1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelFuture;
24 import io.netty.channel.ChannelHandler.Sharable;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.SimpleChannelInboundHandler;
30 import io.netty.handler.ssl.OpenSsl;
31 import io.netty.handler.ssl.OpenSslContext;
32 import io.netty.handler.ssl.SslContext;
33 import io.netty.handler.ssl.SslContextBuilder;
34 import io.netty.handler.ssl.SslHandler;
35 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.handler.stream.ChunkedWriteHandler;
39 import io.netty.testsuite.util.TestUtils;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.concurrent.GenericFutureListener;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.TestInfo;
46 import org.junit.jupiter.api.Timeout;
47 import org.junit.jupiter.params.ParameterizedTest;
48 import org.junit.jupiter.params.provider.MethodSource;
49
50 import javax.net.ssl.SSLEngine;
51 import java.io.File;
52 import java.io.IOException;
53 import java.security.cert.CertificateException;
54 import java.util.ArrayList;
55 import java.util.Collection;
56 import java.util.List;
57 import java.util.Random;
58 import java.util.concurrent.CountDownLatch;
59 import java.util.concurrent.ExecutorService;
60 import java.util.concurrent.Executors;
61 import java.util.concurrent.TimeUnit;
62 import java.util.concurrent.atomic.AtomicInteger;
63 import java.util.concurrent.atomic.AtomicReference;
64
65 import static org.hamcrest.MatcherAssert.assertThat;
66 import static org.hamcrest.Matchers.anyOf;
67 import static org.hamcrest.Matchers.is;
68 import static org.hamcrest.Matchers.not;
69 import static org.hamcrest.Matchers.sameInstance;
70 import static org.junit.jupiter.api.Assertions.assertEquals;
71 import static org.junit.jupiter.api.Assertions.assertSame;
72
73 public class SocketSslEchoTest extends AbstractSocketTest {
74
/** Logger shared by the test and its channel handlers. */
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);

/** Size in bytes of the very first message the client writes after the handshake. */
private static final int FIRST_MESSAGE_SIZE = 16 * 1024;
/** Source of the payload bytes and of the random chunk sizes used while sending. */
private static final Random random = new Random();
/** Certificate file of the self-signed certificate generated in the static initializer. */
private static final File CERT_FILE;
/** Private key file of the self-signed certificate generated in the static initializer. */
private static final File KEY_FILE;
/** The 1 MiB payload that the client sends and the server echoes back. */
static final byte[] data = new byte[1024 * 1024];

static {
    // Generate the certificate/key pair once for all parameterizations.
    final SelfSignedCertificate selfSigned;
    try {
        selfSigned = new SelfSignedCertificate();
    } catch (CertificateException cause) {
        throw new Error(cause);
    }
    CERT_FILE = selfSigned.certificate();
    KEY_FILE = selfSigned.privateKey();

    // Fill the payload with random bytes.
    random.nextBytes(data);
}
95
96 protected enum RenegotiationType {
97 NONE,
98 CLIENT_INITIATED,
99 SERVER_INITIATED,
100 }
101
102 protected static class Renegotiation {
103 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104
105 final RenegotiationType type;
106 final String cipherSuite;
107
108 Renegotiation(RenegotiationType type, String cipherSuite) {
109 this.type = type;
110 this.cipherSuite = cipherSuite;
111 }
112
113 @Override
114 public String toString() {
115 if (type == RenegotiationType.NONE) {
116 return "NONE";
117 }
118
119 return type + "(" + cipherSuite + ')';
120 }
121 }
122
123 public static Collection<Object[]> data() throws Exception {
124 List<SslContext> serverContexts = new ArrayList<SslContext>();
125 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126 .sslProvider(SslProvider.JDK)
127
128 .protocols("TLSv1.2")
129 .build());
130
131 List<SslContext> clientContexts = new ArrayList<SslContext>();
132 clientContexts.add(SslContextBuilder.forClient()
133 .sslProvider(SslProvider.JDK)
134 .trustManager(CERT_FILE)
135
136 .protocols("TLSv1.2")
137 .build());
138
139 boolean hasOpenSsl = OpenSsl.isAvailable();
140 if (hasOpenSsl) {
141 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
142 .sslProvider(SslProvider.OPENSSL)
143
144 .protocols("TLSv1.2")
145 .build());
146 clientContexts.add(SslContextBuilder.forClient()
147 .sslProvider(SslProvider.OPENSSL)
148 .trustManager(CERT_FILE)
149
150 .protocols("TLSv1.2")
151 .build());
152 } else {
153 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
154 }
155
156 List<Object[]> params = new ArrayList<Object[]>();
157 for (SslContext sc: serverContexts) {
158 for (SslContext cc: clientContexts) {
159 for (RenegotiationType rt: RenegotiationType.values()) {
160 if (rt != RenegotiationType.NONE &&
161 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
162
163 continue;
164 }
165
166 final Renegotiation r;
167 switch (rt) {
168 case NONE:
169 r = Renegotiation.NONE;
170 break;
171 case SERVER_INITIATED:
172 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
173 break;
174 case CLIENT_INITIATED:
175 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
176 break;
177 default:
178 throw new Error();
179 }
180
181 for (int i = 0; i < 32; i++) {
182 params.add(new Object[] {
183 sc, cc, r,
184 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
185 }
186 }
187 }
188 }
189
190 return params;
191 }
192
// First Throwable recorded by the client / server handler's exceptionCaught(), or null if none.
private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
// Number of payload bytes the client has written, and the bytes each side has received so far.
private final AtomicInteger clientSendCounter = new AtomicInteger();
private final AtomicInteger clientRecvCounter = new AtomicInteger();
private final AtomicInteger serverRecvCounter = new AtomicInteger();

// Number of SslHandshakeCompletionEvents observed on each side.
private final AtomicInteger clientNegoCounter = new AtomicInteger();
private final AtomicInteger serverNegoCounter = new AtomicInteger();

// Channels captured by the channel initializers of the current run; cleared by reset().
private volatile Channel clientChannel;
private volatile Channel serverChannel;

// SslHandlers installed by the channel initializers of the current run; cleared by reset().
private volatile SslHandler clientSslHandler;
private volatile SslHandler serverSslHandler;

// The handlers share the counters and exception holders declared above, so they must be
// created after those fields.
private final EchoClientHandler clientHandler =
        new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);

private final EchoServerHandler serverHandler =
        new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);

// Configuration of the current run, copied from the test parameters by testSslEcho(...).
private SslContext serverCtx;
private SslContext clientCtx;
private Renegotiation renegotiation;
private boolean serverUsesDelegatedTaskExecutor;
private boolean clientUsesDelegatedTaskExecutor;
private boolean autoRead;
private boolean useChunkedWriteHandler;
private boolean useCompositeByteBuf;
222
/**
 * Invoked once after all tests of this class; delegates to
 * {@link TestUtils#compressHeapDumps()}.
 */
@AfterAll
public static void compressHeapDumps() throws Exception {
    TestUtils.compressHeapDumps();
}
227
228 @ParameterizedTest(name =
229 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
230 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
231 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
232 @MethodSource("data")
233 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
234 public void testSslEcho(
235 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
236 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
237 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
238 TestInfo testInfo) throws Throwable {
239 this.serverCtx = serverCtx;
240 this.clientCtx = clientCtx;
241 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
242 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
243 this.renegotiation = renegotiation;
244 this.autoRead = autoRead;
245 this.useChunkedWriteHandler = useChunkedWriteHandler;
246 this.useCompositeByteBuf = useCompositeByteBuf;
247 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
248 @Override
249 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
250 testSslEcho(serverBootstrap, bootstrap);
251 }
252 });
253 }
254
255 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
256 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
257 reset();
258
259 sb.childOption(ChannelOption.AUTO_READ, autoRead);
260 cb.option(ChannelOption.AUTO_READ, autoRead);
261
262 sb.childHandler(new ChannelInitializer<Channel>() {
263 @Override
264 public void initChannel(Channel sch) {
265 serverChannel = sch;
266
267 if (serverUsesDelegatedTaskExecutor) {
268 SSLEngine sse = serverCtx.newEngine(sch.alloc());
269 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
270 } else {
271 serverSslHandler = serverCtx.newHandler(sch.alloc());
272 }
273 serverSslHandler.setHandshakeTimeoutMillis(0);
274
275 sch.pipeline().addLast("ssl", serverSslHandler);
276 if (useChunkedWriteHandler) {
277 sch.pipeline().addLast(new ChunkedWriteHandler());
278 }
279 sch.pipeline().addLast("serverHandler", serverHandler);
280 }
281 });
282
283 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
284 cb.handler(new ChannelInitializer<Channel>() {
285 @Override
286 public void initChannel(Channel sch) {
287 clientChannel = sch;
288
289 if (clientUsesDelegatedTaskExecutor) {
290 SSLEngine cse = clientCtx.newEngine(sch.alloc());
291 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
292 } else {
293 clientSslHandler = clientCtx.newHandler(sch.alloc());
294 }
295 clientSslHandler.setHandshakeTimeoutMillis(0);
296
297 sch.pipeline().addLast("ssl", clientSslHandler);
298 if (useChunkedWriteHandler) {
299 sch.pipeline().addLast(new ChunkedWriteHandler());
300 }
301 sch.pipeline().addLast("clientHandler", clientHandler);
302 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
303 @Override
304 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
305 if (evt instanceof SslHandshakeCompletionEvent) {
306 clientHandshakeEventLatch.countDown();
307 }
308 ctx.fireUserEventTriggered(evt);
309 }
310 });
311 }
312 });
313
314 final Channel sc = sb.bind().sync().channel();
315 cb.connect(sc.localAddress()).sync();
316
317 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
318
319
320 clientHandshakeFuture.sync();
321 clientHandshakeEventLatch.await();
322
323 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
324 clientSendCounter.set(FIRST_MESSAGE_SIZE);
325
326 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
327 Future<Channel> renegoFuture = null;
328 while (clientSendCounter.get() < data.length) {
329 int clientSendCounterVal = clientSendCounter.get();
330 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
331 ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
332 if (useCompositeByteBuf) {
333 buf = Unpooled.compositeBuffer().addComponent(true, buf);
334 }
335
336 ChannelFuture future = clientChannel.writeAndFlush(buf);
337 clientSendCounter.set(clientSendCounterVal += length);
338 future.sync();
339
340 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
341 needsRenegotiation = false;
342 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
343 renegoFuture = clientSslHandler.renegotiate();
344 logStats("CLIENT RENEGOTIATES");
345 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
346 }
347 }
348
349
350 while (clientRecvCounter.get() < data.length) {
351 if (serverException.get() != null) {
352 break;
353 }
354 if (serverException.get() != null) {
355 break;
356 }
357
358 try {
359 Thread.sleep(50);
360 } catch (InterruptedException e) {
361
362 }
363 }
364
365 while (serverRecvCounter.get() < data.length) {
366 if (serverException.get() != null) {
367 break;
368 }
369 if (clientException.get() != null) {
370 break;
371 }
372
373 try {
374 Thread.sleep(50);
375 } catch (InterruptedException e) {
376
377 }
378 }
379
380
381 if (renegoFuture != null) {
382 renegoFuture.sync();
383 }
384 if (serverHandler.renegoFuture != null) {
385 serverHandler.renegoFuture.sync();
386 }
387
388 serverChannel.close().awaitUninterruptibly();
389 clientChannel.close().awaitUninterruptibly();
390 sc.close().awaitUninterruptibly();
391 delegatedTaskExecutor.shutdown();
392
393 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
394 throw serverException.get();
395 }
396 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
397 throw clientException.get();
398 }
399 if (serverException.get() != null) {
400 throw serverException.get();
401 }
402 if (clientException.get() != null) {
403 throw clientException.get();
404 }
405
406
407 try {
408 switch (renegotiation.type) {
409 case SERVER_INITIATED:
410 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
411 assertThat(serverNegoCounter.get(), is(2));
412 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
413 break;
414 case CLIENT_INITIATED:
415 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
416 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
417 assertThat(clientNegoCounter.get(), is(2));
418 break;
419 case NONE:
420 assertThat(serverNegoCounter.get(), is(1));
421 assertThat(clientNegoCounter.get(), is(1));
422 }
423 } finally {
424 logStats("STATS");
425 }
426 }
427
428 private void reset() {
429 clientException.set(null);
430 serverException.set(null);
431
432 clientSendCounter.set(0);
433 clientRecvCounter.set(0);
434 serverRecvCounter.set(0);
435
436 clientNegoCounter.set(0);
437 serverNegoCounter.set(0);
438
439 clientChannel = null;
440 serverChannel = null;
441
442 clientSslHandler = null;
443 serverSslHandler = null;
444 }
445
446 void logStats(String message) {
447 logger.debug(
448 "{}:\n" +
449 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
450 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
451 message,
452 clientSendCounter, clientRecvCounter, clientNegoCounter,
453 clientSslHandler.engine().getSession().getCipherSuite(),
454 serverRecvCounter, serverNegoCounter,
455 serverSslHandler.engine().getSession().getCipherSuite());
456 }
457
458 @Sharable
459 private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
460
461 protected final AtomicInteger recvCounter;
462 protected final AtomicInteger negoCounter;
463 protected final AtomicReference<Throwable> exception;
464
465 EchoHandler(
466 AtomicInteger recvCounter, AtomicInteger negoCounter,
467 AtomicReference<Throwable> exception) {
468
469 this.recvCounter = recvCounter;
470 this.negoCounter = negoCounter;
471 this.exception = exception;
472 }
473
474 @Override
475 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
476
477
478 if (!autoRead) {
479 ctx.read();
480 }
481 ctx.fireChannelReadComplete();
482 }
483
484 @Override
485 public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
486 if (evt instanceof SslHandshakeCompletionEvent) {
487 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
488 if (handshakeEvt.cause() != null) {
489 logger.warn("Handshake failed:", handshakeEvt.cause());
490 }
491 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
492 negoCounter.incrementAndGet();
493 logStats("HANDSHAKEN");
494 }
495 ctx.fireUserEventTriggered(evt);
496 }
497
498 @Override
499 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
500 if (logger.isWarnEnabled()) {
501 logger.warn("Unexpected exception from the client side:", cause);
502 }
503
504 exception.compareAndSet(null, cause);
505 ctx.close();
506 }
507 }
508
509 private class EchoClientHandler extends EchoHandler {
510
511 EchoClientHandler(
512 AtomicInteger recvCounter, AtomicInteger negoCounter,
513 AtomicReference<Throwable> exception) {
514
515 super(recvCounter, negoCounter, exception);
516 }
517
518 @Override
519 public void handlerAdded(final ChannelHandlerContext ctx) {
520 if (!autoRead) {
521 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
522 new GenericFutureListener<Future<? super Channel>>() {
523 @Override
524 public void operationComplete(Future<? super Channel> future) {
525 ctx.read();
526 }
527 });
528 }
529 }
530
531 @Override
532 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
533 byte[] actual = new byte[in.readableBytes()];
534 in.readBytes(actual);
535
536 int lastIdx = recvCounter.get();
537 for (int i = 0; i < actual.length; i ++) {
538 assertEquals(data[i + lastIdx], actual[i]);
539 }
540
541 recvCounter.addAndGet(actual.length);
542 }
543 }
544
545 private class EchoServerHandler extends EchoHandler {
546 volatile Future<Channel> renegoFuture;
547
548 EchoServerHandler(
549 AtomicInteger recvCounter, AtomicInteger negoCounter,
550 AtomicReference<Throwable> exception) {
551
552 super(recvCounter, negoCounter, exception);
553 }
554
555 @Override
556 public final void channelRegistered(ChannelHandlerContext ctx) {
557 renegoFuture = null;
558 }
559
560 @Override
561 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
562 if (!autoRead) {
563 ctx.read();
564 }
565 ctx.fireChannelActive();
566 }
567
568 @Override
569 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
570 byte[] actual = new byte[in.readableBytes()];
571 in.readBytes(actual);
572
573 int lastIdx = recvCounter.get();
574 for (int i = 0; i < actual.length; i ++) {
575 assertEquals(data[i + lastIdx], actual[i]);
576 }
577
578 ByteBuf buf = Unpooled.wrappedBuffer(actual);
579 if (useCompositeByteBuf) {
580 buf = Unpooled.compositeBuffer().addComponent(true, buf);
581 }
582 ctx.writeAndFlush(buf);
583
584 recvCounter.addAndGet(actual.length);
585
586
587 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
588 recvCounter.get() > data.length / 2 && renegoFuture == null) {
589
590 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
591
592 Future<Channel> hf = sslHandler.handshakeFuture();
593 assertThat(hf.isDone(), is(true));
594
595 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
596 logStats("SERVER RENEGOTIATES");
597 renegoFuture = sslHandler.renegotiate();
598 assertThat(renegoFuture, is(not(sameInstance(hf))));
599 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
600 assertThat(renegoFuture.isDone(), is(false));
601 }
602 }
603 }
604 }