/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelConfig;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelInboundHandlerAdapter;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.ChannelOption;
29  import io.netty.channel.RecvByteBufAllocator;
30  import io.netty.util.ReferenceCountUtil;
31  import io.netty.util.UncheckedBooleanSupplier;
32  import org.junit.jupiter.api.Test;
33  import org.junit.jupiter.api.TestInfo;
34  import org.junit.jupiter.api.Timeout;
35  
36  import java.util.concurrent.CountDownLatch;
37  import java.util.concurrent.TimeUnit;
38  import java.util.concurrent.atomic.AtomicInteger;
39  
40  import static org.junit.jupiter.api.Assertions.assertEquals;
41  import static org.junit.jupiter.api.Assertions.assertFalse;
42  import static org.junit.jupiter.api.Assertions.assertTrue;
43  
44  public class SocketReadPendingTest extends AbstractSocketTest {
45      @Test
46      @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
47      public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
48          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
49              @Override
50              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
51                  testReadPendingIsResetAfterEachRead(serverBootstrap, bootstrap);
52              }
53          });
54      }
55  
56      public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
57          Channel serverChannel = null;
58          Channel clientChannel = null;
59          try {
60              ReadPendingInitializer serverInitializer = new ReadPendingInitializer();
61              ReadPendingInitializer clientInitializer = new ReadPendingInitializer();
62              sb.option(ChannelOption.SO_BACKLOG, 1024)
63                .option(ChannelOption.AUTO_READ, true)
64                .childOption(ChannelOption.AUTO_READ, false)
65                // We intend to do 2 reads per read loop wakeup
66                .childOption(ChannelOption.RCVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
67                .childHandler(serverInitializer);
68  
69              serverChannel = sb.bind().syncUninterruptibly().channel();
70  
71              cb.option(ChannelOption.AUTO_READ, false)
72                // We intend to do 2 reads per read loop wakeup
73                .option(ChannelOption.RCVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
74                .handler(clientInitializer);
75              clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
76  
77              // 4 bytes means 2 read loops for TestNumReadsRecvByteBufAllocator
78              clientChannel.writeAndFlush(Unpooled.wrappedBuffer(new byte[4]));
79  
80              // 4 bytes means 2 read loops for TestNumReadsRecvByteBufAllocator
81              assertTrue(serverInitializer.channelInitLatch.await(5, TimeUnit.SECONDS));
82              serverInitializer.channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[4]));
83  
84              serverInitializer.channel.read();
85              serverInitializer.readPendingHandler.assertAllRead();
86  
87              clientChannel.read();
88              clientInitializer.readPendingHandler.assertAllRead();
89          } finally {
90              if (serverChannel != null) {
91                  serverChannel.close().syncUninterruptibly();
92              }
93              if (clientChannel != null) {
94                  clientChannel.close().syncUninterruptibly();
95              }
96          }
97      }
98  
99      private static class ReadPendingInitializer extends ChannelInitializer<Channel> {
100         final ReadPendingReadHandler readPendingHandler = new ReadPendingReadHandler();
101         final CountDownLatch channelInitLatch = new CountDownLatch(1);
102         volatile Channel channel;
103 
104         @Override
105         protected void initChannel(Channel ch) throws Exception {
106             channel = ch;
107             ch.pipeline().addLast(readPendingHandler);
108             channelInitLatch.countDown();
109         }
110     }
111 
112     private static final class ReadPendingReadHandler extends ChannelInboundHandlerAdapter {
113         private final AtomicInteger count = new AtomicInteger();
114         private final CountDownLatch latch = new CountDownLatch(1);
115         private final CountDownLatch latch2 = new CountDownLatch(2);
116 
117         @Override
118         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
119             ReferenceCountUtil.release(msg);
120             if (count.incrementAndGet() == 1) {
121                 // Call read the first time, to ensure it is not reset the second time.
122                 ctx.read();
123             }
124         }
125 
126         @Override
127         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
128             latch.countDown();
129             latch2.countDown();
130         }
131 
132         void assertAllRead() throws InterruptedException {
133             assertTrue(latch.await(5, TimeUnit.SECONDS));
134             // We should only do 1 read loop, because we only called read() on the first channelRead.
135             assertFalse(latch2.await(1, TimeUnit.SECONDS));
136             assertEquals(2, count.get());
137         }
138     }
139 
140     /**
141      * Designed to read a single byte at a time to control the number of reads done at a fine granularity.
142      */
143     private static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator {
144         private final int numReads;
145         TestNumReadsRecvByteBufAllocator(int numReads) {
146             this.numReads = numReads;
147         }
148 
149         @Override
150         public ExtendedHandle newHandle() {
151             return new ExtendedHandle() {
152                 private int attemptedBytesRead;
153                 private int lastBytesRead;
154                 private int numMessagesRead;
155                 @Override
156                 public ByteBuf allocate(ByteBufAllocator alloc) {
157                     return alloc.ioBuffer(guess(), guess());
158                 }
159 
160                 @Override
161                 public int guess() {
162                     return 1; // only ever allocate buffers of size 1 to ensure the number of reads is controlled.
163                 }
164 
165                 @Override
166                 public void reset(ChannelConfig config) {
167                     numMessagesRead = 0;
168                 }
169 
170                 @Override
171                 public void incMessagesRead(int numMessages) {
172                     numMessagesRead += numMessages;
173                 }
174 
175                 @Override
176                 public void lastBytesRead(int bytes) {
177                     lastBytesRead = bytes;
178                 }
179 
180                 @Override
181                 public int lastBytesRead() {
182                     return lastBytesRead;
183                 }
184 
185                 @Override
186                 public void attemptedBytesRead(int bytes) {
187                     attemptedBytesRead = bytes;
188                 }
189 
190                 @Override
191                 public int attemptedBytesRead() {
192                     return attemptedBytesRead;
193                 }
194 
195                 @Override
196                 public boolean continueReading() {
197                     return numMessagesRead < numReads;
198                 }
199 
200                 @Override
201                 public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) {
202                     return continueReading();
203                 }
204 
205                 @Override
206                 public void readComplete() {
207                     // Nothing needs to be done or adjusted after each read cycle is completed.
208                 }
209             };
210         }
211     }
212 }