1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufUtil;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelFutureListener;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelHandler.Sharable;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.SimpleChannelInboundHandler;
29 import io.netty.channel.socket.SocketChannel;
30 import io.netty.handler.ssl.ResumableX509ExtendedTrustManager;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslContextBuilder;
33 import io.netty.handler.ssl.SslHandler;
34 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty.handler.ssl.SslProvider;
36 import io.netty.handler.ssl.util.SelfSignedCertificate;
37 import io.netty.util.internal.SuppressJava6Requirement;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40
41 import org.junit.jupiter.api.TestInfo;
42 import org.junit.jupiter.api.Timeout;
43 import org.junit.jupiter.params.ParameterizedTest;
44 import org.junit.jupiter.params.provider.MethodSource;
45
46 import javax.net.ssl.SSLEngine;
47 import javax.net.ssl.SSLSession;
48 import javax.net.ssl.SSLSessionContext;
49 import javax.net.ssl.TrustManager;
50 import javax.net.ssl.X509ExtendedTrustManager;
51
52 import java.io.File;
53 import java.io.IOException;
54 import java.net.InetSocketAddress;
55 import java.net.Socket;
56 import java.nio.channels.ClosedChannelException;
57 import java.security.cert.CertificateException;
58 import java.security.cert.X509Certificate;
59 import java.util.Arrays;
60 import java.util.Collection;
61 import java.util.Collections;
62 import java.util.Enumeration;
63 import java.util.HashSet;
64 import java.util.Set;
65 import java.util.concurrent.BlockingQueue;
66 import java.util.concurrent.ExecutionException;
67 import java.util.concurrent.LinkedBlockingQueue;
68 import java.util.concurrent.TimeUnit;
69 import java.util.concurrent.atomic.AtomicReference;
70
71 import static org.junit.jupiter.api.Assertions.assertEquals;
72 import static org.junit.jupiter.api.Assertions.assertTrue;
73
74 public class SocketSslSessionReuseTest extends AbstractSocketTest {
75
76 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
77
78 private static final File CERT_FILE;
79 private static final File KEY_FILE;
80
81 static {
82 SelfSignedCertificate ssc;
83 try {
84 ssc = new SelfSignedCertificate();
85 } catch (CertificateException e) {
86 throw new Error(e);
87 }
88 CERT_FILE = ssc.certificate();
89 KEY_FILE = ssc.privateKey();
90 }
91
92 public static Collection<Object[]> jdkOnly() throws Exception {
93 return Collections.singleton(new Object[]{
94 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
95 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
96 .endpointIdentificationAlgorithm(null)
97 });
98 }
99
100 public static Collection<Object[]> jdkAndOpenSSL() throws Exception {
101 return Arrays.asList(new Object[]{
102 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK),
103 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK)
104 .endpointIdentificationAlgorithm(null)
105 },
106 new Object[]{
107 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.OPENSSL),
108 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.OPENSSL)
109 .endpointIdentificationAlgorithm(null)
110 });
111 }
112
113 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
114 @MethodSource("jdkOnly")
115 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
116 public void testSslSessionReuse(
117 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
118 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
119 @Override
120 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
121 testSslSessionReuse(sb, cb, serverCtx.build(), clientCtx.build());
122 }
123 });
124 }
125
126 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
127 final SslContext serverCtx, final SslContext clientCtx) throws Throwable {
128 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
129 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
130 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
131
132 sb.childHandler(new ChannelInitializer<SocketChannel>() {
133 @Override
134 protected void initChannel(SocketChannel sch) throws Exception {
135 SSLEngine engine = serverCtx.newEngine(sch.alloc());
136 engine.setUseClientMode(false);
137 engine.setEnabledProtocols(protocols);
138
139 sch.pipeline().addLast(new SslHandler(engine));
140 sch.pipeline().addLast(sh);
141 }
142 });
143 final Channel sc = sb.bind().sync().channel();
144
145 cb.handler(new ChannelInitializer<SocketChannel>() {
146 @Override
147 protected void initChannel(SocketChannel sch) throws Exception {
148 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
149 SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
150 engine.setUseClientMode(true);
151 engine.setEnabledProtocols(protocols);
152
153 sch.pipeline().addLast(new SslHandler(engine));
154 sch.pipeline().addLast(ch);
155 }
156 });
157
158 try {
159 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
160 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
161 Channel cc = cb.connect(sc.localAddress()).sync().channel();
162 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
163 cc.closeFuture().sync();
164 rethrowHandlerExceptions(sh, ch);
165 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
166
167 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
168 cc = cb.connect(sc.localAddress()).sync().channel();
169 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
170 cc.closeFuture().sync();
171 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
172 rethrowHandlerExceptions(sh, ch);
173 } finally {
174 sc.close().awaitUninterruptibly();
175 }
176 }
177
178 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
179 @MethodSource("jdkAndOpenSSL")
180 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
181 public void testSslSessionTrustManagerResumption(
182 final SslContextBuilder serverCtx, final SslContextBuilder clientCtx, TestInfo testInfo) throws Throwable {
183 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
184 @Override
185 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
186 testSslSessionTrustManagerResumption(sb, cb, serverCtx, clientCtx);
187 }
188 });
189 }
190
191 public void testSslSessionTrustManagerResumption(
192 ServerBootstrap sb, Bootstrap cb,
193 SslContextBuilder serverCtxBldr, final SslContextBuilder clientCtxBldr) throws Throwable {
194 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
195 serverCtxBldr.protocols(protocols);
196 clientCtxBldr.protocols(protocols);
197 TrustManager clientTrustManager = new SessionSettingTrustManager();
198 clientCtxBldr.trustManager(clientTrustManager);
199 final SslContext serverContext = serverCtxBldr.build();
200 final SslContext clientContext = clientCtxBldr.build();
201
202 final BlockingQueue<String> sessionValue = new LinkedBlockingQueue<String>();
203 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
204 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true) {
205 @Override
206 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
207 if (evt instanceof SslHandshakeCompletionEvent) {
208 SslHandshakeCompletionEvent handshakeCompletionEvent = (SslHandshakeCompletionEvent) evt;
209 if (handshakeCompletionEvent.isSuccess()) {
210 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
211 assertTrue(sessionValue.offer(String.valueOf(session.getValue("key"))));
212 } else {
213 logger.error("SSL handshake failed", handshakeCompletionEvent.cause());
214 }
215 }
216 super.userEventTriggered(ctx, evt);
217 }
218 };
219
220 sb.childHandler(new ChannelInitializer<SocketChannel>() {
221 @Override
222 protected void initChannel(SocketChannel sch) throws Exception {
223 sch.pipeline().addLast(serverContext.newHandler(sch.alloc()));
224 sch.pipeline().addLast(sh);
225 }
226 });
227 final Channel sc = sb.bind().sync().channel();
228
229 cb.handler(new ChannelInitializer<SocketChannel>() {
230 @Override
231 protected void initChannel(SocketChannel sch) throws Exception {
232 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
233 SslHandler sslHandler = clientContext.newHandler(
234 sch.alloc(), serverAddr.getHostString(), serverAddr.getPort());
235
236 sch.pipeline().addLast(sslHandler);
237 sch.pipeline().addLast(ch);
238 }
239 });
240
241 try {
242 ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
243 Channel cc = cb.connect(sc.localAddress()).sync().channel();
244 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
245 cc.closeFuture().sync();
246 rethrowHandlerExceptions(sh, ch);
247 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
248
249 msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4);
250 cc = cb.connect(sc.localAddress()).sync().channel();
251 cc.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE).sync();
252 cc.closeFuture().sync();
253 rethrowHandlerExceptions(sh, ch);
254 assertEquals("value", sessionValue.poll(10, TimeUnit.SECONDS));
255 } finally {
256 sc.close().awaitUninterruptibly();
257 }
258 }
259
260 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
261 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
262 throw new ExecutionException(sh.exception.get());
263 }
264 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
265 throw new ExecutionException(ch.exception.get());
266 }
267 if (sh.exception.get() != null) {
268 throw new ExecutionException(sh.exception.get());
269 }
270 if (ch.exception.get() != null) {
271 throw new ExecutionException(ch.exception.get());
272 }
273 }
274
275 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
276 Set<String> idSet = new HashSet<String>();
277 byte[] id;
278 while (sessionIds.hasMoreElements()) {
279 id = sessionIds.nextElement();
280 idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id)));
281 }
282 return idSet;
283 }
284
285 @Sharable
286 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<ByteBuf> {
287 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
288 private final boolean server;
289 private final boolean autoRead;
290
291 ReadAndDiscardHandler(boolean server, boolean autoRead) {
292 this.server = server;
293 this.autoRead = autoRead;
294 }
295
296 @Override
297 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
298 in.skipBytes(in.readableBytes());
299 }
300
301 @Override
302 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
303 try {
304 ctx.flush();
305 } finally {
306 if (!autoRead) {
307 ctx.read();
308 }
309 }
310 }
311
312 @Override
313 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
314 if (logger.isWarnEnabled()) {
315 logger.warn(
316 "Unexpected exception from the " +
317 (server? "server" : "client") + " side", cause);
318 }
319
320 exception.compareAndSet(null, cause);
321 ctx.close();
322 }
323 }
324
325 @SuppressJava6Requirement(reason = "Test only")
326 private static final class SessionSettingTrustManager extends X509ExtendedTrustManager
327 implements ResumableX509ExtendedTrustManager {
328 @Override
329 public void resumeServerTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
330 engine.getSession().putValue("key", "value");
331 }
332
333 @SuppressJava6Requirement(reason = "Test only")
334 @Override
335 public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
336 throws CertificateException {
337 engine.getHandshakeSession().putValue("key", "value");
338 }
339
340 @Override
341 public void resumeClientTrusted(X509Certificate[] chain, SSLEngine engine) throws CertificateException {
342 throw new CertificateException("Unsupported operation");
343 }
344
345 @Override
346 public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
347 throws CertificateException {
348 throw new CertificateException("Unsupported operation");
349 }
350
351 @Override
352 public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
353 throws CertificateException {
354 throw new CertificateException("Unsupported operation");
355 }
356
357 @Override
358 public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
359 throws CertificateException {
360 throw new CertificateException("Unsupported operation");
361 }
362
363 @Override
364 public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
365 throw new CertificateException("Unsupported operation");
366 }
367
368 @Override
369 public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
370 throw new CertificateException("Unsupported operation");
371 }
372
373 @Override
374 public X509Certificate[] getAcceptedIssuers() {
375 return new X509Certificate[0];
376 }
377 }
378 }