View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelHandler.Sharable;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelInboundHandlerAdapter;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.ChannelOption;
29  import io.netty.channel.SimpleChannelInboundHandler;
30  import io.netty.handler.ssl.OpenSsl;
31  import io.netty.handler.ssl.OpenSslContext;
32  import io.netty.handler.ssl.SslContext;
33  import io.netty.handler.ssl.SslContextBuilder;
34  import io.netty.handler.ssl.SslHandler;
35  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.ssl.util.SelfSignedCertificate;
38  import io.netty.handler.stream.ChunkedWriteHandler;
39  import io.netty.testsuite.util.TestUtils;
40  import io.netty.util.concurrent.Future;
41  import io.netty.util.concurrent.GenericFutureListener;
42  import io.netty.util.internal.logging.InternalLogger;
43  import io.netty.util.internal.logging.InternalLoggerFactory;
44  import org.junit.jupiter.api.AfterAll;
45  import org.junit.jupiter.api.TestInfo;
46  import org.junit.jupiter.api.Timeout;
47  import org.junit.jupiter.params.ParameterizedTest;
48  import org.junit.jupiter.params.provider.MethodSource;
49  
50  import javax.net.ssl.SSLEngine;
51  import java.io.File;
52  import java.io.IOException;
53  import java.security.cert.CertificateException;
54  import java.util.ArrayList;
55  import java.util.Collection;
56  import java.util.List;
57  import java.util.Random;
58  import java.util.concurrent.CountDownLatch;
59  import java.util.concurrent.ExecutorService;
60  import java.util.concurrent.Executors;
61  import java.util.concurrent.TimeUnit;
62  import java.util.concurrent.atomic.AtomicInteger;
63  import java.util.concurrent.atomic.AtomicReference;
64  
65  import static org.hamcrest.MatcherAssert.assertThat;
66  import static org.hamcrest.Matchers.anyOf;
67  import static org.hamcrest.Matchers.is;
68  import static org.hamcrest.Matchers.not;
69  import static org.hamcrest.Matchers.sameInstance;
70  import static org.junit.jupiter.api.Assertions.assertEquals;
71  import static org.junit.jupiter.api.Assertions.assertSame;
72  
73  public class SocketSslEchoTest extends AbstractSocketTest {
74  
75      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
76  
77      private static final int FIRST_MESSAGE_SIZE = 16384;
78      private static final Random random = new Random();
79      private static final File CERT_FILE;
80      private static final File KEY_FILE;
81      static final byte[] data = new byte[1048576];
82  
83      static {
84          random.nextBytes(data);
85  
86          SelfSignedCertificate ssc;
87          try {
88              ssc = new SelfSignedCertificate();
89          } catch (CertificateException e) {
90              throw new Error(e);
91          }
92          CERT_FILE = ssc.certificate();
93          KEY_FILE = ssc.privateKey();
94      }
95  
96      protected enum RenegotiationType {
97          NONE, // no renegotiation
98          CLIENT_INITIATED, // renegotiation from client
99          SERVER_INITIATED, // renegotiation from server
100     }
101 
102     protected static class Renegotiation {
103         static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
104 
105         final RenegotiationType type;
106         final String cipherSuite;
107 
108         Renegotiation(RenegotiationType type, String cipherSuite) {
109             this.type = type;
110             this.cipherSuite = cipherSuite;
111         }
112 
113         @Override
114         public String toString() {
115             if (type == RenegotiationType.NONE) {
116                 return "NONE";
117             }
118 
119             return type + "(" + cipherSuite + ')';
120         }
121     }
122 
123     public static Collection<Object[]> data() throws Exception {
124         List<SslContext> serverContexts = new ArrayList<SslContext>();
125         serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
126                                             .sslProvider(SslProvider.JDK)
127                                             // As we test renegotiation we should use a protocol that support it.
128                                             .protocols("TLSv1.2")
129                                             .build());
130 
131         List<SslContext> clientContexts = new ArrayList<SslContext>();
132         clientContexts.add(SslContextBuilder.forClient()
133                                             .sslProvider(SslProvider.JDK)
134                                             .trustManager(CERT_FILE)
135                                             // As we test renegotiation we should use a protocol that support it.
136                                             .protocols("TLSv1.2")
137                                             .build());
138 
139         boolean hasOpenSsl = OpenSsl.isAvailable();
140         if (hasOpenSsl) {
141             serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
142                                                 .sslProvider(SslProvider.OPENSSL)
143                                                 // As we test renegotiation we should use a protocol that support it.
144                                                 .protocols("TLSv1.2")
145                                                 .build());
146             clientContexts.add(SslContextBuilder.forClient()
147                                                 .sslProvider(SslProvider.OPENSSL)
148                                                 .trustManager(CERT_FILE)
149                                                 // As we test renegotiation we should use a protocol that support it.
150                                                 .protocols("TLSv1.2")
151                                                 .build());
152         } else {
153             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
154         }
155 
156         List<Object[]> params = new ArrayList<Object[]>();
157         for (SslContext sc: serverContexts) {
158             for (SslContext cc: clientContexts) {
159                 for (RenegotiationType rt: RenegotiationType.values()) {
160                     if (rt != RenegotiationType.NONE &&
161                         (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
162                         // TODO: OpenSslEngine does not support renegotiation yet.
163                         continue;
164                     }
165 
166                     final Renegotiation r;
167                     switch (rt) {
168                         case NONE:
169                             r = Renegotiation.NONE;
170                             break;
171                         case SERVER_INITIATED:
172                             r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
173                             break;
174                         case CLIENT_INITIATED:
175                             r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
176                             break;
177                         default:
178                             throw new Error();
179                     }
180 
181                     for (int i = 0; i < 32; i++) {
182                         params.add(new Object[] {
183                                 sc, cc, r,
184                                 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
185                     }
186                 }
187             }
188         }
189 
190         return params;
191     }
192 
    // First exception caught by the client / server echo handler (set via compareAndSet(null, cause)).
    private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
    private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
    // Byte counters: bytes written by the test thread, and bytes received and verified by each handler.
    private final AtomicInteger clientSendCounter = new AtomicInteger();
    private final AtomicInteger clientRecvCounter = new AtomicInteger();
    private final AtomicInteger serverRecvCounter = new AtomicInteger();

    // Number of successful SslHandshakeCompletionEvents observed by the handler of each side.
    private final AtomicInteger clientNegoCounter = new AtomicInteger();
    private final AtomicInteger serverNegoCounter = new AtomicInteger();

    // Channels of the current run; assigned inside the channel initializers.
    private volatile Channel clientChannel;
    private volatile Channel serverChannel;

    // SslHandlers of the current run; assigned inside the channel initializers.
    private volatile SslHandler clientSslHandler;
    private volatile SslHandler serverSslHandler;

    // The handlers share the counters and exception holders declared above, so they must be
    // created after those fields.
    private final EchoClientHandler clientHandler =
            new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);

    private final EchoServerHandler serverHandler =
            new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);

    // Parameters of the current invocation; copied from the arguments of the parameterized test method.
    private SslContext serverCtx;
    private SslContext clientCtx;
    private Renegotiation renegotiation;
    private boolean serverUsesDelegatedTaskExecutor;
    private boolean clientUsesDelegatedTaskExecutor;
    private boolean autoRead;
    private boolean useChunkedWriteHandler;
    private boolean useCompositeByteBuf;
222 
    /**
     * Runs once after all tests of this class and delegates to {@link TestUtils#compressHeapDumps()}.
     *
     * @throws Exception if the delegate fails
     */
    @AfterAll
    public static void compressHeapDumps() throws Exception {
        TestUtils.compressHeapDumps();
    }
227 
228     @ParameterizedTest(name =
229             "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
230             "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
231             "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
232     @MethodSource("data")
233     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
234     public void testSslEcho(
235             SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
236             boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
237             boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
238             TestInfo testInfo) throws Throwable {
239         this.serverCtx = serverCtx;
240         this.clientCtx = clientCtx;
241         this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
242         this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
243         this.renegotiation = renegotiation;
244         this.autoRead = autoRead;
245         this.useChunkedWriteHandler = useChunkedWriteHandler;
246         this.useCompositeByteBuf = useCompositeByteBuf;
247         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
248             @Override
249             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
250                 testSslEcho(serverBootstrap, bootstrap);
251             }
252         });
253     }
254 
255     public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
256         final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
257         reset();
258 
259         sb.childOption(ChannelOption.AUTO_READ, autoRead);
260         cb.option(ChannelOption.AUTO_READ, autoRead);
261 
262         sb.childHandler(new ChannelInitializer<Channel>() {
263             @Override
264             public void initChannel(Channel sch) {
265                 serverChannel = sch;
266 
267                 if (serverUsesDelegatedTaskExecutor) {
268                     SSLEngine sse = serverCtx.newEngine(sch.alloc());
269                     serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
270                 } else {
271                     serverSslHandler = serverCtx.newHandler(sch.alloc());
272                 }
273                 serverSslHandler.setHandshakeTimeoutMillis(0);
274 
275                 sch.pipeline().addLast("ssl", serverSslHandler);
276                 if (useChunkedWriteHandler) {
277                     sch.pipeline().addLast(new ChunkedWriteHandler());
278                 }
279                 sch.pipeline().addLast("serverHandler", serverHandler);
280             }
281         });
282 
283         final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
284         cb.handler(new ChannelInitializer<Channel>() {
285             @Override
286             public void initChannel(Channel sch) {
287                 clientChannel = sch;
288 
289                 if (clientUsesDelegatedTaskExecutor) {
290                     SSLEngine cse = clientCtx.newEngine(sch.alloc());
291                     clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
292                 } else {
293                     clientSslHandler = clientCtx.newHandler(sch.alloc());
294                 }
295                 clientSslHandler.setHandshakeTimeoutMillis(0);
296 
297                 sch.pipeline().addLast("ssl", clientSslHandler);
298                 if (useChunkedWriteHandler) {
299                     sch.pipeline().addLast(new ChunkedWriteHandler());
300                 }
301                 sch.pipeline().addLast("clientHandler", clientHandler);
302                 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
303                     @Override
304                     public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
305                         if (evt instanceof SslHandshakeCompletionEvent) {
306                             clientHandshakeEventLatch.countDown();
307                         }
308                         ctx.fireUserEventTriggered(evt);
309                     }
310                 });
311             }
312         });
313 
314         final Channel sc = sb.bind().sync().channel();
315         cb.connect(sc.localAddress()).sync();
316 
317         final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
318 
319         // Wait for the handshake to complete before we flush anything. SslHandler should flush non-application data.
320         clientHandshakeFuture.sync();
321         clientHandshakeEventLatch.await();
322 
323         clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
324         clientSendCounter.set(FIRST_MESSAGE_SIZE);
325 
326         boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
327         Future<Channel> renegoFuture = null;
328         while (clientSendCounter.get() < data.length) {
329             int clientSendCounterVal = clientSendCounter.get();
330             int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
331             ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
332             if (useCompositeByteBuf) {
333                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
334             }
335 
336             ChannelFuture future = clientChannel.writeAndFlush(buf);
337             clientSendCounter.set(clientSendCounterVal += length);
338             future.sync();
339 
340             if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
341                 needsRenegotiation = false;
342                 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
343                 renegoFuture = clientSslHandler.renegotiate();
344                 logStats("CLIENT RENEGOTIATES");
345                 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
346             }
347         }
348 
349         // Ensure all data has been exchanged.
350         while (clientRecvCounter.get() < data.length) {
351             if (serverException.get() != null) {
352                 break;
353             }
354             if (serverException.get() != null) {
355                 break;
356             }
357 
358             try {
359                 Thread.sleep(50);
360             } catch (InterruptedException e) {
361                 // Ignore.
362             }
363         }
364 
365         while (serverRecvCounter.get() < data.length) {
366             if (serverException.get() != null) {
367                 break;
368             }
369             if (clientException.get() != null) {
370                 break;
371             }
372 
373             try {
374                 Thread.sleep(50);
375             } catch (InterruptedException e) {
376                 // Ignore.
377             }
378         }
379 
380         // Wait until renegotiation is done.
381         if (renegoFuture != null) {
382             renegoFuture.sync();
383         }
384         if (serverHandler.renegoFuture != null) {
385             serverHandler.renegoFuture.sync();
386         }
387 
388         serverChannel.close().awaitUninterruptibly();
389         clientChannel.close().awaitUninterruptibly();
390         sc.close().awaitUninterruptibly();
391         delegatedTaskExecutor.shutdown();
392 
393         if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
394             throw serverException.get();
395         }
396         if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
397             throw clientException.get();
398         }
399         if (serverException.get() != null) {
400             throw serverException.get();
401         }
402         if (clientException.get() != null) {
403             throw clientException.get();
404         }
405 
406         // When renegotiation is done, at least the initiating side should be notified.
407         try {
408             switch (renegotiation.type) {
409             case SERVER_INITIATED:
410                 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
411                 assertThat(serverNegoCounter.get(), is(2));
412                 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
413                 break;
414             case CLIENT_INITIATED:
415                 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
416                 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
417                 assertThat(clientNegoCounter.get(), is(2));
418                 break;
419             case NONE:
420                 assertThat(serverNegoCounter.get(), is(1));
421                 assertThat(clientNegoCounter.get(), is(1));
422             }
423         } finally {
424             logStats("STATS");
425         }
426     }
427 
428     private void reset() {
429         clientException.set(null);
430         serverException.set(null);
431 
432         clientSendCounter.set(0);
433         clientRecvCounter.set(0);
434         serverRecvCounter.set(0);
435 
436         clientNegoCounter.set(0);
437         serverNegoCounter.set(0);
438 
439         clientChannel = null;
440         serverChannel = null;
441 
442         clientSslHandler = null;
443         serverSslHandler = null;
444     }
445 
446     void logStats(String message) {
447         logger.debug(
448                 "{}:\n" +
449                 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
450                 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
451                 message,
452                 clientSendCounter, clientRecvCounter, clientNegoCounter,
453                 clientSslHandler.engine().getSession().getCipherSuite(),
454                 serverRecvCounter, serverNegoCounter,
455                 serverSslHandler.engine().getSession().getCipherSuite());
456     }
457 
458     @Sharable
459     private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
460 
461         protected final AtomicInteger recvCounter;
462         protected final AtomicInteger negoCounter;
463         protected final AtomicReference<Throwable> exception;
464 
465         EchoHandler(
466                 AtomicInteger recvCounter, AtomicInteger negoCounter,
467                 AtomicReference<Throwable> exception) {
468 
469             this.recvCounter = recvCounter;
470             this.negoCounter = negoCounter;
471             this.exception = exception;
472         }
473 
474         @Override
475         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
476             // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
477             // non-application and previously flushed writes internally.
478             if (!autoRead) {
479                 ctx.read();
480             }
481             ctx.fireChannelReadComplete();
482         }
483 
484         @Override
485         public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
486             if (evt instanceof SslHandshakeCompletionEvent) {
487                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
488                 if (handshakeEvt.cause() != null) {
489                     logger.warn("Handshake failed:", handshakeEvt.cause());
490                 }
491                 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
492                 negoCounter.incrementAndGet();
493                 logStats("HANDSHAKEN");
494             }
495             ctx.fireUserEventTriggered(evt);
496         }
497 
498         @Override
499         public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
500             if (logger.isWarnEnabled()) {
501                 logger.warn("Unexpected exception from the client side:", cause);
502             }
503 
504             exception.compareAndSet(null, cause);
505             ctx.close();
506         }
507     }
508 
509     private class EchoClientHandler extends EchoHandler {
510 
511         EchoClientHandler(
512                 AtomicInteger recvCounter, AtomicInteger negoCounter,
513                 AtomicReference<Throwable> exception) {
514 
515             super(recvCounter, negoCounter, exception);
516         }
517 
518         @Override
519         public void handlerAdded(final ChannelHandlerContext ctx) {
520             if (!autoRead) {
521                 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
522                         new GenericFutureListener<Future<? super Channel>>() {
523                             @Override
524                             public void operationComplete(Future<? super Channel> future) {
525                                 ctx.read();
526                             }
527                         });
528             }
529         }
530 
531         @Override
532         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
533             byte[] actual = new byte[in.readableBytes()];
534             in.readBytes(actual);
535 
536             int lastIdx = recvCounter.get();
537             for (int i = 0; i < actual.length; i ++) {
538                 assertEquals(data[i + lastIdx], actual[i]);
539             }
540 
541             recvCounter.addAndGet(actual.length);
542         }
543     }
544 
545     private class EchoServerHandler extends EchoHandler {
546         volatile Future<Channel> renegoFuture;
547 
548         EchoServerHandler(
549                 AtomicInteger recvCounter, AtomicInteger negoCounter,
550                 AtomicReference<Throwable> exception) {
551 
552             super(recvCounter, negoCounter, exception);
553         }
554 
555         @Override
556         public final void channelRegistered(ChannelHandlerContext ctx) {
557             renegoFuture = null;
558         }
559 
560         @Override
561         public void channelActive(final ChannelHandlerContext ctx) throws Exception {
562             if (!autoRead) {
563                 ctx.read();
564             }
565             ctx.fireChannelActive();
566         }
567 
568         @Override
569         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
570             byte[] actual = new byte[in.readableBytes()];
571             in.readBytes(actual);
572 
573             int lastIdx = recvCounter.get();
574             for (int i = 0; i < actual.length; i ++) {
575                 assertEquals(data[i + lastIdx], actual[i]);
576             }
577 
578             ByteBuf buf = Unpooled.wrappedBuffer(actual);
579             if (useCompositeByteBuf) {
580                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
581             }
582             ctx.writeAndFlush(buf);
583 
584             recvCounter.addAndGet(actual.length);
585 
586             // Perform server-initiated renegotiation if necessary.
587             if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
588                 recvCounter.get() > data.length / 2 && renegoFuture == null) {
589 
590                 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
591 
592                 Future<Channel> hf = sslHandler.handshakeFuture();
593                 assertThat(hf.isDone(), is(true));
594 
595                 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
596                 logStats("SERVER RENEGOTIATES");
597                 renegoFuture = sslHandler.renegotiate();
598                 assertThat(renegoFuture, is(not(sameInstance(hf))));
599                 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
600                 assertThat(renegoFuture.isDone(), is(false));
601             }
602         }
603     }
604 }