1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelFuture;
23 import io.netty.channel.ChannelFutureListener;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInboundHandlerAdapter;
26 import io.netty.channel.ChannelInitializer;
27 import io.netty.channel.ChannelOption;
28 import io.netty.channel.socket.SocketChannel;
29 import io.netty.util.concurrent.ImmediateEventExecutor;
30 import io.netty.util.concurrent.Promise;
31 import org.junit.jupiter.api.Test;
32 import org.junit.jupiter.api.TestInfo;
33 import org.junit.jupiter.api.Timeout;
34
35 import java.io.ByteArrayOutputStream;
36 import java.net.InetSocketAddress;
37 import java.net.SocketAddress;
38 import java.util.concurrent.BlockingQueue;
39 import java.util.concurrent.LinkedBlockingQueue;
40 import java.util.concurrent.Semaphore;
41 import java.util.concurrent.TimeUnit;
42
43 import static io.netty.buffer.ByteBufUtil.writeAscii;
44 import static io.netty.buffer.UnpooledByteBufAllocator.DEFAULT;
45 import static io.netty.util.CharsetUtil.US_ASCII;
46 import static org.junit.jupiter.api.Assertions.assertEquals;
47 import static org.junit.jupiter.api.Assertions.assertFalse;
48 import static org.junit.jupiter.api.Assertions.assertNotNull;
49 import static org.junit.jupiter.api.Assertions.assertNull;
50 import static org.junit.jupiter.api.Assertions.assertTrue;
51
52 public class SocketConnectTest extends AbstractSocketTest {
53
54 @Test
55 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
56 public void testCloseTwice(TestInfo testInfo) throws Throwable {
57 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
58 @Override
59 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
60 testCloseTwice(serverBootstrap, bootstrap);
61 }
62 });
63 }
64
65 public void testCloseTwice(ServerBootstrap sb, Bootstrap cb) throws Throwable {
66 Channel serverChannel = null;
67 Channel clientChannel = null;
68 try {
69 serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter()).bind().syncUninterruptibly().channel();
70 final BlockingQueue<ChannelFuture> futures = new LinkedBlockingQueue<>();
71 clientChannel = cb.handler(new ChannelInboundHandlerAdapter() {
72 @Override
73 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
74 futures.add(ctx.close());
75 }
76 })
77 .connect(serverChannel.localAddress()).syncUninterruptibly().channel();
78 clientChannel.pipeline().fireUserEventTriggered("test");
79 clientChannel.close().syncUninterruptibly();
80 futures.take().sync();
81 clientChannel = null;
82
83 serverChannel.close().syncUninterruptibly();
84 serverChannel.close().syncUninterruptibly();
85 serverChannel = null;
86 } finally {
87 if (clientChannel != null) {
88 clientChannel.close().syncUninterruptibly();
89 }
90 if (serverChannel != null) {
91 serverChannel.close().syncUninterruptibly();
92 }
93 }
94 }
95
96 @Test
97 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
98 public void testLocalAddressAfterConnect(TestInfo testInfo) throws Throwable {
99 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
100 @Override
101 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
102 testLocalAddressAfterConnect(serverBootstrap, bootstrap);
103 }
104 });
105 }
106
107 public void testLocalAddressAfterConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
108 Channel serverChannel = null;
109 Channel clientChannel = null;
110 try {
111 final Promise<InetSocketAddress> localAddressPromise = ImmediateEventExecutor.INSTANCE.newPromise();
112 serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() {
113 @Override
114 public void channelActive(ChannelHandlerContext ctx) throws Exception {
115 localAddressPromise.setSuccess((InetSocketAddress) ctx.channel().localAddress());
116 }
117 }).bind().syncUninterruptibly().channel();
118
119 clientChannel = cb.handler(new ChannelInboundHandlerAdapter()).register().syncUninterruptibly().channel();
120
121 assertNull(clientChannel.localAddress());
122 assertNull(clientChannel.remoteAddress());
123
124 clientChannel.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
125 assertLocalAddress((InetSocketAddress) clientChannel.localAddress());
126 assertNotNull(clientChannel.remoteAddress());
127
128 assertLocalAddress(localAddressPromise.get());
129 } finally {
130 if (clientChannel != null) {
131 clientChannel.close().syncUninterruptibly();
132 }
133 if (serverChannel != null) {
134 serverChannel.close().syncUninterruptibly();
135 }
136 }
137 }
138
139 @Test
140 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
141 public void testChannelEventsFiredWhenClosedDirectly(TestInfo testInfo) throws Throwable {
142 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
143 @Override
144 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
145 testChannelEventsFiredWhenClosedDirectly(serverBootstrap, bootstrap);
146 }
147 });
148 }
149
150 public void testChannelEventsFiredWhenClosedDirectly(ServerBootstrap sb, Bootstrap cb) throws Throwable {
151 final BlockingQueue<Integer> events = new LinkedBlockingQueue<Integer>();
152
153 Channel sc = null;
154 Channel cc = null;
155 try {
156 sb.childHandler(new ChannelInboundHandlerAdapter());
157 sc = sb.bind().syncUninterruptibly().channel();
158
159 cb.handler(new ChannelInboundHandlerAdapter() {
160 @Override
161 public void channelActive(ChannelHandlerContext ctx) throws Exception {
162 events.add(0);
163 }
164
165 @Override
166 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
167 events.add(1);
168 }
169 });
170
171 cc = cb.connect(sc.localAddress()).addListener(ChannelFutureListener.CLOSE).
172 syncUninterruptibly().channel();
173 assertEquals(0, events.take().intValue());
174 assertEquals(1, events.take().intValue());
175 } finally {
176 if (cc != null) {
177 cc.close();
178 }
179 if (sc != null) {
180 sc.close();
181 }
182 }
183 }
184
185 @Test
186 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
187 public void testWriteWithFastOpenBeforeConnect(TestInfo testInfo) throws Throwable {
188 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
189 @Override
190 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
191 testWriteWithFastOpenBeforeConnect(serverBootstrap, bootstrap);
192 }
193 });
194 }
195
196 public void testWriteWithFastOpenBeforeConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
197 enableTcpFastOpen(sb, cb);
198 sb.childOption(ChannelOption.AUTO_READ, true);
199 cb.option(ChannelOption.AUTO_READ, true);
200
201 sb.childHandler(new ChannelInitializer<SocketChannel>() {
202 @Override
203 protected void initChannel(SocketChannel ch) throws Exception {
204 ch.pipeline().addLast(new EchoServerHandler());
205 }
206 });
207
208 Channel sc = sb.bind().sync().channel();
209 connectAndVerifyDataTransfer(cb, sc);
210 connectAndVerifyDataTransfer(cb, sc);
211 }
212
213 private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
214 throws InterruptedException {
215 BufferingClientHandler handler = new BufferingClientHandler();
216 cb.handler(handler);
217 ChannelFuture register = cb.register();
218 Channel channel = register.sync().channel();
219 ChannelFuture write = channel.write(writeAscii(DEFAULT, "[fastopen]"));
220 SocketAddress remoteAddress = sc.localAddress();
221 ChannelFuture connectFuture = channel.connect(remoteAddress);
222 Channel cc = connectFuture.sync().channel();
223 cc.writeAndFlush(writeAscii(DEFAULT, "[normal data]")).sync();
224 write.sync();
225 String expectedString = "[fastopen][normal data]";
226 String result = handler.collectBuffer(expectedString.getBytes(US_ASCII).length);
227 cc.disconnect().sync();
228 assertEquals(expectedString, result);
229 }
230
231 protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
232
233 sb.option(ChannelOption.TCP_FASTOPEN, 5);
234 cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
235 }
236
237 private static void assertLocalAddress(InetSocketAddress address) {
238 assertTrue(address.getPort() > 0);
239 assertFalse(address.getAddress().isAnyLocalAddress());
240 }
241
242 private static class BufferingClientHandler extends ChannelInboundHandlerAdapter {
243 private final Semaphore semaphore = new Semaphore(0);
244 private final ByteArrayOutputStream streamBuffer = new ByteArrayOutputStream();
245
246 @Override
247 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
248 if (msg instanceof ByteBuf) {
249 ByteBuf buf = (ByteBuf) msg;
250 int readableBytes = buf.readableBytes();
251 buf.readBytes(streamBuffer, readableBytes);
252 semaphore.release(readableBytes);
253 buf.release();
254 } else {
255 throw new IllegalArgumentException("Unexpected message type: " + msg);
256 }
257 }
258
259 String collectBuffer(int expectedBytes) throws InterruptedException {
260 semaphore.acquire(expectedBytes);
261 byte[] bytes = streamBuffer.toByteArray();
262 streamBuffer.reset();
263 return new String(bytes, US_ASCII);
264 }
265 }
266
267 private static final class EchoServerHandler extends ChannelInboundHandlerAdapter {
268 @Override
269 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
270 if (msg instanceof ByteBuf) {
271 ByteBuf buffer = ctx.alloc().buffer();
272 ByteBuf buf = (ByteBuf) msg;
273 buffer.writeBytes(buf);
274 buf.release();
275 ctx.channel().writeAndFlush(buffer);
276 } else {
277 throw new IllegalArgumentException("Unexpected message type: " + msg);
278 }
279 }
280 }
281 }