View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.ServerBootstrap;
19  import io.netty5.buffer.api.Buffer;
20  import io.netty5.channel.Channel;
21  import io.netty5.channel.ChannelHandlerContext;
22  import io.netty5.channel.ChannelOption;
23  import io.netty5.channel.ChannelShutdownDirection;
24  import io.netty5.channel.SimpleChannelInboundHandler;
25  import org.junit.jupiter.api.Test;
26  import org.junit.jupiter.api.TestInfo;
27  import org.junit.jupiter.api.Timeout;
28  
29  import java.io.IOException;
30  import java.net.SocketAddress;
31  import java.util.concurrent.BlockingQueue;
32  import java.util.concurrent.CountDownLatch;
33  import java.util.concurrent.LinkedBlockingQueue;
34  import java.util.concurrent.TimeUnit;
35  import java.util.concurrent.atomic.AtomicInteger;
36  
37  import static org.junit.jupiter.api.Assertions.assertEquals;
38  import static org.junit.jupiter.api.Assertions.assertFalse;
39  import static org.junit.jupiter.api.Assertions.assertTrue;
40  
41  public abstract class AbstractSocketShutdownOutputByPeerTest<Socket> extends AbstractServerSocketTest {
42  
43      @Test
44      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
45      public void testShutdownOutput(TestInfo testInfo) throws Throwable {
46          run(testInfo, this::testShutdownOutput);
47      }
48  
49      public void testShutdownOutput(ServerBootstrap sb) throws Throwable {
50          TestHandler h = new TestHandler();
51          Socket s = newSocket();
52          Channel sc = null;
53          try {
54              sc = sb.childHandler(h).childOption(ChannelOption.ALLOW_HALF_CLOSURE, true).bind().asStage().get();
55  
56              connect(s, sc.localAddress());
57              write(s, 1);
58  
59              assertEquals(1, (int) h.queue.take());
60  
61              assertTrue(h.ch.isOpen());
62              assertTrue(h.ch.isActive());
63              assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
64              assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
65  
66              shutdownOutput(s);
67  
68              h.halfClosure.await();
69  
70              assertTrue(h.ch.isOpen());
71              assertTrue(h.ch.isActive());
72              assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
73              assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
74  
75              while (h.closure.getCount() != 1 && h.halfClosureCount.intValue() != 1) {
76                  Thread.sleep(100);
77              }
78          } finally {
79              if (sc != null) {
80                  sc.close();
81              }
82              close(s);
83          }
84      }
85  
86      @Test
87      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
88      public void testShutdownOutputWithoutOption(TestInfo testInfo) throws Throwable {
89          run(testInfo, this::testShutdownOutputWithoutOption);
90      }
91  
92      public void testShutdownOutputWithoutOption(ServerBootstrap sb) throws Throwable {
93          TestHandler h = new TestHandler();
94          Socket s = newSocket();
95          Channel sc = null;
96          try {
97              sc = sb.childHandler(h).bind().asStage().get();
98  
99              connect(s, sc.localAddress());
100             write(s, 1);
101 
102             assertEquals(1, (int) h.queue.take());
103 
104             assertTrue(h.ch.isOpen());
105             assertTrue(h.ch.isActive());
106             assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
107             assertFalse(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
108 
109             shutdownOutput(s);
110 
111             h.closure.await();
112 
113             assertFalse(h.ch.isOpen());
114             assertFalse(h.ch.isActive());
115             assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Inbound));
116             assertTrue(h.ch.isShutdown(ChannelShutdownDirection.Outbound));
117 
118             while (h.halfClosure.getCount() != 1 && h.halfClosureCount.intValue() != 0) {
119                 Thread.sleep(100);
120             }
121         } finally {
122             if (sc != null) {
123                 sc.close();
124             }
125             close(s);
126         }
127     }
128 
129     protected abstract void shutdownOutput(Socket s) throws IOException;
130 
131     protected abstract void connect(Socket s, SocketAddress address) throws IOException;
132 
133     protected abstract void close(Socket s) throws IOException;
134 
135     protected abstract void write(Socket s, int data) throws IOException;
136 
137     protected abstract Socket newSocket();
138 
139     private static class TestHandler extends SimpleChannelInboundHandler<Buffer> {
140         volatile Channel ch;
141         final BlockingQueue<Byte> queue = new LinkedBlockingQueue<>();
142         final CountDownLatch halfClosure = new CountDownLatch(1);
143         final CountDownLatch closure = new CountDownLatch(1);
144         final AtomicInteger halfClosureCount = new AtomicInteger();
145 
146         @Override
147         public void channelActive(ChannelHandlerContext ctx) throws Exception {
148             ch = ctx.channel();
149         }
150 
151         @Override
152         public void channelInactive(ChannelHandlerContext ctx) throws Exception {
153             closure.countDown();
154         }
155 
156         @Override
157         public void messageReceived(ChannelHandlerContext ctx, Buffer msg) throws Exception {
158             queue.offer(msg.readByte());
159         }
160 
161         @Override
162         public void channelShutdown(ChannelHandlerContext ctx, ChannelShutdownDirection direction) throws Exception {
163             if (direction == ChannelShutdownDirection.Inbound) {
164                 halfClosureCount.incrementAndGet();
165                 halfClosure.countDown();
166             }
167         }
168     }
169 }