1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.buffer.api.DefaultBufferAllocators;
22 import io.netty5.channel.Channel;
23 import io.netty5.channel.ChannelHandler;
24 import io.netty5.channel.ChannelHandlerAdapter;
25 import io.netty5.channel.ChannelHandlerContext;
26 import io.netty5.channel.ChannelInitializer;
27 import io.netty5.channel.ChannelOption;
28 import io.netty5.channel.socket.SocketChannel;
29 import io.netty5.util.concurrent.Future;
30 import io.netty5.util.concurrent.ImmediateEventExecutor;
31 import io.netty5.util.concurrent.Promise;
32 import org.junit.jupiter.api.Test;
33 import org.junit.jupiter.api.TestInfo;
34 import org.junit.jupiter.api.Timeout;
35
36 import java.io.ByteArrayOutputStream;
37 import java.net.InetSocketAddress;
38 import java.net.SocketAddress;
39 import java.util.concurrent.BlockingQueue;
40 import java.util.concurrent.LinkedBlockingQueue;
41 import java.util.concurrent.Semaphore;
42 import java.util.concurrent.TimeUnit;
43
44 import static io.netty5.util.CharsetUtil.US_ASCII;
45 import static org.junit.jupiter.api.Assertions.assertEquals;
46 import static org.junit.jupiter.api.Assertions.assertFalse;
47 import static org.junit.jupiter.api.Assertions.assertNotNull;
48 import static org.junit.jupiter.api.Assertions.assertNull;
49 import static org.junit.jupiter.api.Assertions.assertTrue;
50
51 public class SocketConnectTest extends AbstractSocketTest {
52
53 @Test
54 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
55 public void testLocalAddressAfterConnect(TestInfo testInfo) throws Throwable {
56 run(testInfo, this::testLocalAddressAfterConnect);
57 }
58
59 public void testLocalAddressAfterConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
60 Channel serverChannel = null;
61 Channel clientChannel = null;
62 try {
63 final Promise<InetSocketAddress> localAddressPromise = ImmediateEventExecutor.INSTANCE.newPromise();
64 serverChannel = sb.childHandler(new ChannelHandler() {
65 @Override
66 public void channelActive(ChannelHandlerContext ctx) throws Exception {
67 localAddressPromise.setSuccess((InetSocketAddress) ctx.channel().localAddress());
68 }
69 }).bind().asStage().get();
70
71 clientChannel = cb.handler(new ChannelHandler() { }).register().asStage().get();
72
73 assertNull(clientChannel.localAddress());
74 assertNull(clientChannel.remoteAddress());
75
76 clientChannel.connect(serverChannel.localAddress()).asStage().get();
77 assertLocalAddress((InetSocketAddress) clientChannel.localAddress());
78 assertNotNull(clientChannel.remoteAddress());
79
80 assertLocalAddress(localAddressPromise.asFuture().asStage().get());
81 } finally {
82 if (clientChannel != null) {
83 clientChannel.close().asStage().sync();
84 }
85 if (serverChannel != null) {
86 serverChannel.close().asStage().sync();
87 }
88 }
89 }
90
91 @Test
92 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
93 public void testChannelEventsFiredWhenClosedDirectly(TestInfo testInfo) throws Throwable {
94 run(testInfo, this::testChannelEventsFiredWhenClosedDirectly);
95 }
96
97 public void testChannelEventsFiredWhenClosedDirectly(ServerBootstrap sb, Bootstrap cb) throws Throwable {
98 final BlockingQueue<Integer> events = new LinkedBlockingQueue<>();
99
100 Channel sc = null;
101 Channel cc = null;
102 try {
103 sb.childHandler(new ChannelHandler() { });
104 sc = sb.bind().asStage().get();
105
106 cb.handler(new ChannelHandler() {
107 @Override
108 public void channelActive(ChannelHandlerContext ctx) throws Exception {
109 events.add(0);
110 }
111
112 @Override
113 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
114 events.add(1);
115 }
116 });
117
118 cc = cb.connect(sc.localAddress()).addListener(future -> future.getNow().close()).asStage().get();
119 assertEquals(0, events.take().intValue());
120 assertEquals(1, events.take().intValue());
121 } finally {
122 if (cc != null) {
123 cc.close();
124 }
125 if (sc != null) {
126 sc.close();
127 }
128 }
129 }
130
131 @Test
132 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
133 public void testWriteWithFastOpenBeforeConnect(TestInfo testInfo) throws Throwable {
134 run(testInfo, this::testWriteWithFastOpenBeforeConnect);
135 }
136
137 public void testWriteWithFastOpenBeforeConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
138 enableTcpFastOpen(sb, cb);
139 sb.childOption(ChannelOption.AUTO_READ, true);
140 cb.option(ChannelOption.AUTO_READ, true);
141
142 sb.childHandler(new ChannelInitializer<SocketChannel>() {
143 @Override
144 protected void initChannel(SocketChannel ch) throws Exception {
145 ch.pipeline().addLast(new EchoServerHandler());
146 }
147 });
148
149 Channel sc = sb.bind().asStage().get();
150 connectAndVerifyDataTransfer(cb, sc);
151 connectAndVerifyDataTransfer(cb, sc);
152 }
153
154 private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
155 throws Exception {
156 BufferingClientHandler handler = new BufferingClientHandler();
157 cb.handler(handler);
158 Future<Channel> register = cb.register();
159 Channel channel = register.asStage().get();
160 Future<Void> write = channel.write(writeAsciiBuffer(sc, "[fastopen]"));
161 SocketAddress remoteAddress = sc.localAddress();
162 Future<Void> connectFuture = channel.connect(remoteAddress);
163 connectFuture.asStage().sync();
164 channel.writeAndFlush(writeAsciiBuffer(sc, "[normal data]")).asStage().sync();
165 write.asStage().sync();
166 String expectedString = "[fastopen][normal data]";
167 String result = handler.collectBuffer(expectedString.getBytes(US_ASCII).length);
168 channel.disconnect().asStage().sync();
169 assertEquals(expectedString, result);
170 }
171
172 private static Object writeAsciiBuffer(Channel sc, String seq) {
173 return DefaultBufferAllocators.preferredAllocator().copyOf(seq, US_ASCII);
174 }
175
176 protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
177
178 sb.option(ChannelOption.TCP_FASTOPEN, 5);
179 cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
180 }
181
182 private static void assertLocalAddress(InetSocketAddress address) {
183 assertTrue(address.getPort() > 0);
184 assertFalse(address.getAddress().isAnyLocalAddress());
185 }
186
187 private static class BufferingClientHandler extends ChannelHandlerAdapter {
188 private final Semaphore semaphore = new Semaphore(0);
189 private final ByteArrayOutputStream streamBuffer = new ByteArrayOutputStream();
190
191 @Override
192 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
193 if (msg instanceof Buffer) {
194 try (Buffer buf = (Buffer) msg) {
195 int readableBytes = buf.readableBytes();
196 byte[] array = new byte[readableBytes];
197 buf.readBytes(array, 0, array.length);
198 streamBuffer.write(array);
199 semaphore.release(readableBytes);
200 }
201 } else {
202 throw new IllegalArgumentException("Unexpected message type: " + msg);
203 }
204 }
205
206 String collectBuffer(int expectedBytes) throws InterruptedException {
207 semaphore.acquire(expectedBytes);
208 String result = streamBuffer.toString(US_ASCII);
209 streamBuffer.reset();
210 return result;
211 }
212 }
213
214 private static final class EchoServerHandler extends ChannelHandlerAdapter {
215 @Override
216 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
217 if (msg instanceof Buffer) {
218 try (Buffer buf = (Buffer) msg) {
219 Buffer buffer = ctx.bufferAllocator().allocate(buf.readableBytes());
220 buffer.writeBytes(buf);
221 ctx.channel().writeAndFlush(buffer);
222 }
223 } else {
224 throw new IllegalArgumentException("Unexpected message type: " + msg);
225 }
226 }
227 }
228 }