1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.stream.ChunkedWriteHandler;
38 import io.netty.pkitesting.CertificateBuilder;
39 import io.netty.pkitesting.X509Bundle;
40 import io.netty.testsuite.util.TestUtils;
41 import io.netty.util.concurrent.Future;
42 import io.netty.util.concurrent.GenericFutureListener;
43 import io.netty.util.internal.logging.InternalLogger;
44 import io.netty.util.internal.logging.InternalLoggerFactory;
45 import org.junit.jupiter.api.AfterAll;
46 import org.junit.jupiter.api.TestInfo;
47 import org.junit.jupiter.api.Timeout;
48 import org.junit.jupiter.params.ParameterizedTest;
49 import org.junit.jupiter.params.provider.MethodSource;
50
51 import javax.net.ssl.SSLEngine;
52 import java.io.File;
53 import java.io.IOException;
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.List;
57 import java.util.Random;
58 import java.util.concurrent.CountDownLatch;
59 import java.util.concurrent.ExecutorService;
60 import java.util.concurrent.Executors;
61 import java.util.concurrent.TimeUnit;
62 import java.util.concurrent.atomic.AtomicInteger;
63 import java.util.concurrent.atomic.AtomicReference;
64
65 import static org.hamcrest.MatcherAssert.assertThat;
66 import static org.hamcrest.Matchers.anyOf;
67 import static org.hamcrest.Matchers.is;
68 import static org.hamcrest.Matchers.not;
69 import static org.hamcrest.Matchers.sameInstance;
70 import static org.junit.jupiter.api.Assertions.assertEquals;
71 import static org.junit.jupiter.api.Assertions.assertSame;
72
73 public class SocketSslEchoTest extends AbstractSocketTest {
74
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);

// Size of the un-timed first write the client sends right after the handshake.
private static final int FIRST_MESSAGE_SIZE = 16 * 1024;
// Shared source of randomness for the payload and for the per-write chunk sizes.
private static final Random random = new Random();
// Self-signed certificate chain and private key (PEM temp files) used by every SslContext.
private static final File CERT_FILE;
private static final File KEY_FILE;
// The 1 MiB payload echoed from client to server and back.
static final byte[] data = new byte[1024 * 1024];

static {
    // Fill the payload once for the whole class.
    random.nextBytes(data);

    try {
        X509Bundle bundle = new CertificateBuilder()
                .rsa2048()
                .subject("cn=localhost")
                .setIsCertificateAuthority(true)
                .buildSelfSigned();
        CERT_FILE = bundle.toTempCertChainPem();
        KEY_FILE = bundle.toTempPrivateKeyPem();
    } catch (Exception e) {
        throw new ExceptionInInitializerError(e);
    }
}
98
99 protected enum RenegotiationType {
100 NONE,
101 CLIENT_INITIATED,
102 SERVER_INITIATED,
103 }
104
105 protected static class Renegotiation {
106 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
107
108 final RenegotiationType type;
109 final String cipherSuite;
110
111 Renegotiation(RenegotiationType type, String cipherSuite) {
112 this.type = type;
113 this.cipherSuite = cipherSuite;
114 }
115
116 @Override
117 public String toString() {
118 if (type == RenegotiationType.NONE) {
119 return "NONE";
120 }
121
122 return type + "(" + cipherSuite + ')';
123 }
124 }
125
126 public static Collection<Object[]> data() throws Exception {
127 List<SslContext> serverContexts = new ArrayList<SslContext>();
128 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
129 .sslProvider(SslProvider.JDK)
130
131 .protocols("TLSv1.2")
132 .build());
133
134 List<SslContext> clientContexts = new ArrayList<SslContext>();
135 clientContexts.add(SslContextBuilder.forClient()
136 .sslProvider(SslProvider.JDK)
137 .trustManager(CERT_FILE)
138
139 .protocols("TLSv1.2")
140 .endpointIdentificationAlgorithm(null)
141 .build());
142
143 boolean hasOpenSsl = OpenSsl.isAvailable();
144 if (hasOpenSsl) {
145 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
146 .sslProvider(SslProvider.OPENSSL)
147
148 .protocols("TLSv1.2")
149 .build());
150 clientContexts.add(SslContextBuilder.forClient()
151 .sslProvider(SslProvider.OPENSSL)
152 .trustManager(CERT_FILE)
153
154 .protocols("TLSv1.2")
155 .endpointIdentificationAlgorithm(null)
156 .build());
157 } else {
158 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
159 }
160
161 List<Object[]> params = new ArrayList<Object[]>();
162 for (SslContext sc: serverContexts) {
163 for (SslContext cc: clientContexts) {
164 for (RenegotiationType rt: RenegotiationType.values()) {
165 if (rt != RenegotiationType.NONE &&
166 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
167
168 continue;
169 }
170
171 final Renegotiation r;
172 switch (rt) {
173 case NONE:
174 r = Renegotiation.NONE;
175 break;
176 case SERVER_INITIATED:
177 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
178 break;
179 case CLIENT_INITIATED:
180 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
181 break;
182 default:
183 throw new Error();
184 }
185
186 for (int i = 0; i < 32; i++) {
187 params.add(new Object[] {
188 sc, cc, r,
189 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
190 }
191 }
192 }
193 }
194
195 return params;
196 }
197
198 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
199 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
200 private final AtomicInteger clientSendCounter = new AtomicInteger();
201 private final AtomicInteger clientRecvCounter = new AtomicInteger();
202 private final AtomicInteger serverRecvCounter = new AtomicInteger();
203
204 private final AtomicInteger clientNegoCounter = new AtomicInteger();
205 private final AtomicInteger serverNegoCounter = new AtomicInteger();
206
207 private volatile Channel clientChannel;
208 private volatile Channel serverChannel;
209
210 private volatile SslHandler clientSslHandler;
211 private volatile SslHandler serverSslHandler;
212
213 private final EchoClientHandler clientHandler =
214 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
215
216 private final EchoServerHandler serverHandler =
217 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
218
219 private SslContext serverCtx;
220 private SslContext clientCtx;
221 private Renegotiation renegotiation;
222 private boolean serverUsesDelegatedTaskExecutor;
223 private boolean clientUsesDelegatedTaskExecutor;
224 private boolean autoRead;
225 private boolean useChunkedWriteHandler;
226 private boolean useCompositeByteBuf;
227
228 @AfterAll
229 public static void compressHeapDumps() throws Exception {
230 TestUtils.compressHeapDumps();
231 }
232
233 @ParameterizedTest(name =
234 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
235 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
236 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
237 @MethodSource("data")
238 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
239 public void testSslEcho(
240 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
241 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
242 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
243 TestInfo testInfo) throws Throwable {
244 this.serverCtx = serverCtx;
245 this.clientCtx = clientCtx;
246 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
247 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
248 this.renegotiation = renegotiation;
249 this.autoRead = autoRead;
250 this.useChunkedWriteHandler = useChunkedWriteHandler;
251 this.useCompositeByteBuf = useCompositeByteBuf;
252 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
253 @Override
254 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
255 testSslEcho(serverBootstrap, bootstrap);
256 }
257 });
258 }
259
260 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
261 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
262 reset();
263
264 sb.childOption(ChannelOption.AUTO_READ, autoRead);
265 cb.option(ChannelOption.AUTO_READ, autoRead);
266
267 sb.childHandler(new ChannelInitializer<Channel>() {
268 @Override
269 public void initChannel(Channel sch) {
270 serverChannel = sch;
271
272 if (serverUsesDelegatedTaskExecutor) {
273 SSLEngine sse = serverCtx.newEngine(sch.alloc());
274 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
275 } else {
276 serverSslHandler = serverCtx.newHandler(sch.alloc());
277 }
278 serverSslHandler.setHandshakeTimeoutMillis(0);
279
280 sch.pipeline().addLast("ssl", serverSslHandler);
281 if (useChunkedWriteHandler) {
282 sch.pipeline().addLast(new ChunkedWriteHandler());
283 }
284 sch.pipeline().addLast("serverHandler", serverHandler);
285 }
286 });
287
288 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
289 cb.handler(new ChannelInitializer<Channel>() {
290 @Override
291 public void initChannel(Channel sch) {
292 clientChannel = sch;
293
294 if (clientUsesDelegatedTaskExecutor) {
295 SSLEngine cse = clientCtx.newEngine(sch.alloc());
296 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
297 } else {
298 clientSslHandler = clientCtx.newHandler(sch.alloc());
299 }
300 clientSslHandler.setHandshakeTimeoutMillis(0);
301
302 sch.pipeline().addLast("ssl", clientSslHandler);
303 if (useChunkedWriteHandler) {
304 sch.pipeline().addLast(new ChunkedWriteHandler());
305 }
306 sch.pipeline().addLast("clientHandler", clientHandler);
307 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
308 @Override
309 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
310 if (evt instanceof SslHandshakeCompletionEvent) {
311 clientHandshakeEventLatch.countDown();
312 }
313 ctx.fireUserEventTriggered(evt);
314 }
315 });
316 }
317 });
318
319 final Channel sc = sb.bind().sync().channel();
320 cb.connect(sc.localAddress()).sync();
321
322 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
323
324
325 clientHandshakeFuture.sync();
326 clientHandshakeEventLatch.await();
327
328 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
329 clientSendCounter.set(FIRST_MESSAGE_SIZE);
330
331 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
332 Future<Channel> renegoFuture = null;
333 while (clientSendCounter.get() < data.length) {
334 int clientSendCounterVal = clientSendCounter.get();
335 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
336 ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
337 if (useCompositeByteBuf) {
338 buf = Unpooled.compositeBuffer().addComponent(true, buf);
339 }
340
341 ChannelFuture future = clientChannel.writeAndFlush(buf);
342 clientSendCounter.set(clientSendCounterVal += length);
343 future.sync();
344
345 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
346 needsRenegotiation = false;
347 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
348 renegoFuture = clientSslHandler.renegotiate();
349 logStats("CLIENT RENEGOTIATES");
350 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
351 }
352 }
353
354
355 while (clientRecvCounter.get() < data.length) {
356 if (serverException.get() != null) {
357 break;
358 }
359 if (clientException.get() != null) {
360 break;
361 }
362
363 Thread.sleep(50);
364 }
365
366 while (serverRecvCounter.get() < data.length) {
367 if (serverException.get() != null) {
368 break;
369 }
370 if (clientException.get() != null) {
371 break;
372 }
373
374 Thread.sleep(50);
375 }
376
377
378 if (renegoFuture != null) {
379 renegoFuture.sync();
380 }
381 if (serverHandler.renegoFuture != null) {
382 serverHandler.renegoFuture.sync();
383 }
384
385 serverChannel.close().awaitUninterruptibly();
386 clientChannel.close().awaitUninterruptibly();
387 sc.close().awaitUninterruptibly();
388 delegatedTaskExecutor.shutdown();
389
390 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
391 throw serverException.get();
392 }
393 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
394 throw clientException.get();
395 }
396 if (serverException.get() != null) {
397 throw serverException.get();
398 }
399 if (clientException.get() != null) {
400 throw clientException.get();
401 }
402
403
404 try {
405 switch (renegotiation.type) {
406 case SERVER_INITIATED:
407 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
408 assertThat(serverNegoCounter.get(), is(2));
409 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
410 break;
411 case CLIENT_INITIATED:
412 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
413 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
414 assertThat(clientNegoCounter.get(), is(2));
415 break;
416 case NONE:
417 assertThat(serverNegoCounter.get(), is(1));
418 assertThat(clientNegoCounter.get(), is(1));
419 }
420 } finally {
421 logStats("STATS");
422 }
423 }
424
425 private void reset() {
426 clientException.set(null);
427 serverException.set(null);
428
429 clientSendCounter.set(0);
430 clientRecvCounter.set(0);
431 serverRecvCounter.set(0);
432
433 clientNegoCounter.set(0);
434 serverNegoCounter.set(0);
435
436 clientChannel = null;
437 serverChannel = null;
438
439 clientSslHandler = null;
440 serverSslHandler = null;
441 }
442
443 void logStats(String message) {
444 logger.debug(
445 "{}:\n" +
446 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
447 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
448 message,
449 clientSendCounter, clientRecvCounter, clientNegoCounter,
450 clientSslHandler.engine().getSession().getCipherSuite(),
451 serverRecvCounter, serverNegoCounter,
452 serverSslHandler.engine().getSession().getCipherSuite());
453 }
454
455 @Sharable
456 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
457
458 protected final AtomicInteger recvCounter;
459 protected final AtomicInteger negoCounter;
460 protected final AtomicReference<Throwable> exception;
461
462 EchoHandler(
463 AtomicInteger recvCounter, AtomicInteger negoCounter,
464 AtomicReference<Throwable> exception) {
465
466 this.recvCounter = recvCounter;
467 this.negoCounter = negoCounter;
468 this.exception = exception;
469 }
470
471 @Override
472 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
473
474
475 if (!autoRead) {
476 ctx.read();
477 }
478 ctx.fireChannelReadComplete();
479 }
480
481 @Override
482 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
483 if (evt instanceof SslHandshakeCompletionEvent) {
484 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
485 if (handshakeEvt.cause() != null) {
486 logger.warn("Handshake failed:", handshakeEvt.cause());
487 }
488 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
489 negoCounter.incrementAndGet();
490 logStats("HANDSHAKEN");
491 }
492 ctx.fireUserEventTriggered(evt);
493 }
494
495 @Override
496 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
497 if (logger.isWarnEnabled()) {
498 logger.warn("Unexpected exception from the client side:", cause);
499 }
500
501 exception.compareAndSet(null, cause);
502 ctx.close();
503 }
504 }
505
506 private class EchoClientHandler extends EchoHandler {
507
508 EchoClientHandler(
509 AtomicInteger recvCounter, AtomicInteger negoCounter,
510 AtomicReference<Throwable> exception) {
511
512 super(recvCounter, negoCounter, exception);
513 }
514
515 @Override
516 public void handlerAdded(final ChannelHandlerContext ctx) {
517 if (!autoRead) {
518 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
519 new GenericFutureListener<Future<? super Channel>>() {
520 @Override
521 public void operationComplete(Future<? super Channel> future) {
522 ctx.read();
523 }
524 });
525 }
526 }
527
528 @Override
529 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
530 byte[] actual = new byte[in.readableBytes()];
531 in.readBytes(actual);
532
533 int lastIdx = recvCounter.get();
534 for (int i = 0; i < actual.length; i ++) {
535 assertEquals(data[i + lastIdx], actual[i]);
536 }
537
538 recvCounter.addAndGet(actual.length);
539 }
540 }
541
542 private class EchoServerHandler extends EchoHandler {
543 volatile Future<Channel> renegoFuture;
544
545 EchoServerHandler(
546 AtomicInteger recvCounter, AtomicInteger negoCounter,
547 AtomicReference<Throwable> exception) {
548
549 super(recvCounter, negoCounter, exception);
550 }
551
552 @Override
553 public final void channelRegistered(ChannelHandlerContext ctx) {
554 renegoFuture = null;
555 }
556
557 @Override
558 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
559 if (!autoRead) {
560 ctx.read();
561 }
562 ctx.fireChannelActive();
563 }
564
565 @Override
566 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
567 byte[] actual = new byte[in.readableBytes()];
568 in.readBytes(actual);
569
570 int lastIdx = recvCounter.get();
571 for (int i = 0; i < actual.length; i ++) {
572 assertEquals(data[i + lastIdx], actual[i]);
573 }
574
575 ByteBuf buf = Unpooled.wrappedBuffer(actual);
576 if (useCompositeByteBuf) {
577 buf = Unpooled.compositeBuffer().addComponent(true, buf);
578 }
579 ctx.writeAndFlush(buf);
580
581 recvCounter.addAndGet(actual.length);
582
583
584 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
585 recvCounter.get() > data.length / 2 && renegoFuture == null) {
586
587 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
588
589 Future<Channel> hf = sslHandler.handshakeFuture();
590 assertThat(hf.isDone(), is(true));
591
592 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
593 logStats("SERVER RENEGOTIATES");
594 renegoFuture = sslHandler.renegotiate();
595 assertThat(renegoFuture, is(not(sameInstance(hf))));
596 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
597 assertThat(renegoFuture.isDone(), is(false));
598 }
599 }
600 }
601 }