1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.channel.Channel;
23 import io.netty.channel.ChannelHandler.Sharable;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInitializer;
26 import io.netty.channel.SimpleChannelInboundHandler;
27 import io.netty.handler.codec.DecoderException;
28 import io.netty.handler.ssl.JdkSslClientContext;
29 import io.netty.handler.ssl.OpenSsl;
30 import io.netty.handler.ssl.OpenSslServerContext;
31 import io.netty.handler.ssl.SslContext;
32 import io.netty.handler.ssl.SslHandler;
33 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
34 import io.netty.pkitesting.CertificateBuilder;
35 import io.netty.pkitesting.X509Bundle;
36 import io.netty.util.concurrent.Future;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39 import org.junit.jupiter.api.TestInfo;
40 import org.junit.jupiter.api.Timeout;
41 import org.junit.jupiter.api.condition.DisabledIf;
42 import org.junit.jupiter.params.ParameterizedTest;
43 import org.junit.jupiter.params.provider.MethodSource;
44
45 import java.io.File;
46 import java.nio.channels.ClosedChannelException;
47 import java.util.ArrayList;
48 import java.util.Collection;
49 import java.util.List;
50 import java.util.concurrent.Executor;
51 import java.util.concurrent.ExecutorService;
52 import java.util.concurrent.Executors;
53 import java.util.concurrent.TimeUnit;
54 import java.util.concurrent.atomic.AtomicReference;
55
56 import javax.net.ssl.SSLHandshakeException;
57
58 import static org.junit.jupiter.api.Assertions.assertSame;
59 import static org.junit.jupiter.api.Assertions.assertTrue;
60 import static org.junit.jupiter.api.Assertions.fail;
61 import static org.junit.jupiter.api.Assumptions.assumeFalse;
62 import static org.junit.jupiter.api.Assumptions.assumeTrue;
63
64 public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
65 private static final InternalLogger logger = InternalLoggerFactory.getInstance(
66 SocketSslClientRenegotiateTest.class);
67 private static final File CERT_FILE;
68 private static final File KEY_FILE;
69
70 static {
71 try {
72 X509Bundle cert = new CertificateBuilder()
73 .subject("cn=localhost")
74 .setIsCertificateAuthority(true)
75 .buildSelfSigned();
76 CERT_FILE = cert.toTempCertChainPem();
77 KEY_FILE = cert.toTempPrivateKeyPem();
78 } catch (Exception e) {
79 throw new ExceptionInInitializerError(e);
80 }
81 }
82
83 private static boolean openSslNotAvailable() {
84 return !OpenSsl.isAvailable();
85 }
86
87 public static Collection<Object[]> data() throws Exception {
88 List<SslContext> serverContexts = new ArrayList<SslContext>();
89 List<SslContext> clientContexts = new ArrayList<SslContext>();
90 clientContexts.add(new JdkSslClientContext(CERT_FILE));
91
92 boolean hasOpenSsl = OpenSsl.isAvailable();
93 if (hasOpenSsl) {
94 OpenSslServerContext context = new OpenSslServerContext(CERT_FILE, KEY_FILE);
95 serverContexts.add(context);
96 } else {
97 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
98 }
99
100 List<Object[]> params = new ArrayList<Object[]>();
101 for (SslContext sc: serverContexts) {
102 for (SslContext cc: clientContexts) {
103 for (int i = 0; i < 32; i++) {
104 params.add(new Object[] { sc, cc, true});
105 params.add(new Object[] { sc, cc, false});
106 }
107 }
108 }
109
110 return params;
111 }
112
113 private final AtomicReference<Throwable> clientException = new AtomicReference<Throwable>();
114 private final AtomicReference<Throwable> serverException = new AtomicReference<Throwable>();
115
116 private volatile Channel clientChannel;
117 private volatile Channel serverChannel;
118
119 private volatile SslHandler clientSslHandler;
120 private volatile SslHandler serverSslHandler;
121
122 private final TestHandler clientHandler = new TestHandler(clientException);
123
124 private final TestHandler serverHandler = new TestHandler(serverException);
125
126 @DisabledIf("openSslNotAvailable")
127 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
128 @MethodSource("data")
129 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
130 public void testSslRenegotiationRejected(final SslContext serverCtx, final SslContext clientCtx,
131 final boolean delegate, TestInfo testInfo) throws Throwable {
132
133 assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
134 assumeTrue(OpenSsl.isAvailable());
135 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
136 @Override
137 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
138 testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate);
139 }
140 });
141 }
142
143 private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
144 if (executor == null) {
145 return sslCtx.newHandler(allocator);
146 } else {
147 return sslCtx.newHandler(allocator, executor);
148 }
149 }
150
151 public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, final SslContext serverCtx,
152 final SslContext clientCtx, boolean delegate) throws Throwable {
153 reset();
154
155 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
156
157 try {
158 sb.childHandler(new ChannelInitializer<Channel>() {
159 @Override
160 @SuppressWarnings("deprecation")
161 public void initChannel(Channel sch) throws Exception {
162 serverChannel = sch;
163 serverSslHandler = newSslHandler(serverCtx, sch.alloc(), executorService);
164
165 serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
166 sch.pipeline().addLast("ssl", serverSslHandler);
167 sch.pipeline().addLast("handler", serverHandler);
168 }
169 });
170
171 cb.handler(new ChannelInitializer<Channel>() {
172 @Override
173 @SuppressWarnings("deprecation")
174 public void initChannel(Channel sch) throws Exception {
175 clientChannel = sch;
176 clientSslHandler = newSslHandler(clientCtx, sch.alloc(), executorService);
177
178 clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
179 sch.pipeline().addLast("ssl", clientSslHandler);
180 sch.pipeline().addLast("handler", clientHandler);
181 }
182 });
183
184 Channel sc = sb.bind().sync().channel();
185 cb.connect(sc.localAddress()).sync();
186
187 Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
188 clientHandshakeFuture.sync();
189
190 String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
191
192 clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
193 clientSslHandler.renegotiate().await();
194 serverChannel.close().awaitUninterruptibly();
195 clientChannel.close().awaitUninterruptibly();
196 sc.close().awaitUninterruptibly();
197 try {
198 if (serverException.get() != null) {
199 throw serverException.get();
200 }
201 fail();
202 } catch (DecoderException e) {
203 assertTrue(e.getCause() instanceof SSLHandshakeException);
204 }
205 if (clientException.get() != null) {
206 throw clientException.get();
207 }
208 } finally {
209 if (executorService != null) {
210 executorService.shutdown();
211 }
212 }
213 }
214
215 private void reset() {
216 clientException.set(null);
217 serverException.set(null);
218 clientHandler.handshakeCounter = 0;
219 serverHandler.handshakeCounter = 0;
220 clientChannel = null;
221 serverChannel = null;
222
223 clientSslHandler = null;
224 serverSslHandler = null;
225 }
226
227 @Sharable
228 private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
229
230 protected final AtomicReference<Throwable> exception;
231 private int handshakeCounter;
232
233 TestHandler(AtomicReference<Throwable> exception) {
234 this.exception = exception;
235 }
236
237 @Override
238 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
239 ctx.flush();
240 }
241
242 @Override
243 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
244 exception.compareAndSet(null, cause);
245 ctx.close();
246 }
247
248 @Override
249 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
250 if (evt instanceof SslHandshakeCompletionEvent) {
251 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
252 if (handshakeCounter == 0) {
253 handshakeCounter++;
254 if (handshakeEvt.cause() != null) {
255 logger.warn("Handshake failed:", handshakeEvt.cause());
256 }
257 assertSame(SslHandshakeCompletionEvent.SUCCESS, evt);
258 } else {
259 if (ctx.channel().parent() == null) {
260 assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
261 }
262 }
263 }
264 }
265
266 @Override
267 public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception { }
268 }
269 }