View Javadoc
1   /*
2    * Copyright 2016 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelFuture;
23  import io.netty.channel.ChannelFutureListener;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelInboundHandlerAdapter;
26  import io.netty.channel.ChannelInitializer;
27  import io.netty.channel.ChannelOption;
28  import io.netty.channel.socket.SocketChannel;
29  import io.netty.util.concurrent.ImmediateEventExecutor;
30  import io.netty.util.concurrent.Promise;
31  import org.junit.jupiter.api.Test;
32  import org.junit.jupiter.api.TestInfo;
33  import org.junit.jupiter.api.Timeout;
34  
35  import java.io.ByteArrayOutputStream;
36  import java.net.InetSocketAddress;
37  import java.net.SocketAddress;
38  import java.util.concurrent.BlockingQueue;
39  import java.util.concurrent.LinkedBlockingQueue;
40  import java.util.concurrent.Semaphore;
41  import java.util.concurrent.TimeUnit;
42  
43  import static io.netty.buffer.ByteBufUtil.writeAscii;
44  import static io.netty.buffer.UnpooledByteBufAllocator.DEFAULT;
45  import static io.netty.util.CharsetUtil.US_ASCII;
46  import static org.junit.jupiter.api.Assertions.assertEquals;
47  import static org.junit.jupiter.api.Assertions.assertFalse;
48  import static org.junit.jupiter.api.Assertions.assertNotNull;
49  import static org.junit.jupiter.api.Assertions.assertNull;
50  import static org.junit.jupiter.api.Assertions.assertTrue;
51  
52  public class SocketConnectTest extends AbstractSocketTest {
53  
54      @Test
55      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
56      public void testLocalAddressAfterConnect(TestInfo testInfo) throws Throwable {
57          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
58              @Override
59              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
60                  testLocalAddressAfterConnect(serverBootstrap, bootstrap);
61              }
62          });
63      }
64  
65      public void testLocalAddressAfterConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
66          Channel serverChannel = null;
67          Channel clientChannel = null;
68          try {
69              final Promise<InetSocketAddress> localAddressPromise = ImmediateEventExecutor.INSTANCE.newPromise();
70              serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() {
71                          @Override
72                          public void channelActive(ChannelHandlerContext ctx) throws Exception {
73                              localAddressPromise.setSuccess((InetSocketAddress) ctx.channel().localAddress());
74                          }
75                      }).bind().syncUninterruptibly().channel();
76  
77              clientChannel = cb.handler(new ChannelInboundHandlerAdapter()).register().syncUninterruptibly().channel();
78  
79              assertNull(clientChannel.localAddress());
80              assertNull(clientChannel.remoteAddress());
81  
82              clientChannel.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
83              assertLocalAddress((InetSocketAddress) clientChannel.localAddress());
84              assertNotNull(clientChannel.remoteAddress());
85  
86              assertLocalAddress(localAddressPromise.get());
87          } finally {
88              if (clientChannel != null) {
89                  clientChannel.close().syncUninterruptibly();
90              }
91              if (serverChannel != null) {
92                  serverChannel.close().syncUninterruptibly();
93              }
94          }
95      }
96  
97      @Test
98      @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
99      public void testChannelEventsFiredWhenClosedDirectly(TestInfo testInfo) throws Throwable {
100         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
101             @Override
102             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
103                 testChannelEventsFiredWhenClosedDirectly(serverBootstrap, bootstrap);
104             }
105         });
106     }
107 
108     public void testChannelEventsFiredWhenClosedDirectly(ServerBootstrap sb, Bootstrap cb) throws Throwable {
109         final BlockingQueue<Integer> events = new LinkedBlockingQueue<Integer>();
110 
111         Channel sc = null;
112         Channel cc = null;
113         try {
114             sb.childHandler(new ChannelInboundHandlerAdapter());
115             sc = sb.bind().syncUninterruptibly().channel();
116 
117             cb.handler(new ChannelInboundHandlerAdapter() {
118                 @Override
119                 public void channelActive(ChannelHandlerContext ctx) throws Exception {
120                     events.add(0);
121                 }
122 
123                 @Override
124                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
125                     events.add(1);
126                 }
127             });
128             // Connect and directly close again.
129             cc = cb.connect(sc.localAddress()).addListener(ChannelFutureListener.CLOSE).
130                     syncUninterruptibly().channel();
131             assertEquals(0, events.take().intValue());
132             assertEquals(1, events.take().intValue());
133         } finally {
134             if (cc != null) {
135                 cc.close();
136             }
137             if (sc != null) {
138                 sc.close();
139             }
140         }
141     }
142 
143     @Test
144     @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
145     public void testWriteWithFastOpenBeforeConnect(TestInfo testInfo) throws Throwable {
146         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
147             @Override
148             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
149                 testWriteWithFastOpenBeforeConnect(serverBootstrap, bootstrap);
150             }
151         });
152     }
153 
154     public void testWriteWithFastOpenBeforeConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
155         enableTcpFastOpen(sb, cb);
156         sb.childOption(ChannelOption.AUTO_READ, true);
157         cb.option(ChannelOption.AUTO_READ, true);
158 
159         sb.childHandler(new ChannelInitializer<SocketChannel>() {
160             @Override
161             protected void initChannel(SocketChannel ch) throws Exception {
162                 ch.pipeline().addLast(new EchoServerHandler());
163             }
164         });
165 
166         Channel sc = sb.bind().sync().channel();
167         connectAndVerifyDataTransfer(cb, sc);
168         connectAndVerifyDataTransfer(cb, sc);
169     }
170 
171     private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
172             throws InterruptedException {
173         BufferingClientHandler handler = new BufferingClientHandler();
174         cb.handler(handler);
175         ChannelFuture register = cb.register();
176         Channel channel = register.sync().channel();
177         ChannelFuture write = channel.write(writeAscii(DEFAULT, "[fastopen]"));
178         SocketAddress remoteAddress = sc.localAddress();
179         ChannelFuture connectFuture = channel.connect(remoteAddress);
180         Channel cc = connectFuture.sync().channel();
181         cc.writeAndFlush(writeAscii(DEFAULT, "[normal data]")).sync();
182         write.sync();
183         String expectedString = "[fastopen][normal data]";
184         String result = handler.collectBuffer(expectedString.getBytes(US_ASCII).length);
185         cc.disconnect().sync();
186         assertEquals(expectedString, result);
187     }
188 
189     protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
190         // TFO is an almost-pure optimisation and should not change any observable behaviour in our tests.
191         sb.option(ChannelOption.TCP_FASTOPEN, 5);
192         cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
193     }
194 
195     private static void assertLocalAddress(InetSocketAddress address) {
196         assertTrue(address.getPort() > 0);
197         assertFalse(address.getAddress().isAnyLocalAddress());
198     }
199 
200     private static class BufferingClientHandler extends ChannelInboundHandlerAdapter {
201         private final Semaphore semaphore = new Semaphore(0);
202         private final ByteArrayOutputStream streamBuffer = new ByteArrayOutputStream();
203 
204         @Override
205         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
206             if (msg instanceof ByteBuf) {
207                 ByteBuf buf = (ByteBuf) msg;
208                 int readableBytes = buf.readableBytes();
209                 buf.readBytes(streamBuffer, readableBytes);
210                 semaphore.release(readableBytes);
211                 buf.release();
212             } else {
213                 throw new IllegalArgumentException("Unexpected message type: " + msg);
214             }
215         }
216 
217         String collectBuffer(int expectedBytes) throws InterruptedException {
218             semaphore.acquire(expectedBytes);
219             byte[] bytes = streamBuffer.toByteArray();
220             streamBuffer.reset();
221             return new String(bytes, US_ASCII);
222         }
223     }
224 
225     private static final class EchoServerHandler extends ChannelInboundHandlerAdapter {
226         @Override
227         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
228             if (msg instanceof ByteBuf) {
229                 ByteBuf buffer = ctx.alloc().buffer();
230                 ByteBuf buf = (ByteBuf) msg;
231                 buffer.writeBytes(buf);
232                 buf.release();
233                 ctx.channel().writeAndFlush(buffer);
234             } else {
235                 throw new IllegalArgumentException("Unexpected message type: " + msg);
236             }
237         }
238     }
239 }