View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufUtil;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelHandler.Sharable;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.SimpleChannelInboundHandler;
29  import io.netty.channel.socket.SocketChannel;
30  import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslContextBuilder;
33  import io.netty.handler.ssl.SslHandler;
34  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35  import io.netty.handler.ssl.SslProvider;
36  import io.netty.pkitesting.CertificateBuilder;
37  import io.netty.pkitesting.X509Bundle;
38  import io.netty.util.internal.SuppressJava6Requirement;
39  import io.netty.util.internal.logging.InternalLogger;
40  import io.netty.util.internal.logging.InternalLoggerFactory;
41  
42  import org.junit.jupiter.api.TestInfo;
43  import org.junit.jupiter.api.Timeout;
44  import org.junit.jupiter.params.ParameterizedTest;
45  import org.junit.jupiter.params.provider.MethodSource;
46  
47  import javax.net.ssl.SSLEngine;
48  import javax.net.ssl.SSLSession;
49  import javax.net.ssl.SSLSessionContext;
50  import javax.net.ssl.TrustManager;
51  import javax.net.ssl.X509ExtendedTrustManager;
52  
53  import java.io.File;
54  import java.io.IOException;
55  import java.net.InetSocketAddress;
56  import java.net.Socket;
57  import java.security.cert.CertificateException;
58  import java.security.cert.X509Certificate;
59  import java.util.Arrays;
60  import java.util.Collection;
61  import java.util.Collections;
62  import java.util.Enumeration;
63  import java.util.HashSet;
64  import java.util.Set;
65  import java.util.concurrent.BlockingQueue;
66  import java.util.concurrent.ExecutionException;
67  import java.util.concurrent.LinkedBlockingQueue;
68  import java.util.concurrent.TimeUnit;
69  import java.util.concurrent.atomic.AtomicReference;
70  
71  import static org.junit.jupiter.api.Assertions.assertEquals;
72  import static org.junit.jupiter.api.Assertions.assertTrue;
73  
74  public class SocketSslSessionReuseTest extends AbstractSocketTest {
75  
76      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
77  
78      private static final File CERT_FILE;
79      private static final File KEY_FILE;
80  
81      static {
82          try {
83              X509Bundle cert = new CertificateBuilder()
84                      .subject("cn=localhost")
85                      .setIsCertificateAuthority(true)
86                      .buildSelfSigned();
87              CERT_FILE = cert.toTempCertChainPem();
88              KEY_FILE = cert.toTempPrivateKeyPem();
89          } catch (Exception e) {
90              throw new ExceptionInInitializerError(e);
91          }
92      }
93  
94      public static Collection<Object[]> jdkOnly() throws Exception {
95          return Collections.singleton(new Object[]{
96                  SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
97                  SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
98                          .endpointIdentificationAlgorithm(null)
99          });
100     }
101 
102     public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
103         return Arrays.asList(new Object[]{
104                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
105                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
106                                 .endpointIdentificationAlgorithm(null)
107                 },
108                 new Object[]{
109                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
110                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
111                                 .endpointIdentificationAlgorithm(null)
112                 });
113     }
114 
115     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
116     @MethodSource("jdkOnly")
117     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
118     public void testSslSessionReuse(
119             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
120         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
121             @Override
122             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
123                 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
124             }
125         });
126     }
127 
128     public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
129                                     final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
130         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
131         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
132         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
133 
134         sb.childHandler(new ChannelInitializer<SocketChannel>() {
135             @Override
136             protected void initChannel(SocketChannel sch) throws Exception {
137                 SSLEngine engine = serverCtx.newEngine(sch.alloc());
138                 engine.setUseClientMode(false);
139                 engine.setEnabledProtocols(protocols);
140 
141                 sch.pipeline().addLast(new SslHandler(engine));
142                 sch.pipeline().addLast(sh);
143             }
144         });
145         final Channel sc = sb.bind().sync().channel();
146 
147         cb.handler(new ChannelInitializer<SocketChannel>() {
148             @Override
149             protected void initChannel(SocketChannel sch) throws Exception {
150                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
151                 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
152                 engine.setUseClientMode(true);
153                 engine.setEnabledProtocols(protocols);
154 
155                 sch.pipeline().addLast(new SslHandler(engine));
156                 sch.pipeline().addLast(ch);
157             }
158         });
159 
160         try {
161             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
162             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
163             Channel cc = cb.connect(sc.localAddress()).sync().channel();
164             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
165             cc.closeFuture().sync();
166             rethrowHandlerExceptions(sh, ch);
167             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
168 
169             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
170             cc = cb.connect(sc.localAddress()).sync().channel();
171             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
172             cc.closeFuture().sync();
173             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
174             rethrowHandlerExceptions(sh, ch);
175         } finally {
176             sc.close().awaitUninterruptibly();
177         }
178     }
179 
180     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
181     @MethodSource("jdkAndOpenSSL")
182     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
183     public void testSslSessionTrustManagerResumption(
184             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
185         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
186             @Override
187             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
188                 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
189             }
190         });
191     }
192 
193     public void testSslSessionTrustManagerResumption(
194             ServerBootstrap sb, Bootstrap cb,
195             SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
196         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
197         serverCtxBldr.protocols(protocols);
198         clientCtxBldr.protocols(protocols);
199         TrustManager clientTrustManager = new SessionSettingTrustManager();
200         clientCtxBldr.trustManager(clientTrustManager);
201         final SslContext serverContext = serverCtxBldr.build();
202         final SslContext clientContext = clientCtxBldr.build();
203 
204         final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
205         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
206         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
207             @Override
208             public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
209                 if (evt instanceof SslHandshakeCompletionEvent) {
210                     SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
211                     if (handshakeCompletionEvent.isSuccess()) {
212                         SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
213                         assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
214                     } else {
215                         logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
216                     }
217                 }
218                 super.userEventTriggered(ctx, evt);
219             }
220         };
221 
222         sb.childHandler(new ChannelInitializer<SocketChannel>() {
223             @Override
224             protected void initChannel(SocketChannel sch) throws Exception {
225                 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
226                 sch.pipeline().addLast(sh);
227             }
228         });
229         final Channel sc = sb.bind().sync().channel();
230 
231         cb.handler(new ChannelInitializer<SocketChannel>() {
232             @Override
233             protected void initChannel(SocketChannel sch) throws Exception {
234                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
235                 SslHandler sslHandler = clientContext.newHandler(
236                         sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
237 
238                 sch.pipeline().addLast(sslHandler);
239                 sch.pipeline().addLast(ch);
240             }
241         });
242 
243         try {
244             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
245             Channel cc = cb.connect(sc.localAddress()).sync().channel();
246             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
247             cc.closeFuture().sync();
248             rethrowHandlerExceptions(sh, ch);
249             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
250 
251             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
252             cc = cb.connect(sc.localAddress()).sync().channel();
253             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
254             cc.closeFuture().sync();
255             rethrowHandlerExceptions(sh, ch);
256             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
257         } finally {
258             sc.close().awaitUninterruptibly();
259         }
260     }
261 
262     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
263         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
264             throw new ExecutionException(sh.exception.get());
265         }
266         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
267             throw new ExecutionException(ch.exception.get());
268         }
269         if (sh.exception.get() != null) {
270             throw new ExecutionException(sh.exception.get());
271         }
272         if (ch.exception.get() != null) {
273             throw new ExecutionException(ch.exception.get());
274         }
275     }
276 
277     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
278         Set<String> idSet = new HashSet<String>();
279         byte[] id;
280         while (sessionIds.hasMoreElements()) {
281             id = sessionIds.nextElement();
282             idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
283         }
284         return idSet;
285     }
286 
287     @Sharable
288     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
289         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
290         private final boolean server;
291         private final boolean autoRead;
292 
293         ReadAndDiscardHandler(boolean server, boolean autoRead) {
294             this.server = server;
295             this.autoRead = autoRead;
296         }
297 
298         @Override
299         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
300             in.skipBytes(in.readableBytes());
301         }
302 
303         @Override
304         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
305             try {
306                 ctx.flush();
307             } finally {
308                 if (!autoRead) {
309                     ctx.read();
310                 }
311             }
312         }
313 
314         @Override
315         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
316             if (logger.isWarnEnabled()) {
317                 logger.warn(
318                         "Unexpected exception from the " +
319                         (server? "server" : "client") + " side", cause);
320             }
321 
322             exception.compareAndSet(null, cause);
323             ctx.close();
324         }
325     }
326 
327     @SuppressJava6Requirement(reason = "Test only")
328     private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
329             implements ResumableX509ExtendedTrustManager {
330         @Override
331         public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
332             engine.getSession().putValue("key", "value");
333         }
334 
335         @SuppressJava6Requirement(reason = "Test only")
336         @Override
337         public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
338                 throws CertificateException {
339             engine.getHandshakeSession().putValue("key", "value");
340         }
341 
342         @Override
343         public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
344             throw new CertificateException("Unsupported operation");
345         }
346 
347         @Override
348         public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
349                 throws CertificateException {
350             throw new CertificateException("Unsupported operation");
351         }
352 
353         @Override
354         public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
355                 throws CertificateException {
356             throw new CertificateException("Unsupported operation");
357         }
358 
359         @Override
360         public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
361                 throws CertificateException {
362             throw new CertificateException("Unsupported operation");
363         }
364 
365         @Override
366         public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
367             throw new CertificateException("Unsupported operation");
368         }
369 
370         @Override
371         public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
372             throw new CertificateException("Unsupported operation");
373         }
374 
375         @Override
376         public X509Certificate[] getAcceptedIssuers() {
377             return new X509Certificate[0];
378         }
379     }
380 }