View Javadoc
1   /*
2   * Copyright 2014 The Netty Project
3   *
4   * The Netty Project licenses this file to you under the Apache License,
5   * version 2.0 (the "License"); you may not use this file except in compliance
6   * with the License. You may obtain a copy of the License at:
7   *
8   *   https://www.apache.org/licenses/LICENSE-2.0
9   *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13  * License for the specific language governing permissions and limitations
14  * under the License.
15  */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInitializer;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.logging.LogLevel;
28  import io.netty.handler.logging.LoggingHandler;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.SslContext;
31  import io.netty.handler.ssl.SslContextBuilder;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.handler.ssl.SslProvider;
35  import io.netty.handler.ssl.util.SelfSignedCertificate;
36  import io.netty.util.internal.PlatformDependent;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  import org.junit.jupiter.api.TestInfo;
40  import org.junit.jupiter.api.Timeout;
41  import org.junit.jupiter.params.ParameterizedTest;
42  import org.junit.jupiter.params.provider.MethodSource;
43  
44  import javax.net.ssl.SSLPeerUnverifiedException;
45  import javax.net.ssl.SSLSession;
46  import java.io.File;
47  import java.io.IOException;
48  import java.security.cert.CertificateException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.CountDownLatch;
53  import java.util.concurrent.Executor;
54  import java.util.concurrent.ExecutorService;
55  import java.util.concurrent.Executors;
56  import java.util.concurrent.TimeUnit;
57  import java.util.concurrent.atomic.AtomicReference;
58  
59  import static org.junit.jupiter.api.Assertions.assertEquals;
60  import static org.junit.jupiter.api.Assertions.assertFalse;
61  import static org.junit.jupiter.api.Assertions.fail;
62  
63  public class SocketSslGreetingTest extends AbstractSocketTest {
64  
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
66  
67      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
68      private static final File CERT_FILE;
69      private static final File KEY_FILE;
70  
71      static {
72          SelfSignedCertificate ssc;
73          try {
74              ssc = new SelfSignedCertificate();
75          } catch (CertificateException e) {
76              throw new Error(e);
77          }
78          CERT_FILE = ssc.certificate();
79          KEY_FILE = ssc.privateKey();
80      }
81  
82      public static Collection<Object[]> data() throws Exception {
83          List<SslContext> serverContexts = new ArrayList<SslContext>();
84          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
85  
86          List<SslContext> clientContexts = new ArrayList<SslContext>();
87          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
88                  .endpointIdentificationAlgorithm(null).trustManager(CERT_FILE).build());
89  
90          boolean hasOpenSsl = OpenSsl.isAvailable();
91          if (hasOpenSsl) {
92              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
93                                                  .sslProvider(SslProvider.OPENSSL).build());
94              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
95                                                  .endpointIdentificationAlgorithm(null)
96                                                  .trustManager(CERT_FILE).build());
97          } else {
98              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
99          }
100 
101         List<Object[]> params = new ArrayList<Object[]>();
102         for (SslContext sc: serverContexts) {
103             for (SslContext cc: clientContexts) {
104                 params.add(new Object[] { sc, cc, true });
105                 params.add(new Object[] { sc, cc, false });
106             }
107         }
108         return params;
109     }
110 
111     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
112         if (executor == null) {
113             return sslCtx.newHandler(allocator);
114         } else {
115             return sslCtx.newHandler(allocator, executor);
116         }
117     }
118 
119     // Test for https://github.com/netty/netty/pull/2437
120     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
121     @MethodSource("data")
122     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
123     public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
124                                 TestInfo testInfo) throws Throwable {
125         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
126             @Override
127             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
128                 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
129             }
130         });
131     }
132 
133     public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
134                                 final SslContext clientCtx, boolean delegate) throws Throwable {
135         final ServerHandler sh = new ServerHandler();
136         final ClientHandler ch = new ClientHandler();
137 
138         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
139         try {
140             sb.childHandler(new ChannelInitializer<Channel>() {
141                 @Override
142                 public void initChannel(Channel sch) throws Exception {
143                     ChannelPipeline p = sch.pipeline();
144                     p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
145                     p.addLast(new LoggingHandler(LOG_LEVEL));
146                     p.addLast(sh);
147                 }
148             });
149 
150             cb.handler(new ChannelInitializer<Channel>() {
151                 @Override
152                 public void initChannel(Channel sch) throws Exception {
153                     ChannelPipeline p = sch.pipeline();
154                     p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
155                     p.addLast(new LoggingHandler(LOG_LEVEL));
156                     p.addLast(ch);
157                 }
158             });
159 
160             Channel sc = sb.bind().sync().channel();
161             Channel cc = cb.connect(sc.localAddress()).sync().channel();
162 
163             ch.latch.await();
164 
165             sh.channel.close().awaitUninterruptibly();
166             cc.close().awaitUninterruptibly();
167             sc.close().awaitUninterruptibly();
168 
169             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
170                 throw sh.exception.get();
171             }
172             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
173                 throw ch.exception.get();
174             }
175             if (sh.exception.get() != null) {
176                 throw sh.exception.get();
177             }
178             if (ch.exception.get() != null) {
179                 throw ch.exception.get();
180             }
181         } finally {
182             if (executorService != null) {
183                 executorService.shutdown();
184             }
185         }
186     }
187 
188     private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
189 
190         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
191         final CountDownLatch latch = new CountDownLatch(1);
192 
193         @Override
194         public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
195             assertEquals('a', buf.readByte());
196             assertFalse(buf.isReadable());
197             latch.countDown();
198             ctx.close();
199         }
200 
201         @Override
202         public void exceptionCaught(ChannelHandlerContext ctx,
203                                     Throwable cause) throws Exception {
204             if (logger.isWarnEnabled()) {
205                 logger.warn("Unexpected exception from the client side", cause);
206             }
207 
208             exception.compareAndSet(null, cause);
209             ctx.close();
210         }
211     }
212 
213     private static class ServerHandler extends SimpleChannelInboundHandler<String> {
214         volatile Channel channel;
215         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
216 
217         @Override
218         protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
219             // discard
220         }
221 
222         @Override
223         public void channelActive(ChannelHandlerContext ctx)
224                 throws Exception {
225             channel = ctx.channel();
226             channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
227         }
228 
229         @Override
230         public void exceptionCaught(ChannelHandlerContext ctx,
231                                     Throwable cause) throws Exception {
232             if (logger.isWarnEnabled()) {
233                 logger.warn("Unexpected exception from the server side", cause);
234             }
235 
236             exception.compareAndSet(null, cause);
237             ctx.close();
238         }
239 
240         @Override
241         public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
242             if (evt instanceof SslHandshakeCompletionEvent) {
243                 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
244                 if (event.isSuccess()) {
245                     SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
246                     try {
247                         session.getPeerCertificates();
248                         fail();
249                     } catch (SSLPeerUnverifiedException e) {
250                         // expected
251                     }
252                     try {
253                         session.getPeerCertificateChain();
254                         fail();
255                     } catch (SSLPeerUnverifiedException e) {
256                         // expected
257                     } catch (UnsupportedOperationException e) {
258                         // Starting from Java15 this method throws UnsupportedOperationException as it was
259                         // deprecated before and getPeerCertificates() should be used
260                         if (PlatformDependent.javaVersion() < 15) {
261                             throw e;
262                         }
263                     }
264                     try {
265                         session.getPeerPrincipal();
266                         fail();
267                     } catch (SSLPeerUnverifiedException e) {
268                         // expected
269                     }
270                 }
271             }
272             ctx.fireUserEventTriggered(evt);
273         }
274     }
275 }