View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.PooledByteBufAllocator;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInitializer;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.LineBasedFrameDecoder;
28  import io.netty.handler.codec.string.StringDecoder;
29  import io.netty.handler.codec.string.StringEncoder;
30  import io.netty.handler.logging.LogLevel;
31  import io.netty.handler.logging.LoggingHandler;
32  import io.netty.handler.ssl.OpenSsl;
33  import io.netty.handler.ssl.SslContext;
34  import io.netty.handler.ssl.SslContextBuilder;
35  import io.netty.handler.ssl.SslHandler;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.ssl.util.SelfSignedCertificate;
38  import io.netty.util.concurrent.DefaultEventExecutorGroup;
39  import io.netty.util.concurrent.EventExecutorGroup;
40  import io.netty.util.concurrent.Future;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import org.junit.jupiter.api.AfterAll;
45  import org.junit.jupiter.api.BeforeAll;
46  import org.junit.jupiter.api.TestInfo;
47  import org.junit.jupiter.api.Timeout;
48  import org.junit.jupiter.params.ParameterizedTest;
49  import org.junit.jupiter.params.provider.MethodSource;
50  
51  import javax.net.ssl.SSLEngine;
52  import java.io.File;
53  import java.io.IOException;
54  import java.security.cert.CertificateException;
55  import java.util.ArrayList;
56  import java.util.Collection;
57  import java.util.List;
58  import java.util.concurrent.TimeUnit;
59  import java.util.concurrent.atomic.AtomicReference;
60  
61  import static org.junit.jupiter.api.Assertions.assertEquals;
62  import static org.junit.jupiter.api.Assertions.assertNotNull;
63  import static org.junit.jupiter.api.Assertions.assertTrue;
64  
65  public class SocketStartTlsTest extends AbstractSocketTest {
66      private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67  
68      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69  
70      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71      private static final File CERT_FILE;
72      private static final File KEY_FILE;
73      private static EventExecutorGroup executor;
74  
75      static {
76          SelfSignedCertificate ssc;
77          try {
78              ssc = new SelfSignedCertificate();
79          } catch (CertificateException e) {
80              throw new Error(e);
81          }
82          CERT_FILE = ssc.certificate();
83          KEY_FILE = ssc.privateKey();
84      }
85  
86      public static Collection<Object[]> data() throws Exception {
87          List<SslContext> serverContexts = new ArrayList<SslContext>();
88          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
89  
90          List<SslContext> clientContexts = new ArrayList<SslContext>();
91          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
92  
93          boolean hasOpenSsl = OpenSsl.isAvailable();
94          if (hasOpenSsl) {
95              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
96                                                  .sslProvider(SslProvider.OPENSSL).build());
97              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
98                                                  .trustManager(CERT_FILE).build());
99          } else {
100             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
101         }
102 
103         List<Object[]> params = new ArrayList<Object[]>();
104         for (SslContext sc: serverContexts) {
105             for (SslContext cc: clientContexts) {
106                 params.add(new Object[] { sc, cc });
107             }
108         }
109         return params;
110     }
111 
112     @BeforeAll
113     public static void createExecutor() {
114         executor = new DefaultEventExecutorGroup(2);
115     }
116 
117     @AfterAll
118     public static void shutdownExecutor() throws Exception {
119         executor.shutdownGracefully().sync();
120     }
121 
122     @ParameterizedTest(name = PARAMETERIZED_NAME)
123     @MethodSource("data")
124     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
125     public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
126             throws Throwable {
127         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128             @Override
129             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130                 testStartTls(sb, cb, serverCtx, clientCtx);
131             }
132         });
133     }
134 
135     public void testStartTls(ServerBootstrap sb, Bootstrap cb,
136                              SslContext serverCtx, SslContext clientCtx) throws Throwable {
137         testStartTls(sb, cb, serverCtx, clientCtx, true);
138     }
139 
140     @ParameterizedTest(name = PARAMETERIZED_NAME)
141     @MethodSource("data")
142     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
143     public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
144                                         TestInfo testInfo) throws Throwable {
145         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
146             @Override
147             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
148                 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
149             }
150         });
151     }
152 
153     public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
154                                         SslContext serverCtx, SslContext clientCtx) throws Throwable {
155         testStartTls(sb, cb, serverCtx, clientCtx, false);
156     }
157 
158     private void testStartTls(ServerBootstrap sb, Bootstrap cb,
159                               SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
160         sb.childOption(ChannelOption.AUTO_READ, autoRead);
161         cb.option(ChannelOption.AUTO_READ, autoRead);
162 
163         final EventExecutorGroup executor = SocketStartTlsTest.executor;
164         SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
165         SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
166 
167         final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
168         final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
169 
170         sb.childHandler(new ChannelInitializer<Channel>() {
171             @Override
172             public void initChannel(Channel sch) throws Exception {
173                 ChannelPipeline p = sch.pipeline();
174                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
175                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
176                 p.addLast(executor, sh);
177             }
178         });
179 
180         cb.handler(new ChannelInitializer<Channel>() {
181             @Override
182             public void initChannel(Channel sch) throws Exception {
183                 ChannelPipeline p = sch.pipeline();
184                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
185                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
186                 p.addLast(executor, ch);
187             }
188         });
189 
190         Channel sc = sb.bind().sync().channel();
191         Channel cc = cb.connect(sc.localAddress()).sync().channel();
192 
193         while (cc.isActive()) {
194             if (sh.exception.get() != null) {
195                 break;
196             }
197             if (ch.exception.get() != null) {
198                 break;
199             }
200 
201             try {
202                 Thread.sleep(50);
203             } catch (InterruptedException e) {
204                 // Ignore.
205             }
206         }
207 
208         while (sh.channel.isActive()) {
209             if (sh.exception.get() != null) {
210                 break;
211             }
212             if (ch.exception.get() != null) {
213                 break;
214             }
215 
216             try {
217                 Thread.sleep(50);
218             } catch (InterruptedException e) {
219                 // Ignore.
220             }
221         }
222 
223         sh.channel.close().awaitUninterruptibly();
224         cc.close().awaitUninterruptibly();
225         sc.close().awaitUninterruptibly();
226 
227         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
228             throw sh.exception.get();
229         }
230         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
231             throw ch.exception.get();
232         }
233         if (sh.exception.get() != null) {
234             throw sh.exception.get();
235         }
236         if (ch.exception.get() != null) {
237             throw ch.exception.get();
238         }
239     }
240 
241     private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
242         private final SslHandler sslHandler;
243         private final boolean autoRead;
244         private Future<Channel> handshakeFuture;
245         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
246 
247         StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
248             engine.setUseClientMode(true);
249             sslHandler = new SslHandler(engine);
250             this.autoRead = autoRead;
251         }
252 
253         @Override
254         public void channelActive(ChannelHandlerContext ctx)
255                 throws Exception {
256             if (!autoRead) {
257                 ctx.read();
258             }
259             ctx.writeAndFlush("StartTlsRequest\n");
260         }
261 
262         @Override
263         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
264             if ("StartTlsResponse".equals(msg)) {
265                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
266                 handshakeFuture = sslHandler.handshakeFuture();
267                 ctx.writeAndFlush("EncryptedRequest\n");
268                 return;
269             }
270 
271             assertEquals("EncryptedResponse", msg);
272             assertNotNull(handshakeFuture);
273             assertTrue(handshakeFuture.isSuccess());
274             ctx.close();
275         }
276 
277         @Override
278         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
279             if (!autoRead) {
280                 ctx.read();
281             }
282         }
283 
284         @Override
285         public void exceptionCaught(ChannelHandlerContext ctx,
286                 Throwable cause) throws Exception {
287             if (logger.isWarnEnabled()) {
288                 logger.warn("Unexpected exception from the client side", cause);
289             }
290 
291             exception.compareAndSet(null, cause);
292             ctx.close();
293         }
294     }
295 
296     private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
297         private final SslHandler sslHandler;
298         private final boolean autoRead;
299         volatile Channel channel;
300         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
301 
302         StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
303             engine.setUseClientMode(false);
304             sslHandler = new SslHandler(engine, true);
305             this.autoRead = autoRead;
306         }
307 
308         @Override
309         public void channelActive(ChannelHandlerContext ctx) throws Exception {
310             channel = ctx.channel();
311             if (!autoRead) {
312                 ctx.read();
313             }
314         }
315 
316         @Override
317         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
318             if ("StartTlsRequest".equals(msg)) {
319                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
320                 ctx.writeAndFlush("StartTlsResponse\n");
321                 return;
322             }
323 
324             assertEquals("EncryptedRequest", msg);
325             ctx.writeAndFlush("EncryptedResponse\n");
326         }
327 
328         @Override
329         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
330             if (!autoRead) {
331                 ctx.read();
332             }
333         }
334 
335         @Override
336         public void exceptionCaught(ChannelHandlerContext ctx,
337                                     Throwable cause) throws Exception {
338             if (logger.isWarnEnabled()) {
339                 logger.warn("Unexpected exception from the server side", cause);
340             }
341 
342             exception.compareAndSet(null, cause);
343             ctx.close();
344         }
345     }
346 }