View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufUtil;
22  import io.netty.buffer.Unpooled;
23  import io.netty.channel.Channel;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelHandler.Sharable;
26  import io.netty.channel.ChannelInitializer;
27  import io.netty.channel.SimpleChannelInboundHandler;
28  import io.netty.channel.socket.SocketChannel;
29  import io.netty.handler.ssl.JdkSslClientContext;
30  import io.netty.handler.ssl.JdkSslServerContext;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.util.SelfSignedCertificate;
34  import io.netty.util.internal.logging.InternalLogger;
35  import io.netty.util.internal.logging.InternalLoggerFactory;
36  
37  import org.junit.jupiter.api.TestInfo;
38  import org.junit.jupiter.api.Timeout;
39  import org.junit.jupiter.params.ParameterizedTest;
40  import org.junit.jupiter.params.provider.MethodSource;
41  
42  import javax.net.ssl.SSLEngine;
43  import javax.net.ssl.SSLSessionContext;
44  
45  import java.io.File;
46  import java.io.IOException;
47  import java.net.InetSocketAddress;
48  import java.security.cert.CertificateException;
49  import java.util.Collection;
50  import java.util.Collections;
51  import java.util.Enumeration;
52  import java.util.HashSet;
53  import java.util.Set;
54  import java.util.concurrent.TimeUnit;
55  import java.util.concurrent.atomic.AtomicReference;
56  
57  import static org.junit.jupiter.api.Assertions.assertEquals;
58  
59  public class SocketSslSessionReuseTest extends AbstractSocketTest {
60  
61      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
62  
63      private static final File CERT_FILE;
64      private static final File KEY_FILE;
65  
66      static {
67          SelfSignedCertificate ssc;
68          try {
69              ssc = new SelfSignedCertificate();
70          } catch (CertificateException e) {
71              throw new Error(e);
72          }
73          CERT_FILE = ssc.certificate();
74          KEY_FILE = ssc.privateKey();
75      }
76  
77      public static Collection<Object[]> data() throws Exception {
78          return Collections.singletonList(new Object[] {
79              new JdkSslServerContext(CERT_FILE, KEY_FILE),
80              new JdkSslClientContext(CERT_FILE)
81          });
82      }
83  
84      @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
85      @MethodSource("data")
86      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
87      public void testSslSessionReuse(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
88              throws Throwable {
89          run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
90              @Override
91              public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
92                  testSslSessionReuse(sb, cb, serverCtx, clientCtx);
93              }
94          });
95      }
96  
97      public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
98                                      final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
99          final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
100         final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
101         final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
102 
103         sb.childHandler(new ChannelInitializer<SocketChannel>() {
104             @Override
105             protected void initChannel(SocketChannel sch) throws Exception {
106                 SSLEngine engine = serverCtx.newEngine(sch.alloc());
107                 engine.setUseClientMode(false);
108                 engine.setEnabledProtocols(protocols);
109 
110                 sch.pipeline().addLast(new SslHandler(engine));
111                 sch.pipeline().addLast(sh);
112             }
113         });
114         final Channel sc = sb.bind().sync().channel();
115 
116         cb.handler(new ChannelInitializer<SocketChannel>() {
117             @Override
118             protected void initChannel(SocketChannel sch) throws Exception {
119                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
120                 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
121                 engine.setUseClientMode(true);
122                 engine.setEnabledProtocols(protocols);
123 
124                 sch.pipeline().addLast(new SslHandler(engine));
125                 sch.pipeline().addLast(ch);
126             }
127         });
128 
129         try {
130             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
131             ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
132             Channel cc = cb.connect(sc.localAddress()).sync().channel();
133             cc.writeAndFlush(msg).sync();
134             cc.closeFuture().sync();
135             rethrowHandlerExceptions(sh, ch);
136             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
137 
138             msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
139             cc = cb.connect(sc.localAddress()).sync().channel();
140             cc.writeAndFlush(msg).sync();
141             cc.closeFuture().sync();
142             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
143             rethrowHandlerExceptions(sh, ch);
144         } finally {
145             sc.close().awaitUninterruptibly();
146         }
147     }
148 
149     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
150         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
151             throw sh.exception.get();
152         }
153         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
154             throw ch.exception.get();
155         }
156         if (sh.exception.get() != null) {
157             throw sh.exception.get();
158         }
159         if (ch.exception.get() != null) {
160             throw ch.exception.get();
161         }
162     }
163 
164     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
165         Set<String> idSet = new HashSet<String>();
166         byte[] id;
167         while (sessionIds.hasMoreElements()) {
168             id = sessionIds.nextElement();
169             idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
170         }
171         return idSet;
172     }
173 
174     @Sharable
175     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
176         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
177         private final boolean server;
178         private final boolean autoRead;
179 
180         ReadAndDiscardHandler(boolean server, boolean autoRead) {
181             this.server = server;
182             this.autoRead = autoRead;
183         }
184 
185         @Override
186         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
187             byte[] actual = new byte[in.readableBytes()];
188             in.readBytes(actual);
189             ctx.close();
190         }
191 
192         @Override
193         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
194             try {
195                 ctx.flush();
196             } finally {
197                 if (!autoRead) {
198                     ctx.read();
199                 }
200             }
201         }
202 
203         @Override
204         public void exceptionCaught(ChannelHandlerContext ctx,
205                 Throwable cause) throws Exception {
206             if (logger.isWarnEnabled()) {
207                 logger.warn(
208                         "Unexpected exception from the " +
209                         (server? "server" : "client") + " side", cause);
210             }
211 
212             exception.compareAndSet(null, cause);
213             ctx.close();
214         }
215     }
216 }