View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.BufferUtil;
21  import io.netty5.buffer.api.Buffer;
22  import io.netty5.buffer.api.DefaultBufferAllocators;
23  import io.netty5.channel.Channel;
24  import io.netty5.channel.ChannelHandlerContext;
25  import io.netty5.channel.ChannelInitializer;
26  import io.netty5.channel.SimpleChannelInboundHandler;
27  import io.netty5.channel.socket.SocketChannel;
28  import io.netty5.handler.ssl.SslContext;
29  import io.netty5.handler.ssl.SslContextBuilder;
30  import io.netty5.handler.ssl.SslHandler;
31  import io.netty5.handler.ssl.SslProvider;
32  import io.netty5.handler.ssl.util.SelfSignedCertificate;
33  import io.netty5.util.internal.logging.InternalLogger;
34  import io.netty5.util.internal.logging.InternalLoggerFactory;
35  import org.junit.jupiter.api.TestInfo;
36  import org.junit.jupiter.api.Timeout;
37  import org.junit.jupiter.params.ParameterizedTest;
38  import org.junit.jupiter.params.provider.MethodSource;
39  
40  import javax.net.ssl.SSLEngine;
41  import javax.net.ssl.SSLSessionContext;
42  import java.io.File;
43  import java.io.IOException;
44  import java.net.InetSocketAddress;
45  import java.security.cert.CertificateException;
46  import java.util.Collection;
47  import java.util.Collections;
48  import java.util.Enumeration;
49  import java.util.HashSet;
50  import java.util.Set;
51  import java.util.concurrent.TimeUnit;
52  import java.util.concurrent.atomic.AtomicReference;
53  
54  import static org.junit.jupiter.api.Assertions.assertEquals;
55  
56  public class SocketSslSessionReuseTest extends AbstractSocketTest {
57  
58      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
59  
60      private static final File CERT_FILE;
61      private static final File KEY_FILE;
62  
63      static {
64          SelfSignedCertificate ssc;
65          try {
66              ssc = new SelfSignedCertificate();
67          } catch (CertificateException e) {
68              throw new Error(e);
69          }
70          CERT_FILE = ssc.certificate();
71          KEY_FILE = ssc.privateKey();
72      }
73  
74      public static Collection<Object[]> data() throws Exception {
75          return Collections.singletonList(new Object[] {
76            SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build(),
77            SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK).build()
78          });
79      }
80  
81      @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
82      @MethodSource("data")
83      @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
84      public void testSslSessionReuse(SslContext serverCtx, SslContext clientCtx, TestInfo testInfo) throws Throwable {
85          run(testInfo, (sb, cb) -> testSslSessionReuse(sb, cb, serverCtx, clientCtx));
86      }
87  
88      public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
89                                      SslContext serverCtx, SslContext clientCtx) throws Throwable {
90          final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
91          final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
92          final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
93  
94          sb.childHandler(new ChannelInitializer<SocketChannel>() {
95              @Override
96              protected void initChannel(SocketChannel sch) throws Exception {
97                  SSLEngine engine = serverCtx.newEngine(sch.bufferAllocator());
98                  engine.setUseClientMode(false);
99                  engine.setEnabledProtocols(protocols);
100 
101                 sch.pipeline().addLast(new SslHandler(engine));
102                 sch.pipeline().addLast(sh);
103             }
104         });
105         final Channel sc = sb.bind().asStage().get();
106 
107         cb.handler(new ChannelInitializer<SocketChannel>() {
108             @Override
109             protected void initChannel(SocketChannel sch) throws Exception {
110                 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
111                 SSLEngine engine = clientCtx.newEngine(
112                         sch.bufferAllocator(), serverAddr.getHostString(), serverAddr.getPort());
113                 engine.setUseClientMode(true);
114                 engine.setEnabledProtocols(protocols);
115 
116                 sch.pipeline().addLast(new SslHandler(engine));
117                 sch.pipeline().addLast(ch);
118             }
119         });
120 
121         try {
122             SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
123             Buffer msg = DefaultBufferAllocators.preferredAllocator().copyOf(new byte[] { 0xa, 0xb, 0xc, 0xd });
124             Channel cc = cb.connect(sc.localAddress()).asStage().get();
125             cc.writeAndFlush(msg).asStage().sync();
126             cc.closeFuture().asStage().sync();
127             rethrowHandlerExceptions(sh, ch);
128             Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
129 
130             msg = DefaultBufferAllocators.preferredAllocator().copyOf(new byte[] { 0xa, 0xb, 0xc, 0xd });
131             cc = cb.connect(sc.localAddress()).asStage().get();
132             cc.writeAndFlush(msg).asStage().sync();
133             cc.closeFuture().asStage().sync();
134             assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
135             rethrowHandlerExceptions(sh, ch);
136         } finally {
137             sc.close().asStage().await();
138         }
139     }
140 
141     private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
142         if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
143             throw sh.exception.get();
144         }
145         if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
146             throw ch.exception.get();
147         }
148         if (sh.exception.get() != null) {
149             throw sh.exception.get();
150         }
151         if (ch.exception.get() != null) {
152             throw ch.exception.get();
153         }
154     }
155 
156     private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
157         Set<String> idSet = new HashSet<>();
158         byte[] id;
159         while (sessionIds.hasMoreElements()) {
160             id = sessionIds.nextElement();
161             idSet.add(BufferUtil.hexDump(id));
162         }
163         return idSet;
164     }
165 
166     private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<Buffer> {
167         final AtomicReference<Throwable> exception = new AtomicReference<>();
168         private final boolean server;
169         private final boolean autoRead;
170 
171         ReadAndDiscardHandler(boolean server, boolean autoRead) {
172             this.server = server;
173             this.autoRead = autoRead;
174         }
175 
176         @Override
177         public boolean isSharable() {
178             return true;
179         }
180 
181         @Override
182         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
183             byte[] actual = new byte[in.readableBytes()];
184             in.readBytes(actual, 0, actual.length);
185             ctx.close();
186         }
187 
188         @Override
189         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
190             try {
191                 ctx.flush();
192             } finally {
193                 if (!autoRead) {
194                     ctx.read();
195                 }
196             }
197         }
198 
199         @Override
200         public void channelExceptionCaught(ChannelHandlerContext ctx,
201                                            Throwable cause) throws Exception {
202             if (logger.isWarnEnabled()) {
203                 logger.warn(
204                         "Unexpected exception from the " +
205                         (server? "server" : "client") + " side", cause);
206             }
207 
208             exception.compareAndSet(null, cause);
209             ctx.close();
210         }
211     }
212 }