1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.logging.LogLevel;
28 import io.netty.handler.logging.LoggingHandler;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.SslContext;
31 import io.netty.handler.ssl.SslContextBuilder;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.handler.ssl.SslProvider;
35 import io.netty.pkitesting.CertificateBuilder;
36 import io.netty.pkitesting.X509Bundle;
37 import io.netty.util.internal.PlatformDependent;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40 import org.junit.jupiter.api.TestInfo;
41 import org.junit.jupiter.api.Timeout;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44
45 import javax.net.ssl.SSLPeerUnverifiedException;
46 import javax.net.ssl.SSLSession;
47 import java.io.File;
48 import java.io.IOException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.CountDownLatch;
53 import java.util.concurrent.Executor;
54 import java.util.concurrent.ExecutorService;
55 import java.util.concurrent.Executors;
56 import java.util.concurrent.TimeUnit;
57 import java.util.concurrent.atomic.AtomicReference;
58
59 import static org.junit.jupiter.api.Assertions.assertEquals;
60 import static org.junit.jupiter.api.Assertions.assertFalse;
61 import static org.junit.jupiter.api.Assertions.assertTrue;
62 import static org.junit.jupiter.api.Assertions.fail;
63
64 public class SocketSslGreetingTest extends AbstractSocketTest {
65
66 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
67
68 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
69 private static final File CERT_FILE;
70 private static final File KEY_FILE;
71
72 static {
73 try {
74 X509Bundle cert = new CertificateBuilder()
75 .subject("cn=localhost")
76 .setIsCertificateAuthority(true)
77 .buildSelfSigned();
78 CERT_FILE = cert.toTempCertChainPem();
79 KEY_FILE = cert.toTempPrivateKeyPem();
80 } catch (Exception e) {
81 throw new ExceptionInInitializerError(e);
82 }
83 }
84
85 public static Collection<Object[]> data() throws Exception {
86 List<SslContext> serverContexts = new ArrayList<SslContext>();
87 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
88
89 List<SslContext> clientContexts = new ArrayList<SslContext>();
90 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
91 .endpointIdentificationAlgorithm(null).trustManager(CERT_FILE).build());
92
93 boolean hasOpenSsl = OpenSsl.isAvailable();
94 if (hasOpenSsl) {
95 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
96 .sslProvider(SslProvider.OPENSSL).build());
97 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
98 .endpointIdentificationAlgorithm(null)
99 .trustManager(CERT_FILE).build());
100 } else {
101 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
102 }
103
104 List<Object[]> params = new ArrayList<Object[]>();
105 for (SslContext sc: serverContexts) {
106 for (SslContext cc: clientContexts) {
107 params.add(new Object[] { sc, cc, true });
108 params.add(new Object[] { sc, cc, false });
109 }
110 }
111 return params;
112 }
113
114 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
115 if (executor == null) {
116 return sslCtx.newHandler(allocator);
117 } else {
118 return sslCtx.newHandler(allocator, executor);
119 }
120 }
121
122
123 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
124 @MethodSource("data")
125 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
126 public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
127 TestInfo testInfo) throws Throwable {
128 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
129 @Override
130 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
131 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
132 }
133 });
134 }
135
136 public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
137 final SslContext clientCtx, boolean delegate) throws Throwable {
138 final ServerHandler sh = new ServerHandler();
139 final ClientHandler ch = new ClientHandler();
140
141 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
142 try {
143 sb.childHandler(new ChannelInitializer<Channel>() {
144 @Override
145 public void initChannel(Channel sch) throws Exception {
146 ChannelPipeline p = sch.pipeline();
147 p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
148 p.addLast(new LoggingHandler(LOG_LEVEL));
149 p.addLast(sh);
150 }
151 });
152
153 cb.handler(new ChannelInitializer<Channel>() {
154 @Override
155 public void initChannel(Channel sch) throws Exception {
156 ChannelPipeline p = sch.pipeline();
157 p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
158 p.addLast(new LoggingHandler(LOG_LEVEL));
159 p.addLast(ch);
160 }
161 });
162
163 Channel sc = sb.bind().sync().channel();
164 Channel cc = cb.connect(sc.localAddress()).sync().channel();
165
166 ch.latch.await();
167
168 sh.channel.close().awaitUninterruptibly();
169 cc.close().awaitUninterruptibly();
170 sc.close().awaitUninterruptibly();
171
172 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
173 throw sh.exception.get();
174 }
175 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
176 throw ch.exception.get();
177 }
178 if (sh.exception.get() != null) {
179 throw sh.exception.get();
180 }
181 if (ch.exception.get() != null) {
182 throw ch.exception.get();
183 }
184 } finally {
185 if (executorService != null) {
186 executorService.shutdown();
187 assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS));
188 }
189 }
190 }
191
192 private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
193
194 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
195 final CountDownLatch latch = new CountDownLatch(1);
196
197 @Override
198 public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
199 assertEquals('a', buf.readByte());
200 assertFalse(buf.isReadable());
201 latch.countDown();
202 ctx.close();
203 }
204
205 @Override
206 public void exceptionCaught(ChannelHandlerContext ctx,
207 Throwable cause) throws Exception {
208 if (logger.isWarnEnabled()) {
209 logger.warn("Unexpected exception from the client side", cause);
210 }
211
212 exception.compareAndSet(null, cause);
213 ctx.close();
214 }
215 }
216
217 private static class ServerHandler extends SimpleChannelInboundHandler<String> {
218 volatile Channel channel;
219 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
220
221 @Override
222 protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
223
224 }
225
226 @Override
227 public void channelActive(ChannelHandlerContext ctx)
228 throws Exception {
229 channel = ctx.channel();
230 channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
231 }
232
233 @Override
234 public void exceptionCaught(ChannelHandlerContext ctx,
235 Throwable cause) throws Exception {
236 if (logger.isWarnEnabled()) {
237 logger.warn("Unexpected exception from the server side", cause);
238 }
239
240 exception.compareAndSet(null, cause);
241 ctx.close();
242 }
243
244 @Override
245 public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
246 if (evt instanceof SslHandshakeCompletionEvent) {
247 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
248 if (event.isSuccess()) {
249 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
250 try {
251 session.getPeerCertificates();
252 fail();
253 } catch (SSLPeerUnverifiedException e) {
254
255 }
256 try {
257 session.getPeerCertificateChain();
258 fail();
259 } catch (SSLPeerUnverifiedException e) {
260
261 } catch (UnsupportedOperationException e) {
262
263
264 if (PlatformDependent.javaVersion() < 15) {
265 throw e;
266 }
267 }
268 try {
269 session.getPeerPrincipal();
270 fail();
271 } catch (SSLPeerUnverifiedException e) {
272
273 }
274 }
275 }
276 ctx.fireUserEventTriggered(evt);
277 }
278 }
279 }