1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.channel.Channel;
22 import io.netty5.channel.ChannelHandler;
23 import io.netty5.channel.ChannelHandlerContext;
24 import io.netty5.channel.ChannelInitializer;
25 import io.netty5.channel.ChannelOption;
26 import io.netty5.testsuite.transport.TestsuitePermutation;
27 import io.netty5.util.concurrent.FutureListener;
28 import io.netty5.util.concurrent.ImmediateEventExecutor;
29 import io.netty5.util.concurrent.Promise;
30 import org.junit.jupiter.api.Test;
31 import org.junit.jupiter.api.TestInfo;
32 import org.junit.jupiter.api.Timeout;
33
34 import java.io.IOException;
35 import java.net.SocketAddress;
36 import java.util.List;
37 import java.util.concurrent.TimeUnit;
38 import java.util.concurrent.atomic.AtomicInteger;
39 import java.util.concurrent.atomic.AtomicReference;
40
41 import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator;
42 import static io.netty5.util.CharsetUtil.US_ASCII;
43
44 public abstract class AbstractSocketReuseFdTest extends AbstractSocketTest {
45 @Override
46 protected abstract SocketAddress newSocketAddress();
47
48 @Override
49 protected abstract List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories();
50
51 @Test
52 @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
53 public void testReuseFd(TestInfo testInfo) throws Throwable {
54 run(testInfo, this::testReuseFd);
55 }
56
57 public void testReuseFd(ServerBootstrap sb, Bootstrap cb) throws Throwable {
58 sb.childOption(ChannelOption.AUTO_READ, true);
59 cb.option(ChannelOption.AUTO_READ, true);
60
61
62
63 int numChannels = 100;
64 final AtomicReference<Throwable> globalException = new AtomicReference<>();
65 final AtomicInteger serverRemaining = new AtomicInteger(numChannels);
66 final AtomicInteger clientRemaining = new AtomicInteger(numChannels);
67 final Promise<Void> serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
68 final Promise<Void> clientDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
69
70 sb.childHandler(new ChannelInitializer<>() {
71 @Override
72 public void initChannel(Channel sch) {
73 ReuseFdHandler sh = new ReuseFdHandler(
74 false,
75 globalException,
76 serverRemaining,
77 serverDonePromise);
78 sch.pipeline().addLast("handler", sh);
79 }
80 });
81
82 cb.handler(new ChannelInitializer<>() {
83 @Override
84 public void initChannel(Channel sch) {
85 ReuseFdHandler ch = new ReuseFdHandler(
86 true,
87 globalException,
88 clientRemaining,
89 clientDonePromise);
90 sch.pipeline().addLast("handler", ch);
91 }
92 });
93
94 FutureListener<Channel> listener = future -> {
95 if (future.isFailed()) {
96 clientDonePromise.tryFailure(future.cause());
97 }
98 };
99
100 Channel sc = sb.bind().asStage().get();
101 for (int i = 0; i < numChannels; i++) {
102 cb.connect(sc.localAddress()).addListener(listener);
103 }
104
105 clientDonePromise.asFuture().asStage().sync();
106 serverDonePromise.asFuture().asStage().sync();
107 sc.close().asStage().sync();
108
109 if (globalException.get() != null && !(globalException.get() instanceof IOException)) {
110 throw globalException.get();
111 }
112 }
113
114 static class ReuseFdHandler implements ChannelHandler {
115 private static final String EXPECTED_PAYLOAD = "payload";
116
117 private final Promise<Void> donePromise;
118 private final AtomicInteger remaining;
119 private final boolean client;
120 volatile Channel channel;
121 final AtomicReference<Throwable> globalException;
122 final AtomicReference<Throwable> exception = new AtomicReference<>();
123 final StringBuilder received = new StringBuilder();
124
125 ReuseFdHandler(
126 boolean client,
127 AtomicReference<Throwable> globalException,
128 AtomicInteger remaining,
129 Promise<Void> donePromise) {
130 this.client = client;
131 this.globalException = globalException;
132 this.remaining = remaining;
133 this.donePromise = donePromise;
134 }
135
136 @Override
137 public void channelActive(ChannelHandlerContext ctx) {
138 channel = ctx.channel();
139 if (client) {
140 ctx.writeAndFlush(preferredAllocator().copyOf(EXPECTED_PAYLOAD, US_ASCII));
141 }
142 }
143
144 @Override
145 public void channelRead(ChannelHandlerContext ctx, Object msg) {
146 Buffer buf = (Buffer) msg;
147 received.append(buf.toString(US_ASCII));
148 buf.close();
149
150 if (received.toString().equals(EXPECTED_PAYLOAD)) {
151 if (client) {
152 ctx.close();
153 } else {
154 ctx.writeAndFlush(preferredAllocator().copyOf(EXPECTED_PAYLOAD, US_ASCII));
155 }
156 }
157 }
158
159 @Override
160 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
161 if (exception.compareAndSet(null, cause)) {
162 donePromise.tryFailure(new IllegalStateException("exceptionCaught: " + ctx.channel(), cause));
163 ctx.close();
164 }
165 globalException.compareAndSet(null, cause);
166 }
167
168 @Override
169 public void channelInactive(ChannelHandlerContext ctx) {
170 if (remaining.decrementAndGet() == 0) {
171 if (received.toString().equals(EXPECTED_PAYLOAD)) {
172 donePromise.setSuccess(null);
173 } else {
174 donePromise.tryFailure(new Exception("Unexpected payload:" + received));
175 }
176 }
177 }
178 }
179 }