1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandler.Sharable;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInitializer;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.DecoderException;
28 import io.netty.handler.ssl.JdkSslClientContext;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.OpenSslServerContext;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.pkitesting.CertificateBuilder;
35 import io.netty.pkitesting.X509Bundle;
36 import io.netty.util.concurrent.Future;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39 import org.junit.jupiter.api.TestInfo;
40 import org.junit.jupiter.api.Timeout;
41 import org.junit.jupiter.api.condition.DisabledIf;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44
45 import java.io.File;
46 import java.nio.channels.ClosedChannelException;
47 import java.util.ArrayList;
48 import java.util.Collection;
49 import java.util.List;
50 import java.util.concurrent.Executor;
51 import java.util.concurrent.ExecutorService;
52 import java.util.concurrent.Executors;
53 import java.util.concurrent.TimeUnit;
54 import java.util.concurrent.atomic.AtomicReference;
55
56 import javax.net.ssl.SSLHandshakeException;
57
58 import static org.junit.jupiter.api.Assertions.assertSame;
59 import static org.junit.jupiter.api.Assertions.assertTrue;
60 import static org.junit.jupiter.api.Assertions.fail;
61 import static org.junit.jupiter.api.Assumptions.assumeTrue;
62
63 public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
64 private static final InternalLogger logger = InternalLoggerFactory.getInstance(
65 SocketSslClientRenegotiateTest.class);
66 private static final File CERT_FILE;
67 private static final File KEY_FILE;
68
69 static {
70 try {
71 X509Bundle cert = new CertificateBuilder()
72 .subject("cn=localhost")
73 .setIsCertificateAuthority(true)
74 .buildSelfSigned();
75 CERT_FILE = cert.toTempCertChainPem();
76 KEY_FILE = cert.toTempPrivateKeyPem();
77 } catch (Exception e) {
78 throw new ExceptionInInitializerError(e);
79 }
80 }
81
82 private static boolean openSslNotAvailable() {
83 return !OpenSsl.isAvailable();
84 }
85
86 public static Collection<Object[]> data() throws Exception {
87 List<SslContext> serverContexts = new ArrayList<SslContext>();
88 List<SslContext> clientContexts = new ArrayList<SslContext>();
89 clientContexts.add(new JdkSslClientContext(CERT_FILE));
90
91 boolean hasOpenSsl = OpenSsl.isAvailable();
92 if (hasOpenSsl) {
93 OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
94 serverContexts.add(context);
95 } else {
96 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
97 }
98
99 List<Object[]> params = new ArrayList<Object[]>();
100 for (SslContext sc: serverContexts) {
101 for (SslContext cc: clientContexts) {
102 for (int i = 0; i < 32; i++) {
103 params.add(new Object[] { sc, cc, true});
104 params.add(new Object[] { sc, cc, false});
105 }
106 }
107 }
108
109 return params;
110 }
111
112 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
113 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
114
115 private volatile Channel clientChannel;
116 private volatile Channel serverChannel;
117
118 private volatile SslHandler clientSslHandler;
119 private volatile SslHandler serverSslHandler;
120
121 private final TestHandler clientHandler = new TestHandler(clientException);
122
123 private final TestHandler serverHandler = new TestHandler(serverException);
124
125 @DisabledIf("openSslNotAvailable")
126 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
127 @MethodSource("data")
128 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
129 public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx,
130 final boolean delegate, TestInfo testInfo) throws Throwable {
131 assumeTrue(OpenSsl.isRenegotiationSupported());
132 assumeTrue(OpenSsl.isAvailable());
133 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
134 @Override
135 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
136 testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate);
137 }
138 });
139 }
140
141 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
142 if (executor == null) {
143 return sslCtx.newHandler(allocator);
144 } else {
145 return sslCtx.newHandler(allocator, executor);
146 }
147 }
148
149 public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
150 final SslContext clientCtx, boolean delegate) throws Throwable {
151 reset();
152
153 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
154
155 try {
156 sb.childHandler(new ChannelInitializer<Channel>() {
157 @Override
158 @SuppressWarnings("deprecation")
159 public void initChannel(Channel sch) throws Exception {
160 serverChannel = sch;
161 serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService);
162
163 serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
164 sch.pipeline().addLast("ssl", serverSslHandler);
165 sch.pipeline().addLast("handler", serverHandler);
166 }
167 });
168
169 cb.handler(new ChannelInitializer<Channel>() {
170 @Override
171 @SuppressWarnings("deprecation")
172 public void initChannel(Channel sch) throws Exception {
173 clientChannel = sch;
174 clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService);
175
176 clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
177 sch.pipeline().addLast("ssl", clientSslHandler);
178 sch.pipeline().addLast("handler", clientHandler);
179 }
180 });
181
182 Channel sc = sb.bind().sync().channel();
183 cb.connect(sc.localAddress()).sync();
184
185 Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
186 clientHandshakeFuture.sync();
187
188 String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
189
190 clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
191 clientSslHandler.renegotiate().await();
192 serverChannel.close().awaitUninterruptibly();
193 clientChannel.close().awaitUninterruptibly();
194 sc.close().awaitUninterruptibly();
195 try {
196 if (serverException.get() != null) {
197 throw serverException.get();
198 }
199 fail();
200 } catch (DecoderException e) {
201 assertTrue(e.getCause() instanceof SSLHandshakeException);
202 }
203 if (clientException.get() != null) {
204 throw clientException.get();
205 }
206 } finally {
207 if (executorService != null) {
208 executorService.shutdown();
209 }
210 }
211 }
212
213 private void reset() {
214 clientException.set(null);
215 serverException.set(null);
216 clientHandler.handshakeCounter = 0;
217 serverHandler.handshakeCounter = 0;
218 clientChannel = null;
219 serverChannel = null;
220
221 clientSslHandler = null;
222 serverSslHandler = null;
223 }
224
225 @Sharable
226 private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
227
228 protected final AtomicReference<Throwable> exception;
229 private int handshakeCounter;
230
231 TestHandler(AtomicReference<Throwable> exception) {
232 this.exception = exception;
233 }
234
235 @Override
236 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
237 ctx.flush();
238 }
239
240 @Override
241 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242 exception.compareAndSet(null, cause);
243 ctx.close();
244 }
245
246 @Override
247 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
248 if (evt instanceof SslHandshakeCompletionEvent) {
249 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
250 if (handshakeCounter == 0) {
251 handshakeCounter++;
252 if (handshakeEvt.cause() != null) {
253 logger.warn("Handshake failed:", handshakeEvt.cause());
254 }
255 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
256 } else {
257 if (ctx.channel().parent() == null) {
258 assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
259 }
260 }
261 }
262 }
263
264 @Override
265 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
266 }
267 }