View Javadoc
1   /*
2    * Copyright 2015 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.api.Buffer;
21  import io.netty5.buffer.api.BufferAllocator;
22  import io.netty5.channel.Channel;
23  import io.netty5.channel.ChannelHandlerContext;
24  import io.netty5.channel.ChannelInitializer;
25  import io.netty5.channel.SimpleChannelInboundHandler;
26  import io.netty5.handler.codec.DecoderException;
27  import io.netty5.handler.ssl.OpenSsl;
28  import io.netty5.handler.ssl.SslContext;
29  import io.netty5.handler.ssl.SslContextBuilder;
30  import io.netty5.handler.ssl.SslHandler;
31  import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
32  import io.netty5.handler.ssl.SslProvider;
33  import io.netty5.handler.ssl.util.SelfSignedCertificate;
34  import io.netty5.util.concurrent.Future;
35  import io.netty5.util.internal.logging.InternalLogger;
36  import io.netty5.util.internal.logging.InternalLoggerFactory;
37  import org.junit.jupiter.api.TestInfo;
38  import org.junit.jupiter.api.Timeout;
39  import org.junit.jupiter.api.condition.DisabledIf;
40  import org.junit.jupiter.params.ParameterizedTest;
41  import org.junit.jupiter.params.provider.MethodSource;
42  
43  import javax.net.ssl.SSLHandshakeException;
44  import java.io.File;
45  import java.nio.channels.ClosedChannelException;
46  import java.security.cert.CertificateException;
47  import java.util.ArrayList;
48  import java.util.Collection;
49  import java.util.List;
50  import java.util.concurrent.Executor;
51  import java.util.concurrent.ExecutorService;
52  import java.util.concurrent.Executors;
53  import java.util.concurrent.TimeUnit;
54  import java.util.concurrent.atomic.AtomicReference;
55  
56  import static org.junit.jupiter.api.Assertions.assertTrue;
57  import static org.junit.jupiter.api.Assertions.fail;
58  import static org.junit.jupiter.api.Assumptions.assumeFalse;
59  
60  public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
61      private static final InternalLogger logger = InternalLoggerFactory.getInstance(
62              SocketSslClientRenegotiateTest.class);
63      private static final File CERT_FILE;
64      private static final File KEY_FILE;
65  
66      static {
67          SelfSignedCertificate ssc;
68          try {
69              ssc = new SelfSignedCertificate();
70          } catch (CertificateException e) {
71              throw new Error(e);
72          }
73          CERT_FILE = ssc.certificate();
74          KEY_FILE = ssc.privateKey();
75      }
76  
77      private static boolean openSslNotAvailable() {
78          return !OpenSsl.isAvailable();
79      }
80  
81      public static Collection<Object[]> data() throws Exception {
82          List<SslContext> serverContexts = new ArrayList<>();
83          List<SslContext> clientContexts = new ArrayList<>();
84          clientContexts.add(
85            SslContextBuilder.forClient()
86              .trustManager(CERT_FILE)
87              .sslProvider(SslProvider.JDK)
88              .build()
89          );
90  
91          boolean hasOpenSsl = OpenSsl.isAvailable();
92          if (hasOpenSsl) {
93              serverContexts.add(
94                SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
95                  .sslProvider(SslProvider.OPENSSL)
96                  .build()
97              );
98          } else {
99              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
100         }
101 
102         List<Object[]> params = new ArrayList<>();
103         for (SslContext sc: serverContexts) {
104             for (SslContext cc: clientContexts) {
105                 for (int i = 0; i < 32; i++) {
106                     params.add(new Object[] { sc, cc, true});
107                     params.add(new Object[] { sc, cc, false});
108                 }
109             }
110         }
111 
112         return params;
113     }
114 
115     private final AtomicReference<Throwable> clientException = new AtomicReference<>();
116     private final AtomicReference<Throwable> serverException = new AtomicReference<>();
117 
118     private volatile Channel clientChannel;
119     private volatile Channel serverChannel;
120 
121     private volatile SslHandler clientSslHandler;
122     private volatile SslHandler serverSslHandler;
123 
124     private final TestHandler clientHandler = new TestHandler(clientException);
125 
126     private final TestHandler serverHandler = new TestHandler(serverException);
127 
128     @DisabledIf("openSslNotAvailable")
129     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
130     @MethodSource("data")
131     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
132     public void testSslRenegotiationRejected(SslContext serverCtx, SslContext clientCtx, boolean delegate,
133                                              TestInfo testInfo) throws Throwable {
134         // BoringSSL does not support renegotiation intentionally.
135         assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
136         run(testInfo, (sb, cb) -> testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate));
137     }
138 
139     private static SslHandler newSslHandler(SslContext sslCtx, BufferAllocator allocator, Executor executor) {
140         if (executor == null) {
141             return sslCtx.newHandler(allocator);
142         } else {
143             return sslCtx.newHandler(allocator, executor);
144         }
145     }
146 
147     public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, SslContext serverCtx,
148                                              SslContext clientCtx, boolean delegate) throws Throwable {
149         reset();
150 
151         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
152 
153         try {
154             sb.childHandler(new ChannelInitializer<>() {
155                 @Override
156                 public void initChannel(Channel sch) throws Exception {
157                     serverChannel = sch;
158                     serverSslHandler = newSslHandler(serverCtx, sch.bufferAllocator(), executorService);
159                     // As we test renegotiation we should use a protocol that support it.
160                     serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
161                     sch.pipeline().addLast("ssl", serverSslHandler);
162                     sch.pipeline().addLast("handler", serverHandler);
163                 }
164             });
165 
166             cb.handler(new ChannelInitializer<>() {
167                 @Override
168                 public void initChannel(Channel sch) throws Exception {
169                     clientChannel = sch;
170                     clientSslHandler = newSslHandler(clientCtx, sch.bufferAllocator(), executorService);
171                     // As we test renegotiation we should use a protocol that support it.
172                     clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
173                     sch.pipeline().addLast("ssl", clientSslHandler);
174                     sch.pipeline().addLast("handler", clientHandler);
175                 }
176             });
177 
178             Channel sc = sb.bind().asStage().get();
179             cb.connect(sc.localAddress()).asStage().sync();
180 
181             Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
182             clientHandshakeFuture.asStage().sync();
183 
184             String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
185             // Use the first previous enabled ciphersuite and try to renegotiate.
186             clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
187             clientSslHandler.renegotiate().asStage().await();
188             serverChannel.close().asStage().await();
189             clientChannel.close().asStage().await();
190             sc.close().asStage().await();
191             try {
192                 if (serverException.get() != null) {
193                     throw serverException.get();
194                 }
195                 fail();
196             } catch (DecoderException e) {
197                 assertTrue(e.getCause() instanceof SSLHandshakeException);
198             }
199             if (clientException.get() != null) {
200                 throw clientException.get();
201             }
202         } finally {
203             if (executorService != null) {
204                 executorService.shutdown();
205             }
206         }
207     }
208 
209     private void reset() {
210         clientException.set(null);
211         serverException.set(null);
212         clientHandler.handshakeCounter = 0;
213         serverHandler.handshakeCounter = 0;
214         clientChannel = null;
215         serverChannel = null;
216 
217         clientSslHandler = null;
218         serverSslHandler = null;
219     }
220 
221     private static final class TestHandler extends SimpleChannelInboundHandler<Buffer> {
222 
223         private final AtomicReference<Throwable> exception;
224         private int handshakeCounter;
225 
226         TestHandler(AtomicReference<Throwable> exception) {
227             this.exception = exception;
228         }
229 
230         @Override
231         public boolean isSharable() {
232             return true;
233         }
234 
235         @Override
236         public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
237             ctx.flush();
238         }
239 
240         @Override
241         public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242             exception.compareAndSet(null, cause);
243             ctx.close();
244         }
245 
246         @Override
247         public void channelInboundEvent(ChannelHandlerContext ctx, Object evt) throws Exception {
248             if (evt instanceof SslHandshakeCompletionEvent) {
249                 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
250                 if (handshakeCounter == 0) {
251                     handshakeCounter++;
252                     if (handshakeEvt.cause() != null) {
253                         logger.warn("Handshake failed:", handshakeEvt.cause());
254                     }
255                     assertTrue(handshakeEvt.isSuccess());
256                 } else {
257                     if (ctx.channel().parent() == null) {
258                         assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
259                     }
260                 }
261             }
262         }
263 
264         @Override
265         public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception { }
266     }
267 }