1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.logging.LogLevel;
28 import io.netty.handler.logging.LoggingHandler;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.SslContext;
31 import io.netty.handler.ssl.SslContextBuilder;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.handler.ssl.SslProvider;
35 import io.netty.pkitesting.CertificateBuilder;
36 import io.netty.pkitesting.X509Bundle;
37 import io.netty.util.internal.PlatformDependent;
38 import io.netty.util.internal.logging.InternalLogger;
39 import io.netty.util.internal.logging.InternalLoggerFactory;
40 import org.junit.jupiter.api.TestInfo;
41 import org.junit.jupiter.api.Timeout;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44
45 import javax.net.ssl.SSLPeerUnverifiedException;
46 import javax.net.ssl.SSLSession;
47 import java.io.File;
48 import java.io.IOException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.CountDownLatch;
53 import java.util.concurrent.Executor;
54 import java.util.concurrent.ExecutorService;
55 import java.util.concurrent.Executors;
56 import java.util.concurrent.TimeUnit;
57 import java.util.concurrent.atomic.AtomicReference;
58
59 import static org.junit.jupiter.api.Assertions.assertEquals;
60 import static org.junit.jupiter.api.Assertions.assertFalse;
61 import static org.junit.jupiter.api.Assertions.fail;
62
63 public class SocketSslGreetingTest extends AbstractSocketTest {
64
65 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
66
67 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
68 private static final File CERT_FILE;
69 private static final File KEY_FILE;
70
71 static {
72 try {
73 X509Bundle cert = new CertificateBuilder()
74 .subject("cn=localhost")
75 .setIsCertificateAuthority(true)
76 .buildSelfSigned();
77 CERT_FILE = cert.toTempCertChainPem();
78 KEY_FILE = cert.toTempPrivateKeyPem();
79 } catch (Exception e) {
80 throw new ExceptionInInitializerError(e);
81 }
82 }
83
84 public static Collection<Object[]> data() throws Exception {
85 List<SslContext> serverContexts = new ArrayList<SslContext>();
86 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
87
88 List<SslContext> clientContexts = new ArrayList<SslContext>();
89 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
90 .endpointIdentificationAlgorithm(null).trustManager(CERT_FILE).build());
91
92 boolean hasOpenSsl = OpenSsl.isAvailable();
93 if (hasOpenSsl) {
94 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
95 .sslProvider(SslProvider.OPENSSL).build());
96 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
97 .endpointIdentificationAlgorithm(null)
98 .trustManager(CERT_FILE).build());
99 } else {
100 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
101 }
102
103 List<Object[]> params = new ArrayList<Object[]>();
104 for (SslContext sc: serverContexts) {
105 for (SslContext cc: clientContexts) {
106 params.add(new Object[] { sc, cc, true });
107 params.add(new Object[] { sc, cc, false });
108 }
109 }
110 return params;
111 }
112
113 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
114 if (executor == null) {
115 return sslCtx.newHandler(allocator);
116 } else {
117 return sslCtx.newHandler(allocator, executor);
118 }
119 }
120
121
122 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
123 @MethodSource("data")
124 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
125 public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
126 TestInfo testInfo) throws Throwable {
127 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
128 @Override
129 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
130 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
131 }
132 });
133 }
134
135 public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
136 final SslContext clientCtx, boolean delegate) throws Throwable {
137 final ServerHandler sh = new ServerHandler();
138 final ClientHandler ch = new ClientHandler();
139
140 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
141 try {
142 sb.childHandler(new ChannelInitializer<Channel>() {
143 @Override
144 public void initChannel(Channel sch) throws Exception {
145 ChannelPipeline p = sch.pipeline();
146 p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
147 p.addLast(new LoggingHandler(LOG_LEVEL));
148 p.addLast(sh);
149 }
150 });
151
152 cb.handler(new ChannelInitializer<Channel>() {
153 @Override
154 public void initChannel(Channel sch) throws Exception {
155 ChannelPipeline p = sch.pipeline();
156 p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
157 p.addLast(new LoggingHandler(LOG_LEVEL));
158 p.addLast(ch);
159 }
160 });
161
162 Channel sc = sb.bind().sync().channel();
163 Channel cc = cb.connect(sc.localAddress()).sync().channel();
164
165 ch.latch.await();
166
167 sh.channel.close().awaitUninterruptibly();
168 cc.close().awaitUninterruptibly();
169 sc.close().awaitUninterruptibly();
170
171 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
172 throw sh.exception.get();
173 }
174 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
175 throw ch.exception.get();
176 }
177 if (sh.exception.get() != null) {
178 throw sh.exception.get();
179 }
180 if (ch.exception.get() != null) {
181 throw ch.exception.get();
182 }
183 } finally {
184 if (executorService != null) {
185 executorService.shutdown();
186 }
187 }
188 }
189
190 private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
191
192 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
193 final CountDownLatch latch = new CountDownLatch(1);
194
195 @Override
196 public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
197 assertEquals('a', buf.readByte());
198 assertFalse(buf.isReadable());
199 latch.countDown();
200 ctx.close();
201 }
202
203 @Override
204 public void exceptionCaught(ChannelHandlerContext ctx,
205 Throwable cause) throws Exception {
206 if (logger.isWarnEnabled()) {
207 logger.warn("Unexpected exception from the client side", cause);
208 }
209
210 exception.compareAndSet(null, cause);
211 ctx.close();
212 }
213 }
214
215 private static class ServerHandler extends SimpleChannelInboundHandler<String> {
216 volatile Channel channel;
217 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
218
219 @Override
220 protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
221
222 }
223
224 @Override
225 public void channelActive(ChannelHandlerContext ctx)
226 throws Exception {
227 channel = ctx.channel();
228 channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
229 }
230
231 @Override
232 public void exceptionCaught(ChannelHandlerContext ctx,
233 Throwable cause) throws Exception {
234 if (logger.isWarnEnabled()) {
235 logger.warn("Unexpected exception from the server side", cause);
236 }
237
238 exception.compareAndSet(null, cause);
239 ctx.close();
240 }
241
242 @Override
243 public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
244 if (evt instanceof SslHandshakeCompletionEvent) {
245 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
246 if (event.isSuccess()) {
247 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
248 try {
249 session.getPeerCertificates();
250 fail();
251 } catch (SSLPeerUnverifiedException e) {
252
253 }
254 try {
255 session.getPeerCertificateChain();
256 fail();
257 } catch (SSLPeerUnverifiedException e) {
258
259 } catch (UnsupportedOperationException e) {
260
261
262 if (PlatformDependent.javaVersion() < 15) {
263 throw e;
264 }
265 }
266 try {
267 session.getPeerPrincipal();
268 fail();
269 } catch (SSLPeerUnverifiedException e) {
270
271 }
272 }
273 }
274 ctx.fireUserEventTriggered(evt);
275 }
276 }
277 }