1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.channel.Channel;
21 import io.netty5.channel.ChannelHandlerContext;
22 import io.netty5.channel.ChannelInitializer;
23 import io.netty5.channel.ChannelOption;
24 import io.netty5.channel.ChannelPipeline;
25 import io.netty5.channel.SimpleChannelInboundHandler;
26 import io.netty5.handler.codec.LineBasedFrameDecoder;
27 import io.netty5.handler.codec.string.StringDecoder;
28 import io.netty5.handler.codec.string.StringEncoder;
29 import io.netty5.handler.logging.LogLevel;
30 import io.netty5.handler.logging.LoggingHandler;
31 import io.netty5.handler.ssl.OpenSsl;
32 import io.netty5.handler.ssl.SslContext;
33 import io.netty5.handler.ssl.SslContextBuilder;
34 import io.netty5.handler.ssl.SslHandler;
35 import io.netty5.handler.ssl.SslProvider;
36 import io.netty5.handler.ssl.util.SelfSignedCertificate;
37 import io.netty5.util.concurrent.Future;
38 import io.netty5.util.internal.logging.InternalLogger;
39 import io.netty5.util.internal.logging.InternalLoggerFactory;
40 import org.junit.jupiter.api.TestInfo;
41 import org.junit.jupiter.api.Timeout;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44
45 import javax.net.ssl.SSLEngine;
46 import java.io.File;
47 import java.io.IOException;
48 import java.security.cert.CertificateException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.TimeUnit;
53 import java.util.concurrent.atomic.AtomicReference;
54
55 import static io.netty5.buffer.api.DefaultBufferAllocators.offHeapAllocator;
56 import static org.junit.jupiter.api.Assertions.assertEquals;
57 import static org.junit.jupiter.api.Assertions.assertNotNull;
58 import static org.junit.jupiter.api.Assertions.assertTrue;
59
60 public class SocketStartTlsTest extends AbstractSocketTest {
61 private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
62
63 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
64
65 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
66 private static final File CERT_FILE;
67 private static final File KEY_FILE;
68
69 static {
70 SelfSignedCertificate ssc;
71 try {
72 ssc = new SelfSignedCertificate();
73 } catch (CertificateException e) {
74 throw new Error(e);
75 }
76 CERT_FILE = ssc.certificate();
77 KEY_FILE = ssc.privateKey();
78 }
79
80 public static Collection<Object[]> data() throws Exception {
81 List<SslContext> serverContexts = new ArrayList<>();
82 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
83
84 List<SslContext> clientContexts = new ArrayList<>();
85 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
86
87 boolean hasOpenSsl = OpenSsl.isAvailable();
88 if (hasOpenSsl) {
89 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
90 .sslProvider(SslProvider.OPENSSL).build());
91 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
92 .trustManager(CERT_FILE).build());
93 } else {
94 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
95 }
96
97 List<Object[]> params = new ArrayList<>();
98 for (SslContext sc: serverContexts) {
99 for (SslContext cc: clientContexts) {
100 params.add(new Object[] { sc, cc });
101 }
102 }
103 return params;
104 }
105
106 @ParameterizedTest(name = PARAMETERIZED_NAME)
107 @MethodSource("data")
108 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
109 public void testStartTls(SslContext serverCtx, SslContext clientCtx, TestInfo testInfo) throws Throwable {
110 run(testInfo, (sb, cb) -> testStartTls(sb, cb, serverCtx, clientCtx));
111 }
112
113 public void testStartTls(ServerBootstrap sb, Bootstrap cb,
114 SslContext serverCtx, SslContext clientCtx) throws Throwable {
115 testStartTls(sb, cb, serverCtx, clientCtx, true);
116 }
117
118 @ParameterizedTest(name = PARAMETERIZED_NAME)
119 @MethodSource("data")
120 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
121 public void testStartTlsNotAutoRead(SslContext serverCtx, SslContext clientCtx,
122 TestInfo testInfo) throws Throwable {
123 run(testInfo, (sb, cb) -> testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx));
124 }
125
126 public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
127 SslContext serverCtx, SslContext clientCtx) throws Throwable {
128 testStartTls(sb, cb, serverCtx, clientCtx, false);
129 }
130
131 private void testStartTls(ServerBootstrap sb, Bootstrap cb,
132 SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
133 sb.childOption(ChannelOption.AUTO_READ, autoRead);
134 cb.option(ChannelOption.AUTO_READ, autoRead);
135
136 SSLEngine sse = serverCtx.newEngine(offHeapAllocator());
137 SSLEngine cse = clientCtx.newEngine(offHeapAllocator());
138
139 final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
140 final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
141
142 sb.childHandler(new ChannelInitializer<>() {
143 @Override
144 public void initChannel(Channel sch) {
145 ChannelPipeline p = sch.pipeline();
146 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
147 p.addLast(new LineBasedFrameDecoder(64));
148 p.addLast(new StringDecoder(), new StringEncoder());
149 p.addLast(sh);
150 }
151 });
152
153 cb.handler(new ChannelInitializer<>() {
154 @Override
155 public void initChannel(Channel sch) {
156 ChannelPipeline p = sch.pipeline();
157 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
158 p.addLast(new LineBasedFrameDecoder(64));
159 p.addLast(new StringDecoder(), new StringEncoder());
160 p.addLast(ch);
161 }
162 });
163
164 Channel sc = sb.bind().asStage().get();
165 Channel cc = cb.connect(sc.localAddress()).asStage().get();
166
167 while (cc.isActive()) {
168 if (sh.exception.get() != null) {
169 break;
170 }
171 if (ch.exception.get() != null) {
172 break;
173 }
174
175 try {
176 Thread.sleep(50);
177 } catch (InterruptedException e) {
178
179 }
180 }
181
182 while (sh.channel.isActive()) {
183 if (sh.exception.get() != null) {
184 break;
185 }
186 if (ch.exception.get() != null) {
187 break;
188 }
189
190 try {
191 Thread.sleep(50);
192 } catch (InterruptedException e) {
193
194 }
195 }
196
197 sh.channel.close().asStage().await();
198 cc.close().asStage().await();
199 sc.close().asStage().await();
200
201 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
202 throw sh.exception.get();
203 }
204 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
205 throw ch.exception.get();
206 }
207 if (sh.exception.get() != null) {
208 throw sh.exception.get();
209 }
210 if (ch.exception.get() != null) {
211 throw ch.exception.get();
212 }
213 }
214
215 private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
216 private final SslHandler sslHandler;
217 private final boolean autoRead;
218 private Future<Channel> handshakeFuture;
219 final AtomicReference<Throwable> exception = new AtomicReference<>();
220
221 StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
222 engine.setUseClientMode(true);
223 sslHandler = new SslHandler(engine);
224 this.autoRead = autoRead;
225 }
226
227 @Override
228 public void channelActive(ChannelHandlerContext ctx)
229 throws Exception {
230 if (!autoRead) {
231 ctx.read();
232 }
233 ctx.writeAndFlush("StartTlsRequest\n");
234 }
235
236 @Override
237 public void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
238 if ("StartTlsResponse".equals(msg)) {
239 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
240 handshakeFuture = sslHandler.handshakeFuture();
241 ctx.writeAndFlush("EncryptedRequest\n");
242 return;
243 }
244
245 assertEquals("EncryptedResponse", msg);
246 assertNotNull(handshakeFuture);
247 assertTrue(handshakeFuture.isSuccess());
248 ctx.close();
249 }
250
251 @Override
252 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
253 if (!autoRead) {
254 ctx.read();
255 }
256 }
257
258 @Override
259 public void channelExceptionCaught(ChannelHandlerContext ctx,
260 Throwable cause) throws Exception {
261 if (logger.isWarnEnabled()) {
262 logger.warn("Unexpected exception from the client side", cause);
263 }
264
265 exception.compareAndSet(null, cause);
266 ctx.close();
267 }
268 }
269
270 private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
271 private final SslHandler sslHandler;
272 private final boolean autoRead;
273 volatile Channel channel;
274 final AtomicReference<Throwable> exception = new AtomicReference<>();
275
276 StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
277 engine.setUseClientMode(false);
278 sslHandler = new SslHandler(engine, true);
279 this.autoRead = autoRead;
280 }
281
282 @Override
283 public void channelActive(ChannelHandlerContext ctx) throws Exception {
284 channel = ctx.channel();
285 if (!autoRead) {
286 ctx.read();
287 }
288 }
289
290 @Override
291 public void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
292 if ("StartTlsRequest".equals(msg)) {
293 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
294 ctx.writeAndFlush("StartTlsResponse\n");
295 return;
296 }
297
298 assertEquals("EncryptedRequest", msg);
299 ctx.writeAndFlush("EncryptedResponse\n");
300 }
301
302 @Override
303 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
304 if (!autoRead) {
305 ctx.read();
306 }
307 }
308
309 @Override
310 public void channelExceptionCaught(ChannelHandlerContext ctx,
311 Throwable cause) throws Exception {
312 if (logger.isWarnEnabled()) {
313 logger.warn("Unexpected exception from the server side", cause);
314 }
315
316 exception.compareAndSet(null, cause);
317 ctx.close();
318 }
319 }
320 }