View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelHandler.Sharable;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelInboundHandlerAdapter;
27  import io.netty.channel.ChannelInitializer;
28  import io.netty.channel.ChannelOption;
29  import io.netty.channel.SimpleChannelInboundHandler;
30  import io.netty.handler.ssl.OpenSsl;
31  import io.netty.handler.ssl.OpenSslContext;
32  import io.netty.handler.ssl.SslContext;
33  import io.netty.handler.ssl.SslContextBuilder;
34  import io.netty.handler.ssl.SslHandler;
35  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
36  import io.netty.handler.ssl.SslProvider;
37  import io.netty.handler.stream.ChunkedWriteHandler;
38  import io.netty.pkitesting.CertificateBuilder;
39  import io.netty.pkitesting.X509Bundle;
40  import io.netty.testsuite.util.TestUtils;
41  import io.netty.util.concurrent.Future;
42  import io.netty.util.concurrent.GenericFutureListener;
43  import io.netty.util.internal.logging.InternalLogger;
44  import io.netty.util.internal.logging.InternalLoggerFactory;
45  import org.junit.jupiter.api.AfterAll;
46  import org.junit.jupiter.api.TestInfo;
47  import org.junit.jupiter.api.Timeout;
48  import org.junit.jupiter.params.ParameterizedTest;
49  import org.junit.jupiter.params.provider.MethodSource;
50  
51  import javax.net.ssl.SSLEngine;
52  import java.io.File;
53  import java.io.IOException;
54  import java.util.ArrayList;
55  import java.util.Collection;
56  import java.util.List;
57  import java.util.Random;
58  import java.util.concurrent.CountDownLatch;
59  import java.util.concurrent.ExecutorService;
60  import java.util.concurrent.Executors;
61  import java.util.concurrent.TimeUnit;
62  import java.util.concurrent.atomic.AtomicInteger;
63  import java.util.concurrent.atomic.AtomicReference;
64  
65  import static org.hamcrest.MatcherAssert.assertThat;
66  import static org.hamcrest.Matchers.anyOf;
67  import static org.hamcrest.Matchers.is;
68  import static org.hamcrest.Matchers.not;
69  import static org.hamcrest.Matchers.sameInstance;
70  import static org.junit.jupiter.api.Assertions.assertEquals;
71  import static org.junit.jupiter.api.Assertions.assertSame;
72  
73  public class SocketSslEchoTest extends AbstractSocketTest {
74  
75      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslEchoTest.class);
76  
77      private static final int FIRST_MESSAGE_SIZE = 16384;
78      private static final Random random = new Random();
79      private static final File CERT_FILE;
80      private static final File KEY_FILE;
81      static final byte[] data = new byte[1048576];
82  
83      static {
84          random.nextBytes(data);
85  
86          try {
87              X509Bundle cert = new CertificateBuilder()
88                      .rsa2048()
89                      .subject("cn=localhost")
90                      .setIsCertificateAuthority(true)
91                      .buildSelfSigned();
92              CERT_FILE = cert.toTempCertChainPem();
93              KEY_FILE = cert.toTempPrivateKeyPem();
94          } catch (Exception e) {
95              throw new ExceptionInInitializerError(e);
96          }
97      }
98  
99      protected enum RenegotiationType {
100         NONE, // no renegotiation
101         CLIENT_INITIATED, // renegotiation from client
102         SERVER_INITIATED, // renegotiation from server
103     }
104 
105     protected static class Renegotiation {
106         static final Renegotiation NONE = new Renegotiation(RenegotiationType.NONE, null);
107 
108         final RenegotiationType type;
109         final String cipherSuite;
110 
111         Renegotiation(RenegotiationType type, String cipherSuite) {
112             this.type = type;
113             this.cipherSuite = cipherSuite;
114         }
115 
116         @Override
117         public String toString() {
118             if (type == RenegotiationType.NONE) {
119                 return "NONE";
120             }
121 
122             return type + "(" + cipherSuite + ')';
123         }
124     }
125 
126     public static Collection<Object[]> data() throws Exception {
127         List<SslContext> serverContexts = new ArrayList<SslContext>();
128         serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
129                                             .sslProvider(SslProvider.JDK)
130                                             // As we test renegotiation we should use a protocol that support it.
131                                             .protocols("TLSv1.2")
132                                             .build());
133 
134         List<SslContext> clientContexts = new ArrayList<SslContext>();
135         clientContexts.add(SslContextBuilder.forClient()
136                                             .sslProvider(SslProvider.JDK)
137                                             .trustManager(CERT_FILE)
138                                             // As we test renegotiation we should use a protocol that support it.
139                                             .protocols("TLSv1.2")
140                                             .endpointIdentificationAlgorithm(null)
141                                             .build());
142 
143         boolean hasOpenSsl = OpenSsl.isAvailable();
144         if (hasOpenSsl) {
145             serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
146                                                 .sslProvider(SslProvider.OPENSSL)
147                                                 // As we test renegotiation we should use a protocol that support it.
148                                                 .protocols("TLSv1.2")
149                                                 .build());
150             clientContexts.add(SslContextBuilder.forClient()
151                                                 .sslProvider(SslProvider.OPENSSL)
152                                                 .trustManager(CERT_FILE)
153                                                 // As we test renegotiation we should use a protocol that support it.
154                                                 .protocols("TLSv1.2")
155                                                 .endpointIdentificationAlgorithm(null)
156                                                 .build());
157         } else {
158             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
159         }
160 
161         List<Object[]> params = new ArrayList<Object[]>();
162         for (SslContext sc: serverContexts) {
163             for (SslContext cc: clientContexts) {
164                 for (RenegotiationType rt: RenegotiationType.values()) {
165                     if (rt != RenegotiationType.NONE &&
166                         (sc instanceof OpenSslContext || cc instanceof OpenSslContext)) {
167                         // TODO: OpenSslEngine does not support renegotiation yet.
168                         continue;
169                     }
170 
171                     final Renegotiation r;
172                     switch (rt) {
173                         case NONE:
174                             r = Renegotiation.NONE;
175                             break;
176                         case SERVER_INITIATED:
177                             r = new Renegotiation(rt, sc.cipherSuites().get(sc.cipherSuites().size() - 1));
178                             break;
179                         case CLIENT_INITIATED:
180                             r = new Renegotiation(rt, cc.cipherSuites().get(cc.cipherSuites().size() - 1));
181                             break;
182                         default:
183                             throw new Error();
184                     }
185 
186                     for (int i = 0; i < 32; i++) {
187                         params.add(new Object[] {
188                                 sc, cc, r,
189                                 (i & 16) != 0, (i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0 });
190                     }
191                 }
192             }
193         }
194 
195         return params;
196     }
197 
198     private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
199     private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
200     private final AtomicInteger clientSendCounter = new AtomicInteger();
201     private final AtomicInteger clientRecvCounter = new AtomicInteger();
202     private final AtomicInteger serverRecvCounter = new AtomicInteger();
203 
204     private final AtomicInteger clientNegoCounter = new AtomicInteger();
205     private final AtomicInteger serverNegoCounter = new AtomicInteger();
206 
207     private volatile Channel clientChannel;
208     private volatile Channel serverChannel;
209 
210     private volatile SslHandler clientSslHandler;
211     private volatile SslHandler serverSslHandler;
212 
213     private final EchoClientHandler clientHandler =
214             new EchoClientHandler(clientRecvCounter, clientNegoCounter, clientException);
215 
216     private final EchoServerHandler serverHandler =
217             new EchoServerHandler(serverRecvCounter, serverNegoCounter, serverException);
218 
219     private SslContext serverCtx;
220     private SslContext clientCtx;
221     private Renegotiation renegotiation;
222     private boolean serverUsesDelegatedTaskExecutor;
223     private boolean clientUsesDelegatedTaskExecutor;
224     private boolean autoRead;
225     private boolean useChunkedWriteHandler;
226     private boolean useCompositeByteBuf;
227 
228     @AfterAll
229     public static void compressHeapDumps() throws Exception {
230         TestUtils.compressHeapDumps();
231     }
232 
233     @ParameterizedTest(name =
234             "{index}: serverEngine = {0}, clientEngine = {1}, renegotiation = {2}, " +
235             "serverUsesDelegatedTaskExecutor = {3}, clientUsesDelegatedTaskExecutor = {4}, " +
236             "autoRead = {5}, useChunkedWriteHandler = {6}, useCompositeByteBuf = {7}")
237     @MethodSource("data")
238     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
239     public void testSslEcho(
240             SslContext serverCtx, SslContext clientCtx, Renegotiation renegotiation,
241             boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
242             boolean autoRead, boolean useChunkedWriteHandler, boolean useCompositeByteBuf,
243             TestInfo testInfo) throws Throwable {
244         this.serverCtx = serverCtx;
245         this.clientCtx = clientCtx;
246         this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
247         this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
248         this.renegotiation = renegotiation;
249         this.autoRead = autoRead;
250         this.useChunkedWriteHandler = useChunkedWriteHandler;
251         this.useCompositeByteBuf = useCompositeByteBuf;
252         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
253             @Override
254             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
255                 testSslEcho(serverBootstrap, bootstrap);
256             }
257         });
258     }
259 
260     public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
261         final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
262         reset();
263 
264         sb.childOption(ChannelOption.AUTO_READ, autoRead);
265         cb.option(ChannelOption.AUTO_READ, autoRead);
266 
267         sb.childHandler(new ChannelInitializer<Channel>() {
268             @Override
269             public void initChannel(Channel sch) {
270                 serverChannel = sch;
271 
272                 if (serverUsesDelegatedTaskExecutor) {
273                     SSLEngine sse = serverCtx.newEngine(sch.alloc());
274                     serverSslHandler = new SslHandler(sse, delegatedTaskExecutor);
275                 } else {
276                     serverSslHandler = serverCtx.newHandler(sch.alloc());
277                 }
278                 serverSslHandler.setHandshakeTimeoutMillis(0);
279 
280                 sch.pipeline().addLast("ssl", serverSslHandler);
281                 if (useChunkedWriteHandler) {
282                     sch.pipeline().addLast(new ChunkedWriteHandler());
283                 }
284                 sch.pipeline().addLast("serverHandler", serverHandler);
285             }
286         });
287 
288         final CountDownLatch clientHandshakeEventLatch = new CountDownLatch(1);
289         cb.handler(new ChannelInitializer<Channel>() {
290             @Override
291             public void initChannel(Channel sch) {
292                 clientChannel = sch;
293 
294                 if (clientUsesDelegatedTaskExecutor) {
295                     SSLEngine cse = clientCtx.newEngine(sch.alloc());
296                     clientSslHandler = new SslHandler(cse, delegatedTaskExecutor);
297                 } else {
298                     clientSslHandler = clientCtx.newHandler(sch.alloc());
299                 }
300                 clientSslHandler.setHandshakeTimeoutMillis(0);
301 
302                 sch.pipeline().addLast("ssl", clientSslHandler);
303                 if (useChunkedWriteHandler) {
304                     sch.pipeline().addLast(new ChunkedWriteHandler());
305                 }
306                 sch.pipeline().addLast("clientHandler", clientHandler);
307                 sch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
308                     @Override
309                     public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
310                         if (evt instanceof SslHandshakeCompletionEvent) {
311                             clientHandshakeEventLatch.countDown();
312                         }
313                         ctx.fireUserEventTriggered(evt);
314                     }
315                 });
316             }
317         });
318 
319         final Channel sc = sb.bind().sync().channel();
320         cb.connect(sc.localAddress()).sync();
321 
322         final Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
323 
324         // Wait for the handshake to complete before we flush anything. SslHandler should flush non-application data.
325         clientHandshakeFuture.sync();
326         clientHandshakeEventLatch.await();
327 
328         clientChannel.writeAndFlush(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
329         clientSendCounter.set(FIRST_MESSAGE_SIZE);
330 
331         boolean needsRenegotiation = renegotiation.type == RenegotiationType.CLIENT_INITIATED;
332         Future<Channel> renegoFuture = null;
333         while (clientSendCounter.get() < data.length) {
334             int clientSendCounterVal = clientSendCounter.get();
335             int length = Math.min(random.nextInt(1024 * 64), data.length - clientSendCounterVal);
336             ByteBuf buf = Unpooled.wrappedBuffer(data, clientSendCounterVal, length);
337             if (useCompositeByteBuf) {
338                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
339             }
340 
341             ChannelFuture future = clientChannel.writeAndFlush(buf);
342             clientSendCounter.set(clientSendCounterVal += length);
343             future.sync();
344 
345             if (needsRenegotiation && clientSendCounterVal >= data.length / 2) {
346                 needsRenegotiation = false;
347                 clientSslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
348                 renegoFuture = clientSslHandler.renegotiate();
349                 logStats("CLIENT RENEGOTIATES");
350                 assertThat(renegoFuture, is(not(sameInstance(clientHandshakeFuture))));
351             }
352         }
353 
354         // Ensure all data has been exchanged.
355         while (clientRecvCounter.get() < data.length) {
356             if (serverException.get() != null) {
357                 break;
358             }
359             if (clientException.get() != null) {
360                 break;
361             }
362 
363             Thread.sleep(50);
364         }
365 
366         while (serverRecvCounter.get() < data.length) {
367             if (serverException.get() != null) {
368                 break;
369             }
370             if (clientException.get() != null) {
371                 break;
372             }
373 
374             Thread.sleep(50);
375         }
376 
377         // Wait until renegotiation is done.
378         if (renegoFuture != null) {
379             renegoFuture.sync();
380         }
381         if (serverHandler.renegoFuture != null) {
382             serverHandler.renegoFuture.sync();
383         }
384 
385         serverChannel.close().awaitUninterruptibly();
386         clientChannel.close().awaitUninterruptibly();
387         sc.close().awaitUninterruptibly();
388         delegatedTaskExecutor.shutdown();
389 
390         if (serverException.get() != null && !(serverException.get() instanceof IOException)) {
391             throw serverException.get();
392         }
393         if (clientException.get() != null && !(clientException.get() instanceof IOException)) {
394             throw clientException.get();
395         }
396         if (serverException.get() != null) {
397             throw serverException.get();
398         }
399         if (clientException.get() != null) {
400             throw clientException.get();
401         }
402 
403         // When renegotiation is done, at least the initiating side should be notified.
404         try {
405             switch (renegotiation.type) {
406             case SERVER_INITIATED:
407                 assertThat(serverSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
408                 assertThat(serverNegoCounter.get(), is(2));
409                 assertThat(clientNegoCounter.get(), anyOf(is(1), is(2)));
410                 break;
411             case CLIENT_INITIATED:
412                 assertThat(serverNegoCounter.get(), anyOf(is(1), is(2)));
413                 assertThat(clientSslHandler.engine().getSession().getCipherSuite(), is(renegotiation.cipherSuite));
414                 assertThat(clientNegoCounter.get(), is(2));
415                 break;
416             case NONE:
417                 assertThat(serverNegoCounter.get(), is(1));
418                 assertThat(clientNegoCounter.get(), is(1));
419             }
420         } finally {
421             logStats("STATS");
422         }
423     }
424 
425     private void reset() {
426         clientException.set(null);
427         serverException.set(null);
428 
429         clientSendCounter.set(0);
430         clientRecvCounter.set(0);
431         serverRecvCounter.set(0);
432 
433         clientNegoCounter.set(0);
434         serverNegoCounter.set(0);
435 
436         clientChannel = null;
437         serverChannel = null;
438 
439         clientSslHandler = null;
440         serverSslHandler = null;
441     }
442 
443     void logStats(String message) {
444         logger.debug(
445                 "{}:\n" +
446                 "\tclient { sent: {}, rcvd: {}, nego: {}, cipher: {} },\n" +
447                 "\tserver { rcvd: {}, nego: {}, cipher: {} }",
448                 message,
449                 clientSendCounter, clientRecvCounter, clientNegoCounter,
450                 clientSslHandler.engine().getSession().getCipherSuite(),
451                 serverRecvCounter, serverNegoCounter,
452                 serverSslHandler.engine().getSession().getCipherSuite());
453     }
454 
455     @Sharable
456     private abstract class EchoHandler extends SimpleChannelInboundHandler<ByteBuf> {
457 
458         protected final AtomicInteger recvCounter;
459         protected final AtomicInteger negoCounter;
460         protected final AtomicReference<Throwable> exception;
461 
462         EchoHandler(
463                 AtomicInteger recvCounter, AtomicInteger negoCounter,
464                 AtomicReference<Throwable> exception) {
465 
466             this.recvCounter = recvCounter;
467             this.negoCounter = negoCounter;
468             this.exception = exception;
469         }
470 
471         @Override
472         public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
473             // We intentionally do not ctx.flush() here because we want to verify the SslHandler correctly flushing
474             // non-application and previously flushed writes internally.
475             if (!autoRead) {
476                 ctx.read();
477             }
478             ctx.fireChannelReadComplete();
479         }
480 
481         @Override
482         public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
483             if (evt instanceof SslHandshakeCompletionEvent) {
484                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
485                 if (handshakeEvt.cause() != null) {
486                     logger.warn("Handshake failed:", handshakeEvt.cause());
487                 }
488                 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
489                 negoCounter.incrementAndGet();
490                 logStats("HANDSHAKEN");
491             }
492             ctx.fireUserEventTriggered(evt);
493         }
494 
495         @Override
496         public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
497             if (logger.isWarnEnabled()) {
498                 logger.warn("Unexpected exception from the client side:", cause);
499             }
500 
501             exception.compareAndSet(null, cause);
502             ctx.close();
503         }
504     }
505 
506     private class EchoClientHandler extends EchoHandler {
507 
508         EchoClientHandler(
509                 AtomicInteger recvCounter, AtomicInteger negoCounter,
510                 AtomicReference<Throwable> exception) {
511 
512             super(recvCounter, negoCounter, exception);
513         }
514 
515         @Override
516         public void handlerAdded(final ChannelHandlerContext ctx) {
517             if (!autoRead) {
518                 ctx.pipeline().get(SslHandler.class).handshakeFuture().addListener(
519                         new GenericFutureListener<Future<? super Channel>>() {
520                             @Override
521                             public void operationComplete(Future<? super Channel> future) {
522                                 ctx.read();
523                             }
524                         });
525             }
526         }
527 
528         @Override
529         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
530             byte[] actual = new byte[in.readableBytes()];
531             in.readBytes(actual);
532 
533             int lastIdx = recvCounter.get();
534             for (int i = 0; i < actual.length; i ++) {
535                 assertEquals(data[i + lastIdx], actual[i]);
536             }
537 
538             recvCounter.addAndGet(actual.length);
539         }
540     }
541 
542     private class EchoServerHandler extends EchoHandler {
543         volatile Future<Channel> renegoFuture;
544 
545         EchoServerHandler(
546                 AtomicInteger recvCounter, AtomicInteger negoCounter,
547                 AtomicReference<Throwable> exception) {
548 
549             super(recvCounter, negoCounter, exception);
550         }
551 
552         @Override
553         public final void channelRegistered(ChannelHandlerContext ctx) {
554             renegoFuture = null;
555         }
556 
557         @Override
558         public void channelActive(final ChannelHandlerContext ctx) throws Exception {
559             if (!autoRead) {
560                 ctx.read();
561             }
562             ctx.fireChannelActive();
563         }
564 
565         @Override
566         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
567             byte[] actual = new byte[in.readableBytes()];
568             in.readBytes(actual);
569 
570             int lastIdx = recvCounter.get();
571             for (int i = 0; i < actual.length; i ++) {
572                 assertEquals(data[i + lastIdx], actual[i]);
573             }
574 
575             ByteBuf buf = Unpooled.wrappedBuffer(actual);
576             if (useCompositeByteBuf) {
577                 buf = Unpooled.compositeBuffer().addComponent(true, buf);
578             }
579             ctx.writeAndFlush(buf);
580 
581             recvCounter.addAndGet(actual.length);
582 
583             // Perform server-initiated renegotiation if necessary.
584             if (renegotiation.type == RenegotiationType.SERVER_INITIATED &&
585                 recvCounter.get() > data.length / 2 && renegoFuture == null) {
586 
587                 SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
588 
589                 Future<Channel> hf = sslHandler.handshakeFuture();
590                 assertThat(hf.isDone(), is(true));
591 
592                 sslHandler.engine().setEnabledCipherSuites(new String[] { renegotiation.cipherSuite });
593                 logStats("SERVER RENEGOTIATES");
594                 renegoFuture = sslHandler.renegotiate();
595                 assertThat(renegoFuture, is(not(sameInstance(hf))));
596                 assertThat(renegoFuture, is(sameInstance(sslHandler.handshakeFuture())));
597                 assertThat(renegoFuture.isDone(), is(false));
598             }
599         }
600     }
601 }