1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.ServerBootstrap;
19 import io.netty5.buffer.api.Buffer;
20 import io.netty5.channel.Channel;
21 import io.netty5.channel.ChannelHandlerContext;
22 import io.netty5.channel.ChannelOption;
23 import io.netty5.channel.ChannelShutdownDirection;
24 import io.netty5.channel.SimpleChannelInboundHandler;
25 import org.junit.jupiter.api.Test;
26 import org.junit.jupiter.api.TestInfo;
27 import org.junit.jupiter.api.Timeout;
28
29 import java.io.IOException;
30 import java.net.SocketAddress;
31 import java.util.concurrent.BlockingQueue;
32 import java.util.concurrent.CountDownLatch;
33 import java.util.concurrent.LinkedBlockingQueue;
34 import java.util.concurrent.TimeUnit;
35 import java.util.concurrent.atomic.AtomicInteger;
36
37 import static org.junit.jupiter.api.Assertions.assertEquals;
38 import static org.junit.jupiter.api.Assertions.assertFalse;
39 import static org.junit.jupiter.api.Assertions.assertTrue;
40
41 public abstract class AbstractSocketShutdownOutputByPeerTest<Socket> extends AbstractServerSocketTest {
42
43 @Test
44 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
45 public void testShutdownOutput(TestInfo testInfo) throws Throwable {
46 run(testInfo, this::testShutdownOutput);
47 }
48
49 public void testShutdownOutput(ServerBootstrap sb) throws Throwable {
50 TestHandler h = new TestHandler();
51 Socket s = newSocket();
52 Channel sc = null;
53 try {
54 sc = sb.childHandler(h).childOption(ChannelOption.ALLOW_HALF_CLOSURE, true).bind().asStage().get();
55
56 connect(s, sc.localAddress());
57 write(s, 1);
58
59 assertEquals(1, (int) h.queue.take());
60
61 assertTrue(h.ch.isOpen());
62 assertTrue(h.ch.isActive());
63 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
64 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
65
66 shutdownOutput(s);
67
68 h.halfClosure.await();
69
70 assertTrue(h.ch.isOpen());
71 assertTrue(h.ch.isActive());
72 assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
73 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
74
75 while (h.closure.getCount() != 1 && h.halfClosureCount.intValue() != 1) {
76 Thread.sleep(100);
77 }
78 } finally {
79 if (sc != null) {
80 sc.close();
81 }
82 close(s);
83 }
84 }
85
86 @Test
87 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
88 public void testShutdownOutputWithoutOption(TestInfo testInfo) throws Throwable {
89 run(testInfo, this::testShutdownOutputWithoutOption);
90 }
91
92 public void testShutdownOutputWithoutOption(ServerBootstrap sb) throws Throwable {
93 TestHandler h = new TestHandler();
94 Socket s = newSocket();
95 Channel sc = null;
96 try {
97 sc = sb.childHandler(h).bind().asStage().get();
98
99 connect(s, sc.localAddress());
100 write(s, 1);
101
102 assertEquals(1, (int) h.queue.take());
103
104 assertTrue(h.ch.isOpen());
105 assertTrue(h.ch.isActive());
106 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
107 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
108
109 shutdownOutput(s);
110
111 h.closure.await();
112
113 assertFalse(h.ch.isOpen());
114 assertFalse(h.ch.isActive());
115 assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
116 assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
117
118 while (h.halfClosure.getCount() != 1 && h.halfClosureCount.intValue() != 0) {
119 Thread.sleep(100);
120 }
121 } finally {
122 if (sc != null) {
123 sc.close();
124 }
125 close(s);
126 }
127 }
128
129 protected abstract void shutdownOutput(Socket s) throws IOException;
130
131 protected abstract void connect(Socket s, SocketAddress address) throws IOException;
132
133 protected abstract void close(Socket s) throws IOException;
134
135 protected abstract void write(Socket s, int data) throws IOException;
136
137 protected abstract Socket newSocket();
138
139 private static class TestHandler extends SimpleChannelInboundHandler<Buffer> {
140 volatile Channel ch;
141 final BlockingQueue<Byte> queue = new LinkedBlockingQueue<>();
142 final CountDownLatch halfClosure = new CountDownLatch(1);
143 final CountDownLatch closure = new CountDownLatch(1);
144 final AtomicInteger halfClosureCount = new AtomicInteger();
145
146 @Override
147 public void channelActive(ChannelHandlerContext ctx) throws Exception {
148 ch = ctx.channel();
149 }
150
151 @Override
152 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
153 closure.countDown();
154 }
155
156 @Override
157 public void messageReceived(ChannelHandlerContext ctx, Buffer msg) throws Exception {
158 queue.offer(msg.readByte());
159 }
160
161 @Override
162 public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) throws Exception {
163 if (direction == ChannelShutdownDirection.Inbound) {
164 halfClosureCount.incrementAndGet();
165 halfClosure.countDown();
166 }
167 }
168 }
169 }