View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.PooledByteBufAllocator;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInitializer;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.LineBasedFrameDecoder;
28  import io.netty.handler.codec.string.StringDecoder;
29  import io.netty.handler.codec.string.StringEncoder;
30  import io.netty.handler.logging.LogLevel;
31  import io.netty.handler.logging.LoggingHandler;
32  import io.netty.handler.ssl.OpenSsl;
33  import io.netty.handler.ssl.SslContext;
34  import io.netty.handler.ssl.SslContextBuilder;
35  import io.netty.handler.ssl.SslHandler;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.ssl.util.SelfSignedCertificate;
38  import io.netty.util.concurrent.DefaultEventExecutorGroup;
39  import io.netty.util.concurrent.EventExecutorGroup;
40  import io.netty.util.concurrent.Future;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import org.junit.jupiter.api.AfterAll;
45  import org.junit.jupiter.api.BeforeAll;
46  import org.junit.jupiter.api.TestInfo;
47  import org.junit.jupiter.api.Timeout;
48  import org.junit.jupiter.params.ParameterizedTest;
49  import org.junit.jupiter.params.provider.MethodSource;
50  
51  import javax.net.ssl.SSLEngine;
52  import java.io.File;
53  import java.io.IOException;
54  import java.security.cert.CertificateException;
55  import java.util.ArrayList;
56  import java.util.Collection;
57  import java.util.List;
58  import java.util.concurrent.TimeUnit;
59  import java.util.concurrent.atomic.AtomicReference;
60  
61  import static org.junit.jupiter.api.Assertions.assertEquals;
62  import static org.junit.jupiter.api.Assertions.assertNotNull;
63  import static org.junit.jupiter.api.Assertions.assertTrue;
64  
65  public class SocketStartTlsTest extends AbstractSocketTest {
66      private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67  
68      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69  
70      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71      private static final File CERT_FILE;
72      private static final File KEY_FILE;
73      private static EventExecutorGroup executor;
74  
75      static {
76          SelfSignedCertificate ssc;
77          try {
78              ssc = new SelfSignedCertificate();
79          } catch (CertificateException e) {
80              throw new Error(e);
81          }
82          CERT_FILE = ssc.certificate();
83          KEY_FILE = ssc.privateKey();
84      }
85  
86      public static Collection<Object[]> data() throws Exception {
87          List<SslContext> serverContexts = new ArrayList<SslContext>();
88          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
89  
90          List<SslContext> clientContexts = new ArrayList<SslContext>();
91          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
92  
93          boolean hasOpenSsl = OpenSsl.isAvailable();
94          if (hasOpenSsl) {
95              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
96                                                  .sslProvider(SslProvider.OPENSSL).build());
97              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
98                                                  .trustManager(CERT_FILE).build());
99          } else {
100             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
101         }
102 
103         List<Object[]> params = new ArrayList<Object[]>();
104         for (SslContext sc: serverContexts) {
105             for (SslContext cc: clientContexts) {
106                 params.add(new Object[] { sc, cc });
107             }
108         }
109         return params;
110     }
111 
112     @BeforeAll
113     public static void createExecutor() {
114         executor = new DefaultEventExecutorGroup(2);
115     }
116 
117     @AfterAll
118     public static void shutdownExecutor() throws Exception {
119         executor.shutdownGracefully().sync();
120     }
121 
122     @ParameterizedTest(name = PARAMETERIZED_NAME)
123     @MethodSource("data")
124     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
125     public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
126             throws Throwable {
127         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128             @Override
129             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130                 testStartTls(sb, cb, serverCtx, clientCtx);
131             }
132         });
133     }
134 
135     public void testStartTls(ServerBootstrap sb, Bootstrap cb,
136                              SslContext serverCtx, SslContext clientCtx) throws Throwable {
137         testStartTls(sb, cb, serverCtx, clientCtx, true);
138     }
139 
140     @ParameterizedTest(name = PARAMETERIZED_NAME)
141     @MethodSource("data")
142     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
143     public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
144                                         TestInfo testInfo) throws Throwable {
145         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
146             @Override
147             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
148                 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
149             }
150         });
151     }
152 
153     public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
154                                         SslContext serverCtx, SslContext clientCtx) throws Throwable {
155         testStartTls(sb, cb, serverCtx, clientCtx, false);
156     }
157 
158     private void testStartTls(ServerBootstrap sb, Bootstrap cb,
159                               SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
160         sb.childOption(ChannelOption.AUTO_READ, autoRead);
161         cb.option(ChannelOption.AUTO_READ, autoRead);
162 
163         final EventExecutorGroup executor = SocketStartTlsTest.executor;
164         SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
165         SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
166 
167         final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
168         final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
169 
170         sb.childHandler(new ChannelInitializer<Channel>() {
171             @Override
172             public void initChannel(Channel sch) throws Exception {
173                 ChannelPipeline p = sch.pipeline();
174                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
175                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
176                 p.addLast(executor, sh);
177             }
178         });
179 
180         cb.handler(new ChannelInitializer<Channel>() {
181             @Override
182             public void initChannel(Channel sch) throws Exception {
183                 ChannelPipeline p = sch.pipeline();
184                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
185                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
186                 p.addLast(executor, ch);
187             }
188         });
189 
190         Channel sc = sb.bind().sync().channel();
191         Channel cc = cb.connect(sc.localAddress()).sync().channel();
192 
193         while (cc.isActive()) {
194             if (sh.exception.get() != null) {
195                 break;
196             }
197             if (ch.exception.get() != null) {
198                 break;
199             }
200 
201             Thread.sleep(50);
202         }
203 
204         while (sh.channel.isActive()) {
205             if (sh.exception.get() != null) {
206                 break;
207             }
208             if (ch.exception.get() != null) {
209                 break;
210             }
211 
212             Thread.sleep(50);
213         }
214 
215         sh.channel.close().awaitUninterruptibly();
216         cc.close().awaitUninterruptibly();
217         sc.close().awaitUninterruptibly();
218 
219         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
220             throw sh.exception.get();
221         }
222         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
223             throw ch.exception.get();
224         }
225         if (sh.exception.get() != null) {
226             throw sh.exception.get();
227         }
228         if (ch.exception.get() != null) {
229             throw ch.exception.get();
230         }
231     }
232 
233     private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
234         private final SslHandler sslHandler;
235         private final boolean autoRead;
236         private Future<Channel> handshakeFuture;
237         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
238 
239         StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
240             engine.setUseClientMode(true);
241             sslHandler = new SslHandler(engine);
242             this.autoRead = autoRead;
243         }
244 
245         @Override
246         public void channelActive(ChannelHandlerContext ctx)
247                 throws Exception {
248             if (!autoRead) {
249                 ctx.read();
250             }
251             ctx.writeAndFlush("StartTlsRequest\n");
252         }
253 
254         @Override
255         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
256             if ("StartTlsResponse".equals(msg)) {
257                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
258                 handshakeFuture = sslHandler.handshakeFuture();
259                 ctx.writeAndFlush("EncryptedRequest\n");
260                 return;
261             }
262 
263             assertEquals("EncryptedResponse", msg);
264             assertNotNull(handshakeFuture);
265             assertTrue(handshakeFuture.isSuccess());
266             ctx.close();
267         }
268 
269         @Override
270         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
271             if (!autoRead) {
272                 ctx.read();
273             }
274         }
275 
276         @Override
277         public void exceptionCaught(ChannelHandlerContext ctx,
278                 Throwable cause) throws Exception {
279             if (logger.isWarnEnabled()) {
280                 logger.warn("Unexpected exception from the client side", cause);
281             }
282 
283             exception.compareAndSet(null, cause);
284             ctx.close();
285         }
286     }
287 
288     private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
289         private final SslHandler sslHandler;
290         private final boolean autoRead;
291         volatile Channel channel;
292         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
293 
294         StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
295             engine.setUseClientMode(false);
296             sslHandler = new SslHandler(engine, true);
297             this.autoRead = autoRead;
298         }
299 
300         @Override
301         public void channelActive(ChannelHandlerContext ctx) throws Exception {
302             channel = ctx.channel();
303             if (!autoRead) {
304                 ctx.read();
305             }
306         }
307 
308         @Override
309         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
310             if ("StartTlsRequest".equals(msg)) {
311                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
312                 ctx.writeAndFlush("StartTlsResponse\n");
313                 return;
314             }
315 
316             assertEquals("EncryptedRequest", msg);
317             ctx.writeAndFlush("EncryptedResponse\n");
318         }
319 
320         @Override
321         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
322             if (!autoRead) {
323                 ctx.read();
324             }
325         }
326 
327         @Override
328         public void exceptionCaught(ChannelHandlerContext ctx,
329                                     Throwable cause) throws Exception {
330             if (logger.isWarnEnabled()) {
331                 logger.warn("Unexpected exception from the server side", cause);
332             }
333 
334             exception.compareAndSet(null, cause);
335             ctx.close();
336         }
337     }
338 }