View Javadoc
1   /*
2    * Copyright 2016 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelFuture;
23  import io.netty.channel.ChannelFutureListener;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelInboundHandlerAdapter;
26  import io.netty.channel.ChannelInitializer;
27  import io.netty.channel.ChannelOption;
28  import io.netty.channel.socket.SocketChannel;
29  import io.netty.util.concurrent.ImmediateEventExecutor;
30  import io.netty.util.concurrent.Promise;
31  import org.junit.jupiter.api.Test;
32  import org.junit.jupiter.api.TestInfo;
33  import org.junit.jupiter.api.Timeout;
34  
35  import java.io.ByteArrayOutputStream;
36  import java.net.InetSocketAddress;
37  import java.net.SocketAddress;
38  import java.util.concurrent.BlockingQueue;
39  import java.util.concurrent.LinkedBlockingQueue;
40  import java.util.concurrent.Semaphore;
41  import java.util.concurrent.TimeUnit;
42  
43  import static io.netty.buffer.ByteBufUtil.writeAscii;
44  import static io.netty.buffer.UnpooledByteBufAllocator.DEFAULT;
45  import static io.netty.util.CharsetUtil.US_ASCII;
46  import static org.junit.jupiter.api.Assertions.assertEquals;
47  import static org.junit.jupiter.api.Assertions.assertFalse;
48  import static org.junit.jupiter.api.Assertions.assertNotNull;
49  import static org.junit.jupiter.api.Assertions.assertNull;
50  import static org.junit.jupiter.api.Assertions.assertTrue;
51  
52  public class SocketConnectTest extends AbstractSocketTest {
53  
54      @Test
55      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
56      public void testCloseTwice(TestInfo testInfo) throws Throwable {
57          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
58              @Override
59              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
60                  testCloseTwice(serverBootstrap, bootstrap);
61              }
62          });
63      }
64  
65      public void testCloseTwice(ServerBootstrap sb, Bootstrap cb) throws Throwable {
66          Channel serverChannel = null;
67          Channel clientChannel = null;
68          try {
69              serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter()).bind().syncUninterruptibly().channel();
70              final BlockingQueue<ChannelFuture> futures = new LinkedBlockingQueue<>();
71              clientChannel = cb.handler(new ChannelInboundHandlerAdapter() {
72                          @Override
73                          public void userEventTriggered(ChannelHandlerContext ctx, Object evt)  {
74                              futures.add(ctx.close());
75                          }
76                      })
77                      .connect(serverChannel.localAddress()).syncUninterruptibly().channel();
78              clientChannel.pipeline().fireUserEventTriggered("test");
79              clientChannel.close().syncUninterruptibly();
80              futures.take().sync();
81              clientChannel = null;
82  
83              serverChannel.close().syncUninterruptibly();
84              serverChannel.close().syncUninterruptibly();
85              serverChannel = null;
86          } finally {
87              if (clientChannel != null) {
88                  clientChannel.close().syncUninterruptibly();
89              }
90              if (serverChannel != null) {
91                  serverChannel.close().syncUninterruptibly();
92              }
93          }
94      }
95  
96      @Test
97      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
98      public void testLocalAddressAfterConnect(TestInfo testInfo) throws Throwable {
99          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
100             @Override
101             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
102                 testLocalAddressAfterConnect(serverBootstrap, bootstrap);
103             }
104         });
105     }
106 
107     public void testLocalAddressAfterConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
108         Channel serverChannel = null;
109         Channel clientChannel = null;
110         try {
111             final Promise<InetSocketAddress> localAddressPromise = ImmediateEventExecutor.INSTANCE.newPromise();
112             serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() {
113                         @Override
114                         public void channelActive(ChannelHandlerContext ctx) throws Exception {
115                             localAddressPromise.setSuccess((InetSocketAddress) ctx.channel().localAddress());
116                         }
117                     }).bind().syncUninterruptibly().channel();
118 
119             clientChannel = cb.handler(new ChannelInboundHandlerAdapter()).register().syncUninterruptibly().channel();
120 
121             assertNull(clientChannel.localAddress());
122             assertNull(clientChannel.remoteAddress());
123 
124             clientChannel.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
125             assertLocalAddress((InetSocketAddress) clientChannel.localAddress());
126             assertNotNull(clientChannel.remoteAddress());
127 
128             assertLocalAddress(localAddressPromise.get());
129         } finally {
130             if (clientChannel != null) {
131                 clientChannel.close().syncUninterruptibly();
132             }
133             if (serverChannel != null) {
134                 serverChannel.close().syncUninterruptibly();
135             }
136         }
137     }
138 
139     @Test
140     @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
141     public void testChannelEventsFiredWhenClosedDirectly(TestInfo testInfo) throws Throwable {
142         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
143             @Override
144             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
145                 testChannelEventsFiredWhenClosedDirectly(serverBootstrap, bootstrap);
146             }
147         });
148     }
149 
150     public void testChannelEventsFiredWhenClosedDirectly(ServerBootstrap sb, Bootstrap cb) throws Throwable {
151         final BlockingQueue<Integer> events = new LinkedBlockingQueue<Integer>();
152 
153         Channel sc = null;
154         Channel cc = null;
155         try {
156             sb.childHandler(new ChannelInboundHandlerAdapter());
157             sc = sb.bind().syncUninterruptibly().channel();
158 
159             cb.handler(new ChannelInboundHandlerAdapter() {
160                 @Override
161                 public void channelActive(ChannelHandlerContext ctx) throws Exception {
162                     events.add(0);
163                 }
164 
165                 @Override
166                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
167                     events.add(1);
168                 }
169             });
170             // Connect and directly close again.
171             cc = cb.connect(sc.localAddress()).addListener(ChannelFutureListener.CLOSE).
172                     syncUninterruptibly().channel();
173             assertEquals(0, events.take().intValue());
174             assertEquals(1, events.take().intValue());
175         } finally {
176             if (cc != null) {
177                 cc.close();
178             }
179             if (sc != null) {
180                 sc.close();
181             }
182         }
183     }
184 
185     @Test
186     @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
187     public void testWriteWithFastOpenBeforeConnect(TestInfo testInfo) throws Throwable {
188         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
189             @Override
190             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
191                 testWriteWithFastOpenBeforeConnect(serverBootstrap, bootstrap);
192             }
193         });
194     }
195 
196     public void testWriteWithFastOpenBeforeConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
197         enableTcpFastOpen(sb, cb);
198         sb.childOption(ChannelOption.AUTO_READ, true);
199         cb.option(ChannelOption.AUTO_READ, true);
200 
201         sb.childHandler(new ChannelInitializer<SocketChannel>() {
202             @Override
203             protected void initChannel(SocketChannel ch) throws Exception {
204                 ch.pipeline().addLast(new EchoServerHandler());
205             }
206         });
207 
208         Channel sc = sb.bind().sync().channel();
209         connectAndVerifyDataTransfer(cb, sc);
210         connectAndVerifyDataTransfer(cb, sc);
211     }
212 
213     private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
214             throws InterruptedException {
215         BufferingClientHandler handler = new BufferingClientHandler();
216         cb.handler(handler);
217         ChannelFuture register = cb.register();
218         Channel channel = register.sync().channel();
219         ChannelFuture write = channel.write(writeAscii(DEFAULT, "[fastopen]"));
220         SocketAddress remoteAddress = sc.localAddress();
221         ChannelFuture connectFuture = channel.connect(remoteAddress);
222         Channel cc = connectFuture.sync().channel();
223         cc.writeAndFlush(writeAscii(DEFAULT, "[normal data]")).sync();
224         write.sync();
225         String expectedString = "[fastopen][normal data]";
226         String result = handler.collectBuffer(expectedString.getBytes(US_ASCII).length);
227         cc.disconnect().sync();
228         assertEquals(expectedString, result);
229     }
230 
231     protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
232         // TFO is an almost-pure optimisation and should not change any observable behaviour in our tests.
233         sb.option(ChannelOption.TCP_FASTOPEN, 5);
234         cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
235     }
236 
237     private static void assertLocalAddress(InetSocketAddress address) {
238         assertTrue(address.getPort() > 0);
239         assertFalse(address.getAddress().isAnyLocalAddress());
240     }
241 
242     private static class BufferingClientHandler extends ChannelInboundHandlerAdapter {
243         private final Semaphore semaphore = new Semaphore(0);
244         private final ByteArrayOutputStream streamBuffer = new ByteArrayOutputStream();
245 
246         @Override
247         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
248             if (msg instanceof ByteBuf) {
249                 ByteBuf buf = (ByteBuf) msg;
250                 int readableBytes = buf.readableBytes();
251                 buf.readBytes(streamBuffer, readableBytes);
252                 semaphore.release(readableBytes);
253                 buf.release();
254             } else {
255                 throw new IllegalArgumentException("Unexpected message type: " + msg);
256             }
257         }
258 
259         String collectBuffer(int expectedBytes) throws InterruptedException {
260             semaphore.acquire(expectedBytes);
261             byte[] bytes = streamBuffer.toByteArray();
262             streamBuffer.reset();
263             return new String(bytes, US_ASCII);
264         }
265     }
266 
267     private static final class EchoServerHandler extends ChannelInboundHandlerAdapter {
268         @Override
269         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
270             if (msg instanceof ByteBuf) {
271                 ByteBuf buffer = ctx.alloc().buffer();
272                 ByteBuf buf = (ByteBuf) msg;
273                 buffer.writeBytes(buf);
274                 buf.release();
275                 ctx.channel().writeAndFlush(buffer);
276             } else {
277                 throw new IllegalArgumentException("Unexpected message type: " + msg);
278             }
279         }
280     }
281 }