1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.CompositeByteBuf;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFuture;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelOption;
27 import io.netty.channel.SimpleChannelInboundHandler;
28 import io.netty.testsuite.util.TestUtils;
29 import io.netty.util.concurrent.ImmediateEventExecutor;
30 import io.netty.util.concurrent.Promise;
31 import io.netty.util.internal.StringUtil;
32 import org.junit.jupiter.api.AfterAll;
33 import org.junit.jupiter.api.Test;
34 import org.junit.jupiter.api.TestInfo;
35 import org.junit.jupiter.api.Timeout;
36
37 import java.io.IOException;
38 import java.util.Random;
39 import java.util.concurrent.TimeUnit;
40 import java.util.concurrent.atomic.AtomicReference;
41
42 import static io.netty.buffer.Unpooled.compositeBuffer;
43 import static io.netty.buffer.Unpooled.wrappedBuffer;
44 import static org.junit.jupiter.api.Assertions.assertEquals;
45 import static org.junit.jupiter.api.Assertions.assertNotEquals;
46 import static org.junit.jupiter.api.Assertions.assertTrue;
47
48 public class SocketGatheringWriteTest extends AbstractSocketTest {
49 private static final long TIMEOUT = 120000;
50
51 private static final Random random = new Random();
52 static final byte[] data = new byte[1048576];
53
54 static {
55 random.nextBytes(data);
56 }
57
58 @AfterAll
59 public static void compressHeapDumps() throws Exception {
60 TestUtils.compressHeapDumps();
61 }
62
63 @Test
64 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
65 public void testGatheringWrite(TestInfo testInfo) throws Throwable {
66 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
67 @Override
68 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
69 testGatheringWrite(serverBootstrap, bootstrap);
70 }
71 });
72 }
73
74 public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
75 testGatheringWrite0(sb, cb, data, false, true);
76 }
77
78 @Test
79 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
80 public void testGatheringWriteNotAutoRead(TestInfo testInfo) throws Throwable {
81 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
82 @Override
83 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
84 testGatheringWriteNotAutoRead(serverBootstrap, bootstrap);
85 }
86 });
87 }
88
89 public void testGatheringWriteNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
90 testGatheringWrite0(sb, cb, data, false, false);
91 }
92
93 @Test
94 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
95 public void testGatheringWriteWithComposite(TestInfo testInfo) throws Throwable {
96 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
97 @Override
98 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
99 testGatheringWriteWithComposite(serverBootstrap, bootstrap);
100 }
101 });
102 }
103
104 public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
105 testGatheringWrite0(sb, cb, data, true, true);
106 }
107
108 @Test
109 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
110 public void testGatheringWriteWithCompositeNotAutoRead(TestInfo testInfo) throws Throwable {
111 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
112 @Override
113 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
114 testGatheringWriteWithCompositeNotAutoRead(serverBootstrap, bootstrap);
115 }
116 });
117 }
118
119 public void testGatheringWriteWithCompositeNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
120 testGatheringWrite0(sb, cb, data, true, false);
121 }
122
123
124 @Test
125 @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
126 public void testGatheringWriteBig(TestInfo testInfo) throws Throwable {
127 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128 @Override
129 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130 testGatheringWriteBig(serverBootstrap, bootstrap);
131 }
132 });
133 }
134
135 public void testGatheringWriteBig(ServerBootstrap sb, Bootstrap cb) throws Throwable {
136 byte[] bigData = new byte[1024 * 1024 * 50];
137 random.nextBytes(bigData);
138 testGatheringWrite0(sb, cb, bigData, false, true);
139 }
140
141 private void testGatheringWrite0(
142 ServerBootstrap sb, Bootstrap cb, byte[] data, boolean composite, boolean autoRead) throws Throwable {
143 sb.childOption(ChannelOption.AUTO_READ, autoRead);
144 cb.option(ChannelOption.AUTO_READ, autoRead);
145
146 Promise<Void> serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
147 final TestServerHandler sh = new TestServerHandler(autoRead, serverDonePromise, data.length);
148 final TestHandler ch = new TestHandler(autoRead);
149
150 cb.handler(ch);
151 sb.childHandler(sh);
152
153 Channel sc = sb.bind().sync().channel();
154 Channel cc = cb.connect(sc.localAddress()).sync().channel();
155
156 for (int i = 0; i < data.length;) {
157 int length = Math.min(random.nextInt(1024 * 8), data.length - i);
158 if (composite && i % 2 == 0) {
159 int firstBufLength = length / 2;
160 CompositeByteBuf comp = compositeBuffer();
161 comp.addComponent(true, wrappedBuffer(data, i, firstBufLength))
162 .addComponent(true, wrappedBuffer(data, i + firstBufLength, length - firstBufLength));
163 cc.write(comp);
164 } else {
165 cc.write(wrappedBuffer(data, i, length));
166 }
167 i += length;
168 }
169
170 ChannelFuture cf = cc.writeAndFlush(Unpooled.EMPTY_BUFFER);
171 assertNotEquals(cc.voidPromise(), cf);
172 try {
173 assertTrue(cf.await(60000));
174 cf.sync();
175 } catch (Throwable t) {
176
177 TestUtils.dump(StringUtil.simpleClassName(this));
178 throw t;
179 }
180
181 serverDonePromise.sync();
182 sh.channel.close().sync();
183 ch.channel.close().sync();
184 sc.close().sync();
185
186 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
187 throw sh.exception.get();
188 }
189 if (sh.exception.get() != null) {
190 throw sh.exception.get();
191 }
192 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
193 throw ch.exception.get();
194 }
195 if (ch.exception.get() != null) {
196 throw ch.exception.get();
197 }
198 ByteBuf expected = wrappedBuffer(data);
199 assertEquals(expected, sh.received);
200 expected.release();
201 sh.received.release();
202 }
203
204 private static final class TestServerHandler extends TestHandler {
205 private final int expectedBytes;
206 private final Promise<Void> doneReadingPromise;
207 final ByteBuf received = Unpooled.buffer();
208
209 TestServerHandler(boolean autoRead, Promise<Void> doneReadingPromise, int expectedBytes) {
210 super(autoRead);
211 this.doneReadingPromise = doneReadingPromise;
212 this.expectedBytes = expectedBytes;
213 }
214
215 @Override
216 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
217 received.writeBytes(in);
218 if (received.readableBytes() >= expectedBytes) {
219 doneReadingPromise.setSuccess(null);
220 }
221 }
222
223 @Override
224 void handleException(ChannelHandlerContext ctx, Throwable cause) {
225 doneReadingPromise.tryFailure(cause);
226 super.handleException(ctx, cause);
227 }
228
229 @Override
230 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
231 doneReadingPromise.tryFailure(new IllegalStateException("server closed!"));
232 super.channelInactive(ctx);
233 }
234 }
235
236 private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
237 private final boolean autoRead;
238 volatile Channel channel;
239 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
240
241 TestHandler(boolean autoRead) {
242 this.autoRead = autoRead;
243 }
244
245 @Override
246 public final void channelActive(ChannelHandlerContext ctx) throws Exception {
247 channel = ctx.channel();
248 if (!autoRead) {
249 ctx.read();
250 }
251 super.channelActive(ctx);
252 }
253
254 @Override
255 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
256 }
257
258 @Override
259 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
260 if (!autoRead) {
261 ctx.read();
262 }
263 super.channelReadComplete(ctx);
264 }
265
266 @Override
267 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
268 if (exception.compareAndSet(null, cause)) {
269 handleException(ctx, cause);
270 }
271 super.exceptionCaught(ctx, cause);
272 }
273
274 void handleException(ChannelHandlerContext ctx, Throwable cause) {
275 ctx.close();
276 }
277 }
278 }