View Javadoc
1   /*
2   * Copyright 2014 The Netty Project
3   *
4   * The Netty Project licenses this file to you under the Apache License,
5   * version 2.0 (the "License"); you may not use this file except in compliance
6   * with the License. You may obtain a copy of the License at:
7   *
8   *   https://www.apache.org/licenses/LICENSE-2.0
9   *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13  * License for the specific language governing permissions and limitations
14  * under the License.
15  */
16  package io.netty.testsuite.transport.socket;
17  
18  import io.netty.bootstrap.Bootstrap;
19  import io.netty.bootstrap.ServerBootstrap;
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.ByteBufAllocator;
22  import io.netty.channel.Channel;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelInitializer;
25  import io.netty.channel.ChannelPipeline;
26  import io.netty.channel.SimpleChannelInboundHandler;
27  import io.netty.handler.logging.LogLevel;
28  import io.netty.handler.logging.LoggingHandler;
29  import io.netty.handler.ssl.OpenSsl;
30  import io.netty.handler.ssl.SslContext;
31  import io.netty.handler.ssl.SslContextBuilder;
32  import io.netty.handler.ssl.SslHandler;
33  import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34  import io.netty.handler.ssl.SslProvider;
35  import io.netty.handler.ssl.util.SelfSignedCertificate;
36  import io.netty.util.internal.PlatformDependent;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  import org.junit.jupiter.api.TestInfo;
40  import org.junit.jupiter.api.Timeout;
41  import org.junit.jupiter.params.ParameterizedTest;
42  import org.junit.jupiter.params.provider.MethodSource;
43  
44  import javax.net.ssl.SSLPeerUnverifiedException;
45  import javax.net.ssl.SSLSession;
46  import java.io.File;
47  import java.io.IOException;
48  import java.security.cert.CertificateException;
49  import java.util.ArrayList;
50  import java.util.Collection;
51  import java.util.List;
52  import java.util.concurrent.CountDownLatch;
53  import java.util.concurrent.Executor;
54  import java.util.concurrent.ExecutorService;
55  import java.util.concurrent.Executors;
56  import java.util.concurrent.TimeUnit;
57  import java.util.concurrent.atomic.AtomicReference;
58  
59  import static org.junit.jupiter.api.Assertions.assertEquals;
60  import static org.junit.jupiter.api.Assertions.assertFalse;
61  import static org.junit.jupiter.api.Assertions.fail;
62  
63  public class SocketSslGreetingTest extends AbstractSocketTest {
64  
65      private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
66  
67      private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
68      private static final File CERT_FILE;
69      private static final File KEY_FILE;
70  
71      static {
72          SelfSignedCertificate ssc;
73          try {
74              ssc = new SelfSignedCertificate();
75          } catch (CertificateException e) {
76              throw new Error(e);
77          }
78          CERT_FILE = ssc.certificate();
79          KEY_FILE = ssc.privateKey();
80      }
81  
82      public static Collection<Object[]> data() throws Exception {
83          List<SslContext> serverContexts = new ArrayList<SslContext>();
84          serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
85  
86          List<SslContext> clientContexts = new ArrayList<SslContext>();
87          clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
88  
89          boolean hasOpenSsl = OpenSsl.isAvailable();
90          if (hasOpenSsl) {
91              serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
92                                                  .sslProvider(SslProvider.OPENSSL).build());
93              clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
94                                                  .trustManager(CERT_FILE).build());
95          } else {
96              logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
97          }
98  
99          List<Object[]> params = new ArrayList<Object[]>();
100         for (SslContext sc: serverContexts) {
101             for (SslContext cc: clientContexts) {
102                 params.add(new Object[] { sc, cc, true });
103                 params.add(new Object[] { sc, cc, false });
104             }
105         }
106         return params;
107     }
108 
109     private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
110         if (executor == null) {
111             return sslCtx.newHandler(allocator);
112         } else {
113             return sslCtx.newHandler(allocator, executor);
114         }
115     }
116 
117     // Test for https://github.com/netty/netty/pull/2437
118     @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
119     @MethodSource("data")
120     @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
121     public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
122                                 TestInfo testInfo) throws Throwable {
123         run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
124             @Override
125             public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
126                 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
127             }
128         });
129     }
130 
131     public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
132                                 final SslContext clientCtx, boolean delegate) throws Throwable {
133         final ServerHandler sh = new ServerHandler();
134         final ClientHandler ch = new ClientHandler();
135 
136         final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
137         try {
138             sb.childHandler(new ChannelInitializer<Channel>() {
139                 @Override
140                 public void initChannel(Channel sch) throws Exception {
141                     ChannelPipeline p = sch.pipeline();
142                     p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
143                     p.addLast(new LoggingHandler(LOG_LEVEL));
144                     p.addLast(sh);
145                 }
146             });
147 
148             cb.handler(new ChannelInitializer<Channel>() {
149                 @Override
150                 public void initChannel(Channel sch) throws Exception {
151                     ChannelPipeline p = sch.pipeline();
152                     p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
153                     p.addLast(new LoggingHandler(LOG_LEVEL));
154                     p.addLast(ch);
155                 }
156             });
157 
158             Channel sc = sb.bind().sync().channel();
159             Channel cc = cb.connect(sc.localAddress()).sync().channel();
160 
161             ch.latch.await();
162 
163             sh.channel.close().awaitUninterruptibly();
164             cc.close().awaitUninterruptibly();
165             sc.close().awaitUninterruptibly();
166 
167             if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
168                 throw sh.exception.get();
169             }
170             if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
171                 throw ch.exception.get();
172             }
173             if (sh.exception.get() != null) {
174                 throw sh.exception.get();
175             }
176             if (ch.exception.get() != null) {
177                 throw ch.exception.get();
178             }
179         } finally {
180             if (executorService != null) {
181                 executorService.shutdown();
182             }
183         }
184     }
185 
186     private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
187 
188         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
189         final CountDownLatch latch = new CountDownLatch(1);
190 
191         @Override
192         public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
193             assertEquals('a', buf.readByte());
194             assertFalse(buf.isReadable());
195             latch.countDown();
196             ctx.close();
197         }
198 
199         @Override
200         public void exceptionCaught(ChannelHandlerContext ctx,
201                                     Throwable cause) throws Exception {
202             if (logger.isWarnEnabled()) {
203                 logger.warn("Unexpected exception from the client side", cause);
204             }
205 
206             exception.compareAndSet(null, cause);
207             ctx.close();
208         }
209     }
210 
211     private static class ServerHandler extends SimpleChannelInboundHandler<String> {
212         volatile Channel channel;
213         final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
214 
215         @Override
216         protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
217             // discard
218         }
219 
220         @Override
221         public void channelActive(ChannelHandlerContext ctx)
222                 throws Exception {
223             channel = ctx.channel();
224             channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
225         }
226 
227         @Override
228         public void exceptionCaught(ChannelHandlerContext ctx,
229                                     Throwable cause) throws Exception {
230             if (logger.isWarnEnabled()) {
231                 logger.warn("Unexpected exception from the server side", cause);
232             }
233 
234             exception.compareAndSet(null, cause);
235             ctx.close();
236         }
237 
238         @Override
239         public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
240             if (evt instanceof SslHandshakeCompletionEvent) {
241                 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
242                 if (event.isSuccess()) {
243                     SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
244                     try {
245                         session.getPeerCertificates();
246                         fail();
247                     } catch (SSLPeerUnverifiedException e) {
248                         // expected
249                     }
250                     try {
251                         session.getPeerCertificateChain();
252                         fail();
253                     } catch (SSLPeerUnverifiedException e) {
254                         // expected
255                     } catch (UnsupportedOperationException e) {
256                         // Starting from Java15 this method throws UnsupportedOperationException as it was
257                         // deprecated before and getPeerCertificates() should be used
258                         if (PlatformDependent.javaVersion() < 15) {
259                             throw e;
260                         }
261                     }
262                     try {
263                         session.getPeerPrincipal();
264                         fail();
265                     } catch (SSLPeerUnverifiedException e) {
266                         // expected
267                     }
268                 }
269             }
270             ctx.fireUserEventTriggered(evt);
271         }
272     }
273 }