1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.buffer.api.BufferAllocator;
22 import io.netty5.buffer.api.DefaultBufferAllocators;
23 import io.netty5.channel.Channel;
24 import io.netty5.channel.ChannelHandler;
25 import io.netty5.channel.ChannelHandlerContext;
26 import io.netty5.channel.ChannelInitializer;
27 import io.netty5.channel.ChannelOption;
28 import io.netty5.channel.SimpleChannelInboundHandler;
29 import io.netty5.handler.ssl.OpenSsl;
30 import io.netty5.handler.ssl.OpenSslContext;
31 import io.netty5.handler.ssl.SslContext;
32 import io.netty5.handler.ssl.SslContextBuilder;
33 import io.netty5.handler.ssl.SslHandler;
34 import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
35 import io.netty5.handler.ssl.SslProvider;
36 import io.netty5.handler.ssl.util.SelfSignedCertificate;
37 import io.netty5.handler.stream.ChunkedWriteHandler;
38 import io.netty5.testsuite.util.TestUtils;
39 import io.netty5.util.concurrent.Future;
40 import io.netty5.util.internal.logging.InternalLogger;
41 import io.netty5.util.internal.logging.InternalLoggerFactory;
42 import org.junit.jupiter.api.AfterAll;
43 import org.junit.jupiter.api.TestInfo;
44 import org.junit.jupiter.api.Timeout;
45 import org.junit.jupiter.params.ParameterizedTest;
46 import org.junit.jupiter.params.provider.MethodSource;
47
48 import javax.net.ssl.SSLEngine;
49 import java.io.File;
50 import java.io.IOException;
51 import java.security.cert.CertificateException;
52 import java.util.ArrayList;
53 import java.util.Collection;
54 import java.util.List;
55 import java.util.Random;
56 import java.util.concurrent.CountDownLatch;
57 import java.util.concurrent.ExecutorService;
58 import java.util.concurrent.Executors;
59 import java.util.concurrent.TimeUnit;
60 import java.util.concurrent.atomic.AtomicInteger;
61 import java.util.concurrent.atomic.AtomicReference;
62
63 import static org.hamcrest.MatcherAssert.assertThat;
64 import static org.hamcrest.Matchers.anyOf;
65 import static org.hamcrest.Matchers.is;
66 import static org.hamcrest.Matchers.not;
67 import static org.hamcrest.Matchers.sameInstance;
68 import static org.junit.jupiter.api.Assertions.assertEquals;
69 import static org.junit.jupiter.api.Assertions.assertTrue;
70
71 public class SocketSslEchoTest extends AbstractSocketTest {
72
73 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
74
75 private static final int FIRST_MESSAGE_SIZE = 16384;
76 private static final Random random = new Random();
77 private static final File CERT_FILE;
78 private static final File KEY_FILE;
79 static final byte[] data = new byte[1048576];
80
81 static {
82 random.nextBytes(data);
83
84 SelfSignedCertificate ssc;
85 try {
86 ssc = new SelfSignedCertificate();
87 } catch (CertificateException e) {
88 throw new Error(e);
89 }
90 CERT_FILE = ssc.certificate();
91 KEY_FILE = ssc.privateKey();
92 }
93
94 private static final BufferAllocator bufferAllocator = DefaultBufferAllocators.preferredAllocator();
95
96 protected enum RenegotiationType {
97 NONE,
98 CLIENT_INITIATED,
99 SERVER_INITIATED,
100 }
101
102 protected static class Renegotiation {
103 static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104
105 final RenegotiationType type;
106 final String cipherSuite;
107
108 Renegotiation(RenegotiationType type, String cipherSuite) {
109 this.type = type;
110 this.cipherSuite = cipherSuite;
111 }
112
113 @Override
114 public String toString() {
115 if (type == RenegotiationType.NONE) {
116 return "NONE";
117 }
118
119 return type + "(" + cipherSuite + ')';
120 }
121 }
122
123 public static Collection<Object[]> data() throws Exception {
124 List<SslContext> serverContexts = new ArrayList<>();
125 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126 .sslProvider(SslProvider.JDK)
127
128 .protocols("TLSv1.2")
129 .build());
130
131 List<SslContext> clientContexts = new ArrayList<>();
132 clientContexts.add(SslContextBuilder.forClient()
133 .sslProvider(SslProvider.JDK)
134 .trustManager(CERT_FILE)
135
136 .protocols("TLSv1.2")
137 .build());
138
139 boolean hasOpenSsl = OpenSsl.isAvailable();
140 if (hasOpenSsl) {
141 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
142 .sslProvider(SslProvider.OPENSSL)
143
144 .protocols("TLSv1.2")
145 .build());
146 clientContexts.add(SslContextBuilder.forClient()
147 .sslProvider(SslProvider.OPENSSL)
148 .trustManager(CERT_FILE)
149
150 .protocols("TLSv1.2")
151 .build());
152 } else {
153 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
154 }
155
156 List<Object[]> params = new ArrayList<>();
157 for (SslContext sc: serverContexts) {
158 for (SslContext cc: clientContexts) {
159 for (RenegotiationType rt: RenegotiationType.values()) {
160 if (rt != RenegotiationType.NONE &&
161 (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
162
163 continue;
164 }
165
166 final Renegotiation r;
167 switch (rt) {
168 case NONE:
169 r = Renegotiation.NONE;
170 break;
171 case SERVER_INITIATED:
172 r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
173 break;
174 case CLIENT_INITIATED:
175 r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
176 break;
177 default:
178 throw new Error();
179 }
180
181 for (int i = 0; i < 32; i++) {
182 params.add(new Object[] {
183 sc, cc, r,
184 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
185 }
186 }
187 }
188 }
189
190 return params;
191 }
192
193 private final AtomicReference<Throwable> clientException = new AtomicReference<>();
194 private final AtomicReference<Throwable> serverException = new AtomicReference<>();
195
196 private final AtomicInteger clientSendCounter = new AtomicInteger();
197 private final AtomicInteger clientRecvCounter = new AtomicInteger();
198 private final AtomicInteger serverRecvCounter = new AtomicInteger();
199
200 private final AtomicInteger clientNegoCounter = new AtomicInteger();
201 private final AtomicInteger serverNegoCounter = new AtomicInteger();
202
203 private volatile Channel clientChannel;
204 private volatile Channel serverChannel;
205
206 private volatile SslHandler clientSslHandler;
207 private volatile SslHandler serverSslHandler;
208
209 private final EchoClientHandler clientHandler =
210 new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
211
212 private final EchoServerHandler serverHandler =
213 new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
214
215 private SslContext serverCtx;
216 private SslContext clientCtx;
217 private Renegotiation renegotiation;
218 private boolean serverUsesDelegatedTaskExecutor;
219 private boolean clientUsesDelegatedTaskExecutor;
220 private boolean autoRead;
221 private boolean useChunkedWriteHandler;
222 private boolean useCompositeBuffer;
223
224 @AfterAll
225 public static void compressHeapDumps() throws Exception {
226 TestUtils.compressHeapDumps();
227 }
228
229 @ParameterizedTest(name =
230 "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
231 "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
232 "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeBuffer = {7}")
233 @MethodSource("data")
234 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
235 public void testSslEcho(
236 SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
237 boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
238 boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeBuffer,
239 TestInfo testInfo) throws Throwable {
240 this.serverCtx = serverCtx;
241 this.clientCtx = clientCtx;
242 this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
243 this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
244 this.renegotiation = renegotiation;
245 this.autoRead = autoRead;
246 this.useChunkedWriteHandler = useChunkedWriteHandler;
247 this.useCompositeBuffer = useCompositeBuffer;
248 run(testInfo, this::testSslEcho);
249 }
250
251 public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
252 final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
253 reset();
254
255 sb.childOption(ChannelOption.AUTO_READ, autoRead);
256 cb.option(ChannelOption.AUTO_READ, autoRead);
257
258 sb.childHandler(new ChannelInitializer<>() {
259 @Override
260 public void initChannel(Channel sch) {
261 serverChannel = sch;
262
263 if (serverUsesDelegatedTaskExecutor) {
264 SSLEngine sse = serverCtx.newEngine(sch.bufferAllocator());
265 serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
266 } else {
267 serverSslHandler = serverCtx.newHandler(sch.bufferAllocator());
268 }
269 serverSslHandler.setHandshakeTimeoutMillis(0);
270
271 sch.pipeline().addLast("ssl", serverSslHandler);
272 if (useChunkedWriteHandler) {
273 sch.pipeline().addLast(new ChunkedWriteHandler());
274 }
275 sch.pipeline().addLast("serverHandler", serverHandler);
276 }
277 });
278
279 final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
280 cb.handler(new ChannelInitializer<>() {
281 @Override
282 public void initChannel(Channel sch) {
283 clientChannel = sch;
284
285 if (clientUsesDelegatedTaskExecutor) {
286 SSLEngine cse = clientCtx.newEngine(sch.bufferAllocator());
287 clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
288 } else {
289 clientSslHandler = clientCtx.newHandler(sch.bufferAllocator());
290 }
291 clientSslHandler.setHandshakeTimeoutMillis(0);
292
293 sch.pipeline().addLast("ssl", clientSslHandler);
294 if (useChunkedWriteHandler) {
295 sch.pipeline().addLast(new ChunkedWriteHandler());
296 }
297 sch.pipeline().addLast("clientHandler", clientHandler);
298 sch.pipeline().addLast(new ChannelHandler() {
299 @Override
300 public void channelInboundEvent(ChannelHandlerContext ctx, Object evt) {
301 if (evt instanceof SslHandshakeCompletionEvent) {
302 clientHandshakeEventLatch.countDown();
303 }
304 ctx.fireChannelInboundEvent(evt);
305 }
306 });
307 }
308 });
309
310 final Channel sc = sb.bind().asStage().get();
311 cb.connect(sc.localAddress()).asStage().sync();
312
313 final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
314
315
316 clientHandshakeFuture.asStage().sync();
317 clientHandshakeEventLatch.await();
318 Buffer dataBuffer = bufferAllocator.copyOf(data);
319
320 clientChannel.writeAndFlush(dataBuffer.readSplit(FIRST_MESSAGE_SIZE));
321 clientSendCounter.set(FIRST_MESSAGE_SIZE);
322
323 boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
324 Future<Channel> renegoFuture = null;
325 while (clientSendCounter.get() < data.length) {
326 int clientSendCounterVal = clientSendCounter.get();
327 int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
328 Buffer buf = dataBuffer.readSplit(length);
329 if (useCompositeBuffer) {
330 buf = bufferAllocator.compose(buf.send());
331 }
332
333 Future<Void> future = clientChannel.writeAndFlush(buf);
334 clientSendCounter.set(clientSendCounterVal += length);
335 future.asStage().sync();
336
337 if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
338 needsRenegotiation = false;
339 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
340 renegoFuture = clientSslHandler.renegotiate();
341 logStats("CLIENT RENEGOTIATES");
342 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
343 }
344 }
345
346
347 while (clientRecvCounter.get() < data.length) {
348 if (serverException.get() != null) {
349 break;
350 }
351 if (clientException.get() != null) {
352 break;
353 }
354
355 try {
356 Thread.sleep(50);
357 } catch (InterruptedException e) {
358
359 }
360 }
361
362 while (serverRecvCounter.get() < data.length) {
363 if (serverException.get() != null) {
364 break;
365 }
366 if (clientException.get() != null) {
367 break;
368 }
369
370 try {
371 Thread.sleep(50);
372 } catch (InterruptedException e) {
373
374 }
375 }
376
377
378 if (renegoFuture != null) {
379 renegoFuture.asStage().sync();
380 }
381 if (serverHandler.renegoFuture != null) {
382 serverHandler.renegoFuture.asStage().sync();
383 }
384
385 serverChannel.close().asStage().await();
386 clientChannel.close().asStage().await();
387 sc.close().asStage().await();
388 delegatedTaskExecutor.shutdown();
389
390 if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
391 throw serverException.get();
392 }
393 if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
394 throw clientException.get();
395 }
396 if (serverException.get() != null) {
397 throw serverException.get();
398 }
399 if (clientException.get() != null) {
400 throw clientException.get();
401 }
402
403
404 try {
405 switch (renegotiation.type) {
406 case SERVER_INITIATED:
407 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
408 assertThat(serverNegoCounter.get(), is(2));
409 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
410 break;
411 case CLIENT_INITIATED:
412 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
413 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
414 assertThat(clientNegoCounter.get(), is(2));
415 break;
416 case NONE:
417 assertThat(serverNegoCounter.get(), is(1));
418 assertThat(clientNegoCounter.get(), is(1));
419 }
420 } finally {
421 logStats("STATS");
422 }
423 }
424
425 private void reset() {
426 clientException.set(null);
427 serverException.set(null);
428
429 clientSendCounter.set(0);
430 clientRecvCounter.set(0);
431 serverRecvCounter.set(0);
432
433 clientNegoCounter.set(0);
434 serverNegoCounter.set(0);
435
436 clientChannel = null;
437 serverChannel = null;
438
439 clientSslHandler = null;
440 serverSslHandler = null;
441 }
442
443 void logStats(String message) {
444 logger.debug(
445 "{}:\n" +
446 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
447 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
448 message,
449 clientSendCounter, clientRecvCounter, clientNegoCounter,
450 clientSslHandler.engine().getSession().getCipherSuite(),
451 serverRecvCounter, serverNegoCounter,
452 serverSslHandler.engine().getSession().getCipherSuite());
453 }
454
455 private abstract class EchoHandler extends SimpleChannelInboundHandler<Buffer> {
456
457 protected final AtomicInteger recvCounter;
458 protected final AtomicInteger negoCounter;
459 protected final AtomicReference<Throwable> exception;
460
461 EchoHandler(
462 AtomicInteger recvCounter, AtomicInteger negoCounter,
463 AtomicReference<Throwable> exception) {
464
465 this.recvCounter = recvCounter;
466 this.negoCounter = negoCounter;
467 this.exception = exception;
468 }
469
470 @Override
471 public boolean isSharable() {
472 return true;
473 }
474
475 @Override
476 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
477
478
479 if (!autoRead) {
480 ctx.read();
481 }
482 ctx.fireChannelReadComplete();
483 }
484
485 @Override
486 public final void channelInboundEvent(ChannelHandlerContext ctx, Object evt) {
487 if (evt instanceof SslHandshakeCompletionEvent) {
488 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
489 if (handshakeEvt.cause() != null) {
490 logger.warn("Handshake failed:", handshakeEvt.cause());
491 }
492 assertTrue(handshakeEvt.isSuccess());
493 negoCounter.incrementAndGet();
494 logStats("HANDSHAKEN");
495 }
496 ctx.fireChannelInboundEvent(evt);
497 }
498
499 @Override
500 public final void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
501 if (logger.isWarnEnabled()) {
502 logger.warn("Unexpected exception from the client side:", cause);
503 }
504
505 exception.compareAndSet(null, cause);
506 ctx.close();
507 }
508 }
509
510 private class EchoClientHandler extends EchoHandler {
511
512 EchoClientHandler(
513 AtomicInteger recvCounter, AtomicInteger negoCounter,
514 AtomicReference<Throwable> exception) {
515
516 super(recvCounter, negoCounter, exception);
517 }
518
519 @Override
520 public void handlerAdded(final ChannelHandlerContext ctx) {
521 if (!autoRead) {
522 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(ctx, (c, f) -> c.read());
523 }
524 }
525
526 @Override
527 public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
528 byte[] actual = new byte[in.readableBytes()];
529 in.readBytes(actual, 0, actual.length);
530
531 int lastIdx = recvCounter.get();
532 for (int i = 0; i < actual.length; i ++) {
533 assertEquals(data[i + lastIdx], actual[i]);
534 }
535
536 recvCounter.addAndGet(actual.length);
537 }
538 }
539
540 private class EchoServerHandler extends EchoHandler {
541 volatile Future<Channel> renegoFuture;
542
543 EchoServerHandler(
544 AtomicInteger recvCounter, AtomicInteger negoCounter,
545 AtomicReference<Throwable> exception) {
546
547 super(recvCounter, negoCounter, exception);
548 }
549
550 @Override
551 public final void channelRegistered(ChannelHandlerContext ctx) {
552 renegoFuture = null;
553 }
554
555 @Override
556 public void channelActive(final ChannelHandlerContext ctx) throws Exception {
557 if (!autoRead) {
558 ctx.read();
559 }
560 ctx.fireChannelActive();
561 }
562
563 @Override
564 public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
565 byte[] actual = new byte[in.readableBytes()];
566 in.readBytes(actual, 0, actual.length);
567
568 int lastIdx = recvCounter.get();
569 for (int i = 0; i < actual.length; i ++) {
570 assertEquals(data[i + lastIdx], actual[i]);
571 }
572
573 Buffer buf = bufferAllocator.copyOf(actual);
574 if (useCompositeBuffer) {
575 buf = bufferAllocator.compose(buf.send());
576 }
577 ctx.writeAndFlush(buf);
578
579 recvCounter.addAndGet(actual.length);
580
581
582 if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
583 recvCounter.get() > data.length / 2 && renegoFuture == null) {
584
585 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
586
587 Future<Channel> hf = sslHandler.handshakeFuture();
588 assertThat(hf.isDone(), is(true));
589
590 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
591 logStats("SERVER RENEGOTIATES");
592 renegoFuture = sslHandler.renegotiate();
593 assertThat(renegoFuture, is(not(sameInstance(hf))));
594 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
595 assertThat(renegoFuture.isDone(), is(false));
596 }
597 }
598 }
599 }