1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.api.Buffer;
21  import io.netty5.buffer.api.BufferAllocator;
22  import io.netty5.util.Resource;
23  import io.netty5.channel.Channel;
24  import io.netty5.channel.ChannelHandler;
25  import io.netty5.channel.ChannelHandlerContext;
26  import io.netty5.channel.ChannelInitializer;
27  import io.netty5.channel.ChannelOption;
28  import io.netty5.channel.RecvBufferAllocator;
29  import org.junit.jupiter.api.Test;
30  import org.junit.jupiter.api.TestInfo;
31  import org.junit.jupiter.api.Timeout;
32  
33  import java.util.concurrent.CountDownLatch;
34  import java.util.concurrent.TimeUnit;
35  import java.util.concurrent.atomic.AtomicInteger;
36  import java.util.function.Predicate;
37  
38  import static io.netty5.buffer.api.DefaultBufferAllocators.preferredAllocator;
39  import static org.junit.jupiter.api.Assertions.assertEquals;
40  import static org.junit.jupiter.api.Assertions.assertFalse;
41  import static org.junit.jupiter.api.Assertions.assertTrue;
42  
43  public class SocketReadPendingTest extends AbstractSocketTest {
44      @Test
45      @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
46      public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
47          run(testInfo, this::testReadPendingIsResetAfterEachRead);
48      }
49  
50      public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb)
51              throws Throwable {
52          Channel serverChannel = null;
53          Channel clientChannel = null;
54          try {
55              ReadPendingInitializer serverInitializer = new ReadPendingInitializer();
56              ReadPendingInitializer clientInitializer = new ReadPendingInitializer();
57              sb.option(ChannelOption.SO_BACKLOG, 1024)
58                .option(ChannelOption.AUTO_READ, true)
59                .childOption(ChannelOption.AUTO_READ, false)
60                
61                .childOption(ChannelOption.RCVBUFFER_ALLOCATOR, new TestNumReadsRecvBufferAllocator(2))
62                .childHandler(serverInitializer);
63  
64              serverChannel = sb.bind().asStage().get();
65  
66              cb.option(ChannelOption.AUTO_READ, false)
67                
68                .option(ChannelOption.RCVBUFFER_ALLOCATOR, new TestNumReadsRecvBufferAllocator(2))
69                .handler(clientInitializer);
70              clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
71  
72              
73              clientChannel.writeAndFlush(preferredAllocator().copyOf(new byte[4]));
74  
75              
76              assertTrue(serverInitializer.channelInitLatch.await(5, TimeUnit.SECONDS));
77              serverInitializer.channel.writeAndFlush(preferredAllocator().copyOf(new byte[4]));
78  
79              serverInitializer.channel.read();
80              serverInitializer.readPendingHandler.assertAllRead();
81  
82              clientChannel.read();
83              clientInitializer.readPendingHandler.assertAllRead();
84          } finally {
85              if (serverChannel != null) {
86                  serverChannel.close().asStage().sync();
87              }
88              if (clientChannel != null) {
89                  clientChannel.close().asStage().sync();
90              }
91          }
92      }
93  
94      private static class ReadPendingInitializer extends ChannelInitializer<Channel> {
95          final ReadPendingReadHandler readPendingHandler = new ReadPendingReadHandler();
96          final CountDownLatch channelInitLatch = new CountDownLatch(1);
97          volatile Channel channel;
98  
99          @Override
100         protected void initChannel(Channel ch) throws Exception {
101             channel = ch;
102             ch.pipeline().addLast(readPendingHandler);
103             channelInitLatch.countDown();
104         }
105     }
106 
107     private static final class ReadPendingReadHandler implements ChannelHandler {
108         private final AtomicInteger count = new AtomicInteger();
109         private final CountDownLatch latch = new CountDownLatch(1);
110         private final CountDownLatch latch2 = new CountDownLatch(2);
111 
112         @Override
113         public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
114             Resource.dispose(msg);
115             if (count.incrementAndGet() == 1) {
116                 
117                 ctx.read();
118             }
119         }
120 
121         @Override
122         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
123             latch.countDown();
124             latch2.countDown();
125         }
126 
127         void assertAllRead() throws InterruptedException {
128             assertTrue(latch.await(5, TimeUnit.SECONDS));
129             
130             assertFalse(latch2.await(1, TimeUnit.SECONDS));
131             assertEquals(2, count.get());
132         }
133     }
134 
135     
136 
137 
138     private static final class TestNumReadsRecvBufferAllocator implements RecvBufferAllocator {
139         private final int numReads;
140         TestNumReadsRecvBufferAllocator(int numReads) {
141             this.numReads = numReads;
142         }
143 
144         @Override
145         public Handle newHandle() {
146             return new Handle() {
147                 private int attemptedBytesRead;
148                 private int lastBytesRead;
149                 private int numMessagesRead;
150 
151                 @Override
152                 public Buffer allocate(BufferAllocator alloc) {
153                     return alloc.allocate(guess());
154                 }
155 
156                 @Override
157                 public int guess() {
158                     return 1; 
159                 }
160 
161                 @Override
162                 public void reset() {
163                     numMessagesRead = 0;
164                 }
165 
166                 @Override
167                 public void incMessagesRead(int numMessages) {
168                     numMessagesRead += numMessages;
169                 }
170 
171                 @Override
172                 public void lastBytesRead(int bytes) {
173                     lastBytesRead = bytes;
174                 }
175 
176                 @Override
177                 public int lastBytesRead() {
178                     return lastBytesRead;
179                 }
180 
181                 @Override
182                 public void attemptedBytesRead(int bytes) {
183                     attemptedBytesRead = bytes;
184                 }
185 
186                 @Override
187                 public int attemptedBytesRead() {
188                     return attemptedBytesRead;
189                 }
190 
191                 @Override
192                 public boolean continueReading(boolean autoRead) {
193                     return numMessagesRead < numReads;
194                 }
195 
196                 @Override
197                 public boolean continueReading(boolean autoRead, Predicate<Handle> maybeMoreDataSupplier) {
198                     return continueReading(autoRead);
199                 }
200 
201                 @Override
202                 public void readComplete() {
203                     
204                 }
205             };
206         }
207     }
208 }