View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.SuppressJava6Requirement;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedTrustManager;

import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.channels.ClosedChannelException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
73  
74  public class SocketSslSessionReuseTest extends AbstractSocketTest {
75  
76      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
77  
78      private static final File CERT_FILE;
79      private static final File KEY_FILE;
80  
81      static {
82          SelfSignedCertificate ssc;
83          try {
84              ssc = new SelfSignedCertificate();
85          } catch (CertificateException e) {
86              throw new Error(e);
87          }
88          CERT_FILE = ssc.certificate();
89          KEY_FILE = ssc.privateKey();
90      }
91  
92      public static Collection<Object[]> jdkOnly() throws Exception {
93          return Collections.singleton(new Object[]{
94                  SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
95                  SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
96                          .endpointIdentificationAlgorithm(null)
97          });
98      }
99  
100     public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
101         return Arrays.asList(new Object[]{
102                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
103                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
104                                 .endpointIdentificationAlgorithm(null)
105                 },
106                 new Object[]{
107                         SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
108                         SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
109                                 .endpointIdentificationAlgorithm(null)
110                 });
111     }
112 
113     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
114     @MethodSource("jdkOnly")
115     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
116     public void testSslSessionReuse(
117             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
118         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
119             @Override
120             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
121                 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
122             }
123         });
124     }
125 
126     public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
127                                     final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
128         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
129         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
130         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
131 
132         sb.childHandler(new ChannelInitializer<SocketChannel>() {
133             @Override
134             protected void initChannel(SocketChannel sch) throws Exception {
135                 SSLEngine engine = serverCtx.newEngine(sch.alloc());
136                 engine.setUseClientMode(false);
137                 engine.setEnabledProtocols(protocols);
138 
139                 sch.pipeline().addLast(new SslHandler(engine));
140                 sch.pipeline().addLast(sh);
141             }
142         });
143         final Channel sc = sb.bind().sync().channel();
144 
145         cb.handler(new ChannelInitializer<SocketChannel>() {
146             @Override
147             protected void initChannel(SocketChannel sch) throws Exception {
148                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
149                 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
150                 engine.setUseClientMode(true);
151                 engine.setEnabledProtocols(protocols);
152 
153                 sch.pipeline().addLast(new SslHandler(engine));
154                 sch.pipeline().addLast(ch);
155             }
156         });
157 
158         try {
159             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
160             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
161             Channel cc = cb.connect(sc.localAddress()).sync().channel();
162             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
163             cc.closeFuture().sync();
164             rethrowHandlerExceptions(sh, ch);
165             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
166 
167             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
168             cc = cb.connect(sc.localAddress()).sync().channel();
169             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
170             cc.closeFuture().sync();
171             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
172             rethrowHandlerExceptions(sh, ch);
173         } finally {
174             sc.close().awaitUninterruptibly();
175         }
176     }
177 
178     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
179     @MethodSource("jdkAndOpenSSL")
180     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
181     public void testSslSessionTrustManagerResumption(
182             final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
183         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
184             @Override
185             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
186                 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
187             }
188         });
189     }
190 
191     public void testSslSessionTrustManagerResumption(
192             ServerBootstrap sb, Bootstrap cb,
193             SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
194         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
195         serverCtxBldr.protocols(protocols);
196         clientCtxBldr.protocols(protocols);
197         TrustManager clientTrustManager = new SessionSettingTrustManager();
198         clientCtxBldr.trustManager(clientTrustManager);
199         final SslContext serverContext = serverCtxBldr.build();
200         final SslContext clientContext = clientCtxBldr.build();
201 
202         final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
203         final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
204         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
205             @Override
206             public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
207                 if (evt instanceof SslHandshakeCompletionEvent) {
208                     SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
209                     if (handshakeCompletionEvent.isSuccess()) {
210                         SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
211                         assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
212                     } else {
213                         logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
214                     }
215                 }
216                 super.userEventTriggered(ctx, evt);
217             }
218         };
219 
220         sb.childHandler(new ChannelInitializer<SocketChannel>() {
221             @Override
222             protected void initChannel(SocketChannel sch) throws Exception {
223                 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
224                 sch.pipeline().addLast(sh);
225             }
226         });
227         final Channel sc = sb.bind().sync().channel();
228 
229         cb.handler(new ChannelInitializer<SocketChannel>() {
230             @Override
231             protected void initChannel(SocketChannel sch) throws Exception {
232                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
233                 SslHandler sslHandler = clientContext.newHandler(
234                         sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
235 
236                 sch.pipeline().addLast(sslHandler);
237                 sch.pipeline().addLast(ch);
238             }
239         });
240 
241         try {
242             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
243             Channel cc = cb.connect(sc.localAddress()).sync().channel();
244             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
245             cc.closeFuture().sync();
246             rethrowHandlerExceptions(sh, ch);
247             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
248 
249             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
250             cc = cb.connect(sc.localAddress()).sync().channel();
251             cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
252             cc.closeFuture().sync();
253             rethrowHandlerExceptions(sh, ch);
254             assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
255         } finally {
256             sc.close().awaitUninterruptibly();
257         }
258     }
259 
260     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
261         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
262             throw new ExecutionException(sh.exception.get());
263         }
264         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
265             throw new ExecutionException(ch.exception.get());
266         }
267         if (sh.exception.get() != null) {
268             throw new ExecutionException(sh.exception.get());
269         }
270         if (ch.exception.get() != null) {
271             throw new ExecutionException(ch.exception.get());
272         }
273     }
274 
275     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
276         Set<String> idSet = new HashSet<String>();
277         byte[] id;
278         while (sessionIds.hasMoreElements()) {
279             id = sessionIds.nextElement();
280             idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
281         }
282         return idSet;
283     }
284 
285     @Sharable
286     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
287         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
288         private final boolean server;
289         private final boolean autoRead;
290 
291         ReadAndDiscardHandler(boolean server, boolean autoRead) {
292             this.server = server;
293             this.autoRead = autoRead;
294         }
295 
296         @Override
297         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
298             in.skipBytes(in.readableBytes());
299         }
300 
301         @Override
302         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
303             try {
304                 ctx.flush();
305             } finally {
306                 if (!autoRead) {
307                     ctx.read();
308                 }
309             }
310         }
311 
312         @Override
313         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
314             if (logger.isWarnEnabled()) {
315                 logger.warn(
316                         "Unexpected exception from the " +
317                         (server? "server" : "client") + " side", cause);
318             }
319 
320             exception.compareAndSet(null, cause);
321             ctx.close();
322         }
323     }
324 
325     @SuppressJava6Requirement(reason = "Test only")
326     private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
327             implements ResumableX509ExtendedTrustManager {
328         @Override
329         public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
330             engine.getSession().putValue("key", "value");
331         }
332 
333         @SuppressJava6Requirement(reason = "Test only")
334         @Override
335         public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
336                 throws CertificateException {
337             engine.getHandshakeSession().putValue("key", "value");
338         }
339 
340         @Override
341         public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
342             throw new CertificateException("Unsupported operation");
343         }
344 
345         @Override
346         public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
347                 throws CertificateException {
348             throw new CertificateException("Unsupported operation");
349         }
350 
351         @Override
352         public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
353                 throws CertificateException {
354             throw new CertificateException("Unsupported operation");
355         }
356 
357         @Override
358         public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
359                 throws CertificateException {
360             throw new CertificateException("Unsupported operation");
361         }
362 
363         @Override
364         public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
365             throw new CertificateException("Unsupported operation");
366         }
367 
368         @Override
369         public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
370             throw new CertificateException("Unsupported operation");
371         }
372 
373         @Override
374         public X509Certificate[] getAcceptedIssuers() {
375             return new X509Certificate[0];
376         }
377     }
378 }