View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.PooledByteBufAllocator;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInitializer;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.LineBasedFrameDecoder;
28  import io.netty.handler.codec.string.StringDecoder;
29  import io.netty.handler.codec.string.StringEncoder;
30  import io.netty.handler.logging.LogLevel;
31  import io.netty.handler.logging.LoggingHandler;
32  import io.netty.handler.ssl.OpenSsl;
33  import io.netty.handler.ssl.SslContext;
34  import io.netty.handler.ssl.SslContextBuilder;
35  import io.netty.handler.ssl.SslHandler;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.pkitesting.CertificateBuilder;
38  import io.netty.pkitesting.X509Bundle;
39  import io.netty.util.concurrent.DefaultEventExecutorGroup;
40  import io.netty.util.concurrent.EventExecutorGroup;
41  import io.netty.util.concurrent.Future;
42  import io.netty.util.internal.logging.InternalLogger;
43  import io.netty.util.internal.logging.InternalLoggerFactory;
44  
45  import org.junit.jupiter.api.AfterAll;
46  import org.junit.jupiter.api.BeforeAll;
47  import org.junit.jupiter.api.TestInfo;
48  import org.junit.jupiter.api.Timeout;
49  import org.junit.jupiter.params.ParameterizedTest;
50  import org.junit.jupiter.params.provider.MethodSource;
51  
52  import javax.net.ssl.SSLEngine;
53  import java.io.File;
54  import java.io.IOException;
55  import java.util.ArrayList;
56  import java.util.Collection;
57  import java.util.List;
58  import java.util.concurrent.TimeUnit;
59  import java.util.concurrent.atomic.AtomicReference;
60  
61  import static org.junit.jupiter.api.Assertions.assertEquals;
62  import static org.junit.jupiter.api.Assertions.assertNotNull;
63  import static org.junit.jupiter.api.Assertions.assertTrue;
64  
65  public class SocketStartTlsTest extends AbstractSocketTest {
66      private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67  
68      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69  
70      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71      private static final File CERT_FILE;
72      private static final File KEY_FILE;
73      private static EventExecutorGroup executor;
74  
75      static {
76          try {
77              X509Bundle cert = new CertificateBuilder()
78                      .subject("cn=localhost")
79                      .setIsCertificateAuthority(true)
80                      .buildSelfSigned();
81              CERT_FILE = cert.toTempCertChainPem();
82              KEY_FILE = cert.toTempPrivateKeyPem();
83          } catch (Exception e) {
84              throw new ExceptionInInitializerError(e);
85          }
86      }
87  
88      public static Collection<Object[]> data() throws Exception {
89          List<SslContext> serverContexts = new ArrayList<SslContext>();
90          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
91  
92          List<SslContext> clientContexts = new ArrayList<SslContext>();
93          clientContexts.add(SslContextBuilder.forClient()
94                  .sslProvider(SslProvider.JDK)
95                  .trustManager(CERT_FILE)
96                  .endpointIdentificationAlgorithm(null)
97                  .build());
98  
99          boolean hasOpenSsl = OpenSsl.isAvailable();
100         if (hasOpenSsl) {
101             serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
102                                                 .sslProvider(SslProvider.OPENSSL).build());
103             clientContexts.add(SslContextBuilder.forClient()
104                     .sslProvider(SslProvider.OPENSSL)
105                     .trustManager(CERT_FILE)
106                     .endpointIdentificationAlgorithm(null)
107                     .build());
108         } else {
109             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
110         }
111 
112         List<Object[]> params = new ArrayList<Object[]>();
113         for (SslContext sc: serverContexts) {
114             for (SslContext cc: clientContexts) {
115                 params.add(new Object[] { sc, cc });
116             }
117         }
118         return params;
119     }
120 
121     @BeforeAll
122     public static void createExecutor() {
123         executor = new DefaultEventExecutorGroup(2);
124     }
125 
126     @AfterAll
127     public static void shutdownExecutor() throws Exception {
128         executor.shutdownGracefully().sync();
129     }
130 
131     @ParameterizedTest(name = PARAMETERIZED_NAME)
132     @MethodSource("data")
133     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
134     public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
135             throws Throwable {
136         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
137             @Override
138             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
139                 testStartTls(sb, cb, serverCtx, clientCtx);
140             }
141         });
142     }
143 
144     public void testStartTls(ServerBootstrap sb, Bootstrap cb,
145                              SslContext serverCtx, SslContext clientCtx) throws Throwable {
146         testStartTls(sb, cb, serverCtx, clientCtx, true);
147     }
148 
149     @ParameterizedTest(name = PARAMETERIZED_NAME)
150     @MethodSource("data")
151     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
152     public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
153                                         TestInfo testInfo) throws Throwable {
154         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
155             @Override
156             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
157                 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
158             }
159         });
160     }
161 
162     public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
163                                         SslContext serverCtx, SslContext clientCtx) throws Throwable {
164         testStartTls(sb, cb, serverCtx, clientCtx, false);
165     }
166 
167     private void testStartTls(ServerBootstrap sb, Bootstrap cb,
168                               SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
169         sb.childOption(ChannelOption.AUTO_READ, autoRead);
170         cb.option(ChannelOption.AUTO_READ, autoRead);
171 
172         final EventExecutorGroup executor = SocketStartTlsTest.executor;
173         SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
174         SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
175 
176         final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
177         final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
178 
179         sb.childHandler(new ChannelInitializer<Channel>() {
180             @Override
181             public void initChannel(Channel sch) throws Exception {
182                 ChannelPipeline p = sch.pipeline();
183                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
184                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
185                 p.addLast(executor, sh);
186             }
187         });
188 
189         cb.handler(new ChannelInitializer<Channel>() {
190             @Override
191             public void initChannel(Channel sch) throws Exception {
192                 ChannelPipeline p = sch.pipeline();
193                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
194                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
195                 p.addLast(executor, ch);
196             }
197         });
198 
199         Channel sc = sb.bind().sync().channel();
200         Channel cc = cb.connect(sc.localAddress()).sync().channel();
201 
202         while (cc.isActive()) {
203             if (sh.exception.get() != null) {
204                 break;
205             }
206             if (ch.exception.get() != null) {
207                 break;
208             }
209 
210             Thread.sleep(50);
211         }
212 
213         while (sh.channel.isActive()) {
214             if (sh.exception.get() != null) {
215                 break;
216             }
217             if (ch.exception.get() != null) {
218                 break;
219             }
220 
221             Thread.sleep(50);
222         }
223 
224         sh.channel.close().awaitUninterruptibly();
225         cc.close().awaitUninterruptibly();
226         sc.close().awaitUninterruptibly();
227 
228         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
229             throw sh.exception.get();
230         }
231         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
232             throw ch.exception.get();
233         }
234         if (sh.exception.get() != null) {
235             throw sh.exception.get();
236         }
237         if (ch.exception.get() != null) {
238             throw ch.exception.get();
239         }
240     }
241 
242     private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
243         private final SslHandler sslHandler;
244         private final boolean autoRead;
245         private Future<Channel> handshakeFuture;
246         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
247 
248         StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
249             engine.setUseClientMode(true);
250             sslHandler = new SslHandler(engine);
251             this.autoRead = autoRead;
252         }
253 
254         @Override
255         public void channelActive(ChannelHandlerContext ctx)
256                 throws Exception {
257             if (!autoRead) {
258                 ctx.read();
259             }
260             ctx.writeAndFlush("StartTlsRequest\n");
261         }
262 
263         @Override
264         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
265             if ("StartTlsResponse".equals(msg)) {
266                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
267                 handshakeFuture = sslHandler.handshakeFuture();
268                 ctx.writeAndFlush("EncryptedRequest\n");
269                 return;
270             }
271 
272             assertEquals("EncryptedResponse", msg);
273             assertNotNull(handshakeFuture);
274             assertTrue(handshakeFuture.isSuccess());
275             ctx.close();
276         }
277 
278         @Override
279         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
280             if (!autoRead) {
281                 ctx.read();
282             }
283         }
284 
285         @Override
286         public void exceptionCaught(ChannelHandlerContext ctx,
287                 Throwable cause) throws Exception {
288             if (logger.isWarnEnabled()) {
289                 logger.warn("Unexpected exception from the client side", cause);
290             }
291 
292             exception.compareAndSet(null, cause);
293             ctx.close();
294         }
295     }
296 
297     private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
298         private final SslHandler sslHandler;
299         private final boolean autoRead;
300         volatile Channel channel;
301         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
302 
303         StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
304             engine.setUseClientMode(false);
305             sslHandler = new SslHandler(engine, true);
306             this.autoRead = autoRead;
307         }
308 
309         @Override
310         public void channelActive(ChannelHandlerContext ctx) throws Exception {
311             channel = ctx.channel();
312             if (!autoRead) {
313                 ctx.read();
314             }
315         }
316 
317         @Override
318         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
319             if ("StartTlsRequest".equals(msg)) {
320                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
321                 ctx.writeAndFlush("StartTlsResponse\n");
322                 return;
323             }
324 
325             assertEquals("EncryptedRequest", msg);
326             ctx.writeAndFlush("EncryptedResponse\n");
327         }
328 
329         @Override
330         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
331             if (!autoRead) {
332                 ctx.read();
333             }
334         }
335 
336         @Override
337         public void exceptionCaught(ChannelHandlerContext ctx,
338                                     Throwable cause) throws Exception {
339             if (logger.isWarnEnabled()) {
340                 logger.warn("Unexpected exception from the server side", cause);
341             }
342 
343             exception.compareAndSet(null, cause);
344             ctx.close();
345         }
346     }
347 }