View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.api.Buffer;
21  import io.netty5.buffer.api.BufferAllocator;
22  import io.netty5.buffer.api.DefaultBufferAllocators;
23  import io.netty5.channel.Channel;
24  import io.netty5.channel.ChannelHandler;
25  import io.netty5.channel.ChannelHandlerContext;
26  import io.netty5.channel.ChannelInitializer;
27  import io.netty5.channel.ChannelOption;
28  import io.netty5.channel.SimpleChannelInboundHandler;
29  import io.netty5.handler.ssl.OpenSsl;
30  import io.netty5.handler.ssl.OpenSslContext;
31  import io.netty5.handler.ssl.SslContext;
32  import io.netty5.handler.ssl.SslContextBuilder;
33  import io.netty5.handler.ssl.SslHandler;
34  import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
35  import io.netty5.handler.ssl.SslProvider;
36  import io.netty5.handler.ssl.util.SelfSignedCertificate;
37  import io.netty5.handler.stream.ChunkedWriteHandler;
38  import io.netty5.testsuite.util.TestUtils;
39  import io.netty5.util.concurrent.Future;
40  import io.netty5.util.internal.logging.InternalLogger;
41  import io.netty5.util.internal.logging.InternalLoggerFactory;
42  import org.junit.jupiter.api.AfterAll;
43  import org.junit.jupiter.api.TestInfo;
44  import org.junit.jupiter.api.Timeout;
45  import org.junit.jupiter.params.ParameterizedTest;
46  import org.junit.jupiter.params.provider.MethodSource;
47  
48  import javax.net.ssl.SSLEngine;
49  import java.io.File;
50  import java.io.IOException;
51  import java.security.cert.CertificateException;
52  import java.util.ArrayList;
53  import java.util.Collection;
54  import java.util.List;
55  import java.util.Random;
56  import java.util.concurrent.CountDownLatch;
57  import java.util.concurrent.ExecutorService;
58  import java.util.concurrent.Executors;
59  import java.util.concurrent.TimeUnit;
60  import java.util.concurrent.atomic.AtomicInteger;
61  import java.util.concurrent.atomic.AtomicReference;
62  
63  import static org.hamcrest.MatcherAssert.assertThat;
64  import static org.hamcrest.Matchers.anyOf;
65  import static org.hamcrest.Matchers.is;
66  import static org.hamcrest.Matchers.not;
67  import static org.hamcrest.Matchers.sameInstance;
68  import static org.junit.jupiter.api.Assertions.assertEquals;
69  import static org.junit.jupiter.api.Assertions.assertTrue;
70  
71  public class SocketSslEchoTest extends AbstractSocketTest {
72  
73      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
74  
75      private static final int FIRST_MESSAGE_SIZE = 16384;
76      private static final Random random = new Random();
77      private static final File CERT_FILE;
78      private static final File KEY_FILE;
79      static final byte[] data = new byte[1048576];
80  
81      static {
82          random.nextBytes(data);
83  
84          SelfSignedCertificate ssc;
85          try {
86              ssc = new SelfSignedCertificate();
87          } catch (CertificateException e) {
88              throw new Error(e);
89          }
90          CERT_FILE = ssc.certificate();
91          KEY_FILE = ssc.privateKey();
92      }
93  
94      private static final BufferAllocator bufferAllocator = DefaultBufferAllocators.preferredAllocator();
95  
96      protected enum RenegotiationType {
97          NONE, // no renegotiation
98          CLIENT_INITIATED, // renegotiation from client
99          SERVER_INITIATED, // renegotiation from server
100     }
101 
102     protected static class Renegotiation {
103         static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104 
105         final RenegotiationType type;
106         final String cipherSuite;
107 
108         Renegotiation(RenegotiationType type, String cipherSuite) {
109             this.type = type;
110             this.cipherSuite = cipherSuite;
111         }
112 
113         @Override
114         public String toString() {
115             if (type == RenegotiationType.NONE) {
116                 return "NONE";
117             }
118 
119             return type + "(" + cipherSuite + ')';
120         }
121     }
122 
123     public static Collection<Object[]> data() throws Exception {
124         List<SslContext> serverContexts = new ArrayList<>();
125         serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126                                             .sslProvider(SslProvider.JDK)
127                                             // As we test renegotiation we should use a protocol that support it.
128                                             .protocols("TLSv1.2")
129                                             .build());
130 
131         List<SslContext> clientContexts = new ArrayList<>();
132         clientContexts.add(SslContextBuilder.forClient()
133                                             .sslProvider(SslProvider.JDK)
134                                             .trustManager(CERT_FILE)
135                                             // As we test renegotiation we should use a protocol that support it.
136                                             .protocols("TLSv1.2")
137                                             .build());
138 
139         boolean hasOpenSsl = OpenSsl.isAvailable();
140         if (hasOpenSsl) {
141             serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
142                                                 .sslProvider(SslProvider.OPENSSL)
143                                                 // As we test renegotiation we should use a protocol that support it.
144                                                 .protocols("TLSv1.2")
145                                                 .build());
146             clientContexts.add(SslContextBuilder.forClient()
147                                                 .sslProvider(SslProvider.OPENSSL)
148                                                 .trustManager(CERT_FILE)
149                                                 // As we test renegotiation we should use a protocol that support it.
150                                                 .protocols("TLSv1.2")
151                                                 .build());
152         } else {
153             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
154         }
155 
156         List<Object[]> params = new ArrayList<>();
157         for (SslContext sc: serverContexts) {
158             for (SslContext cc: clientContexts) {
159                 for (RenegotiationType rt: RenegotiationType.values()) {
160                     if (rt != RenegotiationType.NONE &&
161                         (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
162                         // TODO: OpenSslEngine does not support renegotiation yet.
163                         continue;
164                     }
165 
166                     final Renegotiation r;
167                     switch (rt) {
168                         case NONE:
169                             r = Renegotiation.NONE;
170                             break;
171                         case SERVER_INITIATED:
172                             r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
173                             break;
174                         case CLIENT_INITIATED:
175                             r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
176                             break;
177                         default:
178                             throw new Error();
179                     }
180 
181                     for (int i = 0; i < 32; i++) {
182                         params.add(new Object[] {
183                                 sc, cc, r,
184                                 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
185                     }
186                 }
187             }
188         }
189 
190         return params;
191     }
192 
193     private final AtomicReference<Throwable> clientException = new AtomicReference<>();
194     private final AtomicReference<Throwable> serverException = new AtomicReference<>();
195 
196     private final AtomicInteger clientSendCounter = new AtomicInteger();
197     private final AtomicInteger clientRecvCounter = new AtomicInteger();
198     private final AtomicInteger serverRecvCounter = new AtomicInteger();
199 
200     private final AtomicInteger clientNegoCounter = new AtomicInteger();
201     private final AtomicInteger serverNegoCounter = new AtomicInteger();
202 
203     private volatile Channel clientChannel;
204     private volatile Channel serverChannel;
205 
206     private volatile SslHandler clientSslHandler;
207     private volatile SslHandler serverSslHandler;
208 
209     private final EchoClientHandler clientHandler =
210             new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
211 
212     private final EchoServerHandler serverHandler =
213             new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
214 
215     private SslContext serverCtx;
216     private SslContext clientCtx;
217     private Renegotiation renegotiation;
218     private boolean serverUsesDelegatedTaskExecutor;
219     private boolean clientUsesDelegatedTaskExecutor;
220     private boolean autoRead;
221     private boolean useChunkedWriteHandler;
222     private boolean useCompositeBuffer;
223 
224     @AfterAll
225     public static void compressHeapDumps() throws Exception {
226         TestUtils.compressHeapDumps();
227     }
228 
229     @ParameterizedTest(name =
230             "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
231             "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
232             "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeBuffer = {7}")
233     @MethodSource("data")
234     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
235     public void testSslEcho(
236             SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
237             boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
238             boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeBuffer,
239             TestInfo testInfo) throws Throwable {
240         this.serverCtx = serverCtx;
241         this.clientCtx = clientCtx;
242         this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
243         this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
244         this.renegotiation = renegotiation;
245         this.autoRead = autoRead;
246         this.useChunkedWriteHandler = useChunkedWriteHandler;
247         this.useCompositeBuffer = useCompositeBuffer;
248         run(testInfo, this::testSslEcho);
249     }
250 
251     public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
252         final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
253         reset();
254 
255         sb.childOption(ChannelOption.AUTO_READ, autoRead);
256         cb.option(ChannelOption.AUTO_READ, autoRead);
257 
258         sb.childHandler(new ChannelInitializer<>() {
259             @Override
260             public void initChannel(Channel sch) {
261                 serverChannel = sch;
262 
263                 if (serverUsesDelegatedTaskExecutor) {
264                     SSLEngine sse = serverCtx.newEngine(sch.bufferAllocator());
265                     serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
266                 } else {
267                     serverSslHandler = serverCtx.newHandler(sch.bufferAllocator());
268                 }
269                 serverSslHandler.setHandshakeTimeoutMillis(0);
270 
271                 sch.pipeline().addLast("ssl", serverSslHandler);
272                 if (useChunkedWriteHandler) {
273                     sch.pipeline().addLast(new ChunkedWriteHandler());
274                 }
275                 sch.pipeline().addLast("serverHandler", serverHandler);
276             }
277         });
278 
279         final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
280         cb.handler(new ChannelInitializer<>() {
281             @Override
282             public void initChannel(Channel sch) {
283                 clientChannel = sch;
284 
285                 if (clientUsesDelegatedTaskExecutor) {
286                     SSLEngine cse = clientCtx.newEngine(sch.bufferAllocator());
287                     clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
288                 } else {
289                     clientSslHandler = clientCtx.newHandler(sch.bufferAllocator());
290                 }
291                 clientSslHandler.setHandshakeTimeoutMillis(0);
292 
293                 sch.pipeline().addLast("ssl", clientSslHandler);
294                 if (useChunkedWriteHandler) {
295                     sch.pipeline().addLast(new ChunkedWriteHandler());
296                 }
297                 sch.pipeline().addLast("clientHandler", clientHandler);
298                 sch.pipeline().addLast(new ChannelHandler() {
299                     @Override
300                     public void channelInboundEvent(ChannelHandlerContext ctx, Object evt) {
301                         if (evt instanceof SslHandshakeCompletionEvent) {
302                             clientHandshakeEventLatch.countDown();
303                         }
304                         ctx.fireChannelInboundEvent(evt);
305                     }
306                 });
307             }
308         });
309 
310         final Channel sc = sb.bind().asStage().get();
311         cb.connect(sc.localAddress()).asStage().sync();
312 
313         final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
314 
315         // Wait for the handshake to complete before we flush anything. SslHandler should flush non-application data.
316         clientHandshakeFuture.asStage().sync();
317         clientHandshakeEventLatch.await();
318         Buffer dataBuffer = bufferAllocator.copyOf(data);
319 
320         clientChannel.writeAndFlush(dataBuffer.readSplit(FIRST_MESSAGE_SIZE));
321         clientSendCounter.set(FIRST_MESSAGE_SIZE);
322 
323         boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
324         Future<Channel> renegoFuture = null;
325         while (clientSendCounter.get() < data.length) {
326             int clientSendCounterVal = clientSendCounter.get();
327             int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
328             Buffer buf = dataBuffer.readSplit(length);
329             if (useCompositeBuffer) {
330                 buf = bufferAllocator.compose(buf.send());
331             }
332 
333             Future<Void> future = clientChannel.writeAndFlush(buf);
334             clientSendCounter.set(clientSendCounterVal += length);
335             future.asStage().sync();
336 
337             if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
338                 needsRenegotiation = false;
339                 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
340                 renegoFuture = clientSslHandler.renegotiate();
341                 logStats("CLIENT RENEGOTIATES");
342                 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
343             }
344         }
345 
346         // Ensure all data has been exchanged.
347         while (clientRecvCounter.get() < data.length) {
348             if (serverException.get() != null) {
349                 break;
350             }
351             if (clientException.get() != null) {
352                 break;
353             }
354 
355             try {
356                 Thread.sleep(50);
357             } catch (InterruptedException e) {
358                 // Ignore.
359             }
360         }
361 
362         while (serverRecvCounter.get() < data.length) {
363             if (serverException.get() != null) {
364                 break;
365             }
366             if (clientException.get() != null) {
367                 break;
368             }
369 
370             try {
371                 Thread.sleep(50);
372             } catch (InterruptedException e) {
373                 // Ignore.
374             }
375         }
376 
377         // Wait until renegotiation is done.
378         if (renegoFuture != null) {
379             renegoFuture.asStage().sync();
380         }
381         if (serverHandler.renegoFuture != null) {
382             serverHandler.renegoFuture.asStage().sync();
383         }
384 
385         serverChannel.close().asStage().await();
386         clientChannel.close().asStage().await();
387         sc.close().asStage().await();
388         delegatedTaskExecutor.shutdown();
389 
390         if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
391             throw serverException.get();
392         }
393         if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
394             throw clientException.get();
395         }
396         if (serverException.get() != null) {
397             throw serverException.get();
398         }
399         if (clientException.get() != null) {
400             throw clientException.get();
401         }
402 
403         // When renegotiation is done, at least the initiating side should be notified.
404         try {
405             switch (renegotiation.type) {
406             case SERVER_INITIATED:
407                 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
408                 assertThat(serverNegoCounter.get(), is(2));
409                 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
410                 break;
411             case CLIENT_INITIATED:
412                 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
413                 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
414                 assertThat(clientNegoCounter.get(), is(2));
415                 break;
416             case NONE:
417                 assertThat(serverNegoCounter.get(), is(1));
418                 assertThat(clientNegoCounter.get(), is(1));
419             }
420         } finally {
421             logStats("STATS");
422         }
423     }
424 
425     private void reset() {
426         clientException.set(null);
427         serverException.set(null);
428 
429         clientSendCounter.set(0);
430         clientRecvCounter.set(0);
431         serverRecvCounter.set(0);
432 
433         clientNegoCounter.set(0);
434         serverNegoCounter.set(0);
435 
436         clientChannel = null;
437         serverChannel = null;
438 
439         clientSslHandler = null;
440         serverSslHandler = null;
441     }
442 
443     void logStats(String message) {
444         logger.debug(
445                 "{}:\n" +
446                 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
447                 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
448                 message,
449                 clientSendCounter, clientRecvCounter, clientNegoCounter,
450                 clientSslHandler.engine().getSession().getCipherSuite(),
451                 serverRecvCounter, serverNegoCounter,
452                 serverSslHandler.engine().getSession().getCipherSuite());
453     }
454 
455     private abstract class EchoHandler extends SimpleChannelInboundHandler<Buffer> {
456 
457         protected final AtomicInteger recvCounter;
458         protected final AtomicInteger negoCounter;
459         protected final AtomicReference<Throwable> exception;
460 
461         EchoHandler(
462                 AtomicInteger recvCounter, AtomicInteger negoCounter,
463                 AtomicReference<Throwable> exception) {
464 
465             this.recvCounter = recvCounter;
466             this.negoCounter = negoCounter;
467             this.exception = exception;
468         }
469 
470         @Override
471         public boolean isSharable() {
472             return true;
473         }
474 
475         @Override
476         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
477             // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
478             // non-application and previously flushed writes internally.
479             if (!autoRead) {
480                 ctx.read();
481             }
482             ctx.fireChannelReadComplete();
483         }
484 
485         @Override
486         public final void channelInboundEvent(ChannelHandlerContext ctx, Object evt) {
487             if (evt instanceof SslHandshakeCompletionEvent) {
488                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
489                 if (handshakeEvt.cause() != null) {
490                     logger.warn("Handshake failed:", handshakeEvt.cause());
491                 }
492                 assertTrue(handshakeEvt.isSuccess());
493                 negoCounter.incrementAndGet();
494                 logStats("HANDSHAKEN");
495             }
496             ctx.fireChannelInboundEvent(evt);
497         }
498 
499         @Override
500         public final void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
501             if (logger.isWarnEnabled()) {
502                 logger.warn("Unexpected exception from the client side:", cause);
503             }
504 
505             exception.compareAndSet(null, cause);
506             ctx.close();
507         }
508     }
509 
510     private class EchoClientHandler extends EchoHandler {
511 
512         EchoClientHandler(
513                 AtomicInteger recvCounter, AtomicInteger negoCounter,
514                 AtomicReference<Throwable> exception) {
515 
516             super(recvCounter, negoCounter, exception);
517         }
518 
519         @Override
520         public void handlerAdded(final ChannelHandlerContext ctx) {
521             if (!autoRead) {
522                 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(ctx, (c, f) -> c.read());
523             }
524         }
525 
526         @Override
527         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
528             byte[] actual = new byte[in.readableBytes()];
529             in.readBytes(actual, 0, actual.length);
530 
531             int lastIdx = recvCounter.get();
532             for (int i = 0; i < actual.length; i ++) {
533                 assertEquals(data[i + lastIdx], actual[i]);
534             }
535 
536             recvCounter.addAndGet(actual.length);
537         }
538     }
539 
540     private class EchoServerHandler extends EchoHandler {
541         volatile Future<Channel> renegoFuture;
542 
543         EchoServerHandler(
544                 AtomicInteger recvCounter, AtomicInteger negoCounter,
545                 AtomicReference<Throwable> exception) {
546 
547             super(recvCounter, negoCounter, exception);
548         }
549 
550         @Override
551         public final void channelRegistered(ChannelHandlerContext ctx) {
552             renegoFuture = null;
553         }
554 
555         @Override
556         public void channelActive(final ChannelHandlerContext ctx) throws Exception {
557             if (!autoRead) {
558                 ctx.read();
559             }
560             ctx.fireChannelActive();
561         }
562 
563         @Override
564         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
565             byte[] actual = new byte[in.readableBytes()];
566             in.readBytes(actual, 0, actual.length);
567 
568             int lastIdx = recvCounter.get();
569             for (int i = 0; i < actual.length; i ++) {
570                 assertEquals(data[i + lastIdx], actual[i]);
571             }
572 
573             Buffer buf = bufferAllocator.copyOf(actual);
574             if (useCompositeBuffer) {
575                 buf = bufferAllocator.compose(buf.send());
576             }
577             ctx.writeAndFlush(buf);
578 
579             recvCounter.addAndGet(actual.length);
580 
581             // Perform server-initiated renegotiation if necessary.
582             if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
583                 recvCounter.get() > data.length / 2 && renegoFuture == null) {
584 
585                 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
586 
587                 Future<Channel> hf = sslHandler.handshakeFuture();
588                 assertThat(hf.isDone(), is(true));
589 
590                 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
591                 logStats("SERVER RENEGOTIATES");
592                 renegoFuture = sslHandler.renegotiate();
593                 assertThat(renegoFuture, is(not(sameInstance(hf))));
594                 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
595                 assertThat(renegoFuture.isDone(), is(false));
596             }
597         }
598     }
599 }