View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.channel.Channel;
21  import io.netty5.channel.ChannelHandlerContext;
22  import io.netty5.channel.ChannelInitializer;
23  import io.netty5.channel.ChannelOption;
24  import io.netty5.channel.ChannelPipeline;
25  import io.netty5.channel.SimpleChannelInboundHandler;
26  import io.netty5.handler.codec.LineBasedFrameDecoder;
27  import io.netty5.handler.codec.string.StringDecoder;
28  import io.netty5.handler.codec.string.StringEncoder;
29  import io.netty5.handler.logging.LogLevel;
30  import io.netty5.handler.logging.LoggingHandler;
31  import io.netty5.handler.ssl.OpenSsl;
32  import io.netty5.handler.ssl.SslContext;
33  import io.netty5.handler.ssl.SslContextBuilder;
34  import io.netty5.handler.ssl.SslHandler;
35  import io.netty5.handler.ssl.SslProvider;
36  import io.netty5.handler.ssl.util.SelfSignedCertificate;
37  import io.netty5.util.concurrent.Future;
38  import io.netty5.util.internal.logging.InternalLogger;
39  import io.netty5.util.internal.logging.InternalLoggerFactory;
40  import org.junit.jupiter.api.TestInfo;
41  import org.junit.jupiter.api.Timeout;
42  import org.junit.jupiter.params.ParameterizedTest;
43  import org.junit.jupiter.params.provider.MethodSource;
44  
45  import javax.net.ssl.SSLEngine;
46  import java.io.File;
47  import java.io.IOException;
48  import java.security.cert.CertificateException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.TimeUnit;
53  import java.util.concurrent.atomic.AtomicReference;
54  
55  import static io.netty5.buffer.api.DefaultBufferAllocators.offHeapAllocator;
56  import static org.junit.jupiter.api.Assertions.assertEquals;
57  import static org.junit.jupiter.api.Assertions.assertNotNull;
58  import static org.junit.jupiter.api.Assertions.assertTrue;
59  
60  public class SocketStartTlsTest extends AbstractSocketTest {
61      private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
62  
63      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
64  
65      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
66      private static final File CERT_FILE;
67      private static final File KEY_FILE;
68  
69      static {
70          SelfSignedCertificate ssc;
71          try {
72              ssc = new SelfSignedCertificate();
73          } catch (CertificateException e) {
74              throw new Error(e);
75          }
76          CERT_FILE = ssc.certificate();
77          KEY_FILE = ssc.privateKey();
78      }
79  
80      public static Collection<Object[]> data() throws Exception {
81          List<SslContext> serverContexts = new ArrayList<>();
82          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
83  
84          List<SslContext> clientContexts = new ArrayList<>();
85          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
86  
87          boolean hasOpenSsl = OpenSsl.isAvailable();
88          if (hasOpenSsl) {
89              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
90                                                  .sslProvider(SslProvider.OPENSSL).build());
91              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
92                                                  .trustManager(CERT_FILE).build());
93          } else {
94              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
95          }
96  
97          List<Object[]> params = new ArrayList<>();
98          for (SslContext sc: serverContexts) {
99              for (SslContext cc: clientContexts) {
100                 params.add(new Object[] { sc, cc });
101             }
102         }
103         return params;
104     }
105 
106     @ParameterizedTest(name = PARAMETERIZED_NAME)
107     @MethodSource("data")
108     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
109     public void testStartTls(SslContext serverCtx, SslContext clientCtx, TestInfo testInfo) throws Throwable {
110         run(testInfo, (sb, cb) -> testStartTls(sb, cb, serverCtx, clientCtx));
111     }
112 
113     public void testStartTls(ServerBootstrap sb, Bootstrap cb,
114                                     SslContext serverCtx, SslContext clientCtx) throws Throwable {
115         testStartTls(sb, cb, serverCtx, clientCtx, true);
116     }
117 
118     @ParameterizedTest(name = PARAMETERIZED_NAME)
119     @MethodSource("data")
120     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
121     public void testStartTlsNotAutoRead(SslContext serverCtx, SslContext clientCtx,
122                                         TestInfo testInfo) throws Throwable {
123         run(testInfo, (sb, cb) -> testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx));
124     }
125 
126     public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
127                                                SslContext serverCtx, SslContext clientCtx) throws Throwable {
128         testStartTls(sb, cb, serverCtx, clientCtx, false);
129     }
130 
131     private void testStartTls(ServerBootstrap sb, Bootstrap cb,
132                                      SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
133         sb.childOption(ChannelOption.AUTO_READ, autoRead);
134         cb.option(ChannelOption.AUTO_READ, autoRead);
135 
136         SSLEngine sse = serverCtx.newEngine(offHeapAllocator());
137         SSLEngine cse = clientCtx.newEngine(offHeapAllocator());
138 
139         final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
140         final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
141 
142         sb.childHandler(new ChannelInitializer<>() {
143             @Override
144             public void initChannel(Channel sch) {
145                 ChannelPipeline p = sch.pipeline();
146                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
147                 p.addLast(new LineBasedFrameDecoder(64));
148                 p.addLast(new StringDecoder(), new StringEncoder());
149                 p.addLast(sh);
150             }
151         });
152 
153         cb.handler(new ChannelInitializer<>() {
154             @Override
155             public void initChannel(Channel sch) {
156                 ChannelPipeline p = sch.pipeline();
157                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
158                 p.addLast(new LineBasedFrameDecoder(64));
159                 p.addLast(new StringDecoder(), new StringEncoder());
160                 p.addLast(ch);
161             }
162         });
163 
164         Channel sc = sb.bind().asStage().get();
165         Channel cc = cb.connect(sc.localAddress()).asStage().get();
166 
167         while (cc.isActive()) {
168             if (sh.exception.get() != null) {
169                 break;
170             }
171             if (ch.exception.get() != null) {
172                 break;
173             }
174 
175             try {
176                 Thread.sleep(50);
177             } catch (InterruptedException e) {
178                 // Ignore.
179             }
180         }
181 
182         while (sh.channel.isActive()) {
183             if (sh.exception.get() != null) {
184                 break;
185             }
186             if (ch.exception.get() != null) {
187                 break;
188             }
189 
190             try {
191                 Thread.sleep(50);
192             } catch (InterruptedException e) {
193                 // Ignore.
194             }
195         }
196 
197         sh.channel.close().asStage().await();
198         cc.close().asStage().await();
199         sc.close().asStage().await();
200 
201         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
202             throw sh.exception.get();
203         }
204         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
205             throw ch.exception.get();
206         }
207         if (sh.exception.get() != null) {
208             throw sh.exception.get();
209         }
210         if (ch.exception.get() != null) {
211             throw ch.exception.get();
212         }
213     }
214 
215     private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
216         private final SslHandler sslHandler;
217         private final boolean autoRead;
218         private Future<Channel> handshakeFuture;
219         final AtomicReference<Throwable> exception = new AtomicReference<>();
220 
221         StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
222             engine.setUseClientMode(true);
223             sslHandler = new SslHandler(engine);
224             this.autoRead = autoRead;
225         }
226 
227         @Override
228         public void channelActive(ChannelHandlerContext ctx)
229                 throws Exception {
230             if (!autoRead) {
231                 ctx.read();
232             }
233             ctx.writeAndFlush("StartTlsRequest\n");
234         }
235 
236         @Override
237         public void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
238             if ("StartTlsResponse".equals(msg)) {
239                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
240                 handshakeFuture = sslHandler.handshakeFuture();
241                 ctx.writeAndFlush("EncryptedRequest\n");
242                 return;
243             }
244 
245             assertEquals("EncryptedResponse", msg);
246             assertNotNull(handshakeFuture);
247             assertTrue(handshakeFuture.isSuccess());
248             ctx.close();
249         }
250 
251         @Override
252         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
253             if (!autoRead) {
254                 ctx.read();
255             }
256         }
257 
258         @Override
259         public void channelExceptionCaught(ChannelHandlerContext ctx,
260                                            Throwable cause) throws Exception {
261             if (logger.isWarnEnabled()) {
262                 logger.warn("Unexpected exception from the client side", cause);
263             }
264 
265             exception.compareAndSet(null, cause);
266             ctx.close();
267         }
268     }
269 
270     private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
271         private final SslHandler sslHandler;
272         private final boolean autoRead;
273         volatile Channel channel;
274         final AtomicReference<Throwable> exception = new AtomicReference<>();
275 
276         StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
277             engine.setUseClientMode(false);
278             sslHandler = new SslHandler(engine, true);
279             this.autoRead = autoRead;
280         }
281 
282         @Override
283         public void channelActive(ChannelHandlerContext ctx) throws Exception {
284             channel = ctx.channel();
285             if (!autoRead) {
286                 ctx.read();
287             }
288         }
289 
290         @Override
291         public void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
292             if ("StartTlsRequest".equals(msg)) {
293                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
294                 ctx.writeAndFlush("StartTlsResponse\n");
295                 return;
296             }
297 
298             assertEquals("EncryptedRequest", msg);
299             ctx.writeAndFlush("EncryptedResponse\n");
300         }
301 
302         @Override
303         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
304             if (!autoRead) {
305                 ctx.read();
306             }
307         }
308 
309         @Override
310         public void channelExceptionCaught(ChannelHandlerContext ctx,
311                                            Throwable cause) throws Exception {
312             if (logger.isWarnEnabled()) {
313                 logger.warn("Unexpected exception from the server side", cause);
314             }
315 
316             exception.compareAndSet(null, cause);
317             ctx.close();
318         }
319     }
320 }