1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelHandler.Sharable;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.SimpleChannelInboundHandler;
29 import io.netty.channel.socket.SocketChannel;
30 import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslContextBuilder;
33 import io.netty.handler.ssl.SslHandler;
34 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty.handler.ssl.SslProvider;
36 import io.netty.pkitesting.CertificateBuilder;
37 import io.netty.pkitesting.X509Bundle;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40
41 import org.junit.jupiter.api.TestInfo;
42 import org.junit.jupiter.api.Timeout;
43 import org.junit.jupiter.params.ParameterizedTest;
44 import org.junit.jupiter.params.provider.MethodSource;
45
46 import javax.net.ssl.SSLEngine;
47 import javax.net.ssl.SSLSession;
48 import javax.net.ssl.SSLSessionContext;
49 import javax.net.ssl.TrustManager;
50 import javax.net.ssl.X509ExtendedTrustManager;
51
52 import java.io.File;
53 import java.io.IOException;
54 import java.net.InetSocketAddress;
55 import java.net.Socket;
56 import java.security.cert.CertificateException;
57 import java.security.cert.X509Certificate;
58 import java.util.Arrays;
59 import java.util.Collection;
60 import java.util.Collections;
61 import java.util.Enumeration;
62 import java.util.HashSet;
63 import java.util.Set;
64 import java.util.concurrent.BlockingQueue;
65 import java.util.concurrent.ExecutionException;
66 import java.util.concurrent.LinkedBlockingQueue;
67 import java.util.concurrent.TimeUnit;
68 import java.util.concurrent.atomic.AtomicReference;
69
70 import static org.junit.jupiter.api.Assertions.assertEquals;
71 import static org.junit.jupiter.api.Assertions.assertTrue;
72
73 public class SocketSslSessionReuseTest extends AbstractSocketTest {
74
/** Logger shared by the test methods and the handlers declared in this class. */
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);

/** File returned by {@code X509Bundle.toTempCertChainPem()} for the self-signed test certificate. */
private static final File CERT_FILE;

/** File returned by {@code X509Bundle.toTempPrivateKeyPem()} for the same certificate bundle. */
private static final File KEY_FILE;

static {
    // One self-signed "cn=localhost" CA certificate is generated per class load and shared by every
    // server and client context builder created in this test.
    try {
        X509Bundle bundle = new CertificateBuilder()
                .subject("cn=localhost")
                .setIsCertificateAuthority(true)
                .buildSelfSigned();
        File certChainPem = bundle.toTempCertChainPem();
        File privateKeyPem = bundle.toTempPrivateKeyPem();
        CERT_FILE = certChainPem;
        KEY_FILE = privateKeyPem;
    } catch (Exception failure) {
        // Surface any generation or I/O problem as a class-initialization failure.
        throw new ExceptionInInitializerError(failure);
    }
}
92
93 public static Collection<Object[]> jdkOnly() throws Exception {
94 return Collections.singleton(new Object[]{
95 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
96 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
97 .endpointIdentificationAlgorithm(null)
98 });
99 }
100
101 public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
102 return Arrays.asList(new Object[]{
103 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
104 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
105 .endpointIdentificationAlgorithm(null)
106 },
107 new Object[]{
108 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
109 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
110 .endpointIdentificationAlgorithm(null)
111 });
112 }
113
114 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
115 @MethodSource("jdkOnly")
116 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
117 public void testSslSessionReuse(
118 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
119 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
120 @Override
121 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
122 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
123 }
124 });
125 }
126
127 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
128 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
129 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
130 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
131 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
132
133 sb.childHandler(new ChannelInitializer<SocketChannel>() {
134 @Override
135 protected void initChannel(SocketChannel sch) throws Exception {
136 SSLEngine engine = serverCtx.newEngine(sch.alloc());
137 engine.setUseClientMode(false);
138 engine.setEnabledProtocols(protocols);
139
140 sch.pipeline().addLast(new SslHandler(engine));
141 sch.pipeline().addLast(sh);
142 }
143 });
144 final Channel sc = sb.bind().sync().channel();
145
146 cb.handler(new ChannelInitializer<SocketChannel>() {
147 @Override
148 protected void initChannel(SocketChannel sch) throws Exception {
149 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
150 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
151 engine.setUseClientMode(true);
152 engine.setEnabledProtocols(protocols);
153
154 sch.pipeline().addLast(new SslHandler(engine));
155 sch.pipeline().addLast(ch);
156 }
157 });
158
159 try {
160 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
161 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
162 Channel cc = cb.connect(sc.localAddress()).sync().channel();
163 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
164 cc.closeFuture().sync();
165 rethrowHandlerExceptions(sh, ch);
166 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
167
168 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
169 cc = cb.connect(sc.localAddress()).sync().channel();
170 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
171 cc.closeFuture().sync();
172 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
173 rethrowHandlerExceptions(sh, ch);
174 } finally {
175 sc.close().awaitUninterruptibly();
176 }
177 }
178
179 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
180 @MethodSource("jdkAndOpenSSL")
181 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
182 public void testSslSessionTrustManagerResumption(
183 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
184 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
185 @Override
186 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
187 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
188 }
189 });
190 }
191
192 public void testSslSessionTrustManagerResumption(
193 ServerBootstrap sb, Bootstrap cb,
194 SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
195 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
196 serverCtxBldr.protocols(protocols);
197 clientCtxBldr.protocols(protocols);
198 TrustManager clientTrustManager = new SessionSettingTrustManager();
199 clientCtxBldr.trustManager(clientTrustManager);
200 final SslContext serverContext = serverCtxBldr.build();
201 final SslContext clientContext = clientCtxBldr.build();
202
203 final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
204 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
205 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
206 @Override
207 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
208 if (evt instanceof SslHandshakeCompletionEvent) {
209 SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
210 if (handshakeCompletionEvent.isSuccess()) {
211 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
212 assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
213 } else {
214 logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
215 }
216 }
217 super.userEventTriggered(ctx, evt);
218 }
219 };
220
221 sb.childHandler(new ChannelInitializer<SocketChannel>() {
222 @Override
223 protected void initChannel(SocketChannel sch) throws Exception {
224 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
225 sch.pipeline().addLast(sh);
226 }
227 });
228 final Channel sc = sb.bind().sync().channel();
229
230 cb.handler(new ChannelInitializer<SocketChannel>() {
231 @Override
232 protected void initChannel(SocketChannel sch) throws Exception {
233 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
234 SslHandler sslHandler = clientContext.newHandler(
235 sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
236
237 sch.pipeline().addLast(sslHandler);
238 sch.pipeline().addLast(ch);
239 }
240 });
241
242 try {
243 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
244 Channel cc = cb.connect(sc.localAddress()).sync().channel();
245 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
246 cc.closeFuture().sync();
247 rethrowHandlerExceptions(sh, ch);
248 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
249
250 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
251 cc = cb.connect(sc.localAddress()).sync().channel();
252 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
253 cc.closeFuture().sync();
254 rethrowHandlerExceptions(sh, ch);
255 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
256 } finally {
257 sc.close().awaitUninterruptibly();
258 }
259 }
260
261 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
262 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
263 throw new ExecutionException(sh.exception.get());
264 }
265 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
266 throw new ExecutionException(ch.exception.get());
267 }
268 if (sh.exception.get() != null) {
269 throw new ExecutionException(sh.exception.get());
270 }
271 if (ch.exception.get() != null) {
272 throw new ExecutionException(ch.exception.get());
273 }
274 }
275
276 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
277 Set<String> idSet = new HashSet<String>();
278 byte[] id;
279 while (sessionIds.hasMoreElements()) {
280 id = sessionIds.nextElement();
281 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
282 }
283 return idSet;
284 }
285
286 @Sharable
287 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
288 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
289 private final boolean server;
290 private final boolean autoRead;
291
292 ReadAndDiscardHandler(boolean server, boolean autoRead) {
293 this.server = server;
294 this.autoRead = autoRead;
295 }
296
297 @Override
298 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
299 in.skipBytes(in.readableBytes());
300 }
301
302 @Override
303 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
304 try {
305 ctx.flush();
306 } finally {
307 if (!autoRead) {
308 ctx.read();
309 }
310 }
311 }
312
313 @Override
314 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
315 if (logger.isWarnEnabled()) {
316 logger.warn(
317 "Unexpected exception from the " +
318 (server? "server" : "client") + " side", cause);
319 }
320
321 exception.compareAndSet(null, cause);
322 ctx.close();
323 }
324 }
325
326 private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
327 implements ResumableX509ExtendedTrustManager {
328 @Override
329 public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
330 engine.getSession().putValue("key", "value");
331 }
332
333 @Override
334 public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
335 throws CertificateException {
336 engine.getHandshakeSession().putValue("key", "value");
337 }
338
339 @Override
340 public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
341 throw new CertificateException("Unsupported operation");
342 }
343
344 @Override
345 public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
346 throws CertificateException {
347 throw new CertificateException("Unsupported operation");
348 }
349
350 @Override
351 public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
352 throws CertificateException {
353 throw new CertificateException("Unsupported operation");
354 }
355
356 @Override
357 public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
358 throws CertificateException {
359 throw new CertificateException("Unsupported operation");
360 }
361
362 @Override
363 public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
364 throw new CertificateException("Unsupported operation");
365 }
366
367 @Override
368 public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
369 throw new CertificateException("Unsupported operation");
370 }
371
372 @Override
373 public X509Certificate[] getAcceptedIssuers() {
374 return new X509Certificate[0];
375 }
376 }
377 }