View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.buffer.ByteBuf;
20  import io.netty.buffer.Unpooled;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInboundHandlerAdapter;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.SimpleChannelInboundHandler;
26  import io.netty.channel.WriteBufferWaterMark;
27  import io.netty.channel.socket.SocketChannel;
28  import io.netty.channel.socket.oio.OioSocketChannel;
29  import org.junit.jupiter.api.Disabled;
30  import org.junit.jupiter.api.Test;
31  import org.junit.jupiter.api.TestInfo;
32  import org.junit.jupiter.api.Timeout;
33  
34  import java.net.ServerSocket;
35  import java.net.Socket;
36  import java.net.SocketException;
37  import java.nio.channels.ClosedChannelException;
38  import java.util.concurrent.BlockingDeque;
39  import java.util.concurrent.BlockingQueue;
40  import java.util.concurrent.LinkedBlockingDeque;
41  import java.util.concurrent.LinkedBlockingQueue;
42  import java.util.concurrent.TimeUnit;
43  
44  import static org.junit.jupiter.api.Assertions.assertEquals;
45  import static org.junit.jupiter.api.Assertions.assertFalse;
46  import static org.junit.jupiter.api.Assertions.assertNull;
47  import static org.junit.jupiter.api.Assertions.assertTrue;
48  import static org.junit.jupiter.api.Assertions.fail;
49  import static org.junit.jupiter.api.Assumptions.assumeFalse;
50  
51  public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
52  
53      @Test
54      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
55      public void testShutdownOutput(TestInfo testInfo) throws Throwable {
56          run(testInfo, new Runner<Bootstrap>() {
57              @Override
58              public void run(Bootstrap bootstrap) throws Throwable {
59                  testShutdownOutput(bootstrap);
60              }
61          });
62      }
63  
64      public void testShutdownOutput(Bootstrap cb) throws Throwable {
65          TestHandler h = new TestHandler();
66          ServerSocket ss = new ServerSocket();
67          Socket s = null;
68          SocketChannel ch = null;
69          try {
70              ss.bind(newSocketAddress());
71              ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
72              assertTrue(ch.isActive());
73              assertFalse(ch.isOutputShutdown());
74  
75              s = ss.accept();
76              ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 1 })).sync();
77              assertEquals(1, s.getInputStream().read());
78  
79              assertTrue(h.ch.isOpen());
80              assertTrue(h.ch.isActive());
81              assertFalse(h.ch.isInputShutdown());
82              assertFalse(h.ch.isOutputShutdown());
83  
84              // Make the connection half-closed and ensure read() returns -1.
85              ch.shutdownOutput().sync();
86              assertEquals(-1, s.getInputStream().read());
87  
88              assertTrue(h.ch.isOpen());
89              assertTrue(h.ch.isActive());
90              assertFalse(h.ch.isInputShutdown());
91              assertTrue(h.ch.isOutputShutdown());
92  
93              // If half-closed, the peer should be able to write something.
94              s.getOutputStream().write(new byte[] { 1 });
95              assertEquals(1, (int) h.queue.take());
96          } finally {
97              if (s != null) {
98                  s.close();
99              }
100             if (ch != null) {
101                 ch.close();
102             }
103             ss.close();
104         }
105     }
106 
107     @Test
108     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
109     public void testShutdownOutputAfterClosed(TestInfo testInfo) throws Throwable {
110         run(testInfo, new Runner<Bootstrap>() {
111             @Override
112             public void run(Bootstrap bootstrap) throws Throwable {
113                 testShutdownOutputAfterClosed(bootstrap);
114             }
115         });
116     }
117 
118     public void testShutdownOutputAfterClosed(Bootstrap cb) throws Throwable {
119         TestHandler h = new TestHandler();
120         ServerSocket ss = new ServerSocket();
121         Socket s = null;
122         try {
123             ss.bind(newSocketAddress());
124             SocketChannel ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
125             assertTrue(ch.isActive());
126             s = ss.accept();
127 
128             ch.close().syncUninterruptibly();
129             try {
130                 ch.shutdownInput().syncUninterruptibly();
131                 fail();
132             } catch (Throwable cause) {
133                 checkThrowable(cause);
134             }
135             try {
136                 ch.shutdownOutput().syncUninterruptibly();
137                 fail();
138             } catch (Throwable cause) {
139                 checkThrowable(cause);
140             }
141         } finally {
142             if (s != null) {
143                 s.close();
144             }
145             ss.close();
146         }
147     }
148 
149     @Disabled
150     @Test
151     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
152     public void testWriteAfterShutdownOutputNoWritabilityChange(TestInfo testInfo) throws Throwable {
153         run(testInfo, new Runner<Bootstrap>() {
154             @Override
155             public void run(Bootstrap bootstrap) throws Throwable {
156                 testWriteAfterShutdownOutputNoWritabilityChange(bootstrap);
157             }
158         });
159     }
160 
161     public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
162         final TestHandler h = new TestHandler();
163         ServerSocket ss = new ServerSocket();
164         Socket s = null;
165         SocketChannel ch = null;
166         try {
167             ss.bind(newSocketAddress());
168             cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
169             ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
170             assumeFalse(ch instanceof OioSocketChannel);
171             assertTrue(ch.isActive());
172             assertFalse(ch.isOutputShutdown());
173 
174             s = ss.accept();
175 
176             byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 };
177             ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes));
178             h.assertWritability(false);
179             ch.flush();
180             writeFuture.sync();
181             h.assertWritability(true);
182             for (int i = 0; i < expectedBytes.length; ++i) {
183                 assertEquals(expectedBytes[i], s.getInputStream().read());
184             }
185 
186             assertTrue(h.ch.isOpen());
187             assertTrue(h.ch.isActive());
188             assertFalse(h.ch.isInputShutdown());
189             assertFalse(h.ch.isOutputShutdown());
190 
191             // Make the connection half-closed and ensure read() returns -1.
192             ch.shutdownOutput().sync();
193             assertEquals(-1, s.getInputStream().read());
194 
195             assertTrue(h.ch.isOpen());
196             assertTrue(h.ch.isActive());
197             assertFalse(h.ch.isInputShutdown());
198             assertTrue(h.ch.isOutputShutdown());
199 
200             try {
201                 // If half-closed, the local endpoint shouldn't be able to write
202                 ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync();
203                 fail();
204             } catch (Throwable cause) {
205                 checkThrowable(cause);
206             }
207             assertNull(h.writabilityQueue.poll());
208         } finally {
209             if (s != null) {
210                 s.close();
211             }
212             if (ch != null) {
213                 ch.close();
214             }
215             ss.close();
216         }
217     }
218 
219     @Test
220     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
221     public void testShutdownOutputSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
222         run(testInfo, new Runner<Bootstrap>() {
223             @Override
224             public void run(Bootstrap bootstrap) throws Throwable {
225                 testShutdownOutputSoLingerNoAssertError(bootstrap);
226             }
227         });
228     }
229 
230     public void testShutdownOutputSoLingerNoAssertError(Bootstrap cb) throws Throwable {
231         testShutdownSoLingerNoAssertError0(cb, true);
232     }
233 
234     @Test
235     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
236     public void testShutdownSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
237         run(testInfo, new Runner<Bootstrap>() {
238             @Override
239             public void run(Bootstrap bootstrap) throws Throwable {
240                 testShutdownSoLingerNoAssertError(bootstrap);
241             }
242         });
243     }
244 
245     public void testShutdownSoLingerNoAssertError(Bootstrap cb) throws Throwable {
246         testShutdownSoLingerNoAssertError0(cb, false);
247     }
248 
249     private void testShutdownSoLingerNoAssertError0(Bootstrap cb, boolean output) throws Throwable {
250         ServerSocket ss = new ServerSocket();
251         Socket s = null;
252 
253         ChannelFuture cf = null;
254         try {
255             ss.bind(newSocketAddress());
256             cf = cb.option(ChannelOption.SO_LINGER, 1).handler(new ChannelInboundHandlerAdapter())
257                     .connect(ss.getLocalSocketAddress()).sync();
258             s = ss.accept();
259 
260             cf.sync();
261 
262             if (output) {
263                 ((SocketChannel) cf.channel()).shutdownOutput().sync();
264             } else {
265                 ((SocketChannel) cf.channel()).shutdown().sync();
266             }
267         } finally {
268             if (s != null) {
269                 s.close();
270             }
271             if (cf != null) {
272                 cf.channel().close();
273             }
274             ss.close();
275         }
276     }
277     private static void checkThrowable(Throwable cause) throws Throwable {
278         // Depending on OIO / NIO both are ok
279         if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
280             throw cause;
281         }
282     }
283 
284     private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
285         volatile SocketChannel ch;
286         final BlockingQueue<Byte> queue = new LinkedBlockingQueue<Byte>();
287         final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<Boolean>();
288 
289         @Override
290         public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
291             writabilityQueue.add(ctx.channel().isWritable());
292         }
293 
294         @Override
295         public void channelActive(ChannelHandlerContext ctx) throws Exception {
296             ch = (SocketChannel) ctx.channel();
297         }
298 
299         @Override
300         public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
301             queue.offer(msg.readByte());
302         }
303 
304         private void drainWritabilityQueue() throws InterruptedException {
305             while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
306                 // Just drain the queue.
307             }
308         }
309 
310         void assertWritability(boolean isWritable) throws InterruptedException {
311             try {
312                 Boolean writability = writabilityQueue.takeLast();
313                 assertEquals(isWritable, writability);
314                 // TODO(scott): why do we get multiple writability changes here ... race condition?
315                 drainWritabilityQueue();
316             } catch (Throwable c) {
317                 c.printStackTrace();
318             }
319         }
320     }
321 }