1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.stream.ChunkedWriteHandler;
38 import io.netty.pkitesting.CertificateBuilder;
39 import io.netty.pkitesting.X509Bundle;
40 import io.netty.testsuite.util.TestUtils;
41 import io.netty.util.concurrent.Future;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.TestInfo;
46 import org.junit.jupiter.api.Timeout;
47 import org.junit.jupiter.params.ParameterizedTest;
48 import org.junit.jupiter.params.provider.MethodSource;
49
50 import javax.net.ssl.SSLEngine;
51 import java.io.File;
52 import java.io.IOException;
53 import java.util.ArrayList;
54 import java.util.Collection;
55 import java.util.List;
56 import java.util.Random;
57 import java.util.concurrent.CountDownLatch;
58 import java.util.concurrent.ExecutorService;
59 import java.util.concurrent.Executors;
60 import java.util.concurrent.TimeUnit;
61 import java.util.concurrent.atomic.AtomicInteger;
62 import java.util.concurrent.atomic.AtomicReference;
63
64 import static io.netty.testsuite.transport.TestsuitePermutation.randomBufferType;
65 import static org.assertj.core.api.Assertions.assertThat;
66 import static org.junit.jupiter.api.Assertions.assertEquals;
67 import static org.junit.jupiter.api.Assertions.assertFalse;
68 import static org.junit.jupiter.api.Assertions.assertNotSame;
69 import static org.junit.jupiter.api.Assertions.assertSame;
70 import static org.junit.jupiter.api.Assertions.assertTrue;
71
72 public class SocketSslEchoTest extends AbstractSocketTest {
73
74 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
75
76 private static final int FIRST_MESSAGE_SIZE = 16384;
77 private static final Random random = new Random();
78 private static final File CERT_FILE;
79 private static final File KEY_FILE;
80 static final byte[] data = new byte[1048576];
81
82 static {
83 random.nextBytes(data);
84
85 try {
86 X509Bundle cert = new CertificateBuilder()
87 .rsa2048()
88 .subject("cn=localhost")
89 .setIsCertificateAuthority(true)
90 .buildSelfSigned();
91 CERT_FILE = cert.toTempCertChainPem();
92 KEY_FILE = cert.toTempPrivateKeyPem();
93 } catch (Exception e) {
94 throw new ExceptionInInitializerError(e);
95 }
96 }
97
98 protected enum RenegotiationType {
99 NONE,
100 CLIENT_INITIATED,
101 SERVER_INITIATED,
102 }
103
104 protected static class Renegotiation {
105 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
106
107 final RenegotiationType type;
108 final String cipherSuite;
109
110 Renegotiation(RenegotiationType type, String cipherSuite) {
111 this.type = type;
112 this.cipherSuite = cipherSuite;
113 }
114
115 @Override
116 public String toString() {
117 if (type == RenegotiationType.NONE) {
118 return "NONE";
119 }
120
121 return type + "(" + cipherSuite + ')';
122 }
123 }
124
125 public static Collection<Object[]> data() throws Exception {
126 List<SslContext> serverContexts = new ArrayList<SslContext>();
127 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
128 .sslProvider(SslProvider.JDK)
129
130 .protocols("TLSv1.2")
131 .build());
132
133 List<SslContext> clientContexts = new ArrayList<SslContext>();
134 clientContexts.add(SslContextBuilder.forClient()
135 .sslProvider(SslProvider.JDK)
136 .trustManager(CERT_FILE)
137
138 .protocols("TLSv1.2")
139 .endpointIdentificationAlgorithm(null)
140 .build());
141
142 boolean hasOpenSsl = OpenSsl.isAvailable();
143 if (hasOpenSsl) {
144 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
145 .sslProvider(SslProvider.OPENSSL)
146
147 .protocols("TLSv1.2")
148 .build());
149 clientContexts.add(SslContextBuilder.forClient()
150 .sslProvider(SslProvider.OPENSSL)
151 .trustManager(CERT_FILE)
152
153 .protocols("TLSv1.2")
154 .endpointIdentificationAlgorithm(null)
155 .build());
156 } else {
157 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
158 }
159
160 List<Object[]> params = new ArrayList<Object[]>();
161 for (SslContext sc: serverContexts) {
162 for (SslContext cc: clientContexts) {
163 for (RenegotiationType rt: RenegotiationType.values()) {
164 if (rt != RenegotiationType.NONE &&
165 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
166
167 continue;
168 }
169
170 final Renegotiation r;
171 switch (rt) {
172 case NONE:
173 r = Renegotiation.NONE;
174 break;
175 case SERVER_INITIATED:
176 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
177 break;
178 case CLIENT_INITIATED:
179 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
180 break;
181 default:
182 throw new Error("Unexpected renegotiation type: " + rt);
183 }
184
185 for (int i = 0; i < 32; i++) {
186 params.add(new Object[] {
187 sc, cc, r,
188 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
189 }
190 }
191 }
192 }
193
194 return params;
195 }
196
197 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
198 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
199 private final AtomicInteger clientSendCounter = new AtomicInteger();
200 private final AtomicInteger clientRecvCounter = new AtomicInteger();
201 private final AtomicInteger serverRecvCounter = new AtomicInteger();
202
203 private final AtomicInteger clientNegoCounter = new AtomicInteger();
204 private final AtomicInteger serverNegoCounter = new AtomicInteger();
205
206 private volatile Channel clientChannel;
207 private volatile Channel serverChannel;
208
209 private volatile SslHandler clientSslHandler;
210 private volatile SslHandler serverSslHandler;
211
212 private final EchoClientHandler clientHandler =
213 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
214
215 private final EchoServerHandler serverHandler =
216 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
217
218 private SslContext serverCtx;
219 private SslContext clientCtx;
220 private Renegotiation renegotiation;
221 private boolean serverUsesDelegatedTaskExecutor;
222 private boolean clientUsesDelegatedTaskExecutor;
223 private boolean autoRead;
224 private boolean useChunkedWriteHandler;
225 private boolean useCompositeByteBuf;
226
227 @AfterAll
228 public static void compressHeapDumps() throws Exception {
229 TestUtils.compressHeapDumps();
230 }
231
232 @ParameterizedTest(name =
233 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
234 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
235 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
236 @MethodSource("data")
237 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
238 public void testSslEcho(
239 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
240 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
241 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
242 TestInfo testInfo) throws Throwable {
243 this.serverCtx = serverCtx;
244 this.clientCtx = clientCtx;
245 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
246 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
247 this.renegotiation = renegotiation;
248 this.autoRead = autoRead;
249 this.useChunkedWriteHandler = useChunkedWriteHandler;
250 this.useCompositeByteBuf = useCompositeByteBuf;
251 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
252 @Override
253 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
254 testSslEcho(serverBootstrap, bootstrap);
255 }
256 });
257 }
258
259 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
260 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
261 reset();
262
263 sb.childOption(ChannelOption.AUTO_READ, autoRead);
264 cb.option(ChannelOption.AUTO_READ, autoRead);
265
266 sb.childHandler(new ChannelInitializer<Channel>() {
267 @Override
268 public void initChannel(Channel sch) {
269 serverChannel = sch;
270
271 if (serverUsesDelegatedTaskExecutor) {
272 SSLEngine sse = serverCtx.newEngine(sch.alloc());
273 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
274 } else {
275 serverSslHandler = serverCtx.newHandler(sch.alloc());
276 }
277 serverSslHandler.setHandshakeTimeoutMillis(0);
278
279 sch.pipeline().addLast("ssl", serverSslHandler);
280 if (useChunkedWriteHandler) {
281 sch.pipeline().addLast(new ChunkedWriteHandler());
282 }
283 sch.pipeline().addLast("serverHandler", serverHandler);
284 }
285 });
286
287 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
288 cb.handler(new ChannelInitializer<Channel>() {
289 @Override
290 public void initChannel(Channel sch) {
291 clientChannel = sch;
292
293 if (clientUsesDelegatedTaskExecutor) {
294 SSLEngine cse = clientCtx.newEngine(sch.alloc());
295 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
296 } else {
297 clientSslHandler = clientCtx.newHandler(sch.alloc());
298 }
299 clientSslHandler.setHandshakeTimeoutMillis(0);
300
301 sch.pipeline().addLast("ssl", clientSslHandler);
302 if (useChunkedWriteHandler) {
303 sch.pipeline().addLast(new ChunkedWriteHandler());
304 }
305 sch.pipeline().addLast("clientHandler", clientHandler);
306 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
307 @Override
308 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
309 if (evt instanceof SslHandshakeCompletionEvent) {
310 clientHandshakeEventLatch.countDown();
311 }
312 ctx.fireUserEventTriggered(evt);
313 }
314 });
315 }
316 });
317
318 final Channel sc = sb.bind().sync().channel();
319 cb.connect(sc.localAddress()).sync();
320
321 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
322
323
324 clientHandshakeFuture.sync();
325 clientHandshakeEventLatch.await();
326
327 clientChannel.writeAndFlush(randomBufferType(clientChannel.alloc(), data, 0, FIRST_MESSAGE_SIZE));
328 clientSendCounter.set(FIRST_MESSAGE_SIZE);
329
330 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
331 Future<Channel> renegoFuture = null;
332 while (clientSendCounter.get() < data.length) {
333 int clientSendCounterVal = clientSendCounter.get();
334 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
335 ByteBuf buf = randomBufferType(clientChannel.alloc(), data, clientSendCounterVal, length);
336 if (useCompositeByteBuf) {
337 buf = Unpooled.compositeBuffer().addComponent(true, buf);
338 }
339
340 ChannelFuture future = clientChannel.writeAndFlush(buf);
341 clientSendCounter.set(clientSendCounterVal += length);
342 future.sync();
343
344 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
345 needsRenegotiation = false;
346 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
347 renegoFuture = clientSslHandler.renegotiate();
348 logStats("CLIENT RENEGOTIATES");
349 assertNotSame(renegoFuture, clientHandshakeFuture);
350 }
351 }
352
353
354 while (clientRecvCounter.get() < data.length) {
355 if (serverException.get() != null) {
356 break;
357 }
358 if (clientException.get() != null) {
359 break;
360 }
361
362 Thread.sleep(50);
363 }
364
365 while (serverRecvCounter.get() < data.length) {
366 if (serverException.get() != null) {
367 break;
368 }
369 if (clientException.get() != null) {
370 break;
371 }
372
373 Thread.sleep(50);
374 }
375
376
377 if (renegoFuture != null) {
378 renegoFuture.sync();
379 }
380 if (serverHandler.renegoFuture != null) {
381 serverHandler.renegoFuture.sync();
382 }
383
384 serverChannel.close().awaitUninterruptibly();
385 clientChannel.close().awaitUninterruptibly();
386 sc.close().awaitUninterruptibly();
387 delegatedTaskExecutor.shutdown();
388 assertTrue(delegatedTaskExecutor.awaitTermination(5, TimeUnit.SECONDS));
389
390 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
391 throw serverException.get();
392 }
393 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
394 throw clientException.get();
395 }
396 if (serverException.get() != null) {
397 throw serverException.get();
398 }
399 if (clientException.get() != null) {
400 throw clientException.get();
401 }
402
403
404 try {
405 switch (renegotiation.type) {
406 case SERVER_INITIATED:
407 assertEquals(renegotiation.cipherSuite, serverSslHandler.engine().getSession().getCipherSuite());
408 assertEquals(2, serverNegoCounter.get());
409 assertThat(clientNegoCounter.get()).isIn(1, 2);
410 break;
411 case CLIENT_INITIATED:
412 assertThat(serverNegoCounter.get()).isIn(1, 2);
413 assertEquals(renegotiation.cipherSuite, clientSslHandler.engine().getSession().getCipherSuite());
414 assertEquals(2, clientNegoCounter.get());
415 break;
416 case NONE:
417 assertEquals(1, serverNegoCounter.get());
418 assertEquals(1, clientNegoCounter.get());
419 }
420 } finally {
421 logStats("STATS");
422 }
423 }
424
425 private void reset() {
426 clientException.set(null);
427 serverException.set(null);
428
429 clientSendCounter.set(0);
430 clientRecvCounter.set(0);
431 serverRecvCounter.set(0);
432
433 clientNegoCounter.set(0);
434 serverNegoCounter.set(0);
435
436 clientChannel = null;
437 serverChannel = null;
438
439 clientSslHandler = null;
440 serverSslHandler = null;
441 }
442
443 void logStats(String message) {
444 logger.debug(
445 "{}:\n" +
446 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
447 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
448 message,
449 clientSendCounter, clientRecvCounter, clientNegoCounter,
450 clientSslHandler.engine().getSession().getCipherSuite(),
451 serverRecvCounter, serverNegoCounter,
452 serverSslHandler.engine().getSession().getCipherSuite());
453 }
454
455 @Sharable
456 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
457
458 protected final AtomicInteger recvCounter;
459 protected final AtomicInteger negoCounter;
460 protected final AtomicReference<Throwable> exception;
461
462 EchoHandler(
463 AtomicInteger recvCounter, AtomicInteger negoCounter,
464 AtomicReference<Throwable> exception) {
465
466 this.recvCounter = recvCounter;
467 this.negoCounter = negoCounter;
468 this.exception = exception;
469 }
470
471 @Override
472 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
473
474
475 if (!autoRead) {
476 ctx.read();
477 }
478 ctx.fireChannelReadComplete();
479 }
480
481 @Override
482 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
483 if (evt instanceof SslHandshakeCompletionEvent) {
484 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
485 if (handshakeEvt.cause() != null) {
486 logger.warn("Handshake failed:", handshakeEvt.cause());
487 }
488 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
489 negoCounter.incrementAndGet();
490 logStats("HANDSHAKEN");
491 }
492 ctx.fireUserEventTriggered(evt);
493 }
494
495 @Override
496 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
497 if (logger.isWarnEnabled()) {
498 logger.warn("Unexpected exception from the client side:", cause);
499 }
500
501 exception.compareAndSet(null, cause);
502 ctx.close();
503 }
504 }
505
506 private class EchoClientHandler extends EchoHandler {
507
508 EchoClientHandler(
509 AtomicInteger recvCounter, AtomicInteger negoCounter,
510 AtomicReference<Throwable> exception) {
511
512 super(recvCounter, negoCounter, exception);
513 }
514
515 @Override
516 public void handlerAdded(final ChannelHandlerContext ctx) {
517 if (!autoRead) {
518 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(future -> ctx.read());
519 }
520 }
521
522 @Override
523 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
524 byte[] actual = new byte[in.readableBytes()];
525 in.readBytes(actual);
526
527 int lastIdx = recvCounter.get();
528 for (int i = 0; i < actual.length; i ++) {
529 assertEquals(data[i + lastIdx], actual[i]);
530 }
531
532 recvCounter.addAndGet(actual.length);
533 }
534 }
535
536 private class EchoServerHandler extends EchoHandler {
537 volatile Future<Channel> renegoFuture;
538
539 EchoServerHandler(
540 AtomicInteger recvCounter, AtomicInteger negoCounter,
541 AtomicReference<Throwable> exception) {
542
543 super(recvCounter, negoCounter, exception);
544 }
545
546 @Override
547 public final void channelRegistered(ChannelHandlerContext ctx) {
548 renegoFuture = null;
549 }
550
551 @Override
552 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
553 if (!autoRead) {
554 ctx.read();
555 }
556 ctx.fireChannelActive();
557 }
558
559 @Override
560 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
561 byte[] actual = new byte[in.readableBytes()];
562 in.readBytes(actual);
563
564 int lastIdx = recvCounter.get();
565 for (int i = 0; i < actual.length; i ++) {
566 assertEquals(data[i + lastIdx], actual[i]);
567 }
568
569 ByteBuf buf = randomBufferType(ctx.alloc(), actual, 0, actual.length);
570 if (useCompositeByteBuf) {
571 buf = Unpooled.compositeBuffer().addComponent(true, buf);
572 }
573 ctx.writeAndFlush(buf);
574
575 recvCounter.addAndGet(actual.length);
576
577
578 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
579 recvCounter.get() > data.length / 2 && renegoFuture == null) {
580
581 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
582
583 Future<Channel> hf = sslHandler.handshakeFuture();
584 assertTrue(hf.isDone());
585
586 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
587 logStats("SERVER RENEGOTIATES");
588 renegoFuture = sslHandler.renegotiate();
589 assertNotSame(renegoFuture, hf);
590 assertSame(renegoFuture, sslHandler.handshakeFuture());
591 assertFalse(renegoFuture.isDone());
592 }
593 }
594 }
595 }