View Javadoc
1   /*
2   * Copyright 2014 The Netty Project
3   *
4   * The Netty Project licenses this file to you under the Apache License,
5   * version 2.0 (the "License"); you may not use this file except in compliance
6   * with the License. You may obtain a copy of the License at:
7   *
8   *   https://www.apache.org/licenses/LICENSE-2.0
9   *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13  * License for the specific language governing permissions and limitations
14  * under the License.
15  */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInitializer;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.logging.LogLevel;
28  import io.netty.handler.logging.LoggingHandler;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.SslContext;
31  import io.netty.handler.ssl.SslContextBuilder;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.handler.ssl.SslProvider;
35  import io.netty.pkitesting.CertificateBuilder;
36  import io.netty.pkitesting.X509Bundle;
37  import io.netty.util.internal.PlatformDependent;
38  import io.netty.util.internal.logging.InternalLogger;
39  import io.netty.util.internal.logging.InternalLoggerFactory;
40  import org.junit.jupiter.api.TestInfo;
41  import org.junit.jupiter.api.Timeout;
42  import org.junit.jupiter.params.ParameterizedTest;
43  import org.junit.jupiter.params.provider.MethodSource;
44  
45  import javax.net.ssl.SSLPeerUnverifiedException;
46  import javax.net.ssl.SSLSession;
47  import java.io.File;
48  import java.io.IOException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.CountDownLatch;
53  import java.util.concurrent.Executor;
54  import java.util.concurrent.ExecutorService;
55  import java.util.concurrent.Executors;
56  import java.util.concurrent.TimeUnit;
57  import java.util.concurrent.atomic.AtomicReference;
58  
59  import static org.junit.jupiter.api.Assertions.assertEquals;
60  import static org.junit.jupiter.api.Assertions.assertFalse;
61  import static org.junit.jupiter.api.Assertions.assertTrue;
62  import static org.junit.jupiter.api.Assertions.fail;
63  
64  public class SocketSslGreetingTest extends AbstractSocketTest {
65  
66      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
67  
68      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
69      private static final File CERT_FILE;
70      private static final File KEY_FILE;
71  
72      static {
73          try {
74              X509Bundle cert = new CertificateBuilder()
75                      .subject("cn=localhost")
76                      .setIsCertificateAuthority(true)
77                      .buildSelfSigned();
78              CERT_FILE = cert.toTempCertChainPem();
79              KEY_FILE = cert.toTempPrivateKeyPem();
80          } catch (Exception e) {
81              throw new ExceptionInInitializerError(e);
82          }
83      }
84  
85      public static Collection<Object[]> data() throws Exception {
86          List<SslContext> serverContexts = new ArrayList<SslContext>();
87          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
88  
89          List<SslContext> clientContexts = new ArrayList<SslContext>();
90          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
91                  .endpointIdentificationAlgorithm(null).trustManager(CERT_FILE).build());
92  
93          boolean hasOpenSsl = OpenSsl.isAvailable();
94          if (hasOpenSsl) {
95              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
96                                                  .sslProvider(SslProvider.OPENSSL).build());
97              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
98                                                  .endpointIdentificationAlgorithm(null)
99                                                  .trustManager(CERT_FILE).build());
100         } else {
101             logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
102         }
103 
104         List<Object[]> params = new ArrayList<Object[]>();
105         for (SslContext sc: serverContexts) {
106             for (SslContext cc: clientContexts) {
107                 params.add(new Object[] { sc, cc, true });
108                 params.add(new Object[] { sc, cc, false });
109             }
110         }
111         return params;
112     }
113 
114     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
115         if (executor == null) {
116             return sslCtx.newHandler(allocator);
117         } else {
118             return sslCtx.newHandler(allocator, executor);
119         }
120     }
121 
122     // Test for https://github.com/netty/netty/pull/2437
123     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
124     @MethodSource("data")
125     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
126     public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
127                                 TestInfo testInfo) throws Throwable {
128         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
129             @Override
130             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
131                 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
132             }
133         });
134     }
135 
136     public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
137                                 final SslContext clientCtx, boolean delegate) throws Throwable {
138         final ServerHandler sh = new ServerHandler();
139         final ClientHandler ch = new ClientHandler();
140 
141         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
142         try {
143             sb.childHandler(new ChannelInitializer<Channel>() {
144                 @Override
145                 public void initChannel(Channel sch) throws Exception {
146                     ChannelPipeline p = sch.pipeline();
147                     p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
148                     p.addLast(new LoggingHandler(LOG_LEVEL));
149                     p.addLast(sh);
150                 }
151             });
152 
153             cb.handler(new ChannelInitializer<Channel>() {
154                 @Override
155                 public void initChannel(Channel sch) throws Exception {
156                     ChannelPipeline p = sch.pipeline();
157                     p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
158                     p.addLast(new LoggingHandler(LOG_LEVEL));
159                     p.addLast(ch);
160                 }
161             });
162 
163             Channel sc = sb.bind().sync().channel();
164             Channel cc = cb.connect(sc.localAddress()).sync().channel();
165 
166             ch.latch.await();
167 
168             sh.channel.close().awaitUninterruptibly();
169             cc.close().awaitUninterruptibly();
170             sc.close().awaitUninterruptibly();
171 
172             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
173                 throw sh.exception.get();
174             }
175             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
176                 throw ch.exception.get();
177             }
178             if (sh.exception.get() != null) {
179                 throw sh.exception.get();
180             }
181             if (ch.exception.get() != null) {
182                 throw ch.exception.get();
183             }
184         } finally {
185             if (executorService != null) {
186                 executorService.shutdown();
187                 assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS));
188             }
189         }
190     }
191 
192     private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
193 
194         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
195         final CountDownLatch latch = new CountDownLatch(1);
196 
197         @Override
198         public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
199             assertEquals('a', buf.readByte());
200             assertFalse(buf.isReadable());
201             latch.countDown();
202             ctx.close();
203         }
204 
205         @Override
206         public void exceptionCaught(ChannelHandlerContext ctx,
207                                     Throwable cause) throws Exception {
208             if (logger.isWarnEnabled()) {
209                 logger.warn("Unexpected exception from the client side", cause);
210             }
211 
212             exception.compareAndSet(null, cause);
213             ctx.close();
214         }
215     }
216 
217     private static class ServerHandler extends SimpleChannelInboundHandler<String> {
218         volatile Channel channel;
219         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
220 
221         @Override
222         protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
223             // discard
224         }
225 
226         @Override
227         public void channelActive(ChannelHandlerContext ctx)
228                 throws Exception {
229             channel = ctx.channel();
230             channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
231         }
232 
233         @Override
234         public void exceptionCaught(ChannelHandlerContext ctx,
235                                     Throwable cause) throws Exception {
236             if (logger.isWarnEnabled()) {
237                 logger.warn("Unexpected exception from the server side", cause);
238             }
239 
240             exception.compareAndSet(null, cause);
241             ctx.close();
242         }
243 
244         @Override
245         public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
246             if (evt instanceof SslHandshakeCompletionEvent) {
247                 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
248                 if (event.isSuccess()) {
249                     SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
250                     try {
251                         session.getPeerCertificates();
252                         fail();
253                     } catch (SSLPeerUnverifiedException e) {
254                         // expected
255                     }
256                     try {
257                         session.getPeerCertificateChain();
258                         fail();
259                     } catch (SSLPeerUnverifiedException e) {
260                         // expected
261                     } catch (UnsupportedOperationException e) {
262                         // Starting from Java15 this method throws UnsupportedOperationException as it was
263                         // deprecated before and getPeerCertificates() should be used
264                         if (PlatformDependent.javaVersion() < 15) {
265                             throw e;
266                         }
267                     }
268                     try {
269                         session.getPeerPrincipal();
270                         fail();
271                     } catch (SSLPeerUnverifiedException e) {
272                         // expected
273                     }
274                 }
275             }
276             ctx.fireUserEventTriggered(evt);
277         }
278     }
279 }