View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufUtil;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelHandler.Sharable;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.SimpleChannelInboundHandler;
29  import io.netty.channel.socket.SocketChannel;
30  import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslContextBuilder;
33  import io.netty.handler.ssl.SslHandler;
34  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35  import io.netty.handler.ssl.SslProvider;
36  import io.netty.pkitesting.CertificateBuilder;
37  import io.netty.pkitesting.X509Bundle;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  
41  import org.junit.jupiter.api.TestInfo;
42  import org.junit.jupiter.api.Timeout;
43  import org.junit.jupiter.params.ParameterizedTest;
44  import org.junit.jupiter.params.provider.MethodSource;
45  
46  import javax.net.ssl.SSLEngine;
47  import javax.net.ssl.SSLSession;
48  import javax.net.ssl.SSLSessionContext;
49  import javax.net.ssl.TrustManager;
50  import javax.net.ssl.X509ExtendedTrustManager;
51  
52  import java.io.File;
53  import java.io.IOException;
54  import java.net.InetSocketAddress;
55  import java.net.Socket;
56  import java.security.cert.CertificateException;
57  import java.security.cert.X509Certificate;
58  import java.util.Arrays;
59  import java.util.Collection;
60  import java.util.Collections;
61  import java.util.Enumeration;
62  import java.util.HashSet;
63  import java.util.Set;
64  import java.util.concurrent.BlockingQueue;
65  import java.util.concurrent.ExecutionException;
66  import java.util.concurrent.LinkedBlockingQueue;
67  import java.util.concurrent.TimeUnit;
68  import java.util.concurrent.atomic.AtomicReference;
69  
70  import static org.junit.jupiter.api.Assertions.assertEquals;
71  import static org.junit.jupiter.api.Assertions.assertTrue;
72  
73  public class SocketSslSessionReuseTest extends AbstractSocketTest {
74  
75      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
76  
77      private static final File CERT_FILE;
78      private static final File KEY_FILE;
79  
80      static {
81          try {
82              X509Bundle cert = new CertificateBuilder()
83                      .subject("cn=localhost")
84                      .setIsCertificateAuthority(true)
85                      .buildSelfSigned();
86              CERT_FILE = cert.toTempCertChainPem();
87              KEY_FILE = cert.toTempPrivateKeyPem();
88          } catch (Exception e) {
89              throw new ExceptionInInitializerError(e);
90          }
91      }
92  
93      public static Collection<Object[]> jdkOnly() throws Exception {
94          return Collections.singleton(new Object[]{
95                  SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
96                  SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
97                          .endpointIdentificationAlgorithm(null)
98          });
99      }
100 
101     public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
102         return Arrays.asList(new Object[]{
103                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
104                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
105                                 .endpointIdentificationAlgorithm(null)
106                 },
107                 new Object[]{
108                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
109                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
110                                 .endpointIdentificationAlgorithm(null)
111                 });
112     }
113 
114     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
115     @MethodSource("jdkOnly")
116     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
117     public void testSslSessionReuse(
118             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
119         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
120             @Override
121             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
122                 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
123             }
124         });
125     }
126 
127     public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
128                                     final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
129         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
130         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
131         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
132 
133         sb.childHandler(new ChannelInitializer<SocketChannel>() {
134             @Override
135             protected void initChannel(SocketChannel sch) throws Exception {
136                 SSLEngine engine = serverCtx.newEngine(sch.alloc());
137                 engine.setUseClientMode(false);
138                 engine.setEnabledProtocols(protocols);
139 
140                 sch.pipeline().addLast(new SslHandler(engine));
141                 sch.pipeline().addLast(sh);
142             }
143         });
144         final Channel sc = sb.bind().sync().channel();
145 
146         cb.handler(new ChannelInitializer<SocketChannel>() {
147             @Override
148             protected void initChannel(SocketChannel sch) throws Exception {
149                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
150                 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
151                 engine.setUseClientMode(true);
152                 engine.setEnabledProtocols(protocols);
153 
154                 sch.pipeline().addLast(new SslHandler(engine));
155                 sch.pipeline().addLast(ch);
156             }
157         });
158 
159         try {
160             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
161             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
162             Channel cc = cb.connect(sc.localAddress()).sync().channel();
163             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
164             cc.closeFuture().sync();
165             rethrowHandlerExceptions(sh, ch);
166             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
167 
168             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
169             cc = cb.connect(sc.localAddress()).sync().channel();
170             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
171             cc.closeFuture().sync();
172             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
173             rethrowHandlerExceptions(sh, ch);
174         } finally {
175             sc.close().awaitUninterruptibly();
176         }
177     }
178 
179     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
180     @MethodSource("jdkAndOpenSSL")
181     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
182     public void testSslSessionTrustManagerResumption(
183             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
184         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
185             @Override
186             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
187                 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
188             }
189         });
190     }
191 
192     public void testSslSessionTrustManagerResumption(
193             ServerBootstrap sb, Bootstrap cb,
194             SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
195         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
196         serverCtxBldr.protocols(protocols);
197         clientCtxBldr.protocols(protocols);
198         TrustManager clientTrustManager = new SessionSettingTrustManager();
199         clientCtxBldr.trustManager(clientTrustManager);
200         final SslContext serverContext = serverCtxBldr.build();
201         final SslContext clientContext = clientCtxBldr.build();
202 
203         final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
204         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
205         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
206             @Override
207             public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
208                 if (evt instanceof SslHandshakeCompletionEvent) {
209                     SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
210                     if (handshakeCompletionEvent.isSuccess()) {
211                         SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
212                         assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
213                     } else {
214                         logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
215                     }
216                 }
217                 super.userEventTriggered(ctx, evt);
218             }
219         };
220 
221         sb.childHandler(new ChannelInitializer<SocketChannel>() {
222             @Override
223             protected void initChannel(SocketChannel sch) throws Exception {
224                 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
225                 sch.pipeline().addLast(sh);
226             }
227         });
228         final Channel sc = sb.bind().sync().channel();
229 
230         cb.handler(new ChannelInitializer<SocketChannel>() {
231             @Override
232             protected void initChannel(SocketChannel sch) throws Exception {
233                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
234                 SslHandler sslHandler = clientContext.newHandler(
235                         sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
236 
237                 sch.pipeline().addLast(sslHandler);
238                 sch.pipeline().addLast(ch);
239             }
240         });
241 
242         try {
243             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
244             Channel cc = cb.connect(sc.localAddress()).sync().channel();
245             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
246             cc.closeFuture().sync();
247             rethrowHandlerExceptions(sh, ch);
248             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
249 
250             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
251             cc = cb.connect(sc.localAddress()).sync().channel();
252             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
253             cc.closeFuture().sync();
254             rethrowHandlerExceptions(sh, ch);
255             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
256         } finally {
257             sc.close().awaitUninterruptibly();
258         }
259     }
260 
261     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
262         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
263             throw new ExecutionException(sh.exception.get());
264         }
265         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
266             throw new ExecutionException(ch.exception.get());
267         }
268         if (sh.exception.get() != null) {
269             throw new ExecutionException(sh.exception.get());
270         }
271         if (ch.exception.get() != null) {
272             throw new ExecutionException(ch.exception.get());
273         }
274     }
275 
276     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
277         Set<String> idSet = new HashSet<String>();
278         byte[] id;
279         while (sessionIds.hasMoreElements()) {
280             id = sessionIds.nextElement();
281             idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
282         }
283         return idSet;
284     }
285 
286     @Sharable
287     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
288         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
289         private final boolean server;
290         private final boolean autoRead;
291 
292         ReadAndDiscardHandler(boolean server, boolean autoRead) {
293             this.server = server;
294             this.autoRead = autoRead;
295         }
296 
297         @Override
298         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
299             in.skipBytes(in.readableBytes());
300         }
301 
302         @Override
303         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
304             try {
305                 ctx.flush();
306             } finally {
307                 if (!autoRead) {
308                     ctx.read();
309                 }
310             }
311         }
312 
313         @Override
314         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
315             if (logger.isWarnEnabled()) {
316                 logger.warn(
317                         "Unexpected exception from the " +
318                         (server? "server" : "client") + " side", cause);
319             }
320 
321             exception.compareAndSet(null, cause);
322             ctx.close();
323         }
324     }
325 
326     private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
327             implements ResumableX509ExtendedTrustManager {
328         @Override
329         public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
330             engine.getSession().putValue("key", "value");
331         }
332 
333         @Override
334         public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
335                 throws CertificateException {
336             engine.getHandshakeSession().putValue("key", "value");
337         }
338 
339         @Override
340         public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
341             throw new CertificateException("Unsupported operation");
342         }
343 
344         @Override
345         public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
346                 throws CertificateException {
347             throw new CertificateException("Unsupported operation");
348         }
349 
350         @Override
351         public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
352                 throws CertificateException {
353             throw new CertificateException("Unsupported operation");
354         }
355 
356         @Override
357         public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
358                 throws CertificateException {
359             throw new CertificateException("Unsupported operation");
360         }
361 
362         @Override
363         public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
364             throw new CertificateException("Unsupported operation");
365         }
366 
367         @Override
368         public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
369             throw new CertificateException("Unsupported operation");
370         }
371 
372         @Override
373         public X509Certificate[] getAcceptedIssuers() {
374             return new X509Certificate[0];
375         }
376     }
377 }