1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.DefaultBufferAllocators;
21 import io.netty5.util.Resource;
22 import io.netty5.channel.Channel;
23 import io.netty5.channel.ChannelHandler;
24 import io.netty5.channel.ChannelHandlerContext;
25 import io.netty5.channel.ChannelInitializer;
26 import io.netty5.channel.ChannelOption;
27 import io.netty5.channel.ChannelPipeline;
28 import org.junit.jupiter.api.Test;
29 import org.junit.jupiter.api.TestInfo;
30
31 import java.util.concurrent.CountDownLatch;
32 import java.util.concurrent.TimeUnit;
33 import java.util.concurrent.atomic.AtomicLong;
34
35 import static org.junit.jupiter.api.Assertions.assertFalse;
36 import static org.junit.jupiter.api.Assertions.assertTrue;
37
38 public class SocketExceptionHandlingTest extends AbstractSocketTest {
39 @Test
40 public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
41 run(testInfo, this::testReadPendingIsResetAfterEachRead);
42 }
43
44 public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
45 Channel serverChannel = null;
46 Channel clientChannel = null;
47 try {
48 MyInitializer serverInitializer = new MyInitializer();
49 sb.option(ChannelOption.SO_BACKLOG, 1024);
50 sb.childHandler(serverInitializer);
51
52 serverChannel = sb.bind().asStage().get();
53
54 cb.handler(new MyInitializer());
55 clientChannel = cb.connect(serverChannel.localAddress()).asStage().get();
56
57 clientChannel.writeAndFlush(DefaultBufferAllocators.preferredAllocator().copyOf(new byte[1024]));
58
59
60 assertTrue(serverInitializer.exceptionHandler.latch1.await(5, TimeUnit.SECONDS));
61
62
63 assertFalse(serverInitializer.exceptionHandler.latch2.await(1, TimeUnit.SECONDS),
64 "Encountered " + serverInitializer.exceptionHandler.count.get() +
65 " exceptions when 1 was expected");
66 } finally {
67 if (serverChannel != null) {
68 serverChannel.close().asStage().sync();
69 }
70 if (clientChannel != null) {
71 clientChannel.close().asStage().sync();
72 }
73 }
74 }
75
76 private static class MyInitializer extends ChannelInitializer<Channel> {
77 final ExceptionHandler exceptionHandler = new ExceptionHandler();
78 @Override
79 protected void initChannel(Channel ch) throws Exception {
80 ChannelPipeline pipeline = ch.pipeline();
81
82 pipeline.addLast(new BuggyChannelHandler());
83 pipeline.addLast(exceptionHandler);
84 }
85 }
86
87 private static class BuggyChannelHandler implements ChannelHandler {
88 @Override
89 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
90 if (msg instanceof Resource<?>) {
91 ((Resource<?>) msg).close();
92 } else {
93 Resource.dispose(msg);
94 }
95 throw new NullPointerException("I am a bug!");
96 }
97 }
98
99 private static class ExceptionHandler implements ChannelHandler {
100 final AtomicLong count = new AtomicLong();
101
102
103
104 final CountDownLatch latch1 = new CountDownLatch(1);
105 final CountDownLatch latch2 = new CountDownLatch(1);
106
107 @Override
108 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
109 if (count.incrementAndGet() <= 2) {
110 latch1.countDown();
111 } else {
112 latch2.countDown();
113 }
114
115 ctx.close();
116 }
117 }
118 }