View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufUtil;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelHandler.Sharable;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.SimpleChannelInboundHandler;
29  import io.netty.channel.socket.SocketChannel;
30  import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslContextBuilder;
33  import io.netty.handler.ssl.SslHandler;
34  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35  import io.netty.handler.ssl.SslProvider;
36  import io.netty.handler.ssl.util.SelfSignedCertificate;
37  import io.netty.util.internal.SuppressJava6Requirement;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  
41  import org.junit.jupiter.api.TestInfo;
42  import org.junit.jupiter.api.Timeout;
43  import org.junit.jupiter.params.ParameterizedTest;
44  import org.junit.jupiter.params.provider.MethodSource;
45  
46  import javax.net.ssl.SSLEngine;
47  import javax.net.ssl.SSLSession;
48  import javax.net.ssl.SSLSessionContext;
49  import javax.net.ssl.TrustManager;
50  import javax.net.ssl.X509ExtendedTrustManager;
51  
52  import java.io.File;
53  import java.io.IOException;
54  import java.net.InetSocketAddress;
55  import java.net.Socket;
56  import java.nio.channels.ClosedChannelException;
57  import java.security.cert.CertificateException;
58  import java.security.cert.X509Certificate;
59  import java.util.Arrays;
60  import java.util.Collection;
61  import java.util.Collections;
62  import java.util.Enumeration;
63  import java.util.HashSet;
64  import java.util.Set;
65  import java.util.concurrent.BlockingQueue;
66  import java.util.concurrent.ExecutionException;
67  import java.util.concurrent.LinkedBlockingQueue;
68  import java.util.concurrent.TimeUnit;
69  import java.util.concurrent.atomic.AtomicReference;
70  
71  import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
72  import static org.junit.jupiter.api.Assertions.assertEquals;
73  import static org.junit.jupiter.api.Assertions.assertTrue;
74  
75  public class SocketSslSessionReuseTest extends AbstractSocketTest {
76  
77      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
78  
79      private static final File CERT_FILE;
80      private static final File KEY_FILE;
81  
82      static {
83          SelfSignedCertificate ssc;
84          try {
85              ssc = new SelfSignedCertificate();
86          } catch (CertificateException e) {
87              throw new Error(e);
88          }
89          CERT_FILE = ssc.certificate();
90          KEY_FILE = ssc.privateKey();
91      }
92  
93      public static Collection<Object[]> jdkOnly() throws Exception {
94          return Collections.singleton(new Object[]{
95                  SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
96                  SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
97          });
98      }
99  
100     public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
101         return Arrays.asList(new Object[]{
102                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
103                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
104                 },
105                 new Object[]{
106                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
107                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
108                 });
109     }
110 
111     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
112     @MethodSource("jdkOnly")
113     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
114     public void testSslSessionReuse(
115             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
116         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
117             @Override
118             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
119                 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
120             }
121         });
122     }
123 
124     public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
125                                     final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
126         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
127         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
128         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
129 
130         sb.childHandler(new ChannelInitializer<SocketChannel>() {
131             @Override
132             protected void initChannel(SocketChannel sch) throws Exception {
133                 SSLEngine engine = serverCtx.newEngine(sch.alloc());
134                 engine.setUseClientMode(false);
135                 engine.setEnabledProtocols(protocols);
136 
137                 sch.pipeline().addLast(new SslHandler(engine));
138                 sch.pipeline().addLast(sh);
139             }
140         });
141         final Channel sc = sb.bind().sync().channel();
142 
143         cb.handler(new ChannelInitializer<SocketChannel>() {
144             @Override
145             protected void initChannel(SocketChannel sch) throws Exception {
146                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
147                 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
148                 engine.setUseClientMode(true);
149                 engine.setEnabledProtocols(protocols);
150 
151                 sch.pipeline().addLast(new SslHandler(engine));
152                 sch.pipeline().addLast(ch);
153             }
154         });
155 
156         try {
157             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
158             Channel cc = cb.connect(sc.localAddress()).sync().channel();
159             ByteBuf msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
160 
161             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
162             cc.closeFuture().sync();
163             rethrowHandlerExceptions(sh, ch);
164             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
165 
166             cc = cb.connect(sc.localAddress()).sync().channel();
167             msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
168             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
169             cc.closeFuture().sync();
170             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
171             rethrowHandlerExceptions(sh, ch);
172         } finally {
173             sc.close().awaitUninterruptibly();
174         }
175     }
176 
177     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
178     @MethodSource("jdkAndOpenSSL")
179     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
180     public void testSslSessionTrustManagerResumption(
181             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
182         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
183             @Override
184             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
185                 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
186             }
187         });
188     }
189 
190     public void testSslSessionTrustManagerResumption(
191             ServerBootstrap sb, Bootstrap cb,
192             SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
193         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
194         serverCtxBldr.protocols(protocols);
195         clientCtxBldr.protocols(protocols);
196         TrustManager clientTrustManager = new SessionSettingTrustManager();
197         clientCtxBldr.trustManager(clientTrustManager);
198         final SslContext serverContext = serverCtxBldr.build();
199         final SslContext clientContext = clientCtxBldr.build();
200 
201         final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
202         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
203         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
204             @Override
205             public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
206                 if (evt instanceof SslHandshakeCompletionEvent) {
207                     SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
208                     if (handshakeCompletionEvent.isSuccess()) {
209                         SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
210                         assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
211                     } else {
212                         logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
213                     }
214                 }
215                 super.userEventTriggered(ctx, evt);
216             }
217         };
218 
219         sb.childHandler(new ChannelInitializer<SocketChannel>() {
220             @Override
221             protected void initChannel(SocketChannel sch) throws Exception {
222                 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
223                 sch.pipeline().addLast(sh);
224             }
225         });
226         final Channel sc = sb.bind().sync().channel();
227 
228         cb.handler(new ChannelInitializer<SocketChannel>() {
229             @Override
230             protected void initChannel(SocketChannel sch) throws Exception {
231                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
232                 SslHandler sslHandler = clientContext.newHandler(
233                         sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
234 
235                 sch.pipeline().addLast(sslHandler);
236                 sch.pipeline().addLast(ch);
237             }
238         });
239 
240         try {
241             Channel cc = cb.connect(sc.localAddress()).sync().channel();
242             ByteBuf msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
243             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
244             cc.closeFuture().sync();
245             rethrowHandlerExceptions(sh, ch);
246             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
247 
248             cc = cb.connect(sc.localAddress()).sync().channel();
249             msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
250             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
251             cc.closeFuture().sync();
252             rethrowHandlerExceptions(sh, ch);
253             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
254         } finally {
255             sc.close().awaitUninterruptibly();
256         }
257     }
258 
259     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
260         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
261             throw new ExecutionException(sh.exception.get());
262         }
263         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
264             throw new ExecutionException(ch.exception.get());
265         }
266         if (sh.exception.get() != null) {
267             throw new ExecutionException(sh.exception.get());
268         }
269         if (ch.exception.get() != null) {
270             throw new ExecutionException(ch.exception.get());
271         }
272     }
273 
274     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
275         Set<String> idSet = new HashSet<String>();
276         byte[] id;
277         while (sessionIds.hasMoreElements()) {
278             id = sessionIds.nextElement();
279             idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
280         }
281         return idSet;
282     }
283 
284     @Sharable
285     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
286         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
287         private final boolean server;
288         private final boolean autoRead;
289 
290         ReadAndDiscardHandler(boolean server, boolean autoRead) {
291             this.server = server;
292             this.autoRead = autoRead;
293         }
294 
295         @Override
296         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
297             in.skipBytes(in.readableBytes());
298         }
299 
300         @Override
301         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
302             try {
303                 ctx.flush();
304             } finally {
305                 if (!autoRead) {
306                     ctx.read();
307                 }
308             }
309         }
310 
311         @Override
312         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
313             if (logger.isWarnEnabled()) {
314                 logger.warn(
315                         "Unexpected exception from the " +
316                         (server? "server" : "client") + " side", cause);
317             }
318 
319             exception.compareAndSet(null, cause);
320             ctx.close();
321         }
322     }
323 
324     @SuppressJava6Requirement(reason = "Test only")
325     private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
326             implements ResumableX509ExtendedTrustManager {
327         @Override
328         public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
329             engine.getSession().putValue("key", "value");
330         }
331 
332         @SuppressJava6Requirement(reason = "Test only")
333         @Override
334         public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
335                 throws CertificateException {
336             engine.getHandshakeSession().putValue("key", "value");
337         }
338 
339         @Override
340         public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
341             throw new CertificateException("Unsupported operation");
342         }
343 
344         @Override
345         public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
346                 throws CertificateException {
347             throw new CertificateException("Unsupported operation");
348         }
349 
350         @Override
351         public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
352                 throws CertificateException {
353             throw new CertificateException("Unsupported operation");
354         }
355 
356         @Override
357         public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
358                 throws CertificateException {
359             throw new CertificateException("Unsupported operation");
360         }
361 
362         @Override
363         public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
364             throw new CertificateException("Unsupported operation");
365         }
366 
367         @Override
368         public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
369             throw new CertificateException("Unsupported operation");
370         }
371 
372         @Override
373         public X509Certificate[] getAcceptedIssuers() {
374             return new X509Certificate[0];
375         }
376     }
377 }