1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.buffer.api.BufferAllocator;
22 import io.netty5.channel.Channel;
23 import io.netty5.channel.ChannelHandlerContext;
24 import io.netty5.channel.ChannelInitializer;
25 import io.netty5.channel.ChannelPipeline;
26 import io.netty5.channel.SimpleChannelInboundHandler;
27 import io.netty5.handler.logging.LogLevel;
28 import io.netty5.handler.logging.LoggingHandler;
29 import io.netty5.handler.ssl.OpenSsl;
30 import io.netty5.handler.ssl.SslContext;
31 import io.netty5.handler.ssl.SslContextBuilder;
32 import io.netty5.handler.ssl.SslHandler;
33 import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty5.handler.ssl.SslProvider;
35 import io.netty5.handler.ssl.util.SelfSignedCertificate;
36 import io.netty5.util.internal.PlatformDependent;
37 import io.netty5.util.internal.logging.InternalLogger;
38 import io.netty5.util.internal.logging.InternalLoggerFactory;
39 import org.junit.jupiter.api.TestInfo;
40 import org.junit.jupiter.api.Timeout;
41 import org.junit.jupiter.params.ParameterizedTest;
42 import org.junit.jupiter.params.provider.MethodSource;
43
44 import javax.net.ssl.SSLPeerUnverifiedException;
45 import javax.net.ssl.SSLSession;
46 import java.io.File;
47 import java.io.IOException;
48 import java.security.cert.CertificateException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.CountDownLatch;
53 import java.util.concurrent.Executor;
54 import java.util.concurrent.ExecutorService;
55 import java.util.concurrent.Executors;
56 import java.util.concurrent.TimeUnit;
57 import java.util.concurrent.atomic.AtomicReference;
58
59 import static org.junit.jupiter.api.Assertions.assertEquals;
60 import static org.junit.jupiter.api.Assertions.fail;
61
62 public class SocketSslGreetingTest extends AbstractSocketTest {
63
/** Logger used by this test class and its nested channel handlers. */
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);

/** Level passed to every {@link LoggingHandler} this test installs. */
private static final LogLevel LOG_LEVEL = LogLevel.TRACE;

/** Certificate file of the self-signed certificate created during class initialization. */
private static final File CERT_FILE;

/** Private-key file taken from the same self-signed certificate as {@link #CERT_FILE}. */
private static final File KEY_FILE;

static {
    // One self-signed certificate is generated per class load; both constants are read from it.
    try {
        final SelfSignedCertificate selfSigned = new SelfSignedCertificate();
        CERT_FILE = selfSigned.certificate();
        KEY_FILE = selfSigned.privateKey();
    } catch (CertificateException e) {
        // Certificate generation failed: abort class initialization, keeping the cause.
        throw new Error(e);
    }
}
80
81 public static Collection<Object[]> data() throws Exception {
82 List<SslContext> serverContexts = new ArrayList<>();
83 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
84
85 List<SslContext> clientContexts = new ArrayList<>();
86 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
87
88 boolean hasOpenSsl = OpenSsl.isAvailable();
89 if (hasOpenSsl) {
90 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
91 .sslProvider(SslProvider.OPENSSL).build());
92 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
93 .trustManager(CERT_FILE).build());
94 } else {
95 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
96 }
97
98 List<Object[]> params = new ArrayList<>();
99 for (SslContext sc: serverContexts) {
100 for (SslContext cc: clientContexts) {
101 params.add(new Object[] { sc, cc, true });
102 params.add(new Object[] { sc, cc, false });
103 }
104 }
105 return params;
106 }
107
108 private static SslHandler newSslHandler(SslContext sslCtx, BufferAllocator allocator, Executor executor) {
109 if (executor == null) {
110 return sslCtx.newHandler(allocator);
111 } else {
112 return sslCtx.newHandler(allocator, executor);
113 }
114 }
115
116
117 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
118 @MethodSource("data")
119 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
120 public void testSslGreeting(SslContext serverCtx, SslContext clientCtx, boolean delegate,
121 TestInfo testInfo) throws Throwable {
122 run(testInfo, (sb, cb) -> testSslGreeting(sb, cb, serverCtx, clientCtx, delegate));
123 }
124
125 public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, SslContext serverCtx,
126 SslContext clientCtx, boolean delegate) throws Throwable {
127 final ServerHandler sh = new ServerHandler();
128 final ClientHandler ch = new ClientHandler();
129
130 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
131 try {
132 sb.childHandler(new ChannelInitializer<>() {
133 @Override
134 public void initChannel(Channel sch) throws Exception {
135 ChannelPipeline p = sch.pipeline();
136 p.addLast(newSslHandler(serverCtx, sch.bufferAllocator(), executorService));
137 p.addLast(new LoggingHandler(LOG_LEVEL));
138 p.addLast(sh);
139 }
140 });
141
142 cb.handler(new ChannelInitializer<>() {
143 @Override
144 public void initChannel(Channel sch) throws Exception {
145 ChannelPipeline p = sch.pipeline();
146 p.addLast(newSslHandler(clientCtx, sch.bufferAllocator(), executorService));
147 p.addLast(new LoggingHandler(LOG_LEVEL));
148 p.addLast(ch);
149 }
150 });
151
152 Channel sc = sb.bind().asStage().get();
153 Channel cc = cb.connect(sc.localAddress()).asStage().get();
154
155 ch.latch.await();
156
157 sh.channel.close().asStage().await();
158 cc.close().asStage().await();
159 sc.close().asStage().await();
160
161 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
162 throw sh.exception.get();
163 }
164 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
165 throw ch.exception.get();
166 }
167 if (sh.exception.get() != null) {
168 throw sh.exception.get();
169 }
170 if (ch.exception.get() != null) {
171 throw ch.exception.get();
172 }
173 } finally {
174 if (executorService != null) {
175 executorService.shutdown();
176 }
177 }
178 }
179
180 private static class ClientHandler extends SimpleChannelInboundHandler<Buffer> {
181
182 final AtomicReference<Throwable> exception = new AtomicReference<>();
183 final CountDownLatch latch = new CountDownLatch(1);
184
185 @Override
186 public void messageReceived(ChannelHandlerContext ctx, Buffer buf) throws Exception {
187 assertEquals('a', buf.readByte());
188 assertEquals(0, buf.readableBytes());
189 latch.countDown();
190 ctx.close();
191 }
192
193 @Override
194 public void channelExceptionCaught(ChannelHandlerContext ctx,
195 Throwable cause) throws Exception {
196 if (logger.isWarnEnabled()) {
197 logger.warn("Unexpected exception from the client side", cause);
198 }
199
200 exception.compareAndSet(null, cause);
201 ctx.close();
202 }
203 }
204
205 private static class ServerHandler extends SimpleChannelInboundHandler<String> {
206 volatile Channel channel;
207 final AtomicReference<Throwable> exception = new AtomicReference<>();
208
209 @Override
210 protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
211
212 }
213
214 @Override
215 public void channelActive(ChannelHandlerContext ctx)
216 throws Exception {
217 channel = ctx.channel();
218 channel.writeAndFlush(ctx.bufferAllocator().copyOf(new byte[] {'a'}));
219 }
220
221 @Override
222 public void channelExceptionCaught(ChannelHandlerContext ctx,
223 Throwable cause) throws Exception {
224 if (logger.isWarnEnabled()) {
225 logger.warn("Unexpected exception from the server side", cause);
226 }
227
228 exception.compareAndSet(null, cause);
229 ctx.close();
230 }
231
232 @Override
233 public void channelInboundEvent(final ChannelHandlerContext ctx, final Object evt) throws Exception {
234 if (evt instanceof SslHandshakeCompletionEvent) {
235 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
236 if (event.isSuccess()) {
237 SSLSession session = event.session();
238 try {
239 session.getPeerCertificates();
240 fail();
241 } catch (SSLPeerUnverifiedException e) {
242
243 }
244 try {
245 session.getPeerCertificateChain();
246 fail();
247 } catch (SSLPeerUnverifiedException e) {
248
249 } catch (UnsupportedOperationException e) {
250
251
252 if (PlatformDependent.javaVersion() < 15) {
253 throw e;
254 }
255 }
256 try {
257 session.getPeerPrincipal();
258 fail();
259 } catch (SSLPeerUnverifiedException e) {
260
261 }
262 }
263 }
264 ctx.fireChannelInboundEvent(evt);
265 }
266 }
267 }