1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.handler.stream.ChunkedWriteHandler;
39 import io.netty.testsuite.util.TestUtils;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.concurrent.GenericFutureListener;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.TestInfo;
46 import org.junit.jupiter.api.Timeout;
47 import org.junit.jupiter.params.ParameterizedTest;
48 import org.junit.jupiter.params.provider.MethodSource;
49
50 import javax.net.ssl.SSLEngine;
51 import java.io.File;
52 import java.io.IOException;
53 import java.security.cert.CertificateException;
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.List;
57 import java.util.Random;
58 import java.util.concurrent.CountDownLatch;
59 import java.util.concurrent.ExecutorService;
60 import java.util.concurrent.Executors;
61 import java.util.concurrent.TimeUnit;
62 import java.util.concurrent.atomic.AtomicInteger;
63 import java.util.concurrent.atomic.AtomicReference;
64
65 import static org.assertj.core.api.Assertions.assertThat;
66 import static org.junit.jupiter.api.Assertions.assertEquals;
67 import static org.junit.jupiter.api.Assertions.assertFalse;
68 import static org.junit.jupiter.api.Assertions.assertNotSame;
69 import static org.junit.jupiter.api.Assertions.assertSame;
70 import static org.junit.jupiter.api.Assertions.assertTrue;
71
72 public class SocketSslEchoTest extends AbstractSocketTest {
73
74 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
75
76 private static final int FIRST_MESSAGE_SIZE = 16384;
77 private static final Random random = new Random();
78 private static final File CERT_FILE;
79 private static final File KEY_FILE;
80 static final byte[] data = new byte[1048576];
81
82 static {
83 random.nextBytes(data);
84
85 SelfSignedCertificate ssc;
86 try {
87 ssc = new SelfSignedCertificate();
88 } catch (CertificateException e) {
89 throw new Error(e);
90 }
91 CERT_FILE = ssc.certificate();
92 KEY_FILE = ssc.privateKey();
93 }
94
95 protected enum RenegotiationType {
96 NONE,
97 CLIENT_INITIATED,
98 SERVER_INITIATED,
99 }
100
101 protected static class Renegotiation {
102 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
103
104 final RenegotiationType type;
105 final String cipherSuite;
106
107 Renegotiation(RenegotiationType type, String cipherSuite) {
108 this.type = type;
109 this.cipherSuite = cipherSuite;
110 }
111
112 @Override
113 public String toString() {
114 if (type == RenegotiationType.NONE) {
115 return "NONE";
116 }
117
118 return type + "(" + cipherSuite + ')';
119 }
120 }
121
122 public static Collection<Object[]> data() throws Exception {
123 List<SslContext> serverContexts = new ArrayList<SslContext>();
124 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
125 .sslProvider(SslProvider.JDK)
126
127 .protocols("TLSv1.2")
128 .build());
129
130 List<SslContext> clientContexts = new ArrayList<SslContext>();
131 clientContexts.add(SslContextBuilder.forClient()
132 .sslProvider(SslProvider.JDK)
133 .trustManager(CERT_FILE)
134
135 .protocols("TLSv1.2")
136 .build());
137
138 boolean hasOpenSsl = OpenSsl.isAvailable();
139 if (hasOpenSsl) {
140 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
141 .sslProvider(SslProvider.OPENSSL)
142
143 .protocols("TLSv1.2")
144 .build());
145 clientContexts.add(SslContextBuilder.forClient()
146 .sslProvider(SslProvider.OPENSSL)
147 .trustManager(CERT_FILE)
148
149 .protocols("TLSv1.2")
150 .build());
151 } else {
152 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
153 }
154
155 List<Object[]> params = new ArrayList<Object[]>();
156 for (SslContext sc: serverContexts) {
157 for (SslContext cc: clientContexts) {
158 for (RenegotiationType rt: RenegotiationType.values()) {
159 if (rt != RenegotiationType.NONE &&
160 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
161
162 continue;
163 }
164
165 final Renegotiation r;
166 switch (rt) {
167 case NONE:
168 r = Renegotiation.NONE;
169 break;
170 case SERVER_INITIATED:
171 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
172 break;
173 case CLIENT_INITIATED:
174 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
175 break;
176 default:
177 throw new Error();
178 }
179
180 for (int i = 0; i < 32; i++) {
181 params.add(new Object[] {
182 sc, cc, r,
183 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
184 }
185 }
186 }
187 }
188
189 return params;
190 }
191
192 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
193 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
194 private final AtomicInteger clientSendCounter = new AtomicInteger();
195 private final AtomicInteger clientRecvCounter = new AtomicInteger();
196 private final AtomicInteger serverRecvCounter = new AtomicInteger();
197
198 private final AtomicInteger clientNegoCounter = new AtomicInteger();
199 private final AtomicInteger serverNegoCounter = new AtomicInteger();
200
201 private volatile Channel clientChannel;
202 private volatile Channel serverChannel;
203
204 private volatile SslHandler clientSslHandler;
205 private volatile SslHandler serverSslHandler;
206
207 private final EchoClientHandler clientHandler =
208 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
209
210 private final EchoServerHandler serverHandler =
211 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
212
213 private SslContext serverCtx;
214 private SslContext clientCtx;
215 private Renegotiation renegotiation;
216 private boolean serverUsesDelegatedTaskExecutor;
217 private boolean clientUsesDelegatedTaskExecutor;
218 private boolean autoRead;
219 private boolean useChunkedWriteHandler;
220 private boolean useCompositeByteBuf;
221
222 @AfterAll
223 public static void compressHeapDumps() throws Exception {
224 TestUtils.compressHeapDumps();
225 }
226
227 @ParameterizedTest(name =
228 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
229 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
230 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
231 @MethodSource("data")
232 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
233 public void testSslEcho(
234 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
235 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
236 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
237 TestInfo testInfo) throws Throwable {
238 this.serverCtx = serverCtx;
239 this.clientCtx = clientCtx;
240 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
241 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
242 this.renegotiation = renegotiation;
243 this.autoRead = autoRead;
244 this.useChunkedWriteHandler = useChunkedWriteHandler;
245 this.useCompositeByteBuf = useCompositeByteBuf;
246 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
247 @Override
248 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
249 testSslEcho(serverBootstrap, bootstrap);
250 }
251 });
252 }
253
254 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
255 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
256 reset();
257
258 sb.childOption(ChannelOption.AUTO_READ, autoRead);
259 cb.option(ChannelOption.AUTO_READ, autoRead);
260
261 sb.childHandler(new ChannelInitializer<Channel>() {
262 @Override
263 public void initChannel(Channel sch) {
264 serverChannel = sch;
265
266 if (serverUsesDelegatedTaskExecutor) {
267 SSLEngine sse = serverCtx.newEngine(sch.alloc());
268 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
269 } else {
270 serverSslHandler = serverCtx.newHandler(sch.alloc());
271 }
272 serverSslHandler.setHandshakeTimeoutMillis(0);
273
274 sch.pipeline().addLast("ssl", serverSslHandler);
275 if (useChunkedWriteHandler) {
276 sch.pipeline().addLast(new ChunkedWriteHandler());
277 }
278 sch.pipeline().addLast("serverHandler", serverHandler);
279 }
280 });
281
282 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
283 cb.handler(new ChannelInitializer<Channel>() {
284 @Override
285 public void initChannel(Channel sch) {
286 clientChannel = sch;
287
288 if (clientUsesDelegatedTaskExecutor) {
289 SSLEngine cse = clientCtx.newEngine(sch.alloc());
290 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
291 } else {
292 clientSslHandler = clientCtx.newHandler(sch.alloc());
293 }
294 clientSslHandler.setHandshakeTimeoutMillis(0);
295
296 sch.pipeline().addLast("ssl", clientSslHandler);
297 if (useChunkedWriteHandler) {
298 sch.pipeline().addLast(new ChunkedWriteHandler());
299 }
300 sch.pipeline().addLast("clientHandler", clientHandler);
301 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
302 @Override
303 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
304 if (evt instanceof SslHandshakeCompletionEvent) {
305 clientHandshakeEventLatch.countDown();
306 }
307 ctx.fireUserEventTriggered(evt);
308 }
309 });
310 }
311 });
312
313 final Channel sc = sb.bind().sync().channel();
314 cb.connect(sc.localAddress()).sync();
315
316 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
317
318
319 clientHandshakeFuture.sync();
320 clientHandshakeEventLatch.await();
321
322 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
323 clientSendCounter.set(FIRST_MESSAGE_SIZE);
324
325 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
326 Future<Channel> renegoFuture = null;
327 while (clientSendCounter.get() < data.length) {
328 int clientSendCounterVal = clientSendCounter.get();
329 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
330 ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
331 if (useCompositeByteBuf) {
332 buf = Unpooled.compositeBuffer().addComponent(true, buf);
333 }
334
335 ChannelFuture future = clientChannel.writeAndFlush(buf);
336 clientSendCounter.set(clientSendCounterVal += length);
337 future.sync();
338
339 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
340 needsRenegotiation = false;
341 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
342 renegoFuture = clientSslHandler.renegotiate();
343 logStats("CLIENT RENEGOTIATES");
344 assertNotSame(renegoFuture, clientHandshakeFuture);
345 }
346 }
347
348
349 while (clientRecvCounter.get() < data.length) {
350 if (serverException.get() != null) {
351 break;
352 }
353 if (clientException.get() != null) {
354 break;
355 }
356
357 Thread.sleep(50);
358 }
359
360 while (serverRecvCounter.get() < data.length) {
361 if (serverException.get() != null) {
362 break;
363 }
364 if (clientException.get() != null) {
365 break;
366 }
367
368 Thread.sleep(50);
369 }
370
371
372 if (renegoFuture != null) {
373 renegoFuture.sync();
374 }
375 if (serverHandler.renegoFuture != null) {
376 serverHandler.renegoFuture.sync();
377 }
378
379 serverChannel.close().awaitUninterruptibly();
380 clientChannel.close().awaitUninterruptibly();
381 sc.close().awaitUninterruptibly();
382 delegatedTaskExecutor.shutdown();
383
384 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
385 throw serverException.get();
386 }
387 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
388 throw clientException.get();
389 }
390 if (serverException.get() != null) {
391 throw serverException.get();
392 }
393 if (clientException.get() != null) {
394 throw clientException.get();
395 }
396
397
398 try {
399 switch (renegotiation.type) {
400 case SERVER_INITIATED:
401 assertEquals(renegotiation.cipherSuite, serverSslHandler.engine().getSession().getCipherSuite());
402 assertEquals(2, serverNegoCounter.get());
403 assertThat(clientNegoCounter.get()).isIn(1, 2);
404 break;
405 case CLIENT_INITIATED:
406 assertThat(serverNegoCounter.get()).isIn(1, 2);
407 assertEquals(renegotiation.cipherSuite, clientSslHandler.engine().getSession().getCipherSuite());
408 assertEquals(2, clientNegoCounter.get());
409 break;
410 case NONE:
411 assertEquals(1, serverNegoCounter.get());
412 assertEquals(1, clientNegoCounter.get());
413 }
414 } finally {
415 logStats("STATS");
416 }
417 }
418
419 private void reset() {
420 clientException.set(null);
421 serverException.set(null);
422
423 clientSendCounter.set(0);
424 clientRecvCounter.set(0);
425 serverRecvCounter.set(0);
426
427 clientNegoCounter.set(0);
428 serverNegoCounter.set(0);
429
430 clientChannel = null;
431 serverChannel = null;
432
433 clientSslHandler = null;
434 serverSslHandler = null;
435 }
436
437 void logStats(String message) {
438 logger.debug(
439 "{}:\n" +
440 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
441 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
442 message,
443 clientSendCounter, clientRecvCounter, clientNegoCounter,
444 clientSslHandler.engine().getSession().getCipherSuite(),
445 serverRecvCounter, serverNegoCounter,
446 serverSslHandler.engine().getSession().getCipherSuite());
447 }
448
449 @Sharable
450 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
451
452 protected final AtomicInteger recvCounter;
453 protected final AtomicInteger negoCounter;
454 protected final AtomicReference<Throwable> exception;
455
456 EchoHandler(
457 AtomicInteger recvCounter, AtomicInteger negoCounter,
458 AtomicReference<Throwable> exception) {
459
460 this.recvCounter = recvCounter;
461 this.negoCounter = negoCounter;
462 this.exception = exception;
463 }
464
465 @Override
466 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
467
468
469 if (!autoRead) {
470 ctx.read();
471 }
472 ctx.fireChannelReadComplete();
473 }
474
475 @Override
476 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
477 if (evt instanceof SslHandshakeCompletionEvent) {
478 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
479 if (handshakeEvt.cause() != null) {
480 logger.warn("Handshake failed:", handshakeEvt.cause());
481 }
482 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
483 negoCounter.incrementAndGet();
484 logStats("HANDSHAKEN");
485 }
486 ctx.fireUserEventTriggered(evt);
487 }
488
489 @Override
490 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
491 if (logger.isWarnEnabled()) {
492 logger.warn("Unexpected exception from the client side:", cause);
493 }
494
495 exception.compareAndSet(null, cause);
496 ctx.close();
497 }
498 }
499
500 private class EchoClientHandler extends EchoHandler {
501
502 EchoClientHandler(
503 AtomicInteger recvCounter, AtomicInteger negoCounter,
504 AtomicReference<Throwable> exception) {
505
506 super(recvCounter, negoCounter, exception);
507 }
508
509 @Override
510 public void handlerAdded(final ChannelHandlerContext ctx) {
511 if (!autoRead) {
512 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
513 new GenericFutureListener<Future<? super Channel>>() {
514 @Override
515 public void operationComplete(Future<? super Channel> future) {
516 ctx.read();
517 }
518 });
519 }
520 }
521
522 @Override
523 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
524 byte[] actual = new byte[in.readableBytes()];
525 in.readBytes(actual);
526
527 int lastIdx = recvCounter.get();
528 for (int i = 0; i < actual.length; i ++) {
529 assertEquals(data[i + lastIdx], actual[i]);
530 }
531
532 recvCounter.addAndGet(actual.length);
533 }
534 }
535
536 private class EchoServerHandler extends EchoHandler {
537 volatile Future<Channel> renegoFuture;
538
539 EchoServerHandler(
540 AtomicInteger recvCounter, AtomicInteger negoCounter,
541 AtomicReference<Throwable> exception) {
542
543 super(recvCounter, negoCounter, exception);
544 }
545
546 @Override
547 public final void channelRegistered(ChannelHandlerContext ctx) {
548 renegoFuture = null;
549 }
550
551 @Override
552 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
553 if (!autoRead) {
554 ctx.read();
555 }
556 ctx.fireChannelActive();
557 }
558
559 @Override
560 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
561 byte[] actual = new byte[in.readableBytes()];
562 in.readBytes(actual);
563
564 int lastIdx = recvCounter.get();
565 for (int i = 0; i < actual.length; i ++) {
566 assertEquals(data[i + lastIdx], actual[i]);
567 }
568
569 ByteBuf buf = Unpooled.wrappedBuffer(actual);
570 if (useCompositeByteBuf) {
571 buf = Unpooled.compositeBuffer().addComponent(true, buf);
572 }
573 ctx.writeAndFlush(buf);
574
575 recvCounter.addAndGet(actual.length);
576
577
578 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
579 recvCounter.get() > data.length / 2 && renegoFuture == null) {
580
581 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
582
583 Future<Channel> hf = sslHandler.handshakeFuture();
584 assertTrue(hf.isDone());
585
586 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
587 logStats("SERVER RENEGOTIATES");
588 renegoFuture = sslHandler.renegotiate();
589 assertNotSame(renegoFuture, hf);
590 assertSame(renegoFuture, sslHandler.handshakeFuture());
591 assertFalse(renegoFuture.isDone());
592 }
593 }
594 }
595 }