1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelHandler.Sharable;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.SimpleChannelInboundHandler;
29 import io.netty.channel.socket.SocketChannel;
30 import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslContextBuilder;
33 import io.netty.handler.ssl.SslHandler;
34 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty.handler.ssl.SslProvider;
36 import io.netty.handler.ssl.util.SelfSignedCertificate;
37 import io.netty.util.internal.SuppressJava6Requirement;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40
41 import org.junit.jupiter.api.TestInfo;
42 import org.junit.jupiter.api.Timeout;
43 import org.junit.jupiter.params.ParameterizedTest;
44 import org.junit.jupiter.params.provider.MethodSource;
45
46 import javax.net.ssl.SSLEngine;
47 import javax.net.ssl.SSLSession;
48 import javax.net.ssl.SSLSessionContext;
49 import javax.net.ssl.TrustManager;
50 import javax.net.ssl.X509ExtendedTrustManager;
51
52 import java.io.File;
53 import java.io.IOException;
54 import java.net.InetSocketAddress;
55 import java.net.Socket;
56 import java.nio.channels.ClosedChannelException;
57 import java.security.cert.CertificateException;
58 import java.security.cert.X509Certificate;
59 import java.util.Arrays;
60 import java.util.Collection;
61 import java.util.Collections;
62 import java.util.Enumeration;
63 import java.util.HashSet;
64 import java.util.Set;
65 import java.util.concurrent.BlockingQueue;
66 import java.util.concurrent.ExecutionException;
67 import java.util.concurrent.LinkedBlockingQueue;
68 import java.util.concurrent.TimeUnit;
69 import java.util.concurrent.atomic.AtomicReference;
70
71 import static org.junit.jupiter.api.Assertions.assertEquals;
72 import static org.junit.jupiter.api.Assertions.assertTrue;
73
74 public class SocketSslSessionReuseTest extends AbstractSocketTest {
75
76 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
77
78 private static final File CERT_FILE;
79 private static final File KEY_FILE;
80
81 static {
82 SelfSignedCertificate ssc;
83 try {
84 ssc = new SelfSignedCertificate();
85 } catch (CertificateException e) {
86 throw new Error(e);
87 }
88 CERT_FILE = ssc.certificate();
89 KEY_FILE = ssc.privateKey();
90 }
91
92 public static Collection<Object[]> jdkOnly() throws Exception {
93 return Collections.singleton(new Object[]{
94 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
95 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
96 });
97 }
98
99 public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
100 return Arrays.asList(new Object[]{
101 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
102 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
103 },
104 new Object[]{
105 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
106 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
107 });
108 }
109
110 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
111 @MethodSource("jdkOnly")
112 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
113 public void testSslSessionReuse(
114 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
115 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
116 @Override
117 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
118 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
119 }
120 });
121 }
122
123 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
124 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
125 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
126 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
127 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
128
129 sb.childHandler(new ChannelInitializer<SocketChannel>() {
130 @Override
131 protected void initChannel(SocketChannel sch) throws Exception {
132 SSLEngine engine = serverCtx.newEngine(sch.alloc());
133 engine.setUseClientMode(false);
134 engine.setEnabledProtocols(protocols);
135
136 sch.pipeline().addLast(new SslHandler(engine));
137 sch.pipeline().addLast(sh);
138 }
139 });
140 final Channel sc = sb.bind().sync().channel();
141
142 cb.handler(new ChannelInitializer<SocketChannel>() {
143 @Override
144 protected void initChannel(SocketChannel sch) throws Exception {
145 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
146 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
147 engine.setUseClientMode(true);
148 engine.setEnabledProtocols(protocols);
149
150 sch.pipeline().addLast(new SslHandler(engine));
151 sch.pipeline().addLast(ch);
152 }
153 });
154
155 try {
156 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
157 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
158 Channel cc = cb.connect(sc.localAddress()).sync().channel();
159 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
160 cc.closeFuture().sync();
161 rethrowHandlerExceptions(sh, ch);
162 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
163
164 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
165 cc = cb.connect(sc.localAddress()).sync().channel();
166 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
167 cc.closeFuture().sync();
168 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
169 rethrowHandlerExceptions(sh, ch);
170 } finally {
171 sc.close().awaitUninterruptibly();
172 }
173 }
174
175 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
176 @MethodSource("jdkAndOpenSSL")
177 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
178 public void testSslSessionTrustManagerResumption(
179 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
180 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
181 @Override
182 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
183 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
184 }
185 });
186 }
187
188 public void testSslSessionTrustManagerResumption(
189 ServerBootstrap sb, Bootstrap cb,
190 SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
191 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
192 serverCtxBldr.protocols(protocols);
193 clientCtxBldr.protocols(protocols);
194 TrustManager clientTrustManager = new SessionSettingTrustManager();
195 clientCtxBldr.trustManager(clientTrustManager);
196 final SslContext serverContext = serverCtxBldr.build();
197 final SslContext clientContext = clientCtxBldr.build();
198
199 final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
200 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
201 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
202 @Override
203 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
204 if (evt instanceof SslHandshakeCompletionEvent) {
205 SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
206 if (handshakeCompletionEvent.isSuccess()) {
207 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
208 assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
209 } else {
210 logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
211 }
212 }
213 super.userEventTriggered(ctx, evt);
214 }
215 };
216
217 sb.childHandler(new ChannelInitializer<SocketChannel>() {
218 @Override
219 protected void initChannel(SocketChannel sch) throws Exception {
220 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
221 sch.pipeline().addLast(sh);
222 }
223 });
224 final Channel sc = sb.bind().sync().channel();
225
226 cb.handler(new ChannelInitializer<SocketChannel>() {
227 @Override
228 protected void initChannel(SocketChannel sch) throws Exception {
229 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
230 SslHandler sslHandler = clientContext.newHandler(
231 sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
232
233 sch.pipeline().addLast(sslHandler);
234 sch.pipeline().addLast(ch);
235 }
236 });
237
238 try {
239 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
240 Channel cc = cb.connect(sc.localAddress()).sync().channel();
241 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
242 cc.closeFuture().sync();
243 rethrowHandlerExceptions(sh, ch);
244 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
245
246 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
247 cc = cb.connect(sc.localAddress()).sync().channel();
248 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
249 cc.closeFuture().sync();
250 rethrowHandlerExceptions(sh, ch);
251 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
252 } finally {
253 sc.close().awaitUninterruptibly();
254 }
255 }
256
257 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
258 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
259 throw new ExecutionException(sh.exception.get());
260 }
261 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
262 throw new ExecutionException(ch.exception.get());
263 }
264 if (sh.exception.get() != null) {
265 throw new ExecutionException(sh.exception.get());
266 }
267 if (ch.exception.get() != null) {
268 throw new ExecutionException(ch.exception.get());
269 }
270 }
271
272 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
273 Set<String> idSet = new HashSet<String>();
274 byte[] id;
275 while (sessionIds.hasMoreElements()) {
276 id = sessionIds.nextElement();
277 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
278 }
279 return idSet;
280 }
281
282 @Sharable
283 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
284 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
285 private final boolean server;
286 private final boolean autoRead;
287
288 ReadAndDiscardHandler(boolean server, boolean autoRead) {
289 this.server = server;
290 this.autoRead = autoRead;
291 }
292
293 @Override
294 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
295 in.skipBytes(in.readableBytes());
296 }
297
298 @Override
299 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
300 try {
301 ctx.flush();
302 } finally {
303 if (!autoRead) {
304 ctx.read();
305 }
306 }
307 }
308
309 @Override
310 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
311 if (logger.isWarnEnabled()) {
312 logger.warn(
313 "Unexpected exception from the " +
314 (server? "server" : "client") + " side", cause);
315 }
316
317 exception.compareAndSet(null, cause);
318 ctx.close();
319 }
320 }
321
322 @SuppressJava6Requirement(reason = "Test only")
323 private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
324 implements ResumableX509ExtendedTrustManager {
325 @Override
326 public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
327 engine.getSession().putValue("key", "value");
328 }
329
330 @SuppressJava6Requirement(reason = "Test only")
331 @Override
332 public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
333 throws CertificateException {
334 engine.getHandshakeSession().putValue("key", "value");
335 }
336
337 @Override
338 public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
339 throw new CertificateException("Unsupported operation");
340 }
341
342 @Override
343 public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
344 throws CertificateException {
345 throw new CertificateException("Unsupported operation");
346 }
347
348 @Override
349 public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
350 throws CertificateException {
351 throw new CertificateException("Unsupported operation");
352 }
353
354 @Override
355 public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
356 throws CertificateException {
357 throw new CertificateException("Unsupported operation");
358 }
359
360 @Override
361 public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
362 throw new CertificateException("Unsupported operation");
363 }
364
365 @Override
366 public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
367 throw new CertificateException("Unsupported operation");
368 }
369
370 @Override
371 public X509Certificate[] getAcceptedIssuers() {
372 return new X509Certificate[0];
373 }
374 }
375 }