View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelHandler.Sharable;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelInboundHandlerAdapter;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.ChannelOption;
29  import io.netty.channel.SimpleChannelInboundHandler;
30  import io.netty.handler.ssl.OpenSsl;
31  import io.netty.handler.ssl.OpenSslContext;
32  import io.netty.handler.ssl.SslContext;
33  import io.netty.handler.ssl.SslContextBuilder;
34  import io.netty.handler.ssl.SslHandler;
35  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.ssl.util.SelfSignedCertificate;
38  import io.netty.handler.stream.ChunkedWriteHandler;
39  import io.netty.testsuite.util.TestUtils;
40  import io.netty.util.concurrent.Future;
41  import io.netty.util.concurrent.GenericFutureListener;
42  import io.netty.util.internal.logging.InternalLogger;
43  import io.netty.util.internal.logging.InternalLoggerFactory;
44  import org.junit.jupiter.api.AfterAll;
45  import org.junit.jupiter.api.TestInfo;
46  import org.junit.jupiter.api.Timeout;
47  import org.junit.jupiter.params.ParameterizedTest;
48  import org.junit.jupiter.params.provider.MethodSource;
49  
50  import javax.net.ssl.SSLEngine;
51  import java.io.File;
52  import java.io.IOException;
53  import java.security.cert.CertificateException;
54  import java.util.ArrayList;
55  import java.util.Collection;
56  import java.util.List;
57  import java.util.Random;
58  import java.util.concurrent.CountDownLatch;
59  import java.util.concurrent.ExecutorService;
60  import java.util.concurrent.Executors;
61  import java.util.concurrent.TimeUnit;
62  import java.util.concurrent.atomic.AtomicInteger;
63  import java.util.concurrent.atomic.AtomicReference;
64  
65  import static org.assertj.core.api.Assertions.assertThat;
66  import static org.junit.jupiter.api.Assertions.assertEquals;
67  import static org.junit.jupiter.api.Assertions.assertFalse;
68  import static org.junit.jupiter.api.Assertions.assertNotSame;
69  import static org.junit.jupiter.api.Assertions.assertSame;
70  import static org.junit.jupiter.api.Assertions.assertTrue;
71  
72  public class SocketSslEchoTest extends AbstractSocketTest {
73  
74      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
75  
76      private static final int FIRST_MESSAGE_SIZE = 16384;
77      private static final Random random = new Random();
78      private static final File CERT_FILE;
79      private static final File KEY_FILE;
80      static final byte[] data = new byte[1048576];
81  
82      static {
83          random.nextBytes(data);
84  
85          SelfSignedCertificate ssc;
86          try {
87              ssc = new SelfSignedCertificate();
88          } catch (CertificateException e) {
89              throw new Error(e);
90          }
91          CERT_FILE = ssc.certificate();
92          KEY_FILE = ssc.privateKey();
93      }
94  
95      protected enum RenegotiationType {
96          NONE, // no renegotiation
97          CLIENT_INITIATED, // renegotiation from client
98          SERVER_INITIATED, // renegotiation from server
99      }
100 
101     protected static class Renegotiation {
102         static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
103 
104         final RenegotiationType type;
105         final String cipherSuite;
106 
107         Renegotiation(RenegotiationType type, String cipherSuite) {
108             this.type = type;
109             this.cipherSuite = cipherSuite;
110         }
111 
112         @Override
113         public String toString() {
114             if (type == RenegotiationType.NONE) {
115                 return "NONE";
116             }
117 
118             return type + "(" + cipherSuite + ')';
119         }
120     }
121 
122     public static Collection<Object[]> data() throws Exception {
123         List<SslContext> serverContexts = new ArrayList<SslContext>();
124         serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
125                                             .sslProvider(SslProvider.JDK)
126                                             // As we test renegotiation we should use a protocol that support it.
127                                             .protocols("TLSv1.2")
128                                             .build());
129 
130         List<SslContext> clientContexts = new ArrayList<SslContext>();
131         clientContexts.add(SslContextBuilder.forClient()
132                                             .sslProvider(SslProvider.JDK)
133                                             .trustManager(CERT_FILE)
134                                             // As we test renegotiation we should use a protocol that support it.
135                                             .protocols("TLSv1.2")
136                                             .build());
137 
138         boolean hasOpenSsl = OpenSsl.isAvailable();
139         if (hasOpenSsl) {
140             serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
141                                                 .sslProvider(SslProvider.OPENSSL)
142                                                 // As we test renegotiation we should use a protocol that support it.
143                                                 .protocols("TLSv1.2")
144                                                 .build());
145             clientContexts.add(SslContextBuilder.forClient()
146                                                 .sslProvider(SslProvider.OPENSSL)
147                                                 .trustManager(CERT_FILE)
148                                                 // As we test renegotiation we should use a protocol that support it.
149                                                 .protocols("TLSv1.2")
150                                                 .build());
151         } else {
152             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
153         }
154 
155         List<Object[]> params = new ArrayList<Object[]>();
156         for (SslContext sc: serverContexts) {
157             for (SslContext cc: clientContexts) {
158                 for (RenegotiationType rt: RenegotiationType.values()) {
159                     if (rt != RenegotiationType.NONE &&
160                         (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
161                         // TODO: OpenSslEngine does not support renegotiation yet.
162                         continue;
163                     }
164 
165                     final Renegotiation r;
166                     switch (rt) {
167                         case NONE:
168                             r = Renegotiation.NONE;
169                             break;
170                         case SERVER_INITIATED:
171                             r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
172                             break;
173                         case CLIENT_INITIATED:
174                             r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
175                             break;
176                         default:
177                             throw new Error();
178                     }
179 
180                     for (int i = 0; i < 32; i++) {
181                         params.add(new Object[] {
182                                 sc, cc, r,
183                                 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
184                     }
185                 }
186             }
187         }
188 
189         return params;
190     }
191 
192     private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
193     private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
194     private final AtomicInteger clientSendCounter = new AtomicInteger();
195     private final AtomicInteger clientRecvCounter = new AtomicInteger();
196     private final AtomicInteger serverRecvCounter = new AtomicInteger();
197 
198     private final AtomicInteger clientNegoCounter = new AtomicInteger();
199     private final AtomicInteger serverNegoCounter = new AtomicInteger();
200 
201     private volatile Channel clientChannel;
202     private volatile Channel serverChannel;
203 
204     private volatile SslHandler clientSslHandler;
205     private volatile SslHandler serverSslHandler;
206 
207     private final EchoClientHandler clientHandler =
208             new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
209 
210     private final EchoServerHandler serverHandler =
211             new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
212 
213     private SslContext serverCtx;
214     private SslContext clientCtx;
215     private Renegotiation renegotiation;
216     private boolean serverUsesDelegatedTaskExecutor;
217     private boolean clientUsesDelegatedTaskExecutor;
218     private boolean autoRead;
219     private boolean useChunkedWriteHandler;
220     private boolean useCompositeByteBuf;
221 
222     @AfterAll
223     public static void compressHeapDumps() throws Exception {
224         TestUtils.compressHeapDumps();
225     }
226 
227     @ParameterizedTest(name =
228             "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
229             "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
230             "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
231     @MethodSource("data")
232     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
233     public void testSslEcho(
234             SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
235             boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
236             boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
237             TestInfo testInfo) throws Throwable {
238         this.serverCtx = serverCtx;
239         this.clientCtx = clientCtx;
240         this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
241         this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
242         this.renegotiation = renegotiation;
243         this.autoRead = autoRead;
244         this.useChunkedWriteHandler = useChunkedWriteHandler;
245         this.useCompositeByteBuf = useCompositeByteBuf;
246         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
247             @Override
248             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
249                 testSslEcho(serverBootstrap, bootstrap);
250             }
251         });
252     }
253 
254     public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
255         final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
256         reset();
257 
258         sb.childOption(ChannelOption.AUTO_READ, autoRead);
259         cb.option(ChannelOption.AUTO_READ, autoRead);
260 
261         sb.childHandler(new ChannelInitializer<Channel>() {
262             @Override
263             public void initChannel(Channel sch) {
264                 serverChannel = sch;
265 
266                 if (serverUsesDelegatedTaskExecutor) {
267                     SSLEngine sse = serverCtx.newEngine(sch.alloc());
268                     serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
269                 } else {
270                     serverSslHandler = serverCtx.newHandler(sch.alloc());
271                 }
272                 serverSslHandler.setHandshakeTimeoutMillis(0);
273 
274                 sch.pipeline().addLast("ssl", serverSslHandler);
275                 if (useChunkedWriteHandler) {
276                     sch.pipeline().addLast(new ChunkedWriteHandler());
277                 }
278                 sch.pipeline().addLast("serverHandler", serverHandler);
279             }
280         });
281 
282         final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
283         cb.handler(new ChannelInitializer<Channel>() {
284             @Override
285             public void initChannel(Channel sch) {
286                 clientChannel = sch;
287 
288                 if (clientUsesDelegatedTaskExecutor) {
289                     SSLEngine cse = clientCtx.newEngine(sch.alloc());
290                     clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
291                 } else {
292                     clientSslHandler = clientCtx.newHandler(sch.alloc());
293                 }
294                 clientSslHandler.setHandshakeTimeoutMillis(0);
295 
296                 sch.pipeline().addLast("ssl", clientSslHandler);
297                 if (useChunkedWriteHandler) {
298                     sch.pipeline().addLast(new ChunkedWriteHandler());
299                 }
300                 sch.pipeline().addLast("clientHandler", clientHandler);
301                 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
302                     @Override
303                     public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
304                         if (evt instanceof SslHandshakeCompletionEvent) {
305                             clientHandshakeEventLatch.countDown();
306                         }
307                         ctx.fireUserEventTriggered(evt);
308                     }
309                 });
310             }
311         });
312 
313         final Channel sc = sb.bind().sync().channel();
314         cb.connect(sc.localAddress()).sync();
315 
316         final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
317 
318         // Wait for the handshake to complete before we flush anything. SslHandler should flush non-application data.
319         clientHandshakeFuture.sync();
320         clientHandshakeEventLatch.await();
321 
322         clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
323         clientSendCounter.set(FIRST_MESSAGE_SIZE);
324 
325         boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
326         Future<Channel> renegoFuture = null;
327         while (clientSendCounter.get() < data.length) {
328             int clientSendCounterVal = clientSendCounter.get();
329             int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
330             ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
331             if (useCompositeByteBuf) {
332                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
333             }
334 
335             ChannelFuture future = clientChannel.writeAndFlush(buf);
336             clientSendCounter.set(clientSendCounterVal += length);
337             future.sync();
338 
339             if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
340                 needsRenegotiation = false;
341                 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
342                 renegoFuture = clientSslHandler.renegotiate();
343                 logStats("CLIENT RENEGOTIATES");
344                 assertNotSame(renegoFuture, clientHandshakeFuture);
345             }
346         }
347 
348         // Ensure all data has been exchanged.
349         while (clientRecvCounter.get() < data.length) {
350             if (serverException.get() != null) {
351                 break;
352             }
353             if (clientException.get() != null) {
354                 break;
355             }
356 
357             Thread.sleep(50);
358         }
359 
360         while (serverRecvCounter.get() < data.length) {
361             if (serverException.get() != null) {
362                 break;
363             }
364             if (clientException.get() != null) {
365                 break;
366             }
367 
368             Thread.sleep(50);
369         }
370 
371         // Wait until renegotiation is done.
372         if (renegoFuture != null) {
373             renegoFuture.sync();
374         }
375         if (serverHandler.renegoFuture != null) {
376             serverHandler.renegoFuture.sync();
377         }
378 
379         serverChannel.close().awaitUninterruptibly();
380         clientChannel.close().awaitUninterruptibly();
381         sc.close().awaitUninterruptibly();
382         delegatedTaskExecutor.shutdown();
383 
384         if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
385             throw serverException.get();
386         }
387         if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
388             throw clientException.get();
389         }
390         if (serverException.get() != null) {
391             throw serverException.get();
392         }
393         if (clientException.get() != null) {
394             throw clientException.get();
395         }
396 
397         // When renegotiation is done, at least the initiating side should be notified.
398         try {
399             switch (renegotiation.type) {
400             case SERVER_INITIATED:
401                 assertEquals(renegotiation.cipherSuite, serverSslHandler.engine().getSession().getCipherSuite());
402                 assertEquals(2, serverNegoCounter.get());
403                 assertThat(clientNegoCounter.get()).isIn(1, 2);
404                 break;
405             case CLIENT_INITIATED:
406                 assertThat(serverNegoCounter.get()).isIn(1, 2);
407                 assertEquals(renegotiation.cipherSuite, clientSslHandler.engine().getSession().getCipherSuite());
408                 assertEquals(2, clientNegoCounter.get());
409                 break;
410             case NONE:
411                 assertEquals(1, serverNegoCounter.get());
412                 assertEquals(1, clientNegoCounter.get());
413             }
414         } finally {
415             logStats("STATS");
416         }
417     }
418 
419     private void reset() {
420         clientException.set(null);
421         serverException.set(null);
422 
423         clientSendCounter.set(0);
424         clientRecvCounter.set(0);
425         serverRecvCounter.set(0);
426 
427         clientNegoCounter.set(0);
428         serverNegoCounter.set(0);
429 
430         clientChannel = null;
431         serverChannel = null;
432 
433         clientSslHandler = null;
434         serverSslHandler = null;
435     }
436 
437     void logStats(String message) {
438         logger.debug(
439                 "{}:\n" +
440                 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
441                 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
442                 message,
443                 clientSendCounter, clientRecvCounter, clientNegoCounter,
444                 clientSslHandler.engine().getSession().getCipherSuite(),
445                 serverRecvCounter, serverNegoCounter,
446                 serverSslHandler.engine().getSession().getCipherSuite());
447     }
448 
449     @Sharable
450     private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
451 
452         protected final AtomicInteger recvCounter;
453         protected final AtomicInteger negoCounter;
454         protected final AtomicReference<Throwable> exception;
455 
456         EchoHandler(
457                 AtomicInteger recvCounter, AtomicInteger negoCounter,
458                 AtomicReference<Throwable> exception) {
459 
460             this.recvCounter = recvCounter;
461             this.negoCounter = negoCounter;
462             this.exception = exception;
463         }
464 
465         @Override
466         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
467             // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
468             // non-application and previously flushed writes internally.
469             if (!autoRead) {
470                 ctx.read();
471             }
472             ctx.fireChannelReadComplete();
473         }
474 
475         @Override
476         public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
477             if (evt instanceof SslHandshakeCompletionEvent) {
478                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
479                 if (handshakeEvt.cause() != null) {
480                     logger.warn("Handshake failed:", handshakeEvt.cause());
481                 }
482                 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
483                 negoCounter.incrementAndGet();
484                 logStats("HANDSHAKEN");
485             }
486             ctx.fireUserEventTriggered(evt);
487         }
488 
489         @Override
490         public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
491             if (logger.isWarnEnabled()) {
492                 logger.warn("Unexpected exception from the client side:", cause);
493             }
494 
495             exception.compareAndSet(null, cause);
496             ctx.close();
497         }
498     }
499 
500     private class EchoClientHandler extends EchoHandler {
501 
502         EchoClientHandler(
503                 AtomicInteger recvCounter, AtomicInteger negoCounter,
504                 AtomicReference<Throwable> exception) {
505 
506             super(recvCounter, negoCounter, exception);
507         }
508 
509         @Override
510         public void handlerAdded(final ChannelHandlerContext ctx) {
511             if (!autoRead) {
512                 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
513                         new GenericFutureListener<Future<? super Channel>>() {
514                             @Override
515                             public void operationComplete(Future<? super Channel> future) {
516                                 ctx.read();
517                             }
518                         });
519             }
520         }
521 
522         @Override
523         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
524             byte[] actual = new byte[in.readableBytes()];
525             in.readBytes(actual);
526 
527             int lastIdx = recvCounter.get();
528             for (int i = 0; i < actual.length; i ++) {
529                 assertEquals(data[i + lastIdx], actual[i]);
530             }
531 
532             recvCounter.addAndGet(actual.length);
533         }
534     }
535 
536     private class EchoServerHandler extends EchoHandler {
537         volatile Future<Channel> renegoFuture;
538 
539         EchoServerHandler(
540                 AtomicInteger recvCounter, AtomicInteger negoCounter,
541                 AtomicReference<Throwable> exception) {
542 
543             super(recvCounter, negoCounter, exception);
544         }
545 
546         @Override
547         public final void channelRegistered(ChannelHandlerContext ctx) {
548             renegoFuture = null;
549         }
550 
551         @Override
552         public void channelActive(final ChannelHandlerContext ctx) throws Exception {
553             if (!autoRead) {
554                 ctx.read();
555             }
556             ctx.fireChannelActive();
557         }
558 
559         @Override
560         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
561             byte[] actual = new byte[in.readableBytes()];
562             in.readBytes(actual);
563 
564             int lastIdx = recvCounter.get();
565             for (int i = 0; i < actual.length; i ++) {
566                 assertEquals(data[i + lastIdx], actual[i]);
567             }
568 
569             ByteBuf buf = Unpooled.wrappedBuffer(actual);
570             if (useCompositeByteBuf) {
571                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
572             }
573             ctx.writeAndFlush(buf);
574 
575             recvCounter.addAndGet(actual.length);
576 
577             // Perform server-initiated renegotiation if necessary.
578             if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
579                 recvCounter.get() > data.length / 2 && renegoFuture == null) {
580 
581                 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
582 
583                 Future<Channel> hf = sslHandler.handshakeFuture();
584                 assertTrue(hf.isDone());
585 
586                 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
587                 logStats("SERVER RENEGOTIATES");
588                 renegoFuture = sslHandler.renegotiate();
589                 assertNotSame(renegoFuture, hf);
590                 assertSame(renegoFuture, sslHandler.handshakeFuture());
591                 assertFalse(renegoFuture.isDone());
592             }
593         }
594     }
595 }