View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.buffer.ByteBuf;
20  import io.netty.buffer.Unpooled;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInboundHandlerAdapter;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.SimpleChannelInboundHandler;
26  import io.netty.channel.WriteBufferWaterMark;
27  import io.netty.channel.socket.SocketChannel;
28  import io.netty.channel.socket.oio.OioSocketChannel;
29  import org.junit.jupiter.api.Disabled;
30  import org.junit.jupiter.api.Test;
31  import org.junit.jupiter.api.TestInfo;
32  import org.junit.jupiter.api.Timeout;
33  
34  import java.net.ServerSocket;
35  import java.net.Socket;
36  import java.net.SocketException;
37  import java.nio.channels.ClosedChannelException;
38  import java.util.concurrent.BlockingDeque;
39  import java.util.concurrent.BlockingQueue;
40  import java.util.concurrent.LinkedBlockingDeque;
41  import java.util.concurrent.LinkedBlockingQueue;
42  import java.util.concurrent.TimeUnit;
43  
44  import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
45  import static org.junit.jupiter.api.Assertions.assertEquals;
46  import static org.junit.jupiter.api.Assertions.assertFalse;
47  import static org.junit.jupiter.api.Assertions.assertNull;
48  import static org.junit.jupiter.api.Assertions.assertTrue;
49  import static org.junit.jupiter.api.Assertions.fail;
50  import static org.junit.jupiter.api.Assumptions.assumeFalse;
51  
52  public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
53  
54      @Test
55      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
56      public void testShutdownOutput(TestInfo testInfo) throws Throwable {
57          run(testInfo, new Runner<Bootstrap>() {
58              @Override
59              public void run(Bootstrap bootstrap) throws Throwable {
60                  testShutdownOutput(bootstrap);
61              }
62          });
63      }
64  
65      public void testShutdownOutput(Bootstrap cb) throws Throwable {
66          TestHandler h = new TestHandler();
67          ServerSocket ss = new ServerSocket();
68          Socket s = null;
69          SocketChannel ch = null;
70          try {
71              ss.bind(newSocketAddress());
72              ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
73              assertTrue(ch.isActive());
74              assertFalse(ch.isOutputShutdown());
75  
76              s = ss.accept();
77              ch.writeAndFlush(randomBufferType(ch.alloc(), new byte[] { 1 }, 0, 1)).sync();
78              assertEquals(1, s.getInputStream().read());
79  
80              assertTrue(h.ch.isOpen());
81              assertTrue(h.ch.isActive());
82              assertFalse(h.ch.isInputShutdown());
83              assertFalse(h.ch.isOutputShutdown());
84  
85              // Make the connection half-closed and ensure read() returns -1.
86              ch.shutdownOutput().sync();
87              assertEquals(-1, s.getInputStream().read());
88  
89              assertTrue(h.ch.isOpen());
90              assertTrue(h.ch.isActive());
91              assertFalse(h.ch.isInputShutdown());
92              assertTrue(h.ch.isOutputShutdown());
93  
94              // If half-closed, the peer should be able to write something.
95              s.getOutputStream().write(new byte[] { 1 });
96              assertEquals(1, (int) h.queue.take());
97          } finally {
98              if (s != null) {
99                  s.close();
100             }
101             if (ch != null) {
102                 ch.close();
103             }
104             ss.close();
105         }
106     }
107 
108     @Test
109     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
110     public void testShutdownOutputAfterClosed(TestInfo testInfo) throws Throwable {
111         run(testInfo, new Runner<Bootstrap>() {
112             @Override
113             public void run(Bootstrap bootstrap) throws Throwable {
114                 testShutdownOutputAfterClosed(bootstrap);
115             }
116         });
117     }
118 
119     public void testShutdownOutputAfterClosed(Bootstrap cb) throws Throwable {
120         TestHandler h = new TestHandler();
121         ServerSocket ss = new ServerSocket();
122         Socket s = null;
123         try {
124             ss.bind(newSocketAddress());
125             SocketChannel ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
126             assertTrue(ch.isActive());
127             s = ss.accept();
128 
129             ch.close().syncUninterruptibly();
130             try {
131                 ch.shutdownInput().syncUninterruptibly();
132                 fail();
133             } catch (Throwable cause) {
134                 checkThrowable(cause);
135             }
136             try {
137                 ch.shutdownOutput().syncUninterruptibly();
138                 fail();
139             } catch (Throwable cause) {
140                 checkThrowable(cause);
141             }
142         } finally {
143             if (s != null) {
144                 s.close();
145             }
146             ss.close();
147         }
148     }
149 
150     @Disabled
151     @Test
152     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
153     public void testWriteAfterShutdownOutputNoWritabilityChange(TestInfo testInfo) throws Throwable {
154         run(testInfo, new Runner<Bootstrap>() {
155             @Override
156             public void run(Bootstrap bootstrap) throws Throwable {
157                 testWriteAfterShutdownOutputNoWritabilityChange(bootstrap);
158             }
159         });
160     }
161 
162     public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
163         final TestHandler h = new TestHandler();
164         ServerSocket ss = new ServerSocket();
165         Socket s = null;
166         SocketChannel ch = null;
167         try {
168             ss.bind(newSocketAddress());
169             cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
170             ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
171             assumeFalse(ch instanceof OioSocketChannel);
172             assertTrue(ch.isActive());
173             assertFalse(ch.isOutputShutdown());
174 
175             s = ss.accept();
176 
177             byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 };
178             ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes));
179             h.assertWritability(false);
180             ch.flush();
181             writeFuture.sync();
182             h.assertWritability(true);
183             for (int i = 0; i < expectedBytes.length; ++i) {
184                 assertEquals(expectedBytes[i], s.getInputStream().read());
185             }
186 
187             assertTrue(h.ch.isOpen());
188             assertTrue(h.ch.isActive());
189             assertFalse(h.ch.isInputShutdown());
190             assertFalse(h.ch.isOutputShutdown());
191 
192             // Make the connection half-closed and ensure read() returns -1.
193             ch.shutdownOutput().sync();
194             assertEquals(-1, s.getInputStream().read());
195 
196             assertTrue(h.ch.isOpen());
197             assertTrue(h.ch.isActive());
198             assertFalse(h.ch.isInputShutdown());
199             assertTrue(h.ch.isOutputShutdown());
200 
201             try {
202                 // If half-closed, the local endpoint shouldn't be able to write
203                 ch.writeAndFlush(randomBufferType(ch.alloc(), new byte[]{ 2 }, 0 , 2)).sync();
204                 fail();
205             } catch (Throwable cause) {
206                 checkThrowable(cause);
207             }
208             assertNull(h.writabilityQueue.poll());
209         } finally {
210             if (s != null) {
211                 s.close();
212             }
213             if (ch != null) {
214                 ch.close();
215             }
216             ss.close();
217         }
218     }
219 
220     @Test
221     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
222     public void testShutdownOutputSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
223         run(testInfo, new Runner<Bootstrap>() {
224             @Override
225             public void run(Bootstrap bootstrap) throws Throwable {
226                 testShutdownOutputSoLingerNoAssertError(bootstrap);
227             }
228         });
229     }
230 
231     public void testShutdownOutputSoLingerNoAssertError(Bootstrap cb) throws Throwable {
232         testShutdownSoLingerNoAssertError0(cb, true);
233     }
234 
235     @Test
236     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
237     public void testShutdownSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
238         run(testInfo, new Runner<Bootstrap>() {
239             @Override
240             public void run(Bootstrap bootstrap) throws Throwable {
241                 testShutdownSoLingerNoAssertError(bootstrap);
242             }
243         });
244     }
245 
246     public void testShutdownSoLingerNoAssertError(Bootstrap cb) throws Throwable {
247         testShutdownSoLingerNoAssertError0(cb, false);
248     }
249 
250     private void testShutdownSoLingerNoAssertError0(Bootstrap cb, boolean output) throws Throwable {
251         ServerSocket ss = new ServerSocket();
252         Socket s = null;
253 
254         ChannelFuture cf = null;
255         try {
256             ss.bind(newSocketAddress());
257             cf = cb.option(ChannelOption.SO_LINGER, 1).handler(new ChannelInboundHandlerAdapter())
258                     .connect(ss.getLocalSocketAddress()).sync();
259             s = ss.accept();
260 
261             cf.sync();
262 
263             if (output) {
264                 ((SocketChannel) cf.channel()).shutdownOutput().sync();
265             } else {
266                 ((SocketChannel) cf.channel()).shutdown().sync();
267             }
268         } finally {
269             if (s != null) {
270                 s.close();
271             }
272             if (cf != null) {
273                 cf.channel().close();
274             }
275             ss.close();
276         }
277     }
278     private static void checkThrowable(Throwable cause) throws Throwable {
279         // Depending on OIO / NIO both are ok
280         if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
281             throw cause;
282         }
283     }
284 
285     private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
286         volatile SocketChannel ch;
287         final BlockingQueue<Byte> queue = new LinkedBlockingQueue<Byte>();
288         final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<Boolean>();
289 
290         @Override
291         public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
292             writabilityQueue.add(ctx.channel().isWritable());
293         }
294 
295         @Override
296         public void channelActive(ChannelHandlerContext ctx) throws Exception {
297             ch = (SocketChannel) ctx.channel();
298         }
299 
300         @Override
301         public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
302             queue.offer(msg.readByte());
303         }
304 
305         private void drainWritabilityQueue() throws InterruptedException {
306             while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
307                 // Just drain the queue.
308             }
309         }
310 
311         void assertWritability(boolean isWritable) throws InterruptedException {
312             try {
313                 Boolean writability = writabilityQueue.takeLast();
314                 assertEquals(isWritable, writability);
315                 // TODO(scott): why do we get multiple writability changes here ... race condition?
316                 drainWritabilityQueue();
317             } catch (Throwable c) {
318                 c.printStackTrace();
319             }
320         }
321     }
322 }