1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.PooledByteBufAllocator;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInitializer;
24 import io.netty.channel.ChannelOption;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.LineBasedFrameDecoder;
28 import io.netty.handler.codec.string.StringDecoder;
29 import io.netty.handler.codec.string.StringEncoder;
30 import io.netty.handler.logging.LogLevel;
31 import io.netty.handler.logging.LoggingHandler;
32 import io.netty.handler.ssl.OpenSsl;
33 import io.netty.handler.ssl.SslContext;
34 import io.netty.handler.ssl.SslContextBuilder;
35 import io.netty.handler.ssl.SslHandler;
36 import io.netty.handler.ssl.SslProvider;
37 import io.netty.handler.ssl.util.SelfSignedCertificate;
38 import io.netty.util.concurrent.DefaultEventExecutorGroup;
39 import io.netty.util.concurrent.EventExecutorGroup;
40 import io.netty.util.concurrent.Future;
41 import io.netty.util.internal.logging.InternalLogger;
42 import io.netty.util.internal.logging.InternalLoggerFactory;
43
44 import org.junit.jupiter.api.AfterAll;
45 import org.junit.jupiter.api.BeforeAll;
46 import org.junit.jupiter.api.TestInfo;
47 import org.junit.jupiter.api.Timeout;
48 import org.junit.jupiter.params.ParameterizedTest;
49 import org.junit.jupiter.params.provider.MethodSource;
50
51 import javax.net.ssl.SSLEngine;
52 import java.io.File;
53 import java.io.IOException;
54 import java.security.cert.CertificateException;
55 import java.util.ArrayList;
56 import java.util.Collection;
57 import java.util.List;
58 import java.util.concurrent.TimeUnit;
59 import java.util.concurrent.atomic.AtomicReference;
60
61 import static org.junit.jupiter.api.Assertions.assertEquals;
62 import static org.junit.jupiter.api.Assertions.assertNotNull;
63 import static org.junit.jupiter.api.Assertions.assertTrue;
64
65 public class SocketStartTlsTest extends AbstractSocketTest {
/** Display-name pattern shared by the parameterized tests in this class. */
private static final String PARAMETERIZED_NAME = "{index}: serverEngine = {0}, clientEngine = {1}";

private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketStartTlsTest.class);

/** Level handed to the {@link LoggingHandler} installed at the head of every test pipeline. */
private static final LogLevel LOG_LEVEL = LogLevel.TRACE;

/** Certificate and private key of the self-signed certificate generated in the static initializer. */
private static final File CERT_FILE;
private static final File KEY_FILE;

/** Executor group the STARTTLS handlers run on; assigned in createExecutor(), shut down in shutdownExecutor(). */
private static EventExecutorGroup executor;

// Generate one self-signed certificate for the whole class and remember where its files live.
static {
    final SelfSignedCertificate selfSigned;
    try {
        selfSigned = new SelfSignedCertificate();
    } catch (CertificateException cause) {
        // Without a certificate no test in this class can run, so fail class initialization.
        throw new Error(cause);
    }
    KEY_FILE = selfSigned.privateKey();
    CERT_FILE = selfSigned.certificate();
}
85
86 public static Collection<Object[]> data() throws Exception {
87 List<SslContext> serverContexts = new ArrayList<SslContext>();
88 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
89
90 List<SslContext> clientContexts = new ArrayList<SslContext>();
91 clientContexts.add(SslContextBuilder.forClient()
92 .sslProvider(SslProvider.JDK)
93 .trustManager(CERT_FILE)
94 .endpointIdentificationAlgorithm(null)
95 .build());
96
97 boolean hasOpenSsl = OpenSsl.isAvailable();
98 if (hasOpenSsl) {
99 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
100 .sslProvider(SslProvider.OPENSSL).build());
101 clientContexts.add(SslContextBuilder.forClient()
102 .sslProvider(SslProvider.OPENSSL)
103 .trustManager(CERT_FILE)
104 .endpointIdentificationAlgorithm(null)
105 .build());
106 } else {
107 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
108 }
109
110 List<Object[]> params = new ArrayList<Object[]>();
111 for (SslContext sc: serverContexts) {
112 for (SslContext cc: clientContexts) {
113 params.add(new Object[] { sc, cc });
114 }
115 }
116 return params;
117 }
118
119 @BeforeAll
120 public static void createExecutor() {
121 executor = new DefaultEventExecutorGroup(2);
122 }
123
124 @AfterAll
125 public static void shutdownExecutor() throws Exception {
126 executor.shutdownGracefully().sync();
127 }
128
129 @ParameterizedTest(name = PARAMETERIZED_NAME)
130 @MethodSource("data")
131 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
132 public void testStartTls(final SslContext serverCtx, final SslContext clientCtx, TestInfo testInfo)
133 throws Throwable {
134 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
135 @Override
136 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
137 testStartTls(sb, cb, serverCtx, clientCtx);
138 }
139 });
140 }
141
142 public void testStartTls(ServerBootstrap sb, Bootstrap cb,
143 SslContext serverCtx, SslContext clientCtx) throws Throwable {
144 testStartTls(sb, cb, serverCtx, clientCtx, true);
145 }
146
147 @ParameterizedTest(name = PARAMETERIZED_NAME)
148 @MethodSource("data")
149 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
150 public void testStartTlsNotAutoRead(final SslContext serverCtx, final SslContext clientCtx,
151 TestInfo testInfo) throws Throwable {
152 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
153 @Override
154 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
155 testStartTlsNotAutoRead(sb, cb, serverCtx, clientCtx);
156 }
157 });
158 }
159
160 public void testStartTlsNotAutoRead(ServerBootstrap sb, Bootstrap cb,
161 SslContext serverCtx, SslContext clientCtx) throws Throwable {
162 testStartTls(sb, cb, serverCtx, clientCtx, false);
163 }
164
165 private void testStartTls(ServerBootstrap sb, Bootstrap cb,
166 SslContext serverCtx, SslContext clientCtx, boolean autoRead) throws Throwable {
167 sb.childOption(ChannelOption.AUTO_READ, autoRead);
168 cb.option(ChannelOption.AUTO_READ, autoRead);
169
170 final EventExecutorGroup executor = SocketStartTlsTest.executor;
171 SSLEngine sse = serverCtx.newEngine(PooledByteBufAllocator.DEFAULT);
172 SSLEngine cse = clientCtx.newEngine(PooledByteBufAllocator.DEFAULT);
173
174 final StartTlsServerHandler sh = new StartTlsServerHandler(sse, autoRead);
175 final StartTlsClientHandler ch = new StartTlsClientHandler(cse, autoRead);
176
177 sb.childHandler(new ChannelInitializer<Channel>() {
178 @Override
179 public void initChannel(Channel sch) throws Exception {
180 ChannelPipeline p = sch.pipeline();
181 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
182 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
183 p.addLast(executor, sh);
184 }
185 });
186
187 cb.handler(new ChannelInitializer<Channel>() {
188 @Override
189 public void initChannel(Channel sch) throws Exception {
190 ChannelPipeline p = sch.pipeline();
191 p.addLast("logger", new LoggingHandler(LOG_LEVEL));
192 p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
193 p.addLast(executor, ch);
194 }
195 });
196
197 Channel sc = sb.bind().sync().channel();
198 Channel cc = cb.connect(sc.localAddress()).sync().channel();
199
200 while (cc.isActive()) {
201 if (sh.exception.get() != null) {
202 break;
203 }
204 if (ch.exception.get() != null) {
205 break;
206 }
207
208 Thread.sleep(50);
209 }
210
211 while (sh.channel.isActive()) {
212 if (sh.exception.get() != null) {
213 break;
214 }
215 if (ch.exception.get() != null) {
216 break;
217 }
218
219 Thread.sleep(50);
220 }
221
222 sh.channel.close().awaitUninterruptibly();
223 cc.close().awaitUninterruptibly();
224 sc.close().awaitUninterruptibly();
225
226 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
227 throw sh.exception.get();
228 }
229 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
230 throw ch.exception.get();
231 }
232 if (sh.exception.get() != null) {
233 throw sh.exception.get();
234 }
235 if (ch.exception.get() != null) {
236 throw ch.exception.get();
237 }
238 }
239
240 private static class StartTlsClientHandler extends SimpleChannelInboundHandler<String> {
241 private final SslHandler sslHandler;
242 private final boolean autoRead;
243 private Future<Channel> handshakeFuture;
244 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
245
246 StartTlsClientHandler(SSLEngine engine, boolean autoRead) {
247 engine.setUseClientMode(true);
248 sslHandler = new SslHandler(engine);
249 this.autoRead = autoRead;
250 }
251
252 @Override
253 public void channelActive(ChannelHandlerContext ctx)
254 throws Exception {
255 if (!autoRead) {
256 ctx.read();
257 }
258 ctx.writeAndFlush("StartTlsRequest\n");
259 }
260
261 @Override
262 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
263 if ("StartTlsResponse".equals(msg)) {
264 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
265 handshakeFuture = sslHandler.handshakeFuture();
266 ctx.writeAndFlush("EncryptedRequest\n");
267 return;
268 }
269
270 assertEquals("EncryptedResponse", msg);
271 assertNotNull(handshakeFuture);
272 assertTrue(handshakeFuture.isSuccess());
273 ctx.close();
274 }
275
276 @Override
277 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
278 if (!autoRead) {
279 ctx.read();
280 }
281 }
282
283 @Override
284 public void exceptionCaught(ChannelHandlerContext ctx,
285 Throwable cause) throws Exception {
286 if (logger.isWarnEnabled()) {
287 logger.warn("Unexpected exception from the client side", cause);
288 }
289
290 exception.compareAndSet(null, cause);
291 ctx.close();
292 }
293 }
294
295 private static class StartTlsServerHandler extends SimpleChannelInboundHandler<String> {
296 private final SslHandler sslHandler;
297 private final boolean autoRead;
298 volatile Channel channel;
299 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
300
301 StartTlsServerHandler(SSLEngine engine, boolean autoRead) {
302 engine.setUseClientMode(false);
303 sslHandler = new SslHandler(engine, true);
304 this.autoRead = autoRead;
305 }
306
307 @Override
308 public void channelActive(ChannelHandlerContext ctx) throws Exception {
309 channel = ctx.channel();
310 if (!autoRead) {
311 ctx.read();
312 }
313 }
314
315 @Override
316 public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
317 if ("StartTlsRequest".equals(msg)) {
318 ctx.pipeline().addAfter("logger", "ssl", sslHandler);
319 ctx.writeAndFlush("StartTlsResponse\n");
320 return;
321 }
322
323 assertEquals("EncryptedRequest", msg);
324 ctx.writeAndFlush("EncryptedResponse\n");
325 }
326
327 @Override
328 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
329 if (!autoRead) {
330 ctx.read();
331 }
332 }
333
334 @Override
335 public void exceptionCaught(ChannelHandlerContext ctx,
336 Throwable cause) throws Exception {
337 if (logger.isWarnEnabled()) {
338 logger.warn("Unexpected exception from the server side", cause);
339 }
340
341 exception.compareAndSet(null, cause);
342 ctx.close();
343 }
344 }
345 }