1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.CompositeByteBuf;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFuture;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelOption;
27 import io.netty.channel.SimpleChannelInboundHandler;
28 import io.netty.testsuite.util.TestUtils;
29 import io.netty.util.concurrent.ImmediateEventExecutor;
30 import io.netty.util.concurrent.Promise;
31 import io.netty.util.internal.StringUtil;
32 import org.junit.jupiter.api.AfterAll;
33 import org.junit.jupiter.api.Test;
34 import org.junit.jupiter.api.TestInfo;
35 import org.junit.jupiter.api.Timeout;
36
37 import java.io.IOException;
38 import java.util.Random;
39 import java.util.concurrent.TimeUnit;
40 import java.util.concurrent.atomic.AtomicReference;
41
42 import static io.netty.buffer.Unpooled.compositeBuffer;
43 import static io.netty.buffer.Unpooled.wrappedBuffer;
44 import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
45 import static org.junit.jupiter.api.Assertions.assertEquals;
46 import static org.junit.jupiter.api.Assertions.assertNotEquals;
47 import static org.junit.jupiter.api.Assertions.assertTrue;
48
49 public class SocketGatheringWriteTest extends AbstractSocketTest {
50 private static final long TIMEOUT = 120000;
51
52 private static final Random random = new Random();
53 static final byte[] data = new byte[1048576];
54
55 static {
56 random.nextBytes(data);
57 }
58
59 @AfterAll
60 public static void compressHeapDumps() throws Exception {
61 TestUtils.compressHeapDumps();
62 }
63
64 @Test
65 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
66 public void testGatheringWrite(TestInfo testInfo) throws Throwable {
67 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
68 @Override
69 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
70 testGatheringWrite(serverBootstrap, bootstrap);
71 }
72 });
73 }
74
75 public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
76 testGatheringWrite0(sb, cb, data, false, true);
77 }
78
79 @Test
80 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
81 public void testGatheringWriteNotAutoRead(TestInfo testInfo) throws Throwable {
82 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
83 @Override
84 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
85 testGatheringWriteNotAutoRead(serverBootstrap, bootstrap);
86 }
87 });
88 }
89
90 public void testGatheringWriteNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
91 testGatheringWrite0(sb, cb, data, false, false);
92 }
93
94 @Test
95 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
96 public void testGatheringWriteWithComposite(TestInfo testInfo) throws Throwable {
97 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
98 @Override
99 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
100 testGatheringWriteWithComposite(serverBootstrap, bootstrap);
101 }
102 });
103 }
104
105 public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
106 testGatheringWrite0(sb, cb, data, true, true);
107 }
108
109 @Test
110 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
111 public void testGatheringWriteWithCompositeNotAutoRead(TestInfo testInfo) throws Throwable {
112 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
113 @Override
114 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
115 testGatheringWriteWithCompositeNotAutoRead(serverBootstrap, bootstrap);
116 }
117 });
118 }
119
120 public void testGatheringWriteWithCompositeNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
121 testGatheringWrite0(sb, cb, data, true, false);
122 }
123
124
125 @Test
126 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
127 public void testGatheringWriteBig(TestInfo testInfo) throws Throwable {
128 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
129 @Override
130 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
131 testGatheringWriteBig(serverBootstrap, bootstrap);
132 }
133 });
134 }
135
136 public void testGatheringWriteBig(ServerBootstrap sb, Bootstrap cb) throws Throwable {
137 byte[] bigData = new byte[1024 * 1024 * 50];
138 random.nextBytes(bigData);
139 testGatheringWrite0(sb, cb, bigData, false, true);
140 }
141
142 private void testGatheringWrite0(
143 ServerBootstrap sb, Bootstrap cb, byte[] data, boolean composite, boolean autoRead) throws Throwable {
144 sb.childOption(ChannelOption.AUTO_READ, autoRead);
145 cb.option(ChannelOption.AUTO_READ, autoRead);
146
147 Promise<Void> serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
148 final TestServerHandler sh = new TestServerHandler(autoRead, serverDonePromise, data.length);
149 final TestHandler ch = new TestHandler(autoRead);
150
151 cb.handler(ch);
152 sb.childHandler(sh);
153
154 Channel sc = sb.bind().sync().channel();
155 Channel cc = cb.connect(sc.localAddress()).sync().channel();
156
157 for (int i = 0; i < data.length;) {
158 int length = Math.min(random.nextInt(1024 * 8), data.length - i);
159 if (composite && i % 2 == 0) {
160 int firstBufLength = length / 2;
161 CompositeByteBuf comp = compositeBuffer();
162 comp.addComponent(true,
163 randomBufferType(cc.alloc(), data, i, firstBufLength))
164 .addComponent(true,
165 randomBufferType(cc.alloc(), data, i + firstBufLength, length - firstBufLength));
166 cc.write(comp);
167 } else {
168 cc.write(randomBufferType(cc.alloc(), data, i, length));
169 }
170 i += length;
171 }
172
173 ChannelFuture cf = cc.writeAndFlush(Unpooled.EMPTY_BUFFER);
174 assertNotEquals(cc.voidPromise(), cf);
175 try {
176 assertTrue(cf.await(60000));
177 cf.sync();
178 } catch (Throwable t) {
179
180 TestUtils.dump(StringUtil.simpleClassName(this));
181 throw t;
182 }
183
184 serverDonePromise.sync();
185 sh.channel.close().sync();
186 ch.channel.close().sync();
187 sc.close().sync();
188
189 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
190 throw sh.exception.get();
191 }
192 if (sh.exception.get() != null) {
193 throw sh.exception.get();
194 }
195 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
196 throw ch.exception.get();
197 }
198 if (ch.exception.get() != null) {
199 throw ch.exception.get();
200 }
201 ByteBuf expected = wrappedBuffer(data);
202 assertEquals(expected, sh.received);
203 expected.release();
204 sh.received.release();
205 }
206
207 private static final class TestServerHandler extends TestHandler {
208 private final int expectedBytes;
209 private final Promise<Void> doneReadingPromise;
210 final ByteBuf received = Unpooled.buffer();
211
212 TestServerHandler(boolean autoRead, Promise<Void> doneReadingPromise, int expectedBytes) {
213 super(autoRead);
214 this.doneReadingPromise = doneReadingPromise;
215 this.expectedBytes = expectedBytes;
216 }
217
218 @Override
219 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
220 received.writeBytes(in);
221 if (received.readableBytes() >= expectedBytes) {
222 doneReadingPromise.setSuccess(null);
223 }
224 }
225
226 @Override
227 void handleException(ChannelHandlerContext ctx, Throwable cause) {
228 doneReadingPromise.tryFailure(cause);
229 super.handleException(ctx, cause);
230 }
231
232 @Override
233 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
234 doneReadingPromise.tryFailure(new IllegalStateException("server closed!"));
235 super.channelInactive(ctx);
236 }
237 }
238
239 private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
240 private final boolean autoRead;
241 volatile Channel channel;
242 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
243
244 TestHandler(boolean autoRead) {
245 this.autoRead = autoRead;
246 }
247
248 @Override
249 public final void channelActive(ChannelHandlerContext ctx) throws Exception {
250 channel = ctx.channel();
251 if (!autoRead) {
252 ctx.read();
253 }
254 super.channelActive(ctx);
255 }
256
257 @Override
258 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
259 }
260
261 @Override
262 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
263 if (!autoRead) {
264 ctx.read();
265 }
266 super.channelReadComplete(ctx);
267 }
268
269 @Override
270 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
271 if (exception.compareAndSet(null, cause)) {
272 handleException(ctx, cause);
273 }
274 super.exceptionCaught(ctx, cause);
275 }
276
277 void handleException(ChannelHandlerContext ctx, Throwable cause) {
278 ctx.close();
279 }
280 }
281 }