1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.buffer.api.Buffer;
20 import io.netty5.channel.Channel;
21 import io.netty5.channel.ChannelHandler;
22 import io.netty5.channel.ChannelHandlerContext;
23 import io.netty5.channel.ChannelOption;
24 import io.netty5.channel.ChannelShutdownDirection;
25 import io.netty5.channel.SimpleChannelInboundHandler;
26 import io.netty5.channel.WriteBufferWaterMark;
27 import io.netty5.channel.socket.SocketChannel;
28 import io.netty5.util.concurrent.Future;
29 import org.junit.jupiter.api.Disabled;
30 import org.junit.jupiter.api.Test;
31 import org.junit.jupiter.api.TestInfo;
32 import org.junit.jupiter.api.Timeout;
33
34 import java.net.ServerSocket;
35 import java.net.Socket;
36 import java.net.SocketException;
37 import java.nio.channels.ClosedChannelException;
38 import java.util.concurrent.BlockingDeque;
39 import java.util.concurrent.BlockingQueue;
40 import java.util.concurrent.LinkedBlockingDeque;
41 import java.util.concurrent.LinkedBlockingQueue;
42 import java.util.concurrent.TimeUnit;
43
44 import static io.netty5.buffer.api.DefaultBufferAllocators.onHeapAllocator;
45 import static org.junit.jupiter.api.Assertions.assertEquals;
46 import static org.junit.jupiter.api.Assertions.assertFalse;
47 import static org.junit.jupiter.api.Assertions.assertNull;
48 import static org.junit.jupiter.api.Assertions.assertTrue;
49 import static org.junit.jupiter.api.Assertions.fail;
50
51 public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
52
53 @Test
54 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
55 public void testShutdownOutput(TestInfo testInfo) throws Throwable {
56 run(testInfo, this::testShutdownOutput);
57 }
58
59 public void testShutdownOutput(Bootstrap cb) throws Throwable {
60 TestHandler h = new TestHandler();
61 ServerSocket ss = new ServerSocket();
62 Socket s = null;
63 SocketChannel ch = null;
64 try {
65 ss.bind(newSocketAddress());
66 ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).asStage().get();
67 assertTrue(ch.isActive());
68 assertFalse(ch.isShutdown(ChannelShutdownDirection.Outbound));
69
70 s = ss.accept();
71 ch.writeAndFlush(onHeapAllocator().copyOf(new byte[] { 1 })).asStage().sync();
72 assertEquals(1, s.getInputStream().read());
73
74 assertTrue(h.ch.isOpen());
75 assertTrue(h.ch.isActive());
76 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
77 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
78
79
80 ch.shutdown(ChannelShutdownDirection.Outbound).asStage().sync();
81 assertEquals(-1, s.getInputStream().read());
82
83 assertTrue(h.ch.isOpen());
84 assertTrue(h.ch.isActive());
85 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
86 assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
87
88
89 s.getOutputStream().write(new byte[] { 1 });
90 assertEquals(1, (int) h.queue.take());
91 } finally {
92 if (s != null) {
93 s.close();
94 }
95 if (ch != null) {
96 ch.close();
97 }
98 ss.close();
99 }
100 }
101
102 @Test
103 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
104 public void testShutdownOutputAfterClosed(TestInfo testInfo) throws Throwable {
105 run(testInfo, this::testShutdownOutputAfterClosed);
106 }
107
108 public void testShutdownOutputAfterClosed(Bootstrap cb) throws Throwable {
109 TestHandler h = new TestHandler();
110 ServerSocket ss = new ServerSocket();
111 Socket s = null;
112 try {
113 ss.bind(newSocketAddress());
114 SocketChannel ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).asStage().get();
115 assertTrue(ch.isActive());
116 s = ss.accept();
117
118 ch.close().asStage().sync();
119 try {
120 ch.shutdown(ChannelShutdownDirection.Inbound).asStage().sync();
121 fail();
122 } catch (Throwable cause) {
123 checkThrowable(cause.getCause());
124 }
125 try {
126 ch.shutdown(ChannelShutdownDirection.Outbound).asStage().sync();
127 fail();
128 } catch (Throwable cause) {
129 checkThrowable(cause.getCause());
130 }
131 } finally {
132 if (s != null) {
133 s.close();
134 }
135 ss.close();
136 }
137 }
138
139 @Disabled
140 @Test
141 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
142 public void testWriteAfterShutdownOutputNoWritabilityChange(TestInfo testInfo) throws Throwable {
143 run(testInfo, this::testWriteAfterShutdownOutputNoWritabilityChange);
144 }
145
146 public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
147 final TestHandler h = new TestHandler();
148 ServerSocket ss = new ServerSocket();
149 Socket s = null;
150 SocketChannel ch = null;
151 try {
152 ss.bind(newSocketAddress());
153 cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
154 ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).asStage().get();
155 assertTrue(ch.isActive());
156 assertFalse(ch.isShutdown(ChannelShutdownDirection.Outbound));
157
158 s = ss.accept();
159
160 byte[] expectedBytes = { 1, 2, 3, 4, 5, 6 };
161 Future<Void> writeFuture = ch.write(onHeapAllocator().copyOf(expectedBytes));
162 h.assertWritability(false);
163 ch.flush();
164 writeFuture.asStage().sync();
165 h.assertWritability(true);
166 for (byte expectedByte : expectedBytes) {
167 assertEquals(expectedByte, s.getInputStream().read());
168 }
169
170 assertTrue(h.ch.isOpen());
171 assertTrue(h.ch.isActive());
172 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
173 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
174
175
176 ch.shutdown(ChannelShutdownDirection.Outbound).asStage().sync();
177 assertEquals(-1, s.getInputStream().read());
178
179 assertTrue(h.ch.isOpen());
180 assertTrue(h.ch.isActive());
181 assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
182 assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
183
184 try {
185
186 ch.writeAndFlush(onHeapAllocator().copyOf(new byte[]{ 2 })).asStage().sync();
187 fail();
188 } catch (Throwable cause) {
189 checkThrowable(cause.getCause());
190 }
191 assertNull(h.writabilityQueue.poll());
192 } finally {
193 if (s != null) {
194 s.close();
195 }
196 if (ch != null) {
197 ch.close();
198 }
199 ss.close();
200 }
201 }
202
203 @Test
204 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
205 public void testShutdownOutputSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
206 run(testInfo, this::testShutdownOutputSoLingerNoAssertError);
207 }
208
209 public void testShutdownOutputSoLingerNoAssertError(Bootstrap cb) throws Throwable {
210 testShutdownOutputSoLingerNoAssertError0(cb, false);
211 }
212
213 @Test
214 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
215 public void testShutdownOutputAndInputSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
216 run(testInfo, this::testShutdownOutputSoLingerNoAssertError);
217 }
218
219 public void testShutdownOutputAndInputSoLingerNoAssertError(Bootstrap cb) throws Throwable {
220 testShutdownOutputSoLingerNoAssertError0(cb, true);
221 }
222
223 public void testShutdownOutputSoLingerNoAssertError0(Bootstrap cb, boolean shutdownInputAsWell) throws Throwable {
224 ServerSocket ss = new ServerSocket();
225 Socket s = null;
226
227 Channel client = null;
228 try {
229 ss.bind(newSocketAddress());
230 client = cb.option(ChannelOption.SO_LINGER, 1).handler(new ChannelHandler() { })
231 .connect(ss.getLocalSocketAddress()).asStage().get();
232 s = ss.accept();
233
234 client.shutdown(ChannelShutdownDirection.Outbound).asStage().sync();
235 if (shutdownInputAsWell) {
236 client.shutdown(ChannelShutdownDirection.Inbound).asStage().sync();
237 }
238 } finally {
239 if (s != null) {
240 s.close();
241 }
242 if (client != null) {
243 client.close();
244 }
245 ss.close();
246 }
247 }
248 private static void checkThrowable(Throwable cause) throws Throwable {
249
250 if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
251 throw cause;
252 }
253 }
254
255 private static final class TestHandler extends SimpleChannelInboundHandler<Buffer> {
256 volatile SocketChannel ch;
257 final BlockingQueue<Byte> queue = new LinkedBlockingQueue<>();
258 final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<>();
259
260 @Override
261 public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
262 writabilityQueue.add(ctx.channel().isWritable());
263 }
264
265 @Override
266 public void channelActive(ChannelHandlerContext ctx) throws Exception {
267 ch = (SocketChannel) ctx.channel();
268 }
269
270 @Override
271 public void messageReceived(ChannelHandlerContext ctx, Buffer msg) throws Exception {
272 queue.offer(msg.readByte());
273 }
274
275 private void drainWritabilityQueue() throws InterruptedException {
276 while (writabilityQueue.poll(100, TimeUnit.MILLISECONDS) != null) {
277
278 }
279 }
280
281 void assertWritability(boolean isWritable) throws InterruptedException {
282 try {
283 Boolean writability = writabilityQueue.takeLast();
284 assertEquals(isWritable, writability);
285
286 drainWritabilityQueue();
287 } catch (Throwable c) {
288 c.printStackTrace();
289 }
290 }
291 }
292 }