1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.buffer.api.BufferAllocator;
22 import io.netty5.channel.ChannelShutdownDirection;
23 import io.netty5.util.Resource;
24 import io.netty5.channel.Channel;
25 import io.netty5.channel.ChannelFutureListeners;
26 import io.netty5.channel.ChannelHandler;
27 import io.netty5.channel.ChannelHandlerContext;
28 import io.netty5.channel.ChannelInitializer;
29 import io.netty5.channel.ChannelOption;
30 import io.netty5.channel.RecvBufferAllocator;
31 import io.netty5.channel.SimpleChannelInboundHandler;
32 import io.netty5.util.internal.PlatformDependent;
33 import org.junit.jupiter.api.Test;
34 import org.junit.jupiter.api.TestInfo;
35 import org.junit.jupiter.api.Timeout;
36
37 import java.util.concurrent.CountDownLatch;
38 import java.util.concurrent.atomic.AtomicInteger;
39 import java.util.concurrent.atomic.AtomicReference;
40 import java.util.function.Predicate;
41
42 import static java.util.concurrent.TimeUnit.MILLISECONDS;
43 import static org.junit.jupiter.api.Assertions.assertEquals;
44 import static org.junit.jupiter.api.Assertions.assertNull;
45 import static org.junit.jupiter.api.Assertions.assertTrue;
46 import static org.junit.jupiter.api.Assumptions.assumeFalse;
47
48 public class SocketHalfClosedTest extends AbstractSocketTest {
49 @Test
50 @Timeout(value = 5000, unit = MILLISECONDS)
51 public void testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet(TestInfo testInfo) throws Throwable {
52 run(testInfo, this::testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet);
53 }
54
55 private void testHalfClosureReceiveDataOnFinalWait2StateWhenSoLingerSet(
56 ServerBootstrap sb, Bootstrap cb)
57 throws Throwable {
58 Channel serverChannel = null;
59 Channel clientChannel = null;
60
61 final CountDownLatch waitHalfClosureDone = new CountDownLatch(1);
62 try {
63 sb.childOption(ChannelOption.SO_LINGER, 1)
64 .childHandler(new ChannelInitializer<>() {
65
66 @Override
67 protected void initChannel(Channel ch) throws Exception {
68 ch.pipeline().addLast(new ChannelHandler() {
69
70 @Override
71 public void channelActive(final ChannelHandlerContext ctx) {
72 ctx.shutdown(ChannelShutdownDirection.Outbound);
73 }
74
75 @Override
76 public void channelRead(ChannelHandlerContext ctx, Object msg) {
77 Resource.dispose(msg);
78 waitHalfClosureDone.countDown();
79 }
80 });
81 }
82 });
83
84 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
85 .handler(new ChannelInitializer<>() {
86 @Override
87 protected void initChannel(Channel ch) throws Exception {
88 ch.pipeline().addLast(new ChannelHandler() {
89
90 @Override
91 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) {
92 if (direction == ChannelShutdownDirection.Inbound) {
93 ctx.writeAndFlush(ctx.bufferAllocator().copyOf(new byte[16]))
94 .addListener(ctx, ChannelFutureListeners.CLOSE);
95 }
96 }
97 });
98 }
99 });
100
101 serverChannel = sb.bind().asStage().get();
102 clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
103 waitHalfClosureDone.await();
104 } finally {
105 if (clientChannel != null) {
106 clientChannel.close().asStage().sync();
107 }
108
109 if (serverChannel != null) {
110 serverChannel.close().asStage().sync();
111 }
112 }
113 }
114
115 @Test
116 @Timeout(value = 10000, unit = MILLISECONDS)
117 public void testHalfClosureOnlyOneEventWhenAutoRead(TestInfo testInfo) throws Throwable {
118 run(testInfo, this::testHalfClosureOnlyOneEventWhenAutoRead);
119 }
120
121 public void testHalfClosureOnlyOneEventWhenAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
122 Channel serverChannel = null;
123 try {
124 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
125 .option(ChannelOption.AUTO_READ, true);
126 sb.childHandler(new ChannelInitializer<>() {
127 @Override
128 protected void initChannel(Channel ch) {
129 ch.pipeline().addLast(new ChannelHandler() {
130 @Override
131 public void channelActive(ChannelHandlerContext ctx) {
132 ctx.shutdown(ChannelShutdownDirection.Outbound);
133 }
134
135 @Override
136 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
137 ctx.close();
138 }
139 });
140 }
141 });
142
143 final AtomicInteger shutdownEventReceivedCounter = new AtomicInteger();
144
145 cb.handler(new ChannelInitializer<>() {
146 @Override
147 protected void initChannel(Channel ch) {
148 ch.pipeline().addLast(new ChannelHandler() {
149
150 @Override
151 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) {
152 if (direction == ChannelShutdownDirection.Inbound) {
153 shutdownEventReceivedCounter.incrementAndGet();
154 ctx.executor().schedule((Runnable) ctx::close, 100, MILLISECONDS);
155 }
156 }
157
158 @Override
159 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
160 ctx.close();
161 }
162 });
163 }
164 });
165
166 serverChannel = sb.bind().asStage().get();
167 Channel clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
168 clientChannel.closeFuture().asStage().await();
169 assertEquals(1, shutdownEventReceivedCounter.get());
170 } finally {
171 if (serverChannel != null) {
172 serverChannel.close().asStage().sync();
173 }
174 }
175 }
176
177 @Test
178 public void testAllDataReadAfterHalfClosure(TestInfo testInfo) throws Throwable {
179 run(testInfo, this::testAllDataReadAfterHalfClosure);
180 }
181
182 public void testAllDataReadAfterHalfClosure(ServerBootstrap sb, Bootstrap cb) throws Throwable {
183 testAllDataReadAfterHalfClosure(true, sb, cb);
184 testAllDataReadAfterHalfClosure(false, sb, cb);
185 }
186
187 private static void testAllDataReadAfterHalfClosure(final boolean autoRead,
188 ServerBootstrap sb, Bootstrap cb) throws Throwable {
189 final int totalServerBytesWritten = 1024 * 16;
190 final int numReadsPerReadLoop = 2;
191 final CountDownLatch serverInitializedLatch = new CountDownLatch(1);
192 final CountDownLatch clientReadAllDataLatch = new CountDownLatch(1);
193 final CountDownLatch clientHalfClosedLatch = new CountDownLatch(1);
194 final AtomicInteger clientReadCompletes = new AtomicInteger();
195 Channel serverChannel = null;
196 Channel clientChannel = null;
197 try {
198 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true)
199 .option(ChannelOption.AUTO_READ, autoRead)
200 .option(ChannelOption.RCVBUFFER_ALLOCATOR, new TestNumReadsRecvBufferAllocator(numReadsPerReadLoop));
201
202 sb.childHandler(new ChannelInitializer<>() {
203 @Override
204 protected void initChannel(Channel ch) throws Exception {
205 ch.pipeline().addLast(new ChannelHandler() {
206 @Override
207 public void channelActive(ChannelHandlerContext ctx) throws Exception {
208 final Buffer buf = ctx.bufferAllocator().allocate(totalServerBytesWritten);
209 buf.writerOffset(buf.capacity());
210 ctx.writeAndFlush(buf).addListener(ctx, (c, f) ->
211 c.shutdown(ChannelShutdownDirection.Outbound));
212 serverInitializedLatch.countDown();
213 }
214
215 @Override
216 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
217 ctx.close();
218 }
219 });
220 }
221 });
222
223 cb.handler(new ChannelInitializer<>() {
224 @Override
225 protected void initChannel(Channel ch) throws Exception {
226 ch.pipeline().addLast(new ChannelHandler() {
227 private int bytesRead;
228
229 @Override
230 public void channelRead(ChannelHandlerContext ctx, Object msg) {
231 try (Buffer buf = (Buffer) msg) {
232 bytesRead += buf.readableBytes();
233 }
234 }
235
236 @Override
237 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) {
238 if (direction == ChannelShutdownDirection.Inbound) {
239 clientHalfClosedLatch.countDown();
240 ctx.close();
241 }
242 }
243
244 @Override
245 public void channelReadComplete(ChannelHandlerContext ctx) {
246 clientReadCompletes.incrementAndGet();
247 if (bytesRead == totalServerBytesWritten) {
248 clientReadAllDataLatch.countDown();
249 }
250 if (!autoRead) {
251 ctx.read();
252 }
253 }
254
255 @Override
256 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
257 ctx.close();
258 }
259 });
260 }
261 });
262
263 serverChannel = sb.bind().asStage().get();
264 clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
265 clientChannel.read();
266
267 serverInitializedLatch.await();
268 clientReadAllDataLatch.await();
269 clientHalfClosedLatch.await();
270 assertTrue(totalServerBytesWritten / numReadsPerReadLoop + 10 > clientReadCompletes.get(),
271 "too many read complete events: " + clientReadCompletes.get());
272 } finally {
273 if (clientChannel != null) {
274 clientChannel.close().asStage().sync();
275 }
276 if (serverChannel != null) {
277 serverChannel.close().asStage().sync();
278 }
279 }
280 }
281
282 @Test
283 public void testAutoCloseFalseDoesShutdownOutput(TestInfo testInfo) throws Throwable {
284
285 assumeFalse(PlatformDependent.isWindows());
286 run(testInfo, this::testAutoCloseFalseDoesShutdownOutput);
287 }
288
289 public void testAutoCloseFalseDoesShutdownOutput(ServerBootstrap sb, Bootstrap cb) throws Throwable {
290 testAutoCloseFalseDoesShutdownOutput(false, false, sb, cb);
291 testAutoCloseFalseDoesShutdownOutput(false, true, sb, cb);
292 testAutoCloseFalseDoesShutdownOutput(true, false, sb, cb);
293 testAutoCloseFalseDoesShutdownOutput(true, true, sb, cb);
294 }
295
296 private static void testAutoCloseFalseDoesShutdownOutput(boolean allowHalfClosed,
297 final boolean clientIsLeader,
298 ServerBootstrap sb,
299 Bootstrap cb) throws Exception {
300 final int expectedBytes = 100;
301 final CountDownLatch serverReadExpectedLatch = new CountDownLatch(1);
302 final CountDownLatch doneLatch = new CountDownLatch(1);
303 final AtomicReference<Throwable> causeRef = new AtomicReference<>();
304 Channel serverChannel = null;
305 Channel clientChannel = null;
306 try {
307 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
308 .option(ChannelOption.AUTO_CLOSE, false)
309 .option(ChannelOption.SO_LINGER, 0);
310 sb.childOption(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
311 .childOption(ChannelOption.AUTO_CLOSE, false)
312 .childOption(ChannelOption.SO_LINGER, 0);
313
314 final SimpleChannelInboundHandler<?> leaderHandler = new AutoCloseFalseLeader(
315 expectedBytes, serverReadExpectedLatch, doneLatch, causeRef);
316 final SimpleChannelInboundHandler<?> followerHandler = new AutoCloseFalseFollower(expectedBytes,
317 serverReadExpectedLatch, doneLatch, causeRef);
318 sb.childHandler(new ChannelInitializer<>() {
319 @Override
320 protected void initChannel(Channel ch) throws Exception {
321 ch.pipeline().addLast(clientIsLeader ? followerHandler : leaderHandler);
322 }
323 });
324
325 cb.handler(new ChannelInitializer<>() {
326 @Override
327 protected void initChannel(Channel ch) throws Exception {
328 ch.pipeline().addLast(clientIsLeader ? leaderHandler : followerHandler);
329 }
330 });
331
332 serverChannel = sb.bind().asStage().get();
333 clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
334
335 doneLatch.await();
336 assertNull(causeRef.get());
337 } finally {
338 if (clientChannel != null) {
339 clientChannel.close().asStage().sync();
340 }
341 if (serverChannel != null) {
342 serverChannel.close().asStage().sync();
343 }
344 }
345 }
346
347 private static final class AutoCloseFalseFollower extends SimpleChannelInboundHandler<Object> {
348 private final int expectedBytes;
349 private final CountDownLatch followerCloseLatch;
350 private final CountDownLatch doneLatch;
351 private final AtomicReference<Throwable> causeRef;
352 private int bytesRead;
353
354 AutoCloseFalseFollower(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch,
355 AtomicReference<Throwable> causeRef) {
356 this.expectedBytes = expectedBytes;
357 this.followerCloseLatch = followerCloseLatch;
358 this.doneLatch = doneLatch;
359 this.causeRef = causeRef;
360 }
361
362 @Override
363 public void channelInactive(ChannelHandlerContext ctx) {
364 checkPrematureClose();
365 }
366
367 @Override
368 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
369 ctx.close();
370 checkPrematureClose();
371 }
372
373 @Override
374 protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
375 bytesRead += ((Buffer) msg).readableBytes();
376 if (bytesRead >= expectedBytes) {
377
378 Buffer buf = ctx.bufferAllocator().allocate(expectedBytes);
379 buf.skipWritableBytes(expectedBytes);
380 ctx.writeAndFlush(buf).addListener(ctx.channel(), (c, f) ->
381 c.close().addListener(c, (channel, future) -> {
382
383
384
385
386
387 channel.executor().schedule(followerCloseLatch::countDown, 200, MILLISECONDS);
388 }));
389 }
390 }
391
392 private void checkPrematureClose() {
393 if (bytesRead < expectedBytes) {
394 causeRef.set(new IllegalStateException("follower premature close"));
395 doneLatch.countDown();
396 }
397 }
398 }
399
400 private static final class AutoCloseFalseLeader extends SimpleChannelInboundHandler<Object> {
401 private final int expectedBytes;
402 private final CountDownLatch followerCloseLatch;
403 private final CountDownLatch doneLatch;
404 private final AtomicReference<Throwable> causeRef;
405 private int bytesRead;
406 private boolean seenOutputShutdown;
407
408 AutoCloseFalseLeader(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch,
409 AtomicReference<Throwable> causeRef) {
410 this.expectedBytes = expectedBytes;
411 this.followerCloseLatch = followerCloseLatch;
412 this.doneLatch = doneLatch;
413 this.causeRef = causeRef;
414 }
415
416 @Override
417 public void channelActive(ChannelHandlerContext ctx) throws Exception {
418 Buffer buf = ctx.bufferAllocator().allocate(expectedBytes);
419 buf.skipWritableBytes(expectedBytes);
420 Buffer msg = buf.copy();
421 ctx.writeAndFlush(buf);
422
423
424
425 followerCloseLatch.await();
426
427
428 ctx.writeAndFlush(msg).addListener(future -> {
429 if (future.cause() == null) {
430 causeRef.set(new IllegalStateException("second write should have failed!"));
431 doneLatch.countDown();
432 }
433 });
434 }
435
436 @Override
437 protected void messageReceived(ChannelHandlerContext ctx, Object msg) throws Exception {
438 bytesRead += ((Buffer) msg).readableBytes();
439 if (bytesRead >= expectedBytes) {
440 if (!seenOutputShutdown) {
441 causeRef.set(new IllegalStateException(
442 ChannelShutdownDirection.Outbound.name() + " event was not seen"));
443 }
444 doneLatch.countDown();
445 }
446 }
447
448 @Override
449 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) throws Exception {
450 if (direction == ChannelShutdownDirection.Outbound) {
451 seenOutputShutdown = true;
452 }
453 super.channelShutdown(ctx, direction);
454 }
455
456 @Override
457 public void channelInactive(ChannelHandlerContext ctx) {
458 checkPrematureClose();
459 }
460
461 @Override
462 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
463 ctx.close();
464 checkPrematureClose();
465 }
466
467 private void checkPrematureClose() {
468 if (bytesRead < expectedBytes || !seenOutputShutdown) {
469 causeRef.set(new IllegalStateException("leader premature close"));
470 doneLatch.countDown();
471 }
472 }
473 }
474
475 @Test
476 public void testAllDataReadClosure(TestInfo testInfo) throws Throwable {
477 run(testInfo, this::testAllDataReadClosure);
478 }
479
480 public void testAllDataReadClosure(ServerBootstrap sb, Bootstrap cb) throws Throwable {
481 testAllDataReadClosure(true, false, sb, cb);
482 testAllDataReadClosure(true, true, sb, cb);
483 testAllDataReadClosure(false, false, sb, cb);
484 testAllDataReadClosure(false, true, sb, cb);
485 }
486
487 private static void testAllDataReadClosure(final boolean autoRead, final boolean allowHalfClosed,
488 ServerBootstrap sb, Bootstrap cb)
489 throws Throwable {
490 final int totalServerBytesWritten = 1024 * 16;
491 final int numReadsPerReadLoop = 2;
492 final CountDownLatch serverInitializedLatch = new CountDownLatch(1);
493 final CountDownLatch clientReadAllDataLatch = new CountDownLatch(1);
494 final CountDownLatch clientHalfClosedLatch = new CountDownLatch(1);
495 final AtomicInteger clientReadCompletes = new AtomicInteger();
496 Channel serverChannel = null;
497 Channel clientChannel = null;
498 try {
499 cb.option(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed)
500 .option(ChannelOption.AUTO_READ, autoRead)
501 .option(ChannelOption.RCVBUFFER_ALLOCATOR,
502 new TestNumReadsRecvBufferAllocator(numReadsPerReadLoop));
503
504 sb.childHandler(new ChannelInitializer<>() {
505 @Override
506 protected void initChannel(Channel ch) throws Exception {
507 ch.pipeline().addLast(new ChannelHandler() {
508 @Override
509 public void channelActive(ChannelHandlerContext ctx) throws Exception {
510 Buffer buf = ctx.bufferAllocator().allocate(totalServerBytesWritten);
511 buf.writerOffset(buf.capacity());
512 ctx.writeAndFlush(buf).addListener(ctx, ChannelFutureListeners.CLOSE);
513 serverInitializedLatch.countDown();
514 }
515
516 @Override
517 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
518 ctx.close();
519 }
520 });
521 }
522 });
523
524 cb.handler(new ChannelInitializer<>() {
525 @Override
526 protected void initChannel(Channel ch) throws Exception {
527 ch.pipeline().addLast(new ChannelHandler() {
528 private int bytesRead;
529
530 @Override
531 public void channelRead(ChannelHandlerContext ctx, Object msg) {
532 try (Buffer buf = (Buffer) msg) {
533 bytesRead += buf.readableBytes();
534 }
535 }
536
537 @Override
538 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) {
539 if (direction == ChannelShutdownDirection.Inbound && allowHalfClosed) {
540 clientHalfClosedLatch.countDown();
541 ctx.close();
542 }
543 }
544
545 @Override
546 public void channelInactive(ChannelHandlerContext ctx) {
547 if (!allowHalfClosed) {
548 clientHalfClosedLatch.countDown();
549 }
550 }
551
552 @Override
553 public void channelReadComplete(ChannelHandlerContext ctx) {
554 clientReadCompletes.incrementAndGet();
555 if (bytesRead == totalServerBytesWritten) {
556 clientReadAllDataLatch.countDown();
557 }
558 if (!autoRead) {
559 ctx.read();
560 }
561 }
562
563 @Override
564 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
565 ctx.close();
566 }
567 });
568 }
569 });
570
571 serverChannel = sb.bind().asStage().get();
572 clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
573 clientChannel.read();
574
575 serverInitializedLatch.await();
576 clientReadAllDataLatch.await();
577 clientHalfClosedLatch.await();
578 assertTrue(totalServerBytesWritten / numReadsPerReadLoop + 10 > clientReadCompletes.get(),
579 "too many read complete events: " + clientReadCompletes.get());
580 } finally {
581 if (clientChannel != null) {
582 clientChannel.close().asStage().sync();
583 }
584 if (serverChannel != null) {
585 serverChannel.close().asStage().sync();
586 }
587 }
588 }
589
590
591
592
593 private static final class TestNumReadsRecvBufferAllocator implements RecvBufferAllocator {
594 private final int numReads;
595 TestNumReadsRecvBufferAllocator(int numReads) {
596 this.numReads = numReads;
597 }
598
599 @Override
600 public Handle newHandle() {
601 return new Handle() {
602 private int attemptedBytesRead;
603 private int lastBytesRead;
604 private int numMessagesRead;
605
606 @Override
607 public Buffer allocate(BufferAllocator alloc) {
608 return alloc.allocate(guess());
609 }
610
611 @Override
612 public int guess() {
613 return 1;
614 }
615
616 @Override
617 public void reset() {
618 numMessagesRead = 0;
619 }
620
621 @Override
622 public void incMessagesRead(int numMessages) {
623 numMessagesRead += numMessages;
624 }
625
626 @Override
627 public void lastBytesRead(int bytes) {
628 lastBytesRead = bytes;
629 }
630
631 @Override
632 public int lastBytesRead() {
633 return lastBytesRead;
634 }
635
636 @Override
637 public void attemptedBytesRead(int bytes) {
638 attemptedBytesRead = bytes;
639 }
640
641 @Override
642 public int attemptedBytesRead() {
643 return attemptedBytesRead;
644 }
645
646 @Override
647 public boolean continueReading(boolean autoRead) {
648 return numMessagesRead < numReads;
649 }
650
651 @Override
652 public boolean continueReading(boolean autoRead, Predicate<Handle> maybeMoreDataSupplier) {
653 return continueReading(autoRead) && maybeMoreDataSupplier.test(this);
654 }
655
656 @Override
657 public void readComplete() {
658
659 }
660 };
661 }
662 }
663 }