View Javadoc
1   /*
2    * Copyright 2013 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.CompositeByteBuf;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelFuture;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelInitializer;
27  import io.netty.channel.ChannelOption;
28  import io.netty.channel.ChannelPromise;
29  import io.netty.channel.SimpleChannelInboundHandler;
30  import io.netty.testsuite.util.TestUtils;
31  import io.netty.util.concurrent.ImmediateEventExecutor;
32  import io.netty.util.concurrent.Promise;
33  import io.netty.util.internal.StringUtil;
34  import org.junit.jupiter.api.AfterAll;
35  import org.junit.jupiter.api.Test;
36  import org.junit.jupiter.api.TestInfo;
37  import org.junit.jupiter.api.Timeout;
38  import org.opentest4j.TestAbortedException;
39  
40  import java.io.IOException;
41  import java.util.Random;
42  import java.util.concurrent.TimeUnit;
43  import java.util.concurrent.atomic.AtomicInteger;
44  import java.util.concurrent.atomic.AtomicReference;
45  
46  import static io.netty.buffer.Unpooled.compositeBuffer;
47  import static io.netty.buffer.Unpooled.wrappedBuffer;
48  import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
49  import static org.junit.jupiter.api.Assertions.assertEquals;
50  import static org.junit.jupiter.api.Assertions.assertNotEquals;
51  import static org.junit.jupiter.api.Assertions.assertSame;
52  import static org.junit.jupiter.api.Assertions.assertTrue;
53  
54  public class SocketGatheringWriteTest extends AbstractSocketTest {
55      private static final long TIMEOUT = 120000;
56  
57      private static final Random random = new Random();
58      static final byte[] data = new byte[1048576];
59  
60      static {
61          random.nextBytes(data);
62      }
63  
64      @AfterAll
65      public static void compressHeapDumps() throws Exception {
66          TestUtils.compressHeapDumps();
67      }
68  
69      @Test
70      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
71      public void testGatheringWrite(TestInfo testInfo) throws Throwable {
72          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
73              @Override
74              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
75                  testGatheringWrite(serverBootstrap, bootstrap);
76              }
77          });
78      }
79  
80      public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
81          testGatheringWrite0(sb, cb, data, false, true);
82      }
83  
84      @Test
85      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
86      public void testGatheringWriteNotAutoRead(TestInfo testInfo) throws Throwable {
87          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
88              @Override
89              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
90                  testGatheringWriteNotAutoRead(serverBootstrap, bootstrap);
91              }
92          });
93      }
94  
95      public void testGatheringWriteNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
96          testGatheringWrite0(sb, cb, data, false, false);
97      }
98  
99      @Test
100     @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
101     public void testGatheringWriteWithComposite(TestInfo testInfo) throws Throwable {
102         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
103             @Override
104             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
105                 testGatheringWriteWithComposite(serverBootstrap, bootstrap);
106             }
107         });
108     }
109 
110     public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
111         testGatheringWrite0(sb, cb, data, true, true);
112     }
113 
114     @Test
115     @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
116     public void testGatheringWriteWithCompositeNotAutoRead(TestInfo testInfo) throws Throwable {
117         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
118             @Override
119             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
120                 testGatheringWriteWithCompositeNotAutoRead(serverBootstrap, bootstrap);
121             }
122         });
123     }
124 
125     public void testGatheringWriteWithCompositeNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
126         testGatheringWrite0(sb, cb, data, true, false);
127     }
128 
129     // Test for https://github.com/netty/netty/issues/2647
130     @Test
131     @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
132     public void testGatheringWriteBig(TestInfo testInfo) throws Throwable {
133         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
134             @Override
135             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
136                 testGatheringWriteBig(serverBootstrap, bootstrap);
137             }
138         });
139     }
140 
141     public void testGatheringWriteBig(ServerBootstrap sb, Bootstrap cb) throws Throwable {
142         byte[] bigData = new byte[1024 * 1024 * 50];
143         random.nextBytes(bigData);
144         testGatheringWrite0(sb, cb, bigData, false, true);
145     }
146 
147     private void testGatheringWrite0(
148             ServerBootstrap sb, Bootstrap cb, byte[] data, boolean composite, boolean autoRead) throws Throwable {
149         sb.childOption(ChannelOption.AUTO_READ, autoRead);
150         cb.option(ChannelOption.AUTO_READ, autoRead);
151 
152         Promise<Void> serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
153         final TestServerHandler sh = new TestServerHandler(autoRead, serverDonePromise, data.length);
154         final TestHandler ch = new TestHandler(autoRead);
155 
156         cb.handler(ch);
157         sb.childHandler(sh);
158 
159         Channel sc = sb.bind().sync().channel();
160         Channel cc = cb.connect(sc.localAddress()).sync().channel();
161 
162         for (int i = 0; i < data.length;) {
163             int length = Math.min(random.nextInt(1024 * 8), data.length - i);
164             if (composite && i % 2 == 0) {
165                 int firstBufLength = length / 2;
166                 CompositeByteBuf comp = compositeBuffer();
167                 comp.addComponent(true,
168                                 randomBufferType(cc.alloc(), data, i, firstBufLength))
169                     .addComponent(true,
170                             randomBufferType(cc.alloc(), data, i + firstBufLength, length - firstBufLength));
171                 cc.write(comp);
172             } else {
173                 cc.write(randomBufferType(cc.alloc(), data, i, length));
174             }
175             i += length;
176         }
177 
178         ChannelFuture cf = cc.writeAndFlush(Unpooled.EMPTY_BUFFER);
179         assertNotEquals(cc.voidPromise(), cf);
180         try {
181             assertTrue(cf.await(60000));
182             cf.sync();
183         } catch (Throwable t) {
184             // TODO: Remove this once we fix this test.
185             TestUtils.dump(StringUtil.simpleClassName(this));
186             throw t;
187         }
188 
189         serverDonePromise.sync();
190         sh.channel.close().sync();
191         ch.channel.close().sync();
192         sc.close().sync();
193 
194         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
195             throw sh.exception.get();
196         }
197         if (sh.exception.get() != null) {
198             throw sh.exception.get();
199         }
200         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
201             throw ch.exception.get();
202         }
203         if (ch.exception.get() != null) {
204             throw ch.exception.get();
205         }
206         ByteBuf expected = wrappedBuffer(data);
207         assertEquals(expected, sh.received);
208         expected.release();
209         sh.received.release();
210     }
211 
212     @Test
213     @Timeout(value = 30, unit = TimeUnit.SECONDS)
214     public void testGatheringWriteSameEventLoop(TestInfo testInfo) throws Throwable {
215         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
216             @Override
217             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
218                 testGatheringWriteSameEventLoop(serverBootstrap, bootstrap);
219             }
220         });
221     }
222 
223     private void testGatheringWriteSameEventLoop(ServerBootstrap sb, Bootstrap cb) throws Throwable {
224         // Ensure all clients are on the same EventLoop.
225         try {
226             cb = cb.clone(cb.group().next());
227         } catch (UnsupportedOperationException e) {
228             throw new TestAbortedException("Not supported by this EventLoopGroup: " + cb.group(), e);
229         }
230 
231         AtomicInteger sHandlersIdx = new AtomicInteger(0);
232         AtomicInteger cHandlersIdx = new AtomicInteger(0);
233         final TestServerHandler[] sHandlers = new TestServerHandler[] {
234                 new TestServerHandler(true, ImmediateEventExecutor.INSTANCE.newPromise(), data.length),
235                 new TestServerHandler(true, ImmediateEventExecutor.INSTANCE.newPromise(), data.length)
236         };
237         final TestHandler[] cHandlers = new TestHandler[] {
238                 new TestHandler(true),
239                 new TestHandler(true)
240         };
241 
242         cb.handler(new ChannelInitializer<Channel>() {
243             @Override
244             protected void initChannel(Channel ch) throws Exception {
245                 ch.pipeline().addLast(cHandlers[cHandlersIdx.getAndIncrement()]);
246             }
247         });
248 
249         sb.childHandler(new ChannelInitializer<Channel>() {
250             @Override
251             protected void initChannel(Channel ch) throws Exception {
252                 ch.pipeline().addLast(sHandlers[sHandlersIdx.getAndIncrement()]);
253             }
254         });
255 
256         Channel sc = sb.bind().sync().channel();
257         Channel cc1 = cb.connect(sc.localAddress()).sync().channel();
258         Channel cc2 = cb.connect(sc.localAddress()).sync().channel();
259 
260         assertSame(cc1.eventLoop(), cc2.eventLoop());
261         ChannelPromise p1 = cc1.newPromise();
262         ChannelPromise p2 = cc2.newPromise();
263         cc1.eventLoop().execute(() -> {
264             for (int i = 0; i < data.length;) {
265                 int length = Math.min(random.nextInt(1024 * 8), data.length - i);
266                 cc1.write(randomBufferType(cc1.alloc(), data, i, length));
267                 cc2.write(randomBufferType(cc2.alloc(), data, i, length));
268                 i += length;
269             }
270             cc1.writeAndFlush(Unpooled.EMPTY_BUFFER, p1);
271             cc2.writeAndFlush(Unpooled.EMPTY_BUFFER, p2);
272         });
273 
274         assertTrue(p1.await(60000));
275         p1.sync();
276         assertTrue(p2.await(60000));
277         p2.sync();
278 
279         for (int i = 0; i < sHandlers.length; i++) {
280             TestServerHandler sh = sHandlers[i];
281             TestHandler ch = cHandlers[i];
282             sh.doneReadingPromise.sync();
283             sh.channel.close().sync();
284             ch.channel.close().sync();
285 
286             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
287                 throw sh.exception.get();
288             }
289             if (sh.exception.get() != null) {
290                 throw sh.exception.get();
291             }
292             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
293                 throw ch.exception.get();
294             }
295             if (ch.exception.get() != null) {
296                 throw ch.exception.get();
297             }
298             ByteBuf expected = wrappedBuffer(data);
299             assertEquals(expected, sh.received);
300             expected.release();
301             sh.received.release();
302         }
303         sc.close().sync();
304     }
305 
306     private static final class TestServerHandler extends TestHandler {
307         private final int expectedBytes;
308         final Promise<Void> doneReadingPromise;
309         final ByteBuf received = Unpooled.buffer();
310 
311         TestServerHandler(boolean autoRead, Promise<Void> doneReadingPromise, int expectedBytes) {
312             super(autoRead);
313             this.doneReadingPromise = doneReadingPromise;
314             this.expectedBytes = expectedBytes;
315         }
316 
317         @Override
318         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
319             received.writeBytes(in);
320             if (received.readableBytes() >= expectedBytes) {
321                 doneReadingPromise.setSuccess(null);
322             }
323         }
324 
325         @Override
326         void handleException(ChannelHandlerContext ctx, Throwable cause) {
327             doneReadingPromise.tryFailure(cause);
328             super.handleException(ctx, cause);
329         }
330 
331         @Override
332         public void channelInactive(ChannelHandlerContext ctx) throws Exception {
333             doneReadingPromise.tryFailure(new IllegalStateException("server closed!"));
334             super.channelInactive(ctx);
335         }
336     }
337 
338     private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
339         private final boolean autoRead;
340         volatile Channel channel;
341         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
342 
343         TestHandler(boolean autoRead) {
344             this.autoRead = autoRead;
345         }
346 
347         @Override
348         public final void channelActive(ChannelHandlerContext ctx) throws Exception {
349             channel = ctx.channel();
350             if (!autoRead) {
351                 ctx.read();
352             }
353             super.channelActive(ctx);
354         }
355 
356         @Override
357         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
358         }
359 
360         @Override
361         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
362             if (!autoRead) {
363                 ctx.read();
364             }
365             super.channelReadComplete(ctx);
366         }
367 
368         @Override
369         public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
370             if (exception.compareAndSet(null, cause)) {
371                 handleException(ctx, cause);
372             }
373             super.exceptionCaught(ctx, cause);
374         }
375 
376         void handleException(ChannelHandlerContext ctx, Throwable cause) {
377             ctx.close();
378         }
379     }
380 }