View Javadoc
1   /*
2   * Copyright 2014 The Netty Project
3   *
4   * The Netty Project licenses this file to you under the Apache License,
5   * version 2.0 (the "License"); you may not use this file except in compliance
6   * with the License. You may obtain a copy of the License at:
7   *
8   *   https://www.apache.org/licenses/LICENSE-2.0
9   *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13  * License for the specific language governing permissions and limitations
14  * under the License.
15  */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInitializer;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.logging.LogLevel;
28  import io.netty.handler.logging.LoggingHandler;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.SslContext;
31  import io.netty.handler.ssl.SslContextBuilder;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.handler.ssl.SslProvider;
35  import io.netty.pkitesting.CertificateBuilder;
36  import io.netty.pkitesting.X509Bundle;
37  import io.netty.util.internal.PlatformDependent;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  import org.junit.jupiter.api.TestInfo;
41  import org.junit.jupiter.api.Timeout;
42  import org.junit.jupiter.params.ParameterizedTest;
43  import org.junit.jupiter.params.provider.MethodSource;
44  
45  import javax.net.ssl.SSLPeerUnverifiedException;
46  import javax.net.ssl.SSLSession;
47  import java.io.File;
48  import java.io.IOException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.CountDownLatch;
53  import java.util.concurrent.Executor;
54  import java.util.concurrent.ExecutorService;
55  import java.util.concurrent.Executors;
56  import java.util.concurrent.TimeUnit;
57  import java.util.concurrent.atomic.AtomicReference;
58  
59  import static org.junit.jupiter.api.Assertions.assertEquals;
60  import static org.junit.jupiter.api.Assertions.assertFalse;
61  import static org.junit.jupiter.api.Assertions.fail;
62  
63  public class SocketSslGreetingTest extends AbstractSocketTest {
64  
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
66  
67      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
68      private static final File CERT_FILE;
69      private static final File KEY_FILE;
70  
71      static {
72          try {
73              X509Bundle cert = new CertificateBuilder()
74                      .subject("cn=localhost")
75                      .setIsCertificateAuthority(true)
76                      .buildSelfSigned();
77              CERT_FILE = cert.toTempCertChainPem();
78              KEY_FILE = cert.toTempPrivateKeyPem();
79          } catch (Exception e) {
80              throw new ExceptionInInitializerError(e);
81          }
82      }
83  
84      public static Collection<Object[]> data() throws Exception {
85          List<SslContext> serverContexts = new ArrayList<SslContext>();
86          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
87  
88          List<SslContext> clientContexts = new ArrayList<SslContext>();
89          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
90                  .endpointIdentificationAlgorithm(null).trustManager(CERT_FILE).build());
91  
92          boolean hasOpenSsl = OpenSsl.isAvailable();
93          if (hasOpenSsl) {
94              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
95                                                  .sslProvider(SslProvider.OPENSSL).build());
96              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
97                                                  .endpointIdentificationAlgorithm(null)
98                                                  .trustManager(CERT_FILE).build());
99          } else {
100             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
101         }
102 
103         List<Object[]> params = new ArrayList<Object[]>();
104         for (SslContext sc: serverContexts) {
105             for (SslContext cc: clientContexts) {
106                 params.add(new Object[] { sc, cc, true });
107                 params.add(new Object[] { sc, cc, false });
108             }
109         }
110         return params;
111     }
112 
113     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
114         if (executor == null) {
115             return sslCtx.newHandler(allocator);
116         } else {
117             return sslCtx.newHandler(allocator, executor);
118         }
119     }
120 
121     // Test for https://github.com/netty/netty/pull/2437
122     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
123     @MethodSource("data")
124     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
125     public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
126                                 TestInfo testInfo) throws Throwable {
127         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128             @Override
129             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130                 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
131             }
132         });
133     }
134 
135     public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
136                                 final SslContext clientCtx, boolean delegate) throws Throwable {
137         final ServerHandler sh = new ServerHandler();
138         final ClientHandler ch = new ClientHandler();
139 
140         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
141         try {
142             sb.childHandler(new ChannelInitializer<Channel>() {
143                 @Override
144                 public void initChannel(Channel sch) throws Exception {
145                     ChannelPipeline p = sch.pipeline();
146                     p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
147                     p.addLast(new LoggingHandler(LOG_LEVEL));
148                     p.addLast(sh);
149                 }
150             });
151 
152             cb.handler(new ChannelInitializer<Channel>() {
153                 @Override
154                 public void initChannel(Channel sch) throws Exception {
155                     ChannelPipeline p = sch.pipeline();
156                     p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
157                     p.addLast(new LoggingHandler(LOG_LEVEL));
158                     p.addLast(ch);
159                 }
160             });
161 
162             Channel sc = sb.bind().sync().channel();
163             Channel cc = cb.connect(sc.localAddress()).sync().channel();
164 
165             ch.latch.await();
166 
167             sh.channel.close().awaitUninterruptibly();
168             cc.close().awaitUninterruptibly();
169             sc.close().awaitUninterruptibly();
170 
171             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
172                 throw sh.exception.get();
173             }
174             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
175                 throw ch.exception.get();
176             }
177             if (sh.exception.get() != null) {
178                 throw sh.exception.get();
179             }
180             if (ch.exception.get() != null) {
181                 throw ch.exception.get();
182             }
183         } finally {
184             if (executorService != null) {
185                 executorService.shutdown();
186             }
187         }
188     }
189 
190     private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
191 
192         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
193         final CountDownLatch latch = new CountDownLatch(1);
194 
195         @Override
196         public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
197             assertEquals('a', buf.readByte());
198             assertFalse(buf.isReadable());
199             latch.countDown();
200             ctx.close();
201         }
202 
203         @Override
204         public void exceptionCaught(ChannelHandlerContext ctx,
205                                     Throwable cause) throws Exception {
206             if (logger.isWarnEnabled()) {
207                 logger.warn("Unexpected exception from the client side", cause);
208             }
209 
210             exception.compareAndSet(null, cause);
211             ctx.close();
212         }
213     }
214 
215     private static class ServerHandler extends SimpleChannelInboundHandler<String> {
216         volatile Channel channel;
217         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
218 
219         @Override
220         protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
221             // discard
222         }
223 
224         @Override
225         public void channelActive(ChannelHandlerContext ctx)
226                 throws Exception {
227             channel = ctx.channel();
228             channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
229         }
230 
231         @Override
232         public void exceptionCaught(ChannelHandlerContext ctx,
233                                     Throwable cause) throws Exception {
234             if (logger.isWarnEnabled()) {
235                 logger.warn("Unexpected exception from the server side", cause);
236             }
237 
238             exception.compareAndSet(null, cause);
239             ctx.close();
240         }
241 
242         @Override
243         public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
244             if (evt instanceof SslHandshakeCompletionEvent) {
245                 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
246                 if (event.isSuccess()) {
247                     SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
248                     try {
249                         session.getPeerCertificates();
250                         fail();
251                     } catch (SSLPeerUnverifiedException e) {
252                         // expected
253                     }
254                     try {
255                         session.getPeerCertificateChain();
256                         fail();
257                     } catch (SSLPeerUnverifiedException e) {
258                         // expected
259                     } catch (UnsupportedOperationException e) {
260                         // Starting from Java15 this method throws UnsupportedOperationException as it was
261                         // deprecated before and getPeerCertificates() should be used
262                         if (PlatformDependent.javaVersion() < 15) {
263                             throw e;
264                         }
265                     }
266                     try {
267                         session.getPeerPrincipal();
268                         fail();
269                     } catch (SSLPeerUnverifiedException e) {
270                         // expected
271                     }
272                 }
273             }
274             ctx.fireUserEventTriggered(evt);
275         }
276     }
277 }