1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.logging.LogLevel;
28 import io.netty.handler.logging.LoggingHandler;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.SslContext;
31 import io.netty.handler.ssl.SslContextBuilder;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.handler.ssl.SslProvider;
35 import io.netty.handler.ssl.util.SelfSignedCertificate;
36 import io.netty.util.internal.PlatformDependent;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39 import org.junit.jupiter.api.TestInfo;
40 import org.junit.jupiter.api.Timeout;
41 import org.junit.jupiter.params.ParameterizedTest;
42 import org.junit.jupiter.params.provider.MethodSource;
43
44 import javax.net.ssl.SSLPeerUnverifiedException;
45 import javax.net.ssl.SSLSession;
46 import java.io.File;
47 import java.io.IOException;
48 import java.security.cert.CertificateException;
49 import java.util.ArrayList;
50 import java.util.Collection;
51 import java.util.List;
52 import java.util.concurrent.CountDownLatch;
53 import java.util.concurrent.Executor;
54 import java.util.concurrent.ExecutorService;
55 import java.util.concurrent.Executors;
56 import java.util.concurrent.TimeUnit;
57 import java.util.concurrent.atomic.AtomicReference;
58
59 import static org.junit.jupiter.api.Assertions.assertEquals;
60 import static org.junit.jupiter.api.Assertions.assertFalse;
61 import static org.junit.jupiter.api.Assertions.assertTrue;
62 import static org.junit.jupiter.api.Assertions.fail;
63
64 public class SocketSslGreetingTest extends AbstractSocketTest {
65
66 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslGreetingTest.class);
67
68 private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
69 private static final File CERT_FILE;
70 private static final File KEY_FILE;
71
72 static {
73 SelfSignedCertificate ssc;
74 try {
75 ssc = new SelfSignedCertificate();
76 } catch (CertificateException e) {
77 throw new Error(e);
78 }
79 CERT_FILE = ssc.certificate();
80 KEY_FILE = ssc.privateKey();
81 }
82
83 public static Collection<Object[]> data() throws Exception {
84 List<SslContext> serverContexts = new ArrayList<SslContext>();
85 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build());
86
87 List<SslContext> clientContexts = new ArrayList<SslContext>();
88 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.JDK).trustManager(CERT_FILE).build());
89
90 boolean hasOpenSsl = OpenSsl.isAvailable();
91 if (hasOpenSsl) {
92 serverContexts.add(SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
93 .sslProvider(SslProvider.OPENSSL).build());
94 clientContexts.add(SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL)
95 .trustManager(CERT_FILE).build());
96 } else {
97 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
98 }
99
100 List<Object[]> params = new ArrayList<Object[]>();
101 for (SslContext sc: serverContexts) {
102 for (SslContext cc: clientContexts) {
103 params.add(new Object[] { sc, cc, true });
104 params.add(new Object[] { sc, cc, false });
105 }
106 }
107 return params;
108 }
109
110 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
111 if (executor == null) {
112 return sslCtx.newHandler(allocator);
113 } else {
114 return sslCtx.newHandler(allocator, executor);
115 }
116 }
117
118
119 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
120 @MethodSource("data")
121 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
122 public void testSslGreeting(final SslContext serverCtx, final SslContext clientCtx, final boolean delegate,
123 TestInfo testInfo) throws Throwable {
124 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
125 @Override
126 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
127 testSslGreeting(sb, cb, serverCtx, clientCtx, delegate);
128 }
129 });
130 }
131
132 public void testSslGreeting(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
133 final SslContext clientCtx, boolean delegate) throws Throwable {
134 final ServerHandler sh = new ServerHandler();
135 final ClientHandler ch = new ClientHandler();
136
137 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
138 try {
139 sb.childHandler(new ChannelInitializer<Channel>() {
140 @Override
141 public void initChannel(Channel sch) throws Exception {
142 ChannelPipeline p = sch.pipeline();
143 p.addLast(newSslHandler(serverCtx, sch.alloc(), executorService));
144 p.addLast(new LoggingHandler(LOG_LEVEL));
145 p.addLast(sh);
146 }
147 });
148
149 cb.handler(new ChannelInitializer<Channel>() {
150 @Override
151 public void initChannel(Channel sch) throws Exception {
152 ChannelPipeline p = sch.pipeline();
153 p.addLast(newSslHandler(clientCtx, sch.alloc(), executorService));
154 p.addLast(new LoggingHandler(LOG_LEVEL));
155 p.addLast(ch);
156 }
157 });
158
159 Channel sc = sb.bind().sync().channel();
160 Channel cc = cb.connect(sc.localAddress()).sync().channel();
161
162 ch.latch.await();
163
164 sh.channel.close().awaitUninterruptibly();
165 cc.close().awaitUninterruptibly();
166 sc.close().awaitUninterruptibly();
167
168 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
169 throw sh.exception.get();
170 }
171 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
172 throw ch.exception.get();
173 }
174 if (sh.exception.get() != null) {
175 throw sh.exception.get();
176 }
177 if (ch.exception.get() != null) {
178 throw ch.exception.get();
179 }
180 } finally {
181 if (executorService != null) {
182 executorService.shutdown();
183 assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS));
184 }
185 }
186 }
187
188 private static class ClientHandler extends SimpleChannelInboundHandler<ByteBuf> {
189
190 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
191 final CountDownLatch latch = new CountDownLatch(1);
192
193 @Override
194 public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception {
195 assertEquals('a', buf.readByte());
196 assertFalse(buf.isReadable());
197 latch.countDown();
198 ctx.close();
199 }
200
201 @Override
202 public void exceptionCaught(ChannelHandlerContext ctx,
203 Throwable cause) throws Exception {
204 if (logger.isWarnEnabled()) {
205 logger.warn("Unexpected exception from the client side", cause);
206 }
207
208 exception.compareAndSet(null, cause);
209 ctx.close();
210 }
211 }
212
213 private static class ServerHandler extends SimpleChannelInboundHandler<String> {
214 volatile Channel channel;
215 final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
216
217 @Override
218 protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
219
220 }
221
222 @Override
223 public void channelActive(ChannelHandlerContext ctx)
224 throws Exception {
225 channel = ctx.channel();
226 channel.writeAndFlush(ctx.alloc().buffer().writeByte('a'));
227 }
228
229 @Override
230 public void exceptionCaught(ChannelHandlerContext ctx,
231 Throwable cause) throws Exception {
232 if (logger.isWarnEnabled()) {
233 logger.warn("Unexpected exception from the server side", cause);
234 }
235
236 exception.compareAndSet(null, cause);
237 ctx.close();
238 }
239
240 @Override
241 public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception {
242 if (evt instanceof SslHandshakeCompletionEvent) {
243 final SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
244 if (event.isSuccess()) {
245 SSLSession session = ctx.pipeline().get(SslHandler.class).engine().getSession();
246 try {
247 session.getPeerCertificates();
248 fail();
249 } catch (SSLPeerUnverifiedException e) {
250
251 }
252 try {
253 session.getPeerCertificateChain();
254 fail();
255 } catch (SSLPeerUnverifiedException e) {
256
257 } catch (UnsupportedOperationException e) {
258
259
260 if (PlatformDependent.javaVersion() < 15) {
261 throw e;
262 }
263 }
264 try {
265 session.getPeerPrincipal();
266 fail();
267 } catch (SSLPeerUnverifiedException e) {
268
269 }
270 }
271 }
272 ctx.fireUserEventTriggered(evt);
273 }
274 }
275 }