View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandler.Sharable;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelInitializer;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.DecoderException;
28  import io.netty.handler.ssl.JdkSslClientContext;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.OpenSslServerContext;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.pkitesting.CertificateBuilder;
35  import io.netty.pkitesting.X509Bundle;
36  import io.netty.util.concurrent.Future;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  import org.junit.jupiter.api.TestInfo;
40  import org.junit.jupiter.api.Timeout;
41  import org.junit.jupiter.api.condition.DisabledIf;
42  import org.junit.jupiter.params.ParameterizedTest;
43  import org.junit.jupiter.params.provider.MethodSource;
44  
45  import java.io.File;
46  import java.nio.channels.ClosedChannelException;
47  import java.util.ArrayList;
48  import java.util.Collection;
49  import java.util.List;
50  import java.util.concurrent.Executor;
51  import java.util.concurrent.ExecutorService;
52  import java.util.concurrent.Executors;
53  import java.util.concurrent.TimeUnit;
54  import java.util.concurrent.atomic.AtomicReference;
55  
56  import javax.net.ssl.SSLHandshakeException;
57  
58  import static org.junit.jupiter.api.Assertions.assertSame;
59  import static org.junit.jupiter.api.Assertions.assertTrue;
60  import static org.junit.jupiter.api.Assertions.fail;
61  import static org.junit.jupiter.api.Assumptions.assumeFalse;
62  import static org.junit.jupiter.api.Assumptions.assumeTrue;
63  
64  public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(
66              SocketSslClientRenegotiateTest.class);
67      private static final File CERT_FILE;
68      private static final File KEY_FILE;
69  
70      static {
71          try {
72              X509Bundle cert = new CertificateBuilder()
73                      .subject("cn=localhost")
74                      .setIsCertificateAuthority(true)
75                      .buildSelfSigned();
76              CERT_FILE = cert.toTempCertChainPem();
77              KEY_FILE = cert.toTempPrivateKeyPem();
78          } catch (Exception e) {
79              throw new ExceptionInInitializerError(e);
80          }
81      }
82  
83      private static boolean openSslNotAvailable() {
84          return !OpenSsl.isAvailable();
85      }
86  
87      public static Collection<Object[]> data() throws Exception {
88          List<SslContext> serverContexts = new ArrayList<SslContext>();
89          List<SslContext> clientContexts = new ArrayList<SslContext>();
90          clientContexts.add(new JdkSslClientContext(CERT_FILE));
91  
92          boolean hasOpenSsl = OpenSsl.isAvailable();
93          if (hasOpenSsl) {
94              OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
95              serverContexts.add(context);
96          } else {
97              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
98          }
99  
100         List<Object[]> params = new ArrayList<Object[]>();
101         for (SslContext sc: serverContexts) {
102             for (SslContext cc: clientContexts) {
103                 for (int i = 0; i < 32; i++) {
104                     params.add(new Object[] { sc, cc, true});
105                     params.add(new Object[] { sc, cc, false});
106                 }
107             }
108         }
109 
110         return params;
111     }
112 
113     private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
114     private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
115 
116     private volatile Channel clientChannel;
117     private volatile Channel serverChannel;
118 
119     private volatile SslHandler clientSslHandler;
120     private volatile SslHandler serverSslHandler;
121 
122     private final TestHandler clientHandler = new TestHandler(clientException);
123 
124     private final TestHandler serverHandler = new TestHandler(serverException);
125 
126     @DisabledIf("openSslNotAvailable")
127     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
128     @MethodSource("data")
129     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
130     public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx,
131                                              final boolean delegate, TestInfo testInfo) throws Throwable {
132         // BoringSSL does not support renegotiation intentionally.
133         assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
134         assumeTrue(OpenSsl.isAvailable());
135         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
136             @Override
137             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
138                 testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate);
139             }
140         });
141     }
142 
143     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
144         if (executor == null) {
145             return sslCtx.newHandler(allocator);
146         } else {
147             return sslCtx.newHandler(allocator, executor);
148         }
149     }
150 
151     public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
152                                              final SslContext clientCtx, boolean delegate) throws Throwable {
153         reset();
154 
155         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
156 
157         try {
158             sb.childHandler(new ChannelInitializer<Channel>() {
159                 @Override
160                 @SuppressWarnings("deprecation")
161                 public void initChannel(Channel sch) throws Exception {
162                     serverChannel = sch;
163                     serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService);
164                     // As we test renegotiation we should use a protocol that support it.
165                     serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
166                     sch.pipeline().addLast("ssl", serverSslHandler);
167                     sch.pipeline().addLast("handler", serverHandler);
168                 }
169             });
170 
171             cb.handler(new ChannelInitializer<Channel>() {
172                 @Override
173                 @SuppressWarnings("deprecation")
174                 public void initChannel(Channel sch) throws Exception {
175                     clientChannel = sch;
176                     clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService);
177                     // As we test renegotiation we should use a protocol that support it.
178                     clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
179                     sch.pipeline().addLast("ssl", clientSslHandler);
180                     sch.pipeline().addLast("handler", clientHandler);
181                 }
182             });
183 
184             Channel sc = sb.bind().sync().channel();
185             cb.connect(sc.localAddress()).sync();
186 
187             Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
188             clientHandshakeFuture.sync();
189 
190             String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
191             // Use the first previous enabled ciphersuite and try to renegotiate.
192             clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
193             clientSslHandler.renegotiate().await();
194             serverChannel.close().awaitUninterruptibly();
195             clientChannel.close().awaitUninterruptibly();
196             sc.close().awaitUninterruptibly();
197             try {
198                 if (serverException.get() != null) {
199                     throw serverException.get();
200                 }
201                 fail();
202             } catch (DecoderException e) {
203                 assertTrue(e.getCause() instanceof SSLHandshakeException);
204             }
205             if (clientException.get() != null) {
206                 throw clientException.get();
207             }
208         } finally {
209             if (executorService != null) {
210                 executorService.shutdown();
211             }
212         }
213     }
214 
215     private void reset() {
216         clientException.set(null);
217         serverException.set(null);
218         clientHandler.handshakeCounter = 0;
219         serverHandler.handshakeCounter = 0;
220         clientChannel = null;
221         serverChannel = null;
222 
223         clientSslHandler = null;
224         serverSslHandler = null;
225     }
226 
227     @Sharable
228     private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
229 
230         protected final AtomicReference<Throwable> exception;
231         private int handshakeCounter;
232 
233         TestHandler(AtomicReference<Throwable> exception) {
234             this.exception = exception;
235         }
236 
237         @Override
238         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
239             ctx.flush();
240         }
241 
242         @Override
243         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
244             exception.compareAndSet(null, cause);
245             ctx.close();
246         }
247 
248         @Override
249         public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
250             if (evt instanceof SslHandshakeCompletionEvent) {
251                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
252                 if (handshakeCounter == 0) {
253                     handshakeCounter++;
254                     if (handshakeEvt.cause() != null) {
255                         logger.warn("Handshake failed:", handshakeEvt.cause());
256                     }
257                     assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
258                 } else {
259                     if (ctx.channel().parent() == null) {
260                         assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
261                     }
262                 }
263             }
264         }
265 
266         @Override
267         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
268     }
269 }