View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandler.Sharable;
24  import io.netty.channel.ChannelHandlerContext;
25  import io.netty.channel.ChannelInitializer;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.codec.DecoderException;
28  import io.netty.handler.ssl.JdkSslClientContext;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.OpenSslServerContext;
31  import io.netty.handler.ssl.SslContext;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.handler.ssl.util.SelfSignedCertificate;
35  import io.netty.util.concurrent.Future;
36  import io.netty.util.internal.logging.InternalLogger;
37  import io.netty.util.internal.logging.InternalLoggerFactory;
38  import org.junit.jupiter.api.TestInfo;
39  import org.junit.jupiter.api.Timeout;
40  import org.junit.jupiter.api.condition.DisabledIf;
41  import org.junit.jupiter.params.ParameterizedTest;
42  import org.junit.jupiter.params.provider.MethodSource;
43  
44  import java.io.File;
45  import java.nio.channels.ClosedChannelException;
46  import java.security.cert.CertificateException;
47  import java.util.ArrayList;
48  import java.util.Collection;
49  import java.util.List;
50  import java.util.concurrent.Executor;
51  import java.util.concurrent.ExecutorService;
52  import java.util.concurrent.Executors;
53  import java.util.concurrent.TimeUnit;
54  import java.util.concurrent.atomic.AtomicReference;
55  
56  import javax.net.ssl.SSLHandshakeException;
57  
58  import static org.junit.jupiter.api.Assertions.assertSame;
59  import static org.junit.jupiter.api.Assertions.assertTrue;
60  import static org.junit.jupiter.api.Assertions.fail;
61  import static org.junit.jupiter.api.Assumptions.assumeFalse;
62  import static org.junit.jupiter.api.Assumptions.assumeTrue;
63  
64  public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(
66              SocketSslClientRenegotiateTest.class);
67      private static final File CERT_FILE;
68      private static final File KEY_FILE;
69  
70      static {
71          SelfSignedCertificate ssc;
72          try {
73              ssc = new SelfSignedCertificate();
74          } catch (CertificateException e) {
75              throw new Error(e);
76          }
77          CERT_FILE = ssc.certificate();
78          KEY_FILE = ssc.privateKey();
79      }
80  
81      private static boolean openSslNotAvailable() {
82          return !OpenSsl.isAvailable();
83      }
84  
85      public static Collection<Object[]> data() throws Exception {
86          List<SslContext> serverContexts = new ArrayList<SslContext>();
87          List<SslContext> clientContexts = new ArrayList<SslContext>();
88          clientContexts.add(new JdkSslClientContext(CERT_FILE));
89  
90          boolean hasOpenSsl = OpenSsl.isAvailable();
91          if (hasOpenSsl) {
92              OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
93              serverContexts.add(context);
94          } else {
95              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
96          }
97  
98          List<Object[]> params = new ArrayList<Object[]>();
99          for (SslContext sc: serverContexts) {
100             for (SslContext cc: clientContexts) {
101                 for (int i = 0; i < 32; i++) {
102                     params.add(new Object[] { sc, cc, true});
103                     params.add(new Object[] { sc, cc, false});
104                 }
105             }
106         }
107 
108         return params;
109     }
110 
111     private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
112     private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
113 
114     private volatile Channel clientChannel;
115     private volatile Channel serverChannel;
116 
117     private volatile SslHandler clientSslHandler;
118     private volatile SslHandler serverSslHandler;
119 
120     private final TestHandler clientHandler = new TestHandler(clientException);
121 
122     private final TestHandler serverHandler = new TestHandler(serverException);
123 
124     @DisabledIf("openSslNotAvailable")
125     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
126     @MethodSource("data")
127     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
128     public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx,
129                                              final boolean delegate, TestInfo testInfo) throws Throwable {
130         // BoringSSL does not support renegotiation intentionally.
131         assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
132         assumeTrue(OpenSsl.isAvailable());
133         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
134             @Override
135             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
136                 testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate);
137             }
138         });
139     }
140 
141     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
142         if (executor == null) {
143             return sslCtx.newHandler(allocator);
144         } else {
145             return sslCtx.newHandler(allocator, executor);
146         }
147     }
148 
149     public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
150                                              final SslContext clientCtx, boolean delegate) throws Throwable {
151         reset();
152 
153         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
154 
155         try {
156             sb.childHandler(new ChannelInitializer<Channel>() {
157                 @Override
158                 @SuppressWarnings("deprecation")
159                 public void initChannel(Channel sch) throws Exception {
160                     serverChannel = sch;
161                     serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService);
162                     // As we test renegotiation we should use a protocol that support it.
163                     serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
164                     sch.pipeline().addLast("ssl", serverSslHandler);
165                     sch.pipeline().addLast("handler", serverHandler);
166                 }
167             });
168 
169             cb.handler(new ChannelInitializer<Channel>() {
170                 @Override
171                 @SuppressWarnings("deprecation")
172                 public void initChannel(Channel sch) throws Exception {
173                     clientChannel = sch;
174                     clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService);
175                     // As we test renegotiation we should use a protocol that support it.
176                     clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
177                     sch.pipeline().addLast("ssl", clientSslHandler);
178                     sch.pipeline().addLast("handler", clientHandler);
179                 }
180             });
181 
182             Channel sc = sb.bind().sync().channel();
183             cb.connect(sc.localAddress()).sync();
184 
185             Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
186             clientHandshakeFuture.sync();
187 
188             String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
189             // Use the first previous enabled ciphersuite and try to renegotiate.
190             clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
191             clientSslHandler.renegotiate().await();
192             serverChannel.close().awaitUninterruptibly();
193             clientChannel.close().awaitUninterruptibly();
194             sc.close().awaitUninterruptibly();
195             try {
196                 if (serverException.get() != null) {
197                     throw serverException.get();
198                 }
199                 fail();
200             } catch (DecoderException e) {
201                 assertTrue(e.getCause() instanceof SSLHandshakeException);
202             }
203             if (clientException.get() != null) {
204                 throw clientException.get();
205             }
206         } finally {
207             if (executorService != null) {
208                 executorService.shutdown();
209             }
210         }
211     }
212 
213     private void reset() {
214         clientException.set(null);
215         serverException.set(null);
216         clientHandler.handshakeCounter = 0;
217         serverHandler.handshakeCounter = 0;
218         clientChannel = null;
219         serverChannel = null;
220 
221         clientSslHandler = null;
222         serverSslHandler = null;
223     }
224 
225     @Sharable
226     private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
227 
228         protected final AtomicReference<Throwable> exception;
229         private int handshakeCounter;
230 
231         TestHandler(AtomicReference<Throwable> exception) {
232             this.exception = exception;
233         }
234 
235         @Override
236         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
237             ctx.flush();
238         }
239 
240         @Override
241         public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242             exception.compareAndSet(null, cause);
243             ctx.close();
244         }
245 
246         @Override
247         public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
248             if (evt instanceof SslHandshakeCompletionEvent) {
249                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
250                 if (handshakeCounter == 0) {
251                     handshakeCounter++;
252                     if (handshakeEvt.cause() != null) {
253                         logger.warn("Handshake failed:", handshakeEvt.cause());
254                     }
255                     assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
256                 } else {
257                     if (ctx.channel().parent() == null) {
258                         assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
259                     }
260                 }
261             }
262         }
263 
264         @Override
265         public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
266     }
267 }