1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelHandler.Sharable;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.SimpleChannelInboundHandler;
29 import io.netty.channel.socket.SocketChannel;
30 import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslContextBuilder;
33 import io.netty.handler.ssl.SslHandler;
34 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty.handler.ssl.SslProvider;
36 import io.netty.pkitesting.CertificateBuilder;
37 import io.netty.pkitesting.X509Bundle;
38 import io.netty.util.internal.SuppressJava6Requirement;
39 import io.netty.util.internal.logging.InternalLogger;
40 import io.netty.util.internal.logging.InternalLoggerFactory;
41
42 import org.junit.jupiter.api.TestInfo;
43 import org.junit.jupiter.api.Timeout;
44 import org.junit.jupiter.params.ParameterizedTest;
45 import org.junit.jupiter.params.provider.MethodSource;
46
47 import javax.net.ssl.SSLEngine;
48 import javax.net.ssl.SSLSession;
49 import javax.net.ssl.SSLSessionContext;
50 import javax.net.ssl.TrustManager;
51 import javax.net.ssl.X509ExtendedTrustManager;
52
53 import java.io.File;
54 import java.io.IOException;
55 import java.net.InetSocketAddress;
56 import java.net.Socket;
57 import java.security.cert.CertificateException;
58 import java.security.cert.X509Certificate;
59 import java.util.Arrays;
60 import java.util.Collection;
61 import java.util.Collections;
62 import java.util.Enumeration;
63 import java.util.HashSet;
64 import java.util.Set;
65 import java.util.concurrent.BlockingQueue;
66 import java.util.concurrent.ExecutionException;
67 import java.util.concurrent.LinkedBlockingQueue;
68 import java.util.concurrent.TimeUnit;
69 import java.util.concurrent.atomic.AtomicReference;
70
71 import static org.junit.jupiter.api.Assertions.assertEquals;
72 import static org.junit.jupiter.api.Assertions.assertTrue;
73
74 public class SocketSslSessionReuseTest extends AbstractSocketTest {
75
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);

/** Temporary PEM file holding the self-signed {@code cn=localhost} certificate shared by all contexts. */
private static final File CERT_FILE;
/** Temporary PEM file holding the private key that belongs to {@link #CERT_FILE}. */
private static final File KEY_FILE;

static {
    try {
        X509Bundle selfSigned = newLocalhostAuthority();
        CERT_FILE = selfSigned.toTempCertChainPem();
        KEY_FILE = selfSigned.toTempPrivateKeyPem();
    } catch (Exception cause) {
        // Any failure while generating or writing the certificate aborts class initialisation.
        throw new ExceptionInInitializerError(cause);
    }
}

/**
 * Builds a self-signed certificate for {@code cn=localhost} that is also flagged as a certificate
 * authority. The server presents it and the client builders trust it via {@link #CERT_FILE}.
 */
private static X509Bundle newLocalhostAuthority() throws Exception {
    return new CertificateBuilder()
            .subject("cn=localhost")
            .setIsCertificateAuthority(true)
            .buildSelfSigned();
}
93
94 public static Collection<Object[]> jdkOnly() throws Exception {
95 return Collections.singleton(new Object[]{
96 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
97 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
98 .endpointIdentificationAlgorithm(null)
99 });
100 }
101
102 public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
103 return Arrays.asList(new Object[]{
104 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
105 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
106 .endpointIdentificationAlgorithm(null)
107 },
108 new Object[]{
109 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
110 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
111 .endpointIdentificationAlgorithm(null)
112 });
113 }
114
115 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
116 @MethodSource("jdkOnly")
117 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
118 public void testSslSessionReuse(
119 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
120 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
121 @Override
122 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
123 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
124 }
125 });
126 }
127
128 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
129 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
130 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
131 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
132 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
133
134 sb.childHandler(new ChannelInitializer<SocketChannel>() {
135 @Override
136 protected void initChannel(SocketChannel sch) throws Exception {
137 SSLEngine engine = serverCtx.newEngine(sch.alloc());
138 engine.setUseClientMode(false);
139 engine.setEnabledProtocols(protocols);
140
141 sch.pipeline().addLast(new SslHandler(engine));
142 sch.pipeline().addLast(sh);
143 }
144 });
145 final Channel sc = sb.bind().sync().channel();
146
147 cb.handler(new ChannelInitializer<SocketChannel>() {
148 @Override
149 protected void initChannel(SocketChannel sch) throws Exception {
150 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
151 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
152 engine.setUseClientMode(true);
153 engine.setEnabledProtocols(protocols);
154
155 sch.pipeline().addLast(new SslHandler(engine));
156 sch.pipeline().addLast(ch);
157 }
158 });
159
160 try {
161 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
162 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
163 Channel cc = cb.connect(sc.localAddress()).sync().channel();
164 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
165 cc.closeFuture().sync();
166 rethrowHandlerExceptions(sh, ch);
167 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
168
169 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
170 cc = cb.connect(sc.localAddress()).sync().channel();
171 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
172 cc.closeFuture().sync();
173 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
174 rethrowHandlerExceptions(sh, ch);
175 } finally {
176 sc.close().awaitUninterruptibly();
177 }
178 }
179
180 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
181 @MethodSource("jdkAndOpenSSL")
182 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
183 public void testSslSessionTrustManagerResumption(
184 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
185 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
186 @Override
187 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
188 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
189 }
190 });
191 }
192
193 public void testSslSessionTrustManagerResumption(
194 ServerBootstrap sb, Bootstrap cb,
195 SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
196 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
197 serverCtxBldr.protocols(protocols);
198 clientCtxBldr.protocols(protocols);
199 TrustManager clientTrustManager = new SessionSettingTrustManager();
200 clientCtxBldr.trustManager(clientTrustManager);
201 final SslContext serverContext = serverCtxBldr.build();
202 final SslContext clientContext = clientCtxBldr.build();
203
204 final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
205 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
206 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
207 @Override
208 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
209 if (evt instanceof SslHandshakeCompletionEvent) {
210 SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
211 if (handshakeCompletionEvent.isSuccess()) {
212 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
213 assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
214 } else {
215 logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
216 }
217 }
218 super.userEventTriggered(ctx, evt);
219 }
220 };
221
222 sb.childHandler(new ChannelInitializer<SocketChannel>() {
223 @Override
224 protected void initChannel(SocketChannel sch) throws Exception {
225 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
226 sch.pipeline().addLast(sh);
227 }
228 });
229 final Channel sc = sb.bind().sync().channel();
230
231 cb.handler(new ChannelInitializer<SocketChannel>() {
232 @Override
233 protected void initChannel(SocketChannel sch) throws Exception {
234 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
235 SslHandler sslHandler = clientContext.newHandler(
236 sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
237
238 sch.pipeline().addLast(sslHandler);
239 sch.pipeline().addLast(ch);
240 }
241 });
242
243 try {
244 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
245 Channel cc = cb.connect(sc.localAddress()).sync().channel();
246 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
247 cc.closeFuture().sync();
248 rethrowHandlerExceptions(sh, ch);
249 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
250
251 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
252 cc = cb.connect(sc.localAddress()).sync().channel();
253 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
254 cc.closeFuture().sync();
255 rethrowHandlerExceptions(sh, ch);
256 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
257 } finally {
258 sc.close().awaitUninterruptibly();
259 }
260 }
261
262 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
263 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
264 throw new ExecutionException(sh.exception.get());
265 }
266 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
267 throw new ExecutionException(ch.exception.get());
268 }
269 if (sh.exception.get() != null) {
270 throw new ExecutionException(sh.exception.get());
271 }
272 if (ch.exception.get() != null) {
273 throw new ExecutionException(ch.exception.get());
274 }
275 }
276
277 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
278 Set<String> idSet = new HashSet<String>();
279 byte[] id;
280 while (sessionIds.hasMoreElements()) {
281 id = sessionIds.nextElement();
282 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
283 }
284 return idSet;
285 }
286
287 @Sharable
288 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
289 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
290 private final boolean server;
291 private final boolean autoRead;
292
293 ReadAndDiscardHandler(boolean server, boolean autoRead) {
294 this.server = server;
295 this.autoRead = autoRead;
296 }
297
298 @Override
299 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
300 in.skipBytes(in.readableBytes());
301 }
302
303 @Override
304 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
305 try {
306 ctx.flush();
307 } finally {
308 if (!autoRead) {
309 ctx.read();
310 }
311 }
312 }
313
314 @Override
315 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
316 if (logger.isWarnEnabled()) {
317 logger.warn(
318 "Unexpected exception from the " +
319 (server? "server" : "client") + " side", cause);
320 }
321
322 exception.compareAndSet(null, cause);
323 ctx.close();
324 }
325 }
326
327 @SuppressJava6Requirement(reason = "Test only")
328 private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
329 implements ResumableX509ExtendedTrustManager {
330 @Override
331 public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
332 engine.getSession().putValue("key", "value");
333 }
334
335 @SuppressJava6Requirement(reason = "Test only")
336 @Override
337 public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
338 throws CertificateException {
339 engine.getHandshakeSession().putValue("key", "value");
340 }
341
342 @Override
343 public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
344 throw new CertificateException("Unsupported operation");
345 }
346
347 @Override
348 public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
349 throws CertificateException {
350 throw new CertificateException("Unsupported operation");
351 }
352
353 @Override
354 public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
355 throws CertificateException {
356 throw new CertificateException("Unsupported operation");
357 }
358
359 @Override
360 public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
361 throws CertificateException {
362 throw new CertificateException("Unsupported operation");
363 }
364
365 @Override
366 public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
367 throw new CertificateException("Unsupported operation");
368 }
369
370 @Override
371 public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
372 throw new CertificateException("Unsupported operation");
373 }
374
375 @Override
376 public X509Certificate[] getAcceptedIssuers() {
377 return new X509Certificate[0];
378 }
379 }
380 }