View Javadoc
1   /*
2    * Copyright 2018 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
35  
36  public class SocketConditionalWritabilityTest extends AbstractSocketTest {
37      @Test
38      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
39      public void testConditionalWritability(TestInfo testInfo) throws Throwable {
40          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
41              @Override
42              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
43                  testConditionalWritability(serverBootstrap, bootstrap);
44              }
45          });
46      }
47  
48      public void testConditionalWritability(ServerBootstrap sb, Bootstrap cb) throws Throwable {
49          Channel serverChannel = null;
50          Channel clientChannel = null;
51          try {
52              final int expectedBytes = 100 * 1024 * 1024;
53              final int maxWriteChunkSize = 16 * 1024;
54              final CountDownLatch latch = new CountDownLatch(1);
55              sb.childOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(8 * 1024, 16 * 1024));
56              sb.childHandler(new ChannelInitializer<Channel>() {
57                  @Override
58                  protected void initChannel(Channel ch) {
59                      ch.pipeline().addLast(new ChannelDuplexHandler() {
60                          private int bytesWritten;
61  
62                          @Override
63                          public void channelRead(ChannelHandlerContext ctx, Object msg) {
64                              ReferenceCountUtil.release(msg);
65                              writeRemainingBytes(ctx);
66                          }
67  
68                          @Override
69                          public void flush(ChannelHandlerContext ctx) {
70                              if (ctx.channel().isWritable()) {
71                                  writeRemainingBytes(ctx);
72                              } else {
73                                  ctx.flush();
74                              }
75                          }
76  
77                          @Override
78                          public void channelWritabilityChanged(ChannelHandlerContext ctx) {
79                              if (ctx.channel().isWritable()) {
80                                  writeRemainingBytes(ctx);
81                              }
82                              ctx.fireChannelWritabilityChanged();
83                          }
84  
85                          private void writeRemainingBytes(ChannelHandlerContext ctx) {
86                              while (ctx.channel().isWritable() && bytesWritten < expectedBytes) {
87                                  int chunkSize = Math.min(expectedBytes - bytesWritten, maxWriteChunkSize);
88                                  bytesWritten += chunkSize;
89                                  ctx.write(ctx.alloc().buffer(chunkSize).writeZero(chunkSize));
90                              }
91                              ctx.flush();
92                          }
93                      });
94                  }
95              });
96  
97              serverChannel = sb.bind().syncUninterruptibly().channel();
98  
99              cb.handler(new ChannelInitializer<Channel>() {
100                 @Override
101                 protected void initChannel(Channel ch) {
102                     ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
103                         private int totalRead;
104                         @Override
105                         public void channelActive(ChannelHandlerContext ctx) {
106                             ctx.writeAndFlush(ctx.alloc().buffer(1).writeByte(0));
107                         }
108 
109                         @Override
110                         public void channelRead(ChannelHandlerContext ctx, Object msg) {
111                             if (msg instanceof ByteBuf) {
112                                 totalRead += ((ByteBuf) msg).readableBytes();
113                                 if (totalRead == expectedBytes) {
114                                     latch.countDown();
115                                 }
116                             }
117                             ReferenceCountUtil.release(msg);
118                         }
119                     });
120                 }
121             });
122             clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
123             latch.await();
124         } finally {
125             if (serverChannel != null) {
126                 serverChannel.close();
127             }
128             if (clientChannel != null) {
129                 clientChannel.close();
130             }
131         }
132     }
133 }