View Javadoc
1   /*
2    * Copyright 2016 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelConfig;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelInboundHandlerAdapter;
26  import io.netty.channel.ChannelInitializer;
27  import io.netty.channel.ChannelOption;
28  import io.netty.channel.RecvByteBufAllocator;
29  import io.netty.util.ReferenceCountUtil;
30  import io.netty.util.UncheckedBooleanSupplier;
31  import org.junit.jupiter.api.Test;
32  import org.junit.jupiter.api.TestInfo;
33  import org.junit.jupiter.api.Timeout;
34  
35  import java.util.concurrent.CountDownLatch;
36  import java.util.concurrent.TimeUnit;
37  import java.util.concurrent.atomic.AtomicInteger;
38  
39  import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
40  import static org.junit.jupiter.api.Assertions.assertEquals;
41  import static org.junit.jupiter.api.Assertions.assertFalse;
42  import static org.junit.jupiter.api.Assertions.assertTrue;
43  
44  public class SocketReadPendingTest extends AbstractSocketTest {
45      @Test
46      @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
47      public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
48          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
49              @Override
50              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
51                  testReadPendingIsResetAfterEachRead(serverBootstrap, bootstrap);
52              }
53          });
54      }
55  
56      public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
57          Channel serverChannel = null;
58          Channel clientChannel = null;
59          try {
60              ReadPendingInitializer serverInitializer = new ReadPendingInitializer();
61              ReadPendingInitializer clientInitializer = new ReadPendingInitializer();
62              sb.option(ChannelOption.SO_BACKLOG, 1024)
63                .option(ChannelOption.AUTO_READ, true)
64                .childOption(ChannelOption.AUTO_READ, false)
65                // We intend to do 2 reads per read loop wakeup
66                .childOption(ChannelOption.RECVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
67                .childHandler(serverInitializer);
68  
69              serverChannel = sb.bind().syncUninterruptibly().channel();
70  
71              cb.option(ChannelOption.AUTO_READ, false)
72                // We intend to do 2 reads per read loop wakeup
73                .option(ChannelOption.RECVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
74                .handler(clientInitializer);
75              clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
76  
77              // 4 bytes means 2 read loops for TestNumReadsRecvByteBufAllocator
78              clientChannel.writeAndFlush(randomBufferType(clientChannel.alloc(), new byte[4], 0, 4));
79  
80              // 4 bytes means 2 read loops for TestNumReadsRecvByteBufAllocator
81              assertTrue(serverInitializer.channelInitLatch.await(5, TimeUnit.SECONDS));
82              serverInitializer.channel.writeAndFlush(
83                      randomBufferType(serverInitializer.channel.alloc(), new byte[4], 0 , 4));
84  
85              serverInitializer.channel.read();
86              serverInitializer.readPendingHandler.assertAllRead();
87  
88              clientChannel.read();
89              clientInitializer.readPendingHandler.assertAllRead();
90          } finally {
91              if (serverChannel != null) {
92                  serverChannel.close().syncUninterruptibly();
93              }
94              if (clientChannel != null) {
95                  clientChannel.close().syncUninterruptibly();
96              }
97          }
98      }
99  
100     private static class ReadPendingInitializer extends ChannelInitializer<Channel> {
101         final ReadPendingReadHandler readPendingHandler = new ReadPendingReadHandler();
102         final CountDownLatch channelInitLatch = new CountDownLatch(1);
103         volatile Channel channel;
104 
105         @Override
106         protected void initChannel(Channel ch) throws Exception {
107             channel = ch;
108             ch.pipeline().addLast(readPendingHandler);
109             channelInitLatch.countDown();
110         }
111     }
112 
113     private static final class ReadPendingReadHandler extends ChannelInboundHandlerAdapter {
114         private final AtomicInteger count = new AtomicInteger();
115         private final CountDownLatch latch = new CountDownLatch(1);
116         private final CountDownLatch latch2 = new CountDownLatch(2);
117 
118         @Override
119         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
120             ReferenceCountUtil.release(msg);
121             if (count.incrementAndGet() == 1) {
122                 // Call read the first time, to ensure it is not reset the second time.
123                 ctx.read();
124             }
125         }
126 
127         @Override
128         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
129             latch.countDown();
130             latch2.countDown();
131         }
132 
133         void assertAllRead() throws InterruptedException {
134             assertTrue(latch.await(5, TimeUnit.SECONDS));
135             // We should only do 1 read loop, because we only called read() on the first channelRead.
136             assertFalse(latch2.await(1, TimeUnit.SECONDS));
137             assertEquals(2, count.get());
138         }
139     }
140 
141     /**
142      * Designed to read a single byte at a time to control the number of reads done at a fine granularity.
143      */
144     private static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator {
145         private final int numReads;
146         TestNumReadsRecvByteBufAllocator(int numReads) {
147             this.numReads = numReads;
148         }
149 
150         @Override
151         public ExtendedHandle newHandle() {
152             return new ExtendedHandle() {
153                 private int attemptedBytesRead;
154                 private int lastBytesRead;
155                 private int numMessagesRead;
156                 @Override
157                 public ByteBuf allocate(ByteBufAllocator alloc) {
158                     return alloc.ioBuffer(guess(), guess());
159                 }
160 
161                 @Override
162                 public int guess() {
163                     return 1; // only ever allocate buffers of size 1 to ensure the number of reads is controlled.
164                 }
165 
166                 @Override
167                 public void reset(ChannelConfig config) {
168                     numMessagesRead = 0;
169                 }
170 
171                 @Override
172                 public void incMessagesRead(int numMessages) {
173                     numMessagesRead += numMessages;
174                 }
175 
176                 @Override
177                 public void lastBytesRead(int bytes) {
178                     lastBytesRead = bytes;
179                 }
180 
181                 @Override
182                 public int lastBytesRead() {
183                     return lastBytesRead;
184                 }
185 
186                 @Override
187                 public void attemptedBytesRead(int bytes) {
188                     attemptedBytesRead = bytes;
189                 }
190 
191                 @Override
192                 public int attemptedBytesRead() {
193                     return attemptedBytesRead;
194                 }
195 
196                 @Override
197                 public boolean continueReading() {
198                     return numMessagesRead < numReads;
199                 }
200 
201                 @Override
202                 public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) {
203                     return continueReading();
204                 }
205 
206                 @Override
207                 public void readComplete() {
208                     // Nothing needs to be done or adjusted after each read cycle is completed.
209                 }
210             };
211         }
212     }
213 }