View Javadoc
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
16  package io.netty5.testsuite.transport.socket;
17  
18  import io.netty5.bootstrap.Bootstrap;
19  import io.netty5.bootstrap.ServerBootstrap;
20  import io.netty5.buffer.api.Buffer;
21  import io.netty5.buffer.api.BufferAllocator;
22  import io.netty5.channel.Channel;
23  import io.netty5.channel.ChannelHandlerContext;
24  import io.netty5.channel.ChannelInitializer;
25  import io.netty5.channel.ChannelPipeline;
26  import io.netty5.channel.SimpleChannelInboundHandler;
27  import io.netty5.handler.logging.LogLevel;
28  import io.netty5.handler.logging.LoggingHandler;
29  import io.netty5.handler.ssl.OpenSsl;
30  import io.netty5.handler.ssl.SslContext;
31  import io.netty5.handler.ssl.SslContextBuilder;
32  import io.netty5.handler.ssl.SslHandler;
33  import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty5.handler.ssl.SslProvider;
35  import io.netty5.handler.ssl.util.SelfSignedCertificate;
36  import io.netty5.util.internal.PlatformDependent;
37  import io.netty5.util.internal.logging.InternalLogger;
38  import io.netty5.util.internal.logging.InternalLoggerFactory;
39  import org.junit.jupiter.api.TestInfo;
40  import org.junit.jupiter.api.Timeout;
41  import org.junit.jupiter.params.ParameterizedTest;
42  import org.junit.jupiter.params.provider.MethodSource;
43  
44  import javax.net.ssl.SSLPeerUnverifiedException;
45  import javax.net.ssl.SSLSession;
46  import java.io.File;
47  import java.io.IOException;
48  import java.security.cert.CertificateException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.CountDownLatch;
53  import java.util.concurrent.Executor;
54  import java.util.concurrent.ExecutorService;
55  import java.util.concurrent.Executors;
56  import java.util.concurrent.TimeUnit;
57  import java.util.concurrent.atomic.AtomicReference;
58  
59  import static org.junit.jupiter.api.Assertions.assertEquals;
60  import static org.junit.jupiter.api.Assertions.fail;
61  
62  public class SocketSslGreetingTest extends AbstractSocketTest {
63  
64      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
65  
66      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
67      private static final File CERT_FILE;
68      private static final File KEY_FILE;
69  
70      static {
71          SelfSignedCertificate ssc;
72          try {
73              ssc = new SelfSignedCertificate();
74          } catch (CertificateException e) {
75              throw new Error(e);
76          }
77          CERT_FILE = ssc.certificate();
78          KEY_FILE = ssc.privateKey();
79      }
80  
81      public static Collection<Object[]> data() throws Exception {
82          List<SslContext> serverContexts = new ArrayList<>();
83          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
84  
85          List<SslContext> clientContexts = new ArrayList<>();
86          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
87  
88          boolean hasOpenSsl = OpenSsl.isAvailable();
89          if (hasOpenSsl) {
90              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
91                                                  .sslProvider(SslProvider.OPENSSL).build());
92              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
93                                                  .trustManager(CERT_FILE).build());
94          } else {
95              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
96          }
97  
98          List<Object[]> params = new ArrayList<>();
99          for (SslContext sc: serverContexts) {
100             for (SslContext cc: clientContexts) {
101                 params.add(new Object[] { sc, cc, true });
102                 params.add(new Object[] { sc, cc, false });
103             }
104         }
105         return params;
106     }
107 
108     private static SslHandler newSslHandler(SslContext sslCtx, BufferAllocator allocator, Executor executor) {
109         if (executor == null) {
110             return sslCtx.newHandler(allocator);
111         } else {
112             return sslCtx.newHandler(allocator, executor);
113         }
114     }
115 
116     // Test for https://github.com/netty/netty/pull/2437
117     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
118     @MethodSource("data")
119     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
120     public void testSslGreeting(SslContext serverCtx, SslContext clientCtx, boolean delegate,
121                                 TestInfo testInfo) throws Throwable {
122         run(testInfo, (sb, cb) -> testSslGreeting(sb, cb, serverCtx, clientCtx, delegate));
123     }
124 
125     public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, SslContext serverCtx,
126                                 SslContext clientCtx, boolean delegate) throws Throwable {
127         final ServerHandler sh = new ServerHandler();
128         final ClientHandler ch = new ClientHandler();
129 
130         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
131         try {
132             sb.childHandler(new ChannelInitializer<>() {
133                 @Override
134                 public void initChannel(Channel sch) throws Exception {
135                     ChannelPipeline p = sch.pipeline();
136                     p.addLast(newSslHandler(serverCtx, sch.bufferAllocator(), executorService));
137                     p.addLast(new LoggingHandler(LOG_LEVEL));
138                     p.addLast(sh);
139                 }
140             });
141 
142             cb.handler(new ChannelInitializer<>() {
143                 @Override
144                 public void initChannel(Channel sch) throws Exception {
145                     ChannelPipeline p = sch.pipeline();
146                     p.addLast(newSslHandler(clientCtx, sch.bufferAllocator(), executorService));
147                     p.addLast(new LoggingHandler(LOG_LEVEL));
148                     p.addLast(ch);
149                 }
150             });
151 
152             Channel sc = sb.bind().asStage().get();
153             Channel cc = cb.connect(sc.localAddress()).asStage().get();
154 
155             ch.latch.await();
156 
157             sh.channel.close().asStage().await();
158             cc.close().asStage().await();
159             sc.close().asStage().await();
160 
161             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
162                 throw sh.exception.get();
163             }
164             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
165                 throw ch.exception.get();
166             }
167             if (sh.exception.get() != null) {
168                 throw sh.exception.get();
169             }
170             if (ch.exception.get() != null) {
171                 throw ch.exception.get();
172             }
173         } finally {
174             if (executorService != null) {
175                 executorService.shutdown();
176             }
177         }
178     }
179 
180     private static class ClientHandler extends SimpleChannelInboundHandler<Buffer> {
181 
182         final AtomicReference<Throwable> exception = new AtomicReference<>();
183         final CountDownLatch latch = new CountDownLatch(1);
184 
185         @Override
186         public void messageReceived(ChannelHandlerContext ctx, Buffer buf) throws Exception {
187             assertEquals('a', buf.readByte());
188             assertEquals(0, buf.readableBytes());
189             latch.countDown();
190             ctx.close();
191         }
192 
193         @Override
194         public void channelExceptionCaught(ChannelHandlerContext ctx,
195                                            Throwable cause) throws Exception {
196             if (logger.isWarnEnabled()) {
197                 logger.warn("Unexpected exception from the client side", cause);
198             }
199 
200             exception.compareAndSet(null, cause);
201             ctx.close();
202         }
203     }
204 
205     private static class ServerHandler extends SimpleChannelInboundHandler<String> {
206         volatile Channel channel;
207         final AtomicReference<Throwable> exception = new AtomicReference<>();
208 
209         @Override
210         protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
211             // discard
212         }
213 
214         @Override
215         public void channelActive(ChannelHandlerContext ctx)
216                 throws Exception {
217             channel = ctx.channel();
218             channel.writeAndFlush(ctx.bufferAllocator().copyOf(new byte[] {'a'}));
219         }
220 
221         @Override
222         public void channelExceptionCaught(ChannelHandlerContext ctx,
223                                            Throwable cause) throws Exception {
224             if (logger.isWarnEnabled()) {
225                 logger.warn("Unexpected exception from the server side", cause);
226             }
227 
228             exception.compareAndSet(null, cause);
229             ctx.close();
230         }
231 
232         @Override
233         public void channelInboundEvent(final ChannelHandlerContext ctx, final Object evt) throws Exception {
234             if (evt instanceof SslHandshakeCompletionEvent) {
235                 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
236                 if (event.isSuccess()) {
237                     SSLSession session = event.session();
238                     try {
239                         session.getPeerCertificates();
240                         fail();
241                     } catch (SSLPeerUnverifiedException e) {
242                         // expected
243                     }
244                     try {
245                         session.getPeerCertificateChain();
246                         fail();
247                     } catch (SSLPeerUnverifiedException e) {
248                         // expected
249                     } catch (UnsupportedOperationException e) {
250                         // Starting from Java15 this method throws UnsupportedOperationException as it was
251                         // deprecated before and getPeerCertificates() should be used
252                         if (PlatformDependent.javaVersion() < 15) {
253                             throw e;
254                         }
255                     }
256                     try {
257                         session.getPeerPrincipal();
258                         fail();
259                     } catch (SSLPeerUnverifiedException e) {
260                         // expected
261                     }
262                 }
263             }
264             ctx.fireChannelInboundEvent(evt);
265         }
266     }
267 }