1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelConfig;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.IoEventLoopGroup;
30 import io.netty.channel.RecvByteBufAllocator;
31 import io.netty.channel.SimpleChannelInboundHandler;
32 import io.netty.channel.nio.NioIoHandler;
33 import io.netty.channel.oio.OioEventLoopGroup;
34 import io.netty.channel.socket.ChannelInputShutdownEvent;
35 import io.netty.channel.socket.ChannelInputShutdownReadComplete;
36 import io.netty.channel.socket.ChannelOutputShutdownEvent;
37 import io.netty.channel.socket.DuplexChannel;
38 import io.netty.channel.socket.SocketChannel;
39 import io.netty.util.ReferenceCountUtil;
40 import io.netty.util.UncheckedBooleanSupplier;
41 import io.netty.util.internal.PlatformDependent;
42 import org.junit.jupiter.api.Test;
43 import org.junit.jupiter.api.TestInfo;
44 import org.junit.jupiter.api.Timeout;
45
46 import java.util.concurrent.CountDownLatch;
47 import java.util.concurrent.atomic.AtomicInteger;
48 import java.util.concurrent.atomic.AtomicReference;
49
50 import static java.util.concurrent.TimeUnit.MILLISECONDS;
51 import static org.junit.jupiter.api.Assertions.assertEquals;
52 import static org.junit.jupiter.api.Assertions.assertNull;
53 import static org.junit.jupiter.api.Assertions.assertTrue;
54 import static org.junit.jupiter.api.Assumptions.assumeFalse;
55
56 @Timeout(value = 20000, unit = MILLISECONDS)
57 public class SocketHalfClosedTest extends AbstractSocketTest {
58
59 protected int maxReadCompleteWithNoDataAfterInputShutdown() {
60 return 2;
61 }
62
63 @Test
64 public void testAllDataReadEventTriggeredAfterHalfClosure(TestInfo testInfo) throws Throwable {
65 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
66 @Override
67 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
68 if (bootstrap.config().group() instanceof OioEventLoopGroup) {
69 logger.debug("Ignoring test for incompatible OIO event system");
70 return;
71 } else if (bootstrap.config().group() instanceof IoEventLoopGroup) {
72 IoEventLoopGroup group = (IoEventLoopGroup) bootstrap.config().group();
73 if (group.isIoType(NioIoHandler.class)) {
74 logger.debug("Ignoring test for incompatible NioHandler");
75 return;
76 }
77 }
78 allDataReadEventTriggeredAfterHalfClosure(serverBootstrap, bootstrap);
79 }
80 });
81 }
82
83 private void allDataReadEventTriggeredAfterHalfClosure(ServerBootstrap sb, Bootstrap cb) throws Throwable {
84 final int totalServerBytesWritten = 1;
85 final CountDownLatch clientReadAllDataLatch = new CountDownLatch(1);
86 final CountDownLatch clientHalfClosedLatch = new CountDownLatch(1);
87 final CountDownLatch clientHalfClosedAllBytesRead = new CountDownLatch(1);
88 final AtomicInteger clientReadCompletes = new AtomicInteger();
89 final AtomicInteger clientZeroDataReadCompletes = new AtomicInteger();
90 Channel serverChannel = null;
91 Channel clientChannel = null;
92 AtomicReference<Channel> serverChildChannel = new AtomicReference<>();
93 try {
94 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
95 .option(ChannelOption.AUTO_CLOSE, false)
96 .option(ChannelOption.AUTO_READ, false);
97
98 sb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
99 .option(ChannelOption.AUTO_CLOSE, false)
100 .childOption(ChannelOption.TCP_NODELAY, true);
101
102 sb.childHandler(new ChannelInitializer<Channel>() {
103 @Override
104 protected void initChannel(Channel ch) throws Exception {
105 serverChildChannel.set(ch);
106 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
107 @Override
108 public void channelActive(ChannelHandlerContext ctx) throws Exception {
109 ByteBuf buf = ctx.alloc().buffer(totalServerBytesWritten);
110 buf.writerIndex(buf.capacity());
111 ctx.writeAndFlush(buf);
112 }
113
114 @Override
115 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
116 ctx.close();
117 }
118 });
119 }
120 });
121
122
123 cb.handler(new ChannelInitializer<Channel>() {
124 @Override
125 protected void initChannel(Channel ch) {
126 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
127 private int bytesRead;
128 private int bytesSinceReadComplete;
129
130 @Override
131 public void channelRead(ChannelHandlerContext ctx, Object msg) {
132 ByteBuf buf = (ByteBuf) msg;
133 bytesRead += buf.readableBytes();
134 bytesSinceReadComplete += buf.readableBytes();
135 buf.release();
136 }
137
138 @Override
139 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
140 if (evt == ChannelInputShutdownEvent.INSTANCE) {
141 clientHalfClosedLatch.countDown();
142 } else if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
143 clientHalfClosedAllBytesRead.countDown();
144 ctx.close();
145 }
146 }
147
148 @Override
149 public void channelReadComplete(ChannelHandlerContext ctx) {
150 if (bytesSinceReadComplete == 0) {
151 clientZeroDataReadCompletes.incrementAndGet();
152 } else {
153 bytesSinceReadComplete = 0;
154 }
155 clientReadCompletes.incrementAndGet();
156 if (bytesRead == totalServerBytesWritten) {
157
158
159 ch.eventLoop().execute(new Runnable() {
160 @Override
161 public void run() {
162 clientReadAllDataLatch.countDown();
163 }
164 });
165 } else {
166 ctx.read();
167 }
168 }
169
170 @Override
171 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
172 ctx.fireExceptionCaught(cause);
173 ctx.close();
174 }
175 });
176 ch.read();
177 }
178 });
179
180 serverChannel = sb.bind().sync().channel();
181 clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
182 clientChannel.read();
183
184 clientReadAllDataLatch.await();
185
186
187 ((DuplexChannel) serverChildChannel.get()).shutdownOutput();
188
189 clientHalfClosedLatch.await();
190 clientHalfClosedAllBytesRead.await();
191 } finally {
192 if (clientChannel != null) {
193 clientChannel.close().sync();
194 }
195 if (serverChannel != null) {
196 serverChannel.close().sync();
197 }
198 }
199 }
200
201 @Test
202 public void testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet(TestInfo testInfo) throws Throwable {
203 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
204 @Override
205 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
206 testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet(serverBootstrap, bootstrap);
207 }
208 });
209 }
210
211 private void testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet(ServerBootstrap sb, Bootstrap cb)
212 throws Throwable {
213 Channel serverChannel = null;
214 Channel clientChannel = null;
215
216 final CountDownLatch waitHalfClosureDone = new CountDownLatch(1);
217 try {
218 sb.childOption(ChannelOption.SO_LINGER, 1)
219 .childHandler(new ChannelInitializer<Channel>() {
220
221 @Override
222 protected void initChannel(Channel ch) throws Exception {
223 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
224
225 @Override
226 public void channelActive(final ChannelHandlerContext ctx) {
227 SocketChannel channel = (SocketChannel) ctx.channel();
228 channel.shutdownOutput();
229 }
230
231 @Override
232 public void channelRead(ChannelHandlerContext ctx, Object msg) {
233 ReferenceCountUtil.release(msg);
234 waitHalfClosureDone.countDown();
235 }
236 });
237 }
238 });
239
240 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
241 .handler(new ChannelInitializer<Channel>() {
242 @Override
243 protected void initChannel(Channel ch) throws Exception {
244 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
245
246 @Override
247 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
248 if (ChannelInputShutdownEvent.INSTANCE == evt) {
249 ctx.writeAndFlush(ctx.alloc().buffer().writeZero(16));
250 }
251
252 if (ChannelInputShutdownReadComplete.INSTANCE == evt) {
253 ctx.close();
254 }
255 }
256 });
257 }
258 });
259
260 serverChannel = sb.bind().sync().channel();
261 clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
262 waitHalfClosureDone.await();
263 } finally {
264 if (clientChannel != null) {
265 clientChannel.close().sync();
266 }
267
268 if (serverChannel != null) {
269 serverChannel.close().sync();
270 }
271 }
272 }
273
274 @Test
275 public void testHalfClosureOnlyOneEventWhenAutoRead(TestInfo testInfo) throws Throwable {
276 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
277 @Override
278 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
279 testHalfClosureOnlyOneEventWhenAutoRead(serverBootstrap, bootstrap);
280 }
281 });
282 }
283
284 public void testHalfClosureOnlyOneEventWhenAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
285 Channel serverChannel = null;
286 try {
287 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
288 .option(ChannelOption.AUTO_READ, true);
289 sb.childHandler(new ChannelInitializer<Channel>() {
290 @Override
291 protected void initChannel(Channel ch) {
292 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
293 @Override
294 public void channelActive(ChannelHandlerContext ctx) {
295 ((DuplexChannel) ctx).shutdownOutput();
296 }
297
298 @Override
299 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
300 ctx.close();
301 }
302 });
303 }
304 });
305
306 final AtomicInteger shutdownEventReceivedCounter = new AtomicInteger();
307 final AtomicInteger shutdownReadCompleteEventReceivedCounter = new AtomicInteger();
308
309 cb.handler(new ChannelInitializer<Channel>() {
310 @Override
311 protected void initChannel(Channel ch) {
312 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
313
314 @Override
315 public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) {
316 if (evt == ChannelInputShutdownEvent.INSTANCE) {
317 shutdownEventReceivedCounter.incrementAndGet();
318 } else if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
319 shutdownReadCompleteEventReceivedCounter.incrementAndGet();
320 ctx.executor().schedule(new Runnable() {
321 @Override
322 public void run() {
323 ctx.close();
324 }
325 }, 100, MILLISECONDS);
326 }
327 }
328
329 @Override
330 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
331 ctx.close();
332 }
333 });
334 }
335 });
336
337 serverChannel = sb.bind().sync().channel();
338 Channel clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
339 clientChannel.closeFuture().await();
340 assertEquals(1, shutdownEventReceivedCounter.get());
341 assertEquals(1, shutdownReadCompleteEventReceivedCounter.get());
342 } finally {
343 if (serverChannel != null) {
344 serverChannel.close().sync();
345 }
346 }
347 }
348
349 @Test
350 public void testAllDataReadAfterHalfClosure(TestInfo testInfo) throws Throwable {
351 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
352 @Override
353 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
354 testAllDataReadAfterHalfClosure(serverBootstrap, bootstrap);
355 }
356 });
357 }
358
359 public void testAllDataReadAfterHalfClosure(ServerBootstrap sb, Bootstrap cb) throws Throwable {
360 testAllDataReadAfterHalfClosure(true, sb, cb);
361 testAllDataReadAfterHalfClosure(false, sb, cb);
362 }
363
364 private void testAllDataReadAfterHalfClosure(final boolean autoRead,
365 ServerBootstrap sb, Bootstrap cb) throws Throwable {
366 final int totalServerBytesWritten = 1024 * 16;
367 final int numReadsPerReadLoop = 2;
368 final CountDownLatch serverInitializedLatch = new CountDownLatch(1);
369 final CountDownLatch clientReadAllDataLatch = new CountDownLatch(1);
370 final CountDownLatch clientHalfClosedLatch = new CountDownLatch(1);
371 final AtomicInteger clientReadCompletes = new AtomicInteger();
372 final AtomicInteger clientZeroDataReadCompletes = new AtomicInteger();
373 Channel serverChannel = null;
374 Channel clientChannel = null;
375 try {
376 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
377 .option(ChannelOption.AUTO_READ, autoRead)
378 .option(ChannelOption.RECVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(numReadsPerReadLoop));
379
380 sb.childHandler(new ChannelInitializer<Channel>() {
381 @Override
382 protected void initChannel(Channel ch) throws Exception {
383 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
384 @Override
385 public void channelActive(ChannelHandlerContext ctx) throws Exception {
386 ByteBuf buf = ctx.alloc().buffer(totalServerBytesWritten);
387 buf.writerIndex(buf.capacity());
388 ctx.writeAndFlush(buf).addListener((ChannelFutureListener) future ->
389 ((DuplexChannel) future.channel()).shutdownOutput());
390
391 serverInitializedLatch.countDown();
392 }
393
394 @Override
395 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
396 ctx.close();
397 }
398 });
399 }
400 });
401
402 cb.handler(new ChannelInitializer<Channel>() {
403 @Override
404 protected void initChannel(Channel ch) throws Exception {
405 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
406 private int bytesRead;
407 private int bytesSinceReadComplete;
408
409 @Override
410 public void channelRead(ChannelHandlerContext ctx, Object msg) {
411 ByteBuf buf = (ByteBuf) msg;
412 bytesRead += buf.readableBytes();
413 bytesSinceReadComplete += buf.readableBytes();
414 buf.release();
415 }
416
417 @Override
418 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
419 if (evt == ChannelInputShutdownEvent.INSTANCE) {
420 clientHalfClosedLatch.countDown();
421 } else if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
422 ctx.close();
423 }
424 }
425
426 @Override
427 public void channelReadComplete(ChannelHandlerContext ctx) {
428 if (bytesSinceReadComplete == 0) {
429 clientZeroDataReadCompletes.incrementAndGet();
430 } else {
431 bytesSinceReadComplete = 0;
432 }
433 clientReadCompletes.incrementAndGet();
434 if (bytesRead == totalServerBytesWritten) {
435 clientReadAllDataLatch.countDown();
436 }
437 if (!autoRead) {
438 ctx.read();
439 }
440 }
441
442 @Override
443 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
444 ctx.close();
445 }
446 });
447 }
448 });
449
450 serverChannel = sb.bind().sync().channel();
451 clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
452 clientChannel.read();
453
454 serverInitializedLatch.await();
455 clientReadAllDataLatch.await();
456 clientHalfClosedLatch.await();
457
458
459
460 assertTrue(totalServerBytesWritten > clientReadCompletes.get(),
461 "too many read complete events: " + clientReadCompletes.get());
462 assertTrue(clientZeroDataReadCompletes.get() <= maxReadCompleteWithNoDataAfterInputShutdown(),
463 "too many readComplete with no data: " + clientZeroDataReadCompletes.get() + " readComplete: " +
464 clientReadCompletes.get());
465 } finally {
466 if (clientChannel != null) {
467 clientChannel.close().sync();
468 }
469 if (serverChannel != null) {
470 serverChannel.close().sync();
471 }
472 }
473 }
474
475 @Test
476 public void testAutoCloseFalseDoesShutdownOutput(TestInfo testInfo) throws Throwable {
477
478 assumeFalse(PlatformDependent.isWindows());
479 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
480 @Override
481 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
482 testAutoCloseFalseDoesShutdownOutput(serverBootstrap, bootstrap);
483 }
484 });
485 }
486
487 public void testAutoCloseFalseDoesShutdownOutput(ServerBootstrap sb, Bootstrap cb) throws Throwable {
488 testAutoCloseFalseDoesShutdownOutput(false, false, sb, cb);
489 testAutoCloseFalseDoesShutdownOutput(false, true, sb, cb);
490 testAutoCloseFalseDoesShutdownOutput(true, false, sb, cb);
491 testAutoCloseFalseDoesShutdownOutput(true, true, sb, cb);
492 }
493
494 private static void testAutoCloseFalseDoesShutdownOutput(boolean allowHalfClosed,
495 final boolean clientIsLeader,
496 ServerBootstrap sb,
497 Bootstrap cb) throws InterruptedException {
498 final int expectedBytes = 100;
499 final CountDownLatch serverReadExpectedLatch = new CountDownLatch(1);
500 final CountDownLatch doneLatch = new CountDownLatch(2);
501 final AtomicReference<Throwable> causeRef = new AtomicReference<Throwable>();
502 Channel serverChannel = null;
503 Channel clientChannel = null;
504 try {
505 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
506 .option(ChannelOption.AUTO_CLOSE, false)
507 .option(ChannelOption.SO_LINGER, 0);
508 sb.childOption(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
509 .childOption(ChannelOption.AUTO_CLOSE, false)
510 .childOption(ChannelOption.SO_LINGER, 0);
511
512 final AutoCloseFalseLeader leaderHandler = new AutoCloseFalseLeader(expectedBytes,
513 serverReadExpectedLatch, doneLatch, causeRef);
514 final AutoCloseFalseFollower followerHandler = new AutoCloseFalseFollower(expectedBytes,
515 serverReadExpectedLatch, doneLatch, causeRef);
516 sb.childHandler(new ChannelInitializer<Channel>() {
517 @Override
518 protected void initChannel(Channel ch) throws Exception {
519 ch.pipeline().addLast(clientIsLeader ? followerHandler :leaderHandler);
520 }
521 });
522
523 cb.handler(new ChannelInitializer<Channel>() {
524 @Override
525 protected void initChannel(Channel ch) throws Exception {
526 ch.pipeline().addLast(clientIsLeader ? leaderHandler : followerHandler);
527 }
528 });
529
530 serverChannel = sb.bind().sync().channel();
531 clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
532
533 doneLatch.await();
534 assertNull(causeRef.get());
535 assertTrue(leaderHandler.seenOutputShutdown);
536 } finally {
537 if (clientChannel != null) {
538 clientChannel.close().sync();
539 }
540 if (serverChannel != null) {
541 serverChannel.close().sync();
542 }
543 }
544 }
545
546 private static final class AutoCloseFalseFollower extends SimpleChannelInboundHandler<ByteBuf> {
547 private final int expectedBytes;
548 private final CountDownLatch followerCloseLatch;
549 private final CountDownLatch doneLatch;
550 private final AtomicReference<Throwable> causeRef;
551 private int bytesRead;
552
553 AutoCloseFalseFollower(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch,
554 AtomicReference<Throwable> causeRef) {
555 this.expectedBytes = expectedBytes;
556 this.followerCloseLatch = followerCloseLatch;
557 this.doneLatch = doneLatch;
558 this.causeRef = causeRef;
559 }
560
561 @Override
562 public void channelInactive(ChannelHandlerContext ctx) {
563 checkPrematureClose();
564 }
565
566 @Override
567 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
568 ctx.close();
569 checkPrematureClose();
570 }
571
572 @Override
573 protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
574 bytesRead += msg.readableBytes();
575 if (bytesRead >= expectedBytes) {
576
577 ByteBuf buf = ctx.alloc().buffer(expectedBytes);
578 buf.writerIndex(buf.writerIndex() + expectedBytes);
579 ctx.writeAndFlush(buf).addListener(future ->
580 ctx.close().addListener(f -> {
581
582
583
584
585
586 ctx.executor().schedule(new Runnable() {
587 @Override
588 public void run() {
589 followerCloseLatch.countDown();
590 }
591 }, 200, MILLISECONDS);
592 }));
593 }
594 }
595
596 private void checkPrematureClose() {
597 if (bytesRead < expectedBytes) {
598 causeRef.set(new IllegalStateException("follower premature close"));
599 doneLatch.countDown();
600 }
601 }
602 }
603
604 private static final class AutoCloseFalseLeader extends SimpleChannelInboundHandler<ByteBuf> {
605 private final int expectedBytes;
606 private final CountDownLatch followerCloseLatch;
607 private final CountDownLatch doneLatch;
608 private final AtomicReference<Throwable> causeRef;
609 private int bytesRead;
610 boolean seenOutputShutdown;
611
612 AutoCloseFalseLeader(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch,
613 AtomicReference<Throwable> causeRef) {
614 this.expectedBytes = expectedBytes;
615 this.followerCloseLatch = followerCloseLatch;
616 this.doneLatch = doneLatch;
617 this.causeRef = causeRef;
618 }
619
620 @Override
621 public void channelActive(ChannelHandlerContext ctx) throws Exception {
622 ByteBuf buf = ctx.alloc().buffer(expectedBytes);
623 buf.writerIndex(buf.writerIndex() + expectedBytes);
624 ctx.writeAndFlush(buf.retainedDuplicate());
625
626
627
628 followerCloseLatch.await();
629
630
631 ctx.writeAndFlush(buf).addListener(future -> {
632 if (future.cause() == null) {
633 causeRef.set(new IllegalStateException("second write should have failed!"));
634 doneLatch.countDown();
635 }
636 });
637 }
638
639 @Override
640 protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
641 bytesRead += msg.readableBytes();
642 if (bytesRead >= expectedBytes) {
643 doneLatch.countDown();
644 }
645 }
646
647 @Override
648 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
649 if (evt instanceof ChannelOutputShutdownEvent) {
650 seenOutputShutdown = true;
651 doneLatch.countDown();
652 }
653 }
654
655 @Override
656 public void channelInactive(ChannelHandlerContext ctx) {
657 checkPrematureClose();
658 }
659
660 @Override
661 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
662 ctx.close();
663 checkPrematureClose();
664 }
665
666 private void checkPrematureClose() {
667 if (bytesRead < expectedBytes || !seenOutputShutdown) {
668 causeRef.set(new IllegalStateException("leader premature close"));
669 doneLatch.countDown();
670 }
671 }
672 }
673
674 @Test
675 public void testAllDataReadClosure(TestInfo testInfo) throws Throwable {
676 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
677 @Override
678 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
679 testAllDataReadClosure(serverBootstrap, bootstrap);
680 }
681 });
682 }
683
684 public void testAllDataReadClosure(ServerBootstrap sb, Bootstrap cb) throws Throwable {
685 testAllDataReadClosure(true, false, sb, cb);
686 testAllDataReadClosure(true, true, sb, cb);
687 testAllDataReadClosure(false, false, sb, cb);
688 testAllDataReadClosure(false, true, sb, cb);
689 }
690
691 private static void testAllDataReadClosure(final boolean autoRead, final boolean allowHalfClosed,
692 ServerBootstrap sb, Bootstrap cb) throws Throwable {
693 final int totalServerBytesWritten = 1024 * 16;
694 final int numReadsPerReadLoop = 2;
695 final CountDownLatch serverInitializedLatch = new CountDownLatch(1);
696 final CountDownLatch clientReadAllDataLatch = new CountDownLatch(1);
697 final CountDownLatch clientHalfClosedLatch = new CountDownLatch(1);
698 final AtomicInteger clientReadCompletes = new AtomicInteger();
699 Channel serverChannel = null;
700 Channel clientChannel = null;
701 try {
702 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
703 .option(ChannelOption.AUTO_READ, autoRead)
704 .option(ChannelOption.RECVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(numReadsPerReadLoop));
705
706 sb.childHandler(new ChannelInitializer<Channel>() {
707 @Override
708 protected void initChannel(Channel ch) throws Exception {
709 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
710 @Override
711 public void channelActive(ChannelHandlerContext ctx) throws Exception {
712 ByteBuf buf = ctx.alloc().buffer(totalServerBytesWritten);
713 buf.writerIndex(buf.capacity());
714 ctx.writeAndFlush(buf).addListener(ChannelFutureListener.CLOSE);
715 serverInitializedLatch.countDown();
716 }
717
718 @Override
719 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
720 ctx.close();
721 }
722 });
723 }
724 });
725
726 cb.handler(new ChannelInitializer<Channel>() {
727 @Override
728 protected void initChannel(Channel ch) throws Exception {
729 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
730 private int bytesRead;
731
732 @Override
733 public void channelRead(ChannelHandlerContext ctx, Object msg) {
734 ByteBuf buf = (ByteBuf) msg;
735 bytesRead += buf.readableBytes();
736 buf.release();
737 }
738
739 @Override
740 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
741 if (evt == ChannelInputShutdownEvent.INSTANCE && allowHalfClosed) {
742 clientHalfClosedLatch.countDown();
743 } else if (evt == ChannelInputShutdownReadComplete.INSTANCE) {
744 ctx.close();
745 }
746 }
747
748 @Override
749 public void channelInactive(ChannelHandlerContext ctx) {
750 if (!allowHalfClosed) {
751 clientHalfClosedLatch.countDown();
752 }
753 }
754
755 @Override
756 public void channelReadComplete(ChannelHandlerContext ctx) {
757 clientReadCompletes.incrementAndGet();
758 if (bytesRead == totalServerBytesWritten) {
759 clientReadAllDataLatch.countDown();
760 }
761 if (!autoRead) {
762 ctx.read();
763 }
764 }
765
766 @Override
767 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
768 ctx.close();
769 }
770 });
771 }
772 });
773
774 serverChannel = sb.bind().sync().channel();
775 clientChannel = cb.connect(serverChannel.localAddress()).sync().channel();
776 clientChannel.read();
777
778 serverInitializedLatch.await();
779 clientReadAllDataLatch.await();
780 clientHalfClosedLatch.await();
781 assertTrue(totalServerBytesWritten / numReadsPerReadLoop + 10 > clientReadCompletes.get(),
782 "too many read complete events: " + clientReadCompletes.get());
783 } finally {
784 if (clientChannel != null) {
785 clientChannel.close().sync();
786 }
787 if (serverChannel != null) {
788 serverChannel.close().sync();
789 }
790 }
791 }
792
793
794
795
796 private static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator {
797 private final int numReads;
798 TestNumReadsRecvByteBufAllocator(int numReads) {
799 this.numReads = numReads;
800 }
801
802 @Override
803 public ExtendedHandle newHandle() {
804 return new ExtendedHandle() {
805 private int attemptedBytesRead;
806 private int lastBytesRead;
807 private int numMessagesRead;
808 @Override
809 public ByteBuf allocate(ByteBufAllocator alloc) {
810 return alloc.ioBuffer(guess(), guess());
811 }
812
813 @Override
814 public int guess() {
815 return 1;
816 }
817
818 @Override
819 public void reset(ChannelConfig config) {
820 numMessagesRead = 0;
821 }
822
823 @Override
824 public void incMessagesRead(int numMessages) {
825 numMessagesRead += numMessages;
826 }
827
828 @Override
829 public void lastBytesRead(int bytes) {
830 lastBytesRead = bytes;
831 }
832
833 @Override
834 public int lastBytesRead() {
835 return lastBytesRead;
836 }
837
838 @Override
839 public void attemptedBytesRead(int bytes) {
840 attemptedBytesRead = bytes;
841 }
842
843 @Override
844 public int attemptedBytesRead() {
845 return attemptedBytesRead;
846 }
847
848 @Override
849 public boolean continueReading() {
850 return numMessagesRead < numReads;
851 }
852
853 @Override
854 public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) {
855 return continueReading() && maybeMoreDataSupplier.get();
856 }
857
858 @Override
859 public void readComplete() {
860
861 }
862 };
863 }
864 }
865 }