1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.PooledByteBufAllocator;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInitializer;
24 import io.netty.channel.ChannelOption;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.LineBasedFrameDecoder;
28 import io.netty.handler.codec.string.StringDecoder;
29 import io.netty.handler.codec.string.StringEncoder;
30 import io.netty.handler.logging.LogLevel;
31 import io.netty.handler.logging.LoggingHandler;
32 import io.netty.handler.ssl.OpenSsl;
33 import io.netty.handler.ssl.SslContext;
34 import io.netty.handler.ssl.SslContextBuilder;
35 import io.netty.handler.ssl.SslHandler;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.pkitesting.CertificateBuilder;
38 import io.netty.pkitesting.X509Bundle;
39 import io.netty.util.concurrent.DefaultEventExecutorGroup;
40 import io.netty.util.concurrent.EventExecutorGroup;
41 import io.netty.util.concurrent.Future;
42 import io.netty.util.internal.logging.InternalLogger;
43 import io.netty.util.internal.logging.InternalLoggerFactory;
44
45 import org.junit.jupiter.api.AfterAll;
46 import org.junit.jupiter.api.BeforeAll;
47 import org.junit.jupiter.api.TestInfo;
48 import org.junit.jupiter.api.Timeout;
49 import org.junit.jupiter.params.ParameterizedTest;
50 import org.junit.jupiter.params.provider.MethodSource;
51
52 import javax.net.ssl.SSLEngine;
53 import java.io.File;
54 import java.io.IOException;
55 import java.util.ArrayList;
56 import java.util.Collection;
57 import java.util.List;
58 import java.util.concurrent.TimeUnit;
59 import java.util.concurrent.atomic.AtomicReference;
60
61 import static org.junit.jupiter.api.Assertions.assertEquals;
62 import static org.junit.jupiter.api.Assertions.assertNotNull;
63 import static org.junit.jupiter.api.Assertions.assertTrue;
64
65 public class SocketStartTlsTest extends AbstractSocketTest {
66 private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";
67
68 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);
69
70 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
71 private static final File CERT_FILE;
72 private static final File KEY_FILE;
73 private static EventExecutorGroup executor;
74
75 static {
76 try {
77 X509Bundle cert = new CertificateBuilder()
78 .subject("cn=localhost")
79 .setIsCertificateAuthority(true)
80 .buildSelfSigned();
81 CERT_FILE = cert.toTempCertChainPem();
82 KEY_FILE = cert.toTempPrivateKeyPem();
83 } catch (Exception e) {
84 throw new ExceptionInInitializerError(e);
85 }
86 }
87
88 public static Collection<Object[]> data() throws Exception {
89 List<SslContext> serverContexts = new ArrayList<SslContext>();
90 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
91
92 List<SslContext> clientContexts = new ArrayList<SslContext>();
93 clientContexts.add(SslContextBuilder.forClient()
94 .sslProvider(SslProvider.JDK)
95 .trustManager(CERT_FILE)
96 .endpointIdentificationAlgorithm(null)
97 .build());
98
99 boolean hasOpenSsl = OpenSsl.isAvailable();
100 if (hasOpenSsl) {
101 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
102 .sslProvider(SslProvider.OPENSSL).build());
103 clientContexts.add(SslContextBuilder.forClient()
104 .sslProvider(SslProvider.OPENSSL)
105 .trustManager(CERT_FILE)
106 .endpointIdentificationAlgorithm(null)
107 .build());
108 } else {
109 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
110 }
111
112 List<Object[]> params = new ArrayList<Object[]>();
113 for (SslContext sc: serverContexts) {
114 for (SslContext cc: clientContexts) {
115 params.add(new Object[] { sc, cc });
116 }
117 }
118 return params;
119 }
120
121 @BeforeAll
122 public static void createExecutor() {
123 executor = new DefaultEventExecutorGroup(2);
124 }
125
126 @AfterAll
127 public static void shutdownExecutor() throws Exception {
128 executor.shutdownGracefully().sync();
129 }
130
131 @ParameterizedTest(name = PARAMETERIZED_NAME)
132 @MethodSource("data")
133 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
134 public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
135 throws Throwable {
136 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
137 @Override
138 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
139 testStartTls(sb, cb, serverCtx, clientCtx);
140 }
141 });
142 }
143
144 public void testStartTls(ServerBootstrap sb, Bootstrap cb,
145 SslContext serverCtx, SslContext clientCtx) throws Throwable {
146 testStartTls(sb, cb, serverCtx, clientCtx, true);
147 }
148
149 @ParameterizedTest(name = PARAMETERIZED_NAME)
150 @MethodSource("data")
151 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
152 public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
153 TestInfo testInfo) throws Throwable {
154 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
155 @Override
156 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
157 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
158 }
159 });
160 }
161
162 public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
163 SslContext serverCtx, SslContext clientCtx) throws Throwable {
164 testStartTls(sb, cb, serverCtx, clientCtx, false);
165 }
166
167 private void testStartTls(ServerBootstrap sb, Bootstrap cb,
168 SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
169 sb.childOption(ChannelOption.AUTO_READ, autoRead);
170 cb.option(ChannelOption.AUTO_READ, autoRead);
171
172 final EventExecutorGroup executor = SocketStartTlsTest.executor;
173 SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
174 SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
175
176 final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
177 final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
178
179 sb.childHandler(new ChannelInitializer<Channel>() {
180 @Override
181 public void initChannel(Channel sch) throws Exception {
182 ChannelPipeline p = sch.pipeline();
183 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
184 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
185 p.addLast(executor, sh);
186 }
187 });
188
189 cb.handler(new ChannelInitializer<Channel>() {
190 @Override
191 public void initChannel(Channel sch) throws Exception {
192 ChannelPipeline p = sch.pipeline();
193 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
194 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
195 p.addLast(executor, ch);
196 }
197 });
198
199 Channel sc = sb.bind().sync().channel();
200 Channel cc = cb.connect(sc.localAddress()).sync().channel();
201
202 while (cc.isActive()) {
203 if (sh.exception.get() != null) {
204 break;
205 }
206 if (ch.exception.get() != null) {
207 break;
208 }
209
210 Thread.sleep(50);
211 }
212
213 while (sh.channel.isActive()) {
214 if (sh.exception.get() != null) {
215 break;
216 }
217 if (ch.exception.get() != null) {
218 break;
219 }
220
221 Thread.sleep(50);
222 }
223
224 sh.channel.close().awaitUninterruptibly();
225 cc.close().awaitUninterruptibly();
226 sc.close().awaitUninterruptibly();
227
228 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
229 throw sh.exception.get();
230 }
231 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
232 throw ch.exception.get();
233 }
234 if (sh.exception.get() != null) {
235 throw sh.exception.get();
236 }
237 if (ch.exception.get() != null) {
238 throw ch.exception.get();
239 }
240 }
241
242 private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
243 private final SslHandler sslHandler;
244 private final boolean autoRead;
245 private Future<Channel> handshakeFuture;
246 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
247
248 StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
249 engine.setUseClientMode(true);
250 sslHandler = new SslHandler(engine);
251 this.autoRead = autoRead;
252 }
253
254 @Override
255 public void channelActive(ChannelHandlerContext ctx)
256 throws Exception {
257 if (!autoRead) {
258 ctx.read();
259 }
260 ctx.writeAndFlush("StartTlsRequest\n");
261 }
262
263 @Override
264 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
265 if ("StartTlsResponse".equals(msg)) {
266 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
267 handshakeFuture = sslHandler.handshakeFuture();
268 ctx.writeAndFlush("EncryptedRequest\n");
269 return;
270 }
271
272 assertEquals("EncryptedResponse", msg);
273 assertNotNull(handshakeFuture);
274 assertTrue(handshakeFuture.isSuccess());
275 ctx.close();
276 }
277
278 @Override
279 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
280 if (!autoRead) {
281 ctx.read();
282 }
283 }
284
285 @Override
286 public void exceptionCaught(ChannelHandlerContext ctx,
287 Throwable cause) throws Exception {
288 if (logger.isWarnEnabled()) {
289 logger.warn("Unexpected exception from the client side", cause);
290 }
291
292 exception.compareAndSet(null, cause);
293 ctx.close();
294 }
295 }
296
297 private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
298 private final SslHandler sslHandler;
299 private final boolean autoRead;
300 volatile Channel channel;
301 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
302
303 StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
304 engine.setUseClientMode(false);
305 sslHandler = new SslHandler(engine, true);
306 this.autoRead = autoRead;
307 }
308
309 @Override
310 public void channelActive(ChannelHandlerContext ctx) throws Exception {
311 channel = ctx.channel();
312 if (!autoRead) {
313 ctx.read();
314 }
315 }
316
317 @Override
318 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
319 if ("StartTlsRequest".equals(msg)) {
320 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
321 ctx.writeAndFlush("StartTlsResponse\n");
322 return;
323 }
324
325 assertEquals("EncryptedRequest", msg);
326 ctx.writeAndFlush("EncryptedResponse\n");
327 }
328
329 @Override
330 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
331 if (!autoRead) {
332 ctx.read();
333 }
334 }
335
336 @Override
337 public void exceptionCaught(ChannelHandlerContext ctx,
338 Throwable cause) throws Exception {
339 if (logger.isWarnEnabled()) {
340 logger.warn("Unexpected exception from the server side", cause);
341 }
342
343 exception.compareAndSet(null, cause);
344 ctx.close();
345 }
346 }
347 }