View Javadoc
1   /*
2    * Copyright 2013 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.api.Buffer;
21  import io.netty5.buffer.api.BufferAllocator;
22  import io.netty5.buffer.api.CompositeBuffer;
23  import io.netty5.buffer.api.MemoryManager;
24  import io.netty5.util.Resource;
25  import io.netty5.channel.Channel;
26  import io.netty5.channel.ChannelHandlerContext;
27  import io.netty5.channel.ChannelOption;
28  import io.netty5.channel.SimpleChannelInboundHandler;
29  import io.netty5.testsuite.util.TestUtils;
30  import io.netty5.util.concurrent.Future;
31  import io.netty5.util.concurrent.ImmediateEventExecutor;
32  import io.netty5.util.concurrent.Promise;
33  import io.netty5.util.internal.StringUtil;
34  import org.junit.jupiter.api.AfterAll;
35  import org.junit.jupiter.api.Test;
36  import org.junit.jupiter.api.TestInfo;
37  import org.junit.jupiter.api.Timeout;
38  
39  import java.io.IOException;
40  import java.util.SplittableRandom;
41  import java.util.concurrent.TimeUnit;
42  import java.util.concurrent.atomic.AtomicReference;
43  
44  import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator;
45  import static java.util.Arrays.asList;
46  import static org.junit.jupiter.api.Assertions.assertEquals;
47  import static org.junit.jupiter.api.Assertions.assertTrue;
48  
49  public class SocketGatheringWriteTest extends AbstractSocketTest {
50      private static final long TIMEOUT = 120000;
51  
52      private static final SplittableRandom random = new SplittableRandom();
53      static final byte[] data = new byte[1048576];
54  
55      static {
56          random.nextBytes(data);
57      }
58  
59      @AfterAll
60      public static void compressHeapDumps() throws Exception {
61          TestUtils.compressHeapDumps();
62      }
63  
64      @Test
65      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
66      public void testGatheringWrite(TestInfo testInfo) throws Throwable {
67          run(testInfo, this::testGatheringWrite);
68      }
69  
70      public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
71          testGatheringWrite0(sb, cb, data, false, true);
72      }
73  
74      @Test
75      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
76      public void testGatheringWriteNotAutoRead(TestInfo testInfo) throws Throwable {
77          run(testInfo, this::testGatheringWriteNotAutoRead);
78      }
79  
80      public void testGatheringWriteNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
81          testGatheringWrite0(sb, cb, data, false, false);
82      }
83  
84      @Test
85      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
86      public void testGatheringWriteWithComposite(TestInfo testInfo) throws Throwable {
87          run(testInfo, this::testGatheringWriteWithComposite);
88      }
89  
90      public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
91          testGatheringWrite0(sb, cb, data, true, true);
92      }
93  
94      @Test
95      @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
96      public void testGatheringWriteWithCompositeNotAutoRead(TestInfo testInfo) throws Throwable {
97          run(testInfo, this::testGatheringWriteWithCompositeNotAutoRead);
98      }
99  
100     public void testGatheringWriteWithCompositeNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
101         testGatheringWrite0(sb, cb, data, true, false);
102     }
103 
104     // Test for https://github.com/netty/netty/issues/2647
105     @Test
106     @Timeout(value = TIMEOUT, unit = TimeUnit.MILLISECONDS)
107     public void testGatheringWriteBig(TestInfo testInfo) throws Throwable {
108         run(testInfo, this::testGatheringWriteBig);
109     }
110 
111     public void testGatheringWriteBig(ServerBootstrap sb, Bootstrap cb) throws Throwable {
112         byte[] bigData = new byte[1024 * 1024 * 50];
113         random.nextBytes(bigData);
114         testGatheringWrite0(sb, cb, bigData, false, true);
115     }
116 
117     private void testGatheringWrite0(
118             ServerBootstrap sb, Bootstrap cb, byte[] data, boolean composite, boolean autoRead)
119             throws Throwable {
120         sb.childOption(ChannelOption.AUTO_READ, autoRead);
121         cb.option(ChannelOption.AUTO_READ, autoRead);
122 
123         Promise<Void> serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise();
124         final TestServerHandler sh = new TestServerHandler(autoRead, serverDonePromise, data.length);
125         final TestHandler ch = new TestHandler(autoRead);
126 
127         cb.handler(ch);
128         sb.childHandler(sh);
129 
130         Channel sc = sb.bind().asStage().get();
131         Channel cc = cb.connect(sc.localAddress()).asStage().get();
132 
133         BufferAllocator alloc = preferredAllocator();
134         try (Buffer src = MemoryManager.unsafeWrap(data)) {
135             for (int i = 0; i < data.length;) {
136                 int length = Math.min(random.nextInt(1024 * 8), data.length - i);
137                 if (composite && i % 2 == 0) {
138                     int firstBufLength = length / 2;
139                     CompositeBuffer comp =
140                             alloc.compose(asList(
141                             src.readSplit(firstBufLength).send(),
142                             src.readSplit(length - firstBufLength).send()));
143                     cc.write(comp);
144                 } else {
145                     cc.write(src.readSplit(length));
146                 }
147                 i += length;
148             }
149         }
150 
151         Future<Void> cf = cc.writeAndFlush(preferredAllocator().allocate(0));
152         try {
153             assertTrue(cf.asStage().await(60000, TimeUnit.MILLISECONDS));
154             cf.asStage().sync();
155         } catch (Throwable t) {
156             // TODO: Remove this once we fix this test.
157             TestUtils.dump(StringUtil.simpleClassName(this));
158             throw t;
159         }
160 
161         serverDonePromise.asFuture().asStage().sync();
162         sh.channel.close().asStage().sync();
163         ch.channel.close().asStage().sync();
164         sc.close().asStage().sync();
165 
166         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
167             throw sh.exception.get();
168         }
169         if (sh.exception.get() != null) {
170             throw sh.exception.get();
171         }
172         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
173             throw ch.exception.get();
174         }
175         if (ch.exception.get() != null) {
176             throw ch.exception.get();
177         }
178         Object expected = MemoryManager.unsafeWrap(data);
179         assertEquals(expected, sh.received);
180         Resource.dispose(sh.received);
181         Resource.dispose(expected);
182     }
183 
184     private static final class TestServerHandler extends TestHandler {
185         private final int expectedBytes;
186         private final Promise<Void> doneReadingPromise;
187         Object received;
188 
189         TestServerHandler(boolean autoRead, Promise<Void> doneReadingPromise, int expectedBytes) {
190             super(autoRead);
191             this.doneReadingPromise = doneReadingPromise;
192             this.expectedBytes = expectedBytes;
193         }
194 
195         @Override
196         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
197             Buffer recv = (Buffer) received;
198             if (recv == null) {
199                 received = recv = ctx.bufferAllocator().allocate(256);
200             }
201             recv.ensureWritable(in.readableBytes(), recv.capacity(), true);
202             recv.writeBytes(in);
203             if (recv.readableBytes() >= expectedBytes) {
204                 doneReadingPromise.setSuccess(null);
205             }
206         }
207 
208         @Override
209         void handleException(ChannelHandlerContext ctx, Throwable cause) {
210             doneReadingPromise.tryFailure(cause);
211             super.handleException(ctx, cause);
212         }
213 
214         @Override
215         public void channelInactive(ChannelHandlerContext ctx) throws Exception {
216             doneReadingPromise.tryFailure(new IllegalStateException("server closed!"));
217             super.channelInactive(ctx);
218         }
219     }
220 
221     private static class TestHandler extends SimpleChannelInboundHandler<Buffer> {
222         private final boolean autoRead;
223         volatile Channel channel;
224         final AtomicReference<Throwable> exception = new AtomicReference<>();
225 
226         TestHandler(boolean autoRead) {
227             this.autoRead = autoRead;
228         }
229 
230         @Override
231         public final void channelActive(ChannelHandlerContext ctx) throws Exception {
232             channel = ctx.channel();
233             if (!autoRead) {
234                 ctx.read();
235             }
236             super.channelActive(ctx);
237         }
238 
239         @Override
240         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
241         }
242 
243         @Override
244         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
245             if (!autoRead) {
246                 ctx.read();
247             }
248             super.channelReadComplete(ctx);
249         }
250 
251         @Override
252         public final void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
253             if (exception.compareAndSet(null, cause)) {
254                 handleException(ctx, cause);
255             }
256             super.channelExceptionCaught(ctx, cause);
257         }
258 
259         void handleException(ChannelHandlerContext ctx, Throwable cause) {
260             ctx.close();
261         }
262     }
263 }