1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.handler.stream.ChunkedWriteHandler;
39 import io.netty.testsuite.util.TestUtils;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.concurrent.GenericFutureListener;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.TestInfo;
46 import org.junit.jupiter.api.Timeout;
47 import org.junit.jupiter.params.ParameterizedTest;
48 import org.junit.jupiter.params.provider.MethodSource;
49
50 import javax.net.ssl.SSLEngine;
51 import java.io.File;
52 import java.io.IOException;
53 import java.security.cert.CertificateException;
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.List;
57 import java.util.Random;
58 import java.util.concurrent.CountDownLatch;
59 import java.util.concurrent.ExecutorService;
60 import java.util.concurrent.Executors;
61 import java.util.concurrent.TimeUnit;
62 import java.util.concurrent.atomic.AtomicInteger;
63 import java.util.concurrent.atomic.AtomicReference;
64
65 import static org.hamcrest.MatcherAssert.assertThat;
66 import static org.hamcrest.Matchers.anyOf;
67 import static org.hamcrest.Matchers.is;
68 import static org.hamcrest.Matchers.not;
69 import static org.hamcrest.Matchers.sameInstance;
70 import static org.junit.jupiter.api.Assertions.assertEquals;
71 import static org.junit.jupiter.api.Assertions.assertSame;
72
73 public class SocketSslEchoTest extends AbstractSocketTest {
74
75 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
76
77 private static final int FIRST_MESSAGE_SIZE = 16384;
78 private static final Random random = new Random();
79 private static final File CERT_FILE;
80 private static final File KEY_FILE;
81 static final byte[] data = new byte[1048576];
82
83 static {
84 random.nextBytes(data);
85
86 SelfSignedCertificate ssc;
87 try {
88 ssc = new SelfSignedCertificate();
89 } catch (CertificateException e) {
90 throw new Error(e);
91 }
92 CERT_FILE = ssc.certificate();
93 KEY_FILE = ssc.privateKey();
94 }
95
96 protected enum RenegotiationType {
97 NONE,
98 CLIENT_INITIATED,
99 SERVER_INITIATED,
100 }
101
102 protected static class Renegotiation {
103 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104
105 final RenegotiationType type;
106 final String cipherSuite;
107
108 Renegotiation(RenegotiationType type, String cipherSuite) {
109 this.type = type;
110 this.cipherSuite = cipherSuite;
111 }
112
113 @Override
114 public String toString() {
115 if (type == RenegotiationType.NONE) {
116 return "NONE";
117 }
118
119 return type + "(" + cipherSuite + ')';
120 }
121 }
122
123 public static Collection<Object[]> data() throws Exception {
124 List<SslContext> serverContexts = new ArrayList<SslContext>();
125 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126 .sslProvider(SslProvider.JDK)
127
128 .protocols("TLSv1.2")
129 .build());
130
131 List<SslContext> clientContexts = new ArrayList<SslContext>();
132 clientContexts.add(SslContextBuilder.forClient()
133 .sslProvider(SslProvider.JDK)
134 .trustManager(CERT_FILE)
135
136 .protocols("TLSv1.2")
137 .endpointIdentificationAlgorithm(null)
138 .build());
139
140 boolean hasOpenSsl = OpenSsl.isAvailable();
141 if (hasOpenSsl) {
142 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
143 .sslProvider(SslProvider.OPENSSL)
144
145 .protocols("TLSv1.2")
146 .build());
147 clientContexts.add(SslContextBuilder.forClient()
148 .sslProvider(SslProvider.OPENSSL)
149 .trustManager(CERT_FILE)
150
151 .protocols("TLSv1.2")
152 .endpointIdentificationAlgorithm(null)
153 .build());
154 } else {
155 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
156 }
157
158 List<Object[]> params = new ArrayList<Object[]>();
159 for (SslContext sc: serverContexts) {
160 for (SslContext cc: clientContexts) {
161 for (RenegotiationType rt: RenegotiationType.values()) {
162 if (rt != RenegotiationType.NONE &&
163 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
164
165 continue;
166 }
167
168 final Renegotiation r;
169 switch (rt) {
170 case NONE:
171 r = Renegotiation.NONE;
172 break;
173 case SERVER_INITIATED:
174 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
175 break;
176 case CLIENT_INITIATED:
177 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
178 break;
179 default:
180 throw new Error();
181 }
182
183 for (int i = 0; i < 32; i++) {
184 params.add(new Object[] {
185 sc, cc, r,
186 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
187 }
188 }
189 }
190 }
191
192 return params;
193 }
194
195 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
196 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
197 private final AtomicInteger clientSendCounter = new AtomicInteger();
198 private final AtomicInteger clientRecvCounter = new AtomicInteger();
199 private final AtomicInteger serverRecvCounter = new AtomicInteger();
200
201 private final AtomicInteger clientNegoCounter = new AtomicInteger();
202 private final AtomicInteger serverNegoCounter = new AtomicInteger();
203
204 private volatile Channel clientChannel;
205 private volatile Channel serverChannel;
206
207 private volatile SslHandler clientSslHandler;
208 private volatile SslHandler serverSslHandler;
209
210 private final EchoClientHandler clientHandler =
211 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
212
213 private final EchoServerHandler serverHandler =
214 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
215
216 private SslContext serverCtx;
217 private SslContext clientCtx;
218 private Renegotiation renegotiation;
219 private boolean serverUsesDelegatedTaskExecutor;
220 private boolean clientUsesDelegatedTaskExecutor;
221 private boolean autoRead;
222 private boolean useChunkedWriteHandler;
223 private boolean useCompositeByteBuf;
224
225 @AfterAll
226 public static void compressHeapDumps() throws Exception {
227 TestUtils.compressHeapDumps();
228 }
229
230 @ParameterizedTest(name =
231 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
232 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
233 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
234 @MethodSource("data")
235 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
236 public void testSslEcho(
237 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
238 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
239 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
240 TestInfo testInfo) throws Throwable {
241 this.serverCtx = serverCtx;
242 this.clientCtx = clientCtx;
243 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
244 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
245 this.renegotiation = renegotiation;
246 this.autoRead = autoRead;
247 this.useChunkedWriteHandler = useChunkedWriteHandler;
248 this.useCompositeByteBuf = useCompositeByteBuf;
249 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
250 @Override
251 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
252 testSslEcho(serverBootstrap, bootstrap);
253 }
254 });
255 }
256
257 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
258 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
259 reset();
260
261 sb.childOption(ChannelOption.AUTO_READ, autoRead);
262 cb.option(ChannelOption.AUTO_READ, autoRead);
263
264 sb.childHandler(new ChannelInitializer<Channel>() {
265 @Override
266 public void initChannel(Channel sch) {
267 serverChannel = sch;
268
269 if (serverUsesDelegatedTaskExecutor) {
270 SSLEngine sse = serverCtx.newEngine(sch.alloc());
271 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
272 } else {
273 serverSslHandler = serverCtx.newHandler(sch.alloc());
274 }
275 serverSslHandler.setHandshakeTimeoutMillis(0);
276
277 sch.pipeline().addLast("ssl", serverSslHandler);
278 if (useChunkedWriteHandler) {
279 sch.pipeline().addLast(new ChunkedWriteHandler());
280 }
281 sch.pipeline().addLast("serverHandler", serverHandler);
282 }
283 });
284
285 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
286 cb.handler(new ChannelInitializer<Channel>() {
287 @Override
288 public void initChannel(Channel sch) {
289 clientChannel = sch;
290
291 if (clientUsesDelegatedTaskExecutor) {
292 SSLEngine cse = clientCtx.newEngine(sch.alloc());
293 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
294 } else {
295 clientSslHandler = clientCtx.newHandler(sch.alloc());
296 }
297 clientSslHandler.setHandshakeTimeoutMillis(0);
298
299 sch.pipeline().addLast("ssl", clientSslHandler);
300 if (useChunkedWriteHandler) {
301 sch.pipeline().addLast(new ChunkedWriteHandler());
302 }
303 sch.pipeline().addLast("clientHandler", clientHandler);
304 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
305 @Override
306 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
307 if (evt instanceof SslHandshakeCompletionEvent) {
308 clientHandshakeEventLatch.countDown();
309 }
310 ctx.fireUserEventTriggered(evt);
311 }
312 });
313 }
314 });
315
316 final Channel sc = sb.bind().sync().channel();
317 cb.connect(sc.localAddress()).sync();
318
319 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
320
321
322 clientHandshakeFuture.sync();
323 clientHandshakeEventLatch.await();
324
325 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
326 clientSendCounter.set(FIRST_MESSAGE_SIZE);
327
328 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
329 Future<Channel> renegoFuture = null;
330 while (clientSendCounter.get() < data.length) {
331 int clientSendCounterVal = clientSendCounter.get();
332 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
333 ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
334 if (useCompositeByteBuf) {
335 buf = Unpooled.compositeBuffer().addComponent(true, buf);
336 }
337
338 ChannelFuture future = clientChannel.writeAndFlush(buf);
339 clientSendCounter.set(clientSendCounterVal += length);
340 future.sync();
341
342 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
343 needsRenegotiation = false;
344 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
345 renegoFuture = clientSslHandler.renegotiate();
346 logStats("CLIENT RENEGOTIATES");
347 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
348 }
349 }
350
351
352 while (clientRecvCounter.get() < data.length) {
353 if (serverException.get() != null) {
354 break;
355 }
356 if (clientException.get() != null) {
357 break;
358 }
359
360 Thread.sleep(50);
361 }
362
363 while (serverRecvCounter.get() < data.length) {
364 if (serverException.get() != null) {
365 break;
366 }
367 if (clientException.get() != null) {
368 break;
369 }
370
371 Thread.sleep(50);
372 }
373
374
375 if (renegoFuture != null) {
376 renegoFuture.sync();
377 }
378 if (serverHandler.renegoFuture != null) {
379 serverHandler.renegoFuture.sync();
380 }
381
382 serverChannel.close().awaitUninterruptibly();
383 clientChannel.close().awaitUninterruptibly();
384 sc.close().awaitUninterruptibly();
385 delegatedTaskExecutor.shutdown();
386
387 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
388 throw serverException.get();
389 }
390 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
391 throw clientException.get();
392 }
393 if (serverException.get() != null) {
394 throw serverException.get();
395 }
396 if (clientException.get() != null) {
397 throw clientException.get();
398 }
399
400
401 try {
402 switch (renegotiation.type) {
403 case SERVER_INITIATED:
404 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
405 assertThat(serverNegoCounter.get(), is(2));
406 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
407 break;
408 case CLIENT_INITIATED:
409 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
410 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
411 assertThat(clientNegoCounter.get(), is(2));
412 break;
413 case NONE:
414 assertThat(serverNegoCounter.get(), is(1));
415 assertThat(clientNegoCounter.get(), is(1));
416 }
417 } finally {
418 logStats("STATS");
419 }
420 }
421
422 private void reset() {
423 clientException.set(null);
424 serverException.set(null);
425
426 clientSendCounter.set(0);
427 clientRecvCounter.set(0);
428 serverRecvCounter.set(0);
429
430 clientNegoCounter.set(0);
431 serverNegoCounter.set(0);
432
433 clientChannel = null;
434 serverChannel = null;
435
436 clientSslHandler = null;
437 serverSslHandler = null;
438 }
439
440 void logStats(String message) {
441 logger.debug(
442 "{}:\n" +
443 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
444 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
445 message,
446 clientSendCounter, clientRecvCounter, clientNegoCounter,
447 clientSslHandler.engine().getSession().getCipherSuite(),
448 serverRecvCounter, serverNegoCounter,
449 serverSslHandler.engine().getSession().getCipherSuite());
450 }
451
452 @Sharable
453 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
454
455 protected final AtomicInteger recvCounter;
456 protected final AtomicInteger negoCounter;
457 protected final AtomicReference<Throwable> exception;
458
459 EchoHandler(
460 AtomicInteger recvCounter, AtomicInteger negoCounter,
461 AtomicReference<Throwable> exception) {
462
463 this.recvCounter = recvCounter;
464 this.negoCounter = negoCounter;
465 this.exception = exception;
466 }
467
468 @Override
469 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
470
471
472 if (!autoRead) {
473 ctx.read();
474 }
475 ctx.fireChannelReadComplete();
476 }
477
478 @Override
479 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
480 if (evt instanceof SslHandshakeCompletionEvent) {
481 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
482 if (handshakeEvt.cause() != null) {
483 logger.warn("Handshake failed:", handshakeEvt.cause());
484 }
485 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
486 negoCounter.incrementAndGet();
487 logStats("HANDSHAKEN");
488 }
489 ctx.fireUserEventTriggered(evt);
490 }
491
492 @Override
493 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
494 if (logger.isWarnEnabled()) {
495 logger.warn("Unexpected exception from the client side:", cause);
496 }
497
498 exception.compareAndSet(null, cause);
499 ctx.close();
500 }
501 }
502
503 private class EchoClientHandler extends EchoHandler {
504
505 EchoClientHandler(
506 AtomicInteger recvCounter, AtomicInteger negoCounter,
507 AtomicReference<Throwable> exception) {
508
509 super(recvCounter, negoCounter, exception);
510 }
511
512 @Override
513 public void handlerAdded(final ChannelHandlerContext ctx) {
514 if (!autoRead) {
515 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
516 new GenericFutureListener<Future<? super Channel>>() {
517 @Override
518 public void operationComplete(Future<? super Channel> future) {
519 ctx.read();
520 }
521 });
522 }
523 }
524
525 @Override
526 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
527 byte[] actual = new byte[in.readableBytes()];
528 in.readBytes(actual);
529
530 int lastIdx = recvCounter.get();
531 for (int i = 0; i < actual.length; i ++) {
532 assertEquals(data[i + lastIdx], actual[i]);
533 }
534
535 recvCounter.addAndGet(actual.length);
536 }
537 }
538
539 private class EchoServerHandler extends EchoHandler {
540 volatile Future<Channel> renegoFuture;
541
542 EchoServerHandler(
543 AtomicInteger recvCounter, AtomicInteger negoCounter,
544 AtomicReference<Throwable> exception) {
545
546 super(recvCounter, negoCounter, exception);
547 }
548
549 @Override
550 public final void channelRegistered(ChannelHandlerContext ctx) {
551 renegoFuture = null;
552 }
553
554 @Override
555 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
556 if (!autoRead) {
557 ctx.read();
558 }
559 ctx.fireChannelActive();
560 }
561
562 @Override
563 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
564 byte[] actual = new byte[in.readableBytes()];
565 in.readBytes(actual);
566
567 int lastIdx = recvCounter.get();
568 for (int i = 0; i < actual.length; i ++) {
569 assertEquals(data[i + lastIdx], actual[i]);
570 }
571
572 ByteBuf buf = Unpooled.wrappedBuffer(actual);
573 if (useCompositeByteBuf) {
574 buf = Unpooled.compositeBuffer().addComponent(true, buf);
575 }
576 ctx.writeAndFlush(buf);
577
578 recvCounter.addAndGet(actual.length);
579
580
581 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
582 recvCounter.get() > data.length / 2 && renegoFuture == null) {
583
584 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
585
586 Future<Channel> hf = sslHandler.handshakeFuture();
587 assertThat(hf.isDone(), is(true));
588
589 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
590 logStats("SERVER RENEGOTIATES");
591 renegoFuture = sslHandler.renegotiate();
592 assertThat(renegoFuture, is(not(sameInstance(hf))));
593 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
594 assertThat(renegoFuture.isDone(), is(false));
595 }
596 }
597 }
598 }