View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.PooledByteBufAllocator;
21  import io.netty.channel.Channel;
22  import io.netty.channel.ChannelHandlerContext;
23  import io.netty.channel.ChannelInitializer;
24  import io.netty.channel.ChannelOption;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.LineBasedFrameDecoder;
28  import io.netty.handler.codec.string.StringDecoder;
29  import io.netty.handler.codec.string.StringEncoder;
30  import io.netty.handler.logging.LogLevel;
31  import io.netty.handler.logging.LoggingHandler;
32  import io.netty.handler.ssl.OpenSsl;
33  import io.netty.handler.ssl.SslContext;
34  import io.netty.handler.ssl.SslContextBuilder;
35  import io.netty.handler.ssl.SslHandler;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.ssl.util.SelfSignedCertificate;
38  import io.netty.util.concurrent.DefaultEventExecutorGroup;
39  import io.netty.util.concurrent.EventExecutorGroup;
40  import io.netty.util.concurrent.Future;
41  import io.netty.util.internal.logging.InternalLogger;
42  import io.netty.util.internal.logging.InternalLoggerFactory;
43  
44  import org.junit.jupiter.api.AfterAll;
45  import org.junit.jupiter.api.BeforeAll;
46  import org.junit.jupiter.api.TestInfo;
47  import org.junit.jupiter.api.Timeout;
48  import org.junit.jupiter.params.ParameterizedTest;
49  import org.junit.jupiter.params.provider.MethodSource;
50  
51  import javax.net.ssl.SSLEngine;
52  import java.io.File;
53  import java.io.IOException;
54  import java.security.cert.CertificateException;
55  import java.util.ArrayList;
56  import java.util.Collection;
57  import java.util.List;
58  import java.util.concurrent.TimeUnit;
59  import java.util.concurrent.atomic.AtomicReference;
60  
61  import static org.junit.jupiter.api.Assertions.assertEquals;
62  import static org.junit.jupiter.api.Assertions.assertNotNull;
63  import static org.junit.jupiter.api.Assertions.assertTrue;
64  
65  public class SocketStartTlsTest extends AbstractSocketTest {
66      private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67  
68      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69  
70      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71      private static final File CERT_FILE;
72      private static final File KEY_FILE;
73      private static EventExecutorGroup executor;
74  
75      static {
76          SelfSignedCertificate ssc;
77          try {
78              ssc = new SelfSignedCertificate();
79          } catch (CertificateException e) {
80              throw new Error(e);
81          }
82          CERT_FILE = ssc.certificate();
83          KEY_FILE = ssc.privateKey();
84      }
85  
86      public static Collection<Object[]> data() throws Exception {
87          List<SslContext> serverContexts = new ArrayList<SslContext>();
88          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
89  
90          List<SslContext> clientContexts = new ArrayList<SslContext>();
91          clientContexts.add(SslContextBuilder.forClient()
92                  .sslProvider(SslProvider.JDK)
93                  .trustManager(CERT_FILE)
94                  .endpointIdentificationAlgorithm(null)
95                  .build());
96  
97          boolean hasOpenSsl = OpenSsl.isAvailable();
98          if (hasOpenSsl) {
99              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
100                                                 .sslProvider(SslProvider.OPENSSL).build());
101             clientContexts.add(SslContextBuilder.forClient()
102                     .sslProvider(SslProvider.OPENSSL)
103                     .trustManager(CERT_FILE)
104                     .endpointIdentificationAlgorithm(null)
105                     .build());
106         } else {
107             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
108         }
109 
110         List<Object[]> params = new ArrayList<Object[]>();
111         for (SslContext sc: serverContexts) {
112             for (SslContext cc: clientContexts) {
113                 params.add(new Object[] { sc, cc });
114             }
115         }
116         return params;
117     }
118 
119     @BeforeAll
120     public static void createExecutor() {
121         executor = new DefaultEventExecutorGroup(2);
122     }
123 
124     @AfterAll
125     public static void shutdownExecutor() throws Exception {
126         executor.shutdownGracefully().sync();
127     }
128 
129     @ParameterizedTest(name = PARAMETERIZED_NAME)
130     @MethodSource("data")
131     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
132     public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
133             throws Throwable {
134         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
135             @Override
136             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
137                 testStartTls(sb, cb, serverCtx, clientCtx);
138             }
139         });
140     }
141 
142     public void testStartTls(ServerBootstrap sb, Bootstrap cb,
143                              SslContext serverCtx, SslContext clientCtx) throws Throwable {
144         testStartTls(sb, cb, serverCtx, clientCtx, true);
145     }
146 
147     @ParameterizedTest(name = PARAMETERIZED_NAME)
148     @MethodSource("data")
149     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
150     public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
151                                         TestInfo testInfo) throws Throwable {
152         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
153             @Override
154             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
155                 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
156             }
157         });
158     }
159 
160     public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
161                                         SslContext serverCtx, SslContext clientCtx) throws Throwable {
162         testStartTls(sb, cb, serverCtx, clientCtx, false);
163     }
164 
165     private void testStartTls(ServerBootstrap sb, Bootstrap cb,
166                               SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
167         sb.childOption(ChannelOption.AUTO_READ, autoRead);
168         cb.option(ChannelOption.AUTO_READ, autoRead);
169 
170         final EventExecutorGroup executor = SocketStartTlsTest.executor;
171         SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
172         SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
173 
174         final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
175         final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
176 
177         sb.childHandler(new ChannelInitializer<Channel>() {
178             @Override
179             public void initChannel(Channel sch) throws Exception {
180                 ChannelPipeline p = sch.pipeline();
181                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
182                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
183                 p.addLast(executor, sh);
184             }
185         });
186 
187         cb.handler(new ChannelInitializer<Channel>() {
188             @Override
189             public void initChannel(Channel sch) throws Exception {
190                 ChannelPipeline p = sch.pipeline();
191                 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
192                 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
193                 p.addLast(executor, ch);
194             }
195         });
196 
197         Channel sc = sb.bind().sync().channel();
198         Channel cc = cb.connect(sc.localAddress()).sync().channel();
199 
200         while (cc.isActive()) {
201             if (sh.exception.get() != null) {
202                 break;
203             }
204             if (ch.exception.get() != null) {
205                 break;
206             }
207 
208             Thread.sleep(50);
209         }
210 
211         while (sh.channel.isActive()) {
212             if (sh.exception.get() != null) {
213                 break;
214             }
215             if (ch.exception.get() != null) {
216                 break;
217             }
218 
219             Thread.sleep(50);
220         }
221 
222         sh.channel.close().awaitUninterruptibly();
223         cc.close().awaitUninterruptibly();
224         sc.close().awaitUninterruptibly();
225 
226         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
227             throw sh.exception.get();
228         }
229         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
230             throw ch.exception.get();
231         }
232         if (sh.exception.get() != null) {
233             throw sh.exception.get();
234         }
235         if (ch.exception.get() != null) {
236             throw ch.exception.get();
237         }
238     }
239 
240     private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
241         private final SslHandler sslHandler;
242         private final boolean autoRead;
243         private Future<Channel> handshakeFuture;
244         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
245 
246         StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
247             engine.setUseClientMode(true);
248             sslHandler = new SslHandler(engine);
249             this.autoRead = autoRead;
250         }
251 
252         @Override
253         public void channelActive(ChannelHandlerContext ctx)
254                 throws Exception {
255             if (!autoRead) {
256                 ctx.read();
257             }
258             ctx.writeAndFlush("StartTlsRequest\n");
259         }
260 
261         @Override
262         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
263             if ("StartTlsResponse".equals(msg)) {
264                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
265                 handshakeFuture = sslHandler.handshakeFuture();
266                 ctx.writeAndFlush("EncryptedRequest\n");
267                 return;
268             }
269 
270             assertEquals("EncryptedResponse", msg);
271             assertNotNull(handshakeFuture);
272             assertTrue(handshakeFuture.isSuccess());
273             ctx.close();
274         }
275 
276         @Override
277         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
278             if (!autoRead) {
279                 ctx.read();
280             }
281         }
282 
283         @Override
284         public void exceptionCaught(ChannelHandlerContext ctx,
285                 Throwable cause) throws Exception {
286             if (logger.isWarnEnabled()) {
287                 logger.warn("Unexpected exception from the client side", cause);
288             }
289 
290             exception.compareAndSet(null, cause);
291             ctx.close();
292         }
293     }
294 
295     private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
296         private final SslHandler sslHandler;
297         private final boolean autoRead;
298         volatile Channel channel;
299         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
300 
301         StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
302             engine.setUseClientMode(false);
303             sslHandler = new SslHandler(engine, true);
304             this.autoRead = autoRead;
305         }
306 
307         @Override
308         public void channelActive(ChannelHandlerContext ctx) throws Exception {
309             channel = ctx.channel();
310             if (!autoRead) {
311                 ctx.read();
312             }
313         }
314 
315         @Override
316         public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
317             if ("StartTlsRequest".equals(msg)) {
318                 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
319                 ctx.writeAndFlush("StartTlsResponse\n");
320                 return;
321             }
322 
323             assertEquals("EncryptedRequest", msg);
324             ctx.writeAndFlush("EncryptedResponse\n");
325         }
326 
327         @Override
328         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
329             if (!autoRead) {
330                 ctx.read();
331             }
332         }
333 
334         @Override
335         public void exceptionCaught(ChannelHandlerContext ctx,
336                                     Throwable cause) throws Exception {
337             if (logger.isWarnEnabled()) {
338                 logger.warn("Unexpected exception from the server side", cause);
339             }
340 
341             exception.compareAndSet(null, cause);
342             ctx.close();
343         }
344     }
345 }