1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelHandler.Sharable;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.SimpleChannelInboundHandler;
29 import io.netty.channel.socket.SocketChannel;
30 import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslContextBuilder;
33 import io.netty.handler.ssl.SslHandler;
34 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty.handler.ssl.SslProvider;
36 import io.netty.handler.ssl.util.SelfSignedCertificate;
37 import io.netty.util.internal.SuppressJava6Requirement;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40
41 import org.junit.jupiter.api.TestInfo;
42 import org.junit.jupiter.api.Timeout;
43 import org.junit.jupiter.params.ParameterizedTest;
44 import org.junit.jupiter.params.provider.MethodSource;
45
46 import javax.net.ssl.SSLEngine;
47 import javax.net.ssl.SSLSession;
48 import javax.net.ssl.SSLSessionContext;
49 import javax.net.ssl.TrustManager;
50 import javax.net.ssl.X509ExtendedTrustManager;
51
52 import java.io.File;
53 import java.io.IOException;
54 import java.net.InetSocketAddress;
55 import java.net.Socket;
56 import java.nio.channels.ClosedChannelException;
57 import java.security.cert.CertificateException;
58 import java.security.cert.X509Certificate;
59 import java.util.Arrays;
60 import java.util.Collection;
61 import java.util.Collections;
62 import java.util.Enumeration;
63 import java.util.HashSet;
64 import java.util.Set;
65 import java.util.concurrent.BlockingQueue;
66 import java.util.concurrent.ExecutionException;
67 import java.util.concurrent.LinkedBlockingQueue;
68 import java.util.concurrent.TimeUnit;
69 import java.util.concurrent.atomic.AtomicReference;
70
71 import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
72 import static org.junit.jupiter.api.Assertions.assertEquals;
73 import static org.junit.jupiter.api.Assertions.assertTrue;
74
75 public class SocketSslSessionReuseTest extends AbstractSocketTest {
76
77 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
78
79 private static final File CERT_FILE;
80 private static final File KEY_FILE;
81
82 static {
83 SelfSignedCertificate ssc;
84 try {
85 ssc = new SelfSignedCertificate();
86 } catch (CertificateException e) {
87 throw new Error(e);
88 }
89 CERT_FILE = ssc.certificate();
90 KEY_FILE = ssc.privateKey();
91 }
92
93 public static Collection<Object[]> jdkOnly() throws Exception {
94 return Collections.singleton(new Object[]{
95 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
96 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
97 });
98 }
99
100 public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
101 return Arrays.asList(new Object[]{
102 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
103 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
104 },
105 new Object[]{
106 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
107 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
108 });
109 }
110
111 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
112 @MethodSource("jdkOnly")
113 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
114 public void testSslSessionReuse(
115 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
116 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
117 @Override
118 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
119 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
120 }
121 });
122 }
123
124 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
125 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
126 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
127 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
128 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
129
130 sb.childHandler(new ChannelInitializer<SocketChannel>() {
131 @Override
132 protected void initChannel(SocketChannel sch) throws Exception {
133 SSLEngine engine = serverCtx.newEngine(sch.alloc());
134 engine.setUseClientMode(false);
135 engine.setEnabledProtocols(protocols);
136
137 sch.pipeline().addLast(new SslHandler(engine));
138 sch.pipeline().addLast(sh);
139 }
140 });
141 final Channel sc = sb.bind().sync().channel();
142
143 cb.handler(new ChannelInitializer<SocketChannel>() {
144 @Override
145 protected void initChannel(SocketChannel sch) throws Exception {
146 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
147 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
148 engine.setUseClientMode(true);
149 engine.setEnabledProtocols(protocols);
150
151 sch.pipeline().addLast(new SslHandler(engine));
152 sch.pipeline().addLast(ch);
153 }
154 });
155
156 try {
157 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
158 Channel cc = cb.connect(sc.localAddress()).sync().channel();
159 ByteBuf msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
160
161 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
162 cc.closeFuture().sync();
163 rethrowHandlerExceptions(sh, ch);
164 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
165
166 cc = cb.connect(sc.localAddress()).sync().channel();
167 msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
168 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
169 cc.closeFuture().sync();
170 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
171 rethrowHandlerExceptions(sh, ch);
172 } finally {
173 sc.close().awaitUninterruptibly();
174 }
175 }
176
177 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
178 @MethodSource("jdkAndOpenSSL")
179 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
180 public void testSslSessionTrustManagerResumption(
181 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
182 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
183 @Override
184 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
185 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
186 }
187 });
188 }
189
190 public void testSslSessionTrustManagerResumption(
191 ServerBootstrap sb, Bootstrap cb,
192 SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
193 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
194 serverCtxBldr.protocols(protocols);
195 clientCtxBldr.protocols(protocols);
196 TrustManager clientTrustManager = new SessionSettingTrustManager();
197 clientCtxBldr.trustManager(clientTrustManager);
198 final SslContext serverContext = serverCtxBldr.build();
199 final SslContext clientContext = clientCtxBldr.build();
200
201 final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
202 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
203 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
204 @Override
205 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
206 if (evt instanceof SslHandshakeCompletionEvent) {
207 SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
208 if (handshakeCompletionEvent.isSuccess()) {
209 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
210 assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
211 } else {
212 logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
213 }
214 }
215 super.userEventTriggered(ctx, evt);
216 }
217 };
218
219 sb.childHandler(new ChannelInitializer<SocketChannel>() {
220 @Override
221 protected void initChannel(SocketChannel sch) throws Exception {
222 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
223 sch.pipeline().addLast(sh);
224 }
225 });
226 final Channel sc = sb.bind().sync().channel();
227
228 cb.handler(new ChannelInitializer<SocketChannel>() {
229 @Override
230 protected void initChannel(SocketChannel sch) throws Exception {
231 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
232 SslHandler sslHandler = clientContext.newHandler(
233 sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
234
235 sch.pipeline().addLast(sslHandler);
236 sch.pipeline().addLast(ch);
237 }
238 });
239
240 try {
241 Channel cc = cb.connect(sc.localAddress()).sync().channel();
242 ByteBuf msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
243 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
244 cc.closeFuture().sync();
245 rethrowHandlerExceptions(sh, ch);
246 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
247
248 cc = cb.connect(sc.localAddress()).sync().channel();
249 msg = randomBufferType(cc.alloc(), new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
250 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
251 cc.closeFuture().sync();
252 rethrowHandlerExceptions(sh, ch);
253 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
254 } finally {
255 sc.close().awaitUninterruptibly();
256 }
257 }
258
259 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
260 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
261 throw new ExecutionException(sh.exception.get());
262 }
263 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
264 throw new ExecutionException(ch.exception.get());
265 }
266 if (sh.exception.get() != null) {
267 throw new ExecutionException(sh.exception.get());
268 }
269 if (ch.exception.get() != null) {
270 throw new ExecutionException(ch.exception.get());
271 }
272 }
273
274 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
275 Set<String> idSet = new HashSet<String>();
276 byte[] id;
277 while (sessionIds.hasMoreElements()) {
278 id = sessionIds.nextElement();
279 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
280 }
281 return idSet;
282 }
283
284 @Sharable
285 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
286 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
287 private final boolean server;
288 private final boolean autoRead;
289
290 ReadAndDiscardHandler(boolean server, boolean autoRead) {
291 this.server = server;
292 this.autoRead = autoRead;
293 }
294
295 @Override
296 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
297 in.skipBytes(in.readableBytes());
298 }
299
300 @Override
301 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
302 try {
303 ctx.flush();
304 } finally {
305 if (!autoRead) {
306 ctx.read();
307 }
308 }
309 }
310
311 @Override
312 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
313 if (logger.isWarnEnabled()) {
314 logger.warn(
315 "Unexpected exception from the " +
316 (server? "server" : "client") + " side", cause);
317 }
318
319 exception.compareAndSet(null, cause);
320 ctx.close();
321 }
322 }
323
324 @SuppressJava6Requirement(reason = "Test only")
325 private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
326 implements ResumableX509ExtendedTrustManager {
327 @Override
328 public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
329 engine.getSession().putValue("key", "value");
330 }
331
332 @SuppressJava6Requirement(reason = "Test only")
333 @Override
334 public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
335 throws CertificateException {
336 engine.getHandshakeSession().putValue("key", "value");
337 }
338
339 @Override
340 public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
341 throw new CertificateException("Unsupported operation");
342 }
343
344 @Override
345 public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
346 throws CertificateException {
347 throw new CertificateException("Unsupported operation");
348 }
349
350 @Override
351 public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
352 throws CertificateException {
353 throw new CertificateException("Unsupported operation");
354 }
355
356 @Override
357 public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
358 throws CertificateException {
359 throw new CertificateException("Unsupported operation");
360 }
361
362 @Override
363 public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
364 throw new CertificateException("Unsupported operation");
365 }
366
367 @Override
368 public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
369 throw new CertificateException("Unsupported operation");
370 }
371
372 @Override
373 public X509Certificate[] getAcceptedIssuers() {
374 return new X509Certificate[0];
375 }
376 }
377 }