1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.api.Buffer;
21 import io.netty5.buffer.api.BufferAllocator;
22 import io.netty5.channel.Channel;
23 import io.netty5.channel.ChannelHandlerContext;
24 import io.netty5.channel.ChannelInitializer;
25 import io.netty5.channel.SimpleChannelInboundHandler;
26 import io.netty5.handler.codec.DecoderException;
27 import io.netty5.handler.ssl.OpenSsl;
28 import io.netty5.handler.ssl.SslContext;
29 import io.netty5.handler.ssl.SslContextBuilder;
30 import io.netty5.handler.ssl.SslHandler;
31 import io.netty5.handler.ssl.SslHandshakeCompletionEvent;
32 import io.netty5.handler.ssl.SslProvider;
33 import io.netty5.handler.ssl.util.SelfSignedCertificate;
34 import io.netty5.util.concurrent.Future;
35 import io.netty5.util.internal.logging.InternalLogger;
36 import io.netty5.util.internal.logging.InternalLoggerFactory;
37 import org.junit.jupiter.api.TestInfo;
38 import org.junit.jupiter.api.Timeout;
39 import org.junit.jupiter.api.condition.DisabledIf;
40 import org.junit.jupiter.params.ParameterizedTest;
41 import org.junit.jupiter.params.provider.MethodSource;
42
43 import javax.net.ssl.SSLHandshakeException;
44 import java.io.File;
45 import java.nio.channels.ClosedChannelException;
46 import java.security.cert.CertificateException;
47 import java.util.ArrayList;
48 import java.util.Collection;
49 import java.util.List;
50 import java.util.concurrent.Executor;
51 import java.util.concurrent.ExecutorService;
52 import java.util.concurrent.Executors;
53 import java.util.concurrent.TimeUnit;
54 import java.util.concurrent.atomic.AtomicReference;
55
56 import static org.junit.jupiter.api.Assertions.assertTrue;
57 import static org.junit.jupiter.api.Assertions.fail;
58 import static org.junit.jupiter.api.Assumptions.assumeFalse;
59
60 public class SocketSslClientRenegotiateTest extends AbstractSocketTest {
61 private static final InternalLogger logger = InternalLoggerFactory.getInstance(
62 SocketSslClientRenegotiateTest.class);
63 private static final File CERT_FILE;
64 private static final File KEY_FILE;
65
66 static {
67 SelfSignedCertificate ssc;
68 try {
69 ssc = new SelfSignedCertificate();
70 } catch (CertificateException e) {
71 throw new Error(e);
72 }
73 CERT_FILE = ssc.certificate();
74 KEY_FILE = ssc.privateKey();
75 }
76
77 private static boolean openSslNotAvailable() {
78 return !OpenSsl.isAvailable();
79 }
80
81 public static Collection<Object[]> data() throws Exception {
82 List<SslContext> serverContexts = new ArrayList<>();
83 List<SslContext> clientContexts = new ArrayList<>();
84 clientContexts.add(
85 SslContextBuilder.forClient()
86 .trustManager(CERT_FILE)
87 .sslProvider(SslProvider.JDK)
88 .build()
89 );
90
91 boolean hasOpenSsl = OpenSsl.isAvailable();
92 if (hasOpenSsl) {
93 serverContexts.add(
94 SslContextBuilder.forServer(CERT_FILE, KEY_FILE)
95 .sslProvider(SslProvider.OPENSSL)
96 .build()
97 );
98 } else {
99 logger.warn("OpenSSL is unavailable and thus will not be tested.", OpenSsl.unavailabilityCause());
100 }
101
102 List<Object[]> params = new ArrayList<>();
103 for (SslContext sc: serverContexts) {
104 for (SslContext cc: clientContexts) {
105 for (int i = 0; i < 32; i++) {
106 params.add(new Object[] { sc, cc, true});
107 params.add(new Object[] { sc, cc, false});
108 }
109 }
110 }
111
112 return params;
113 }
114
115 private final AtomicReference<Throwable> clientException = new AtomicReference<>();
116 private final AtomicReference<Throwable> serverException = new AtomicReference<>();
117
118 private volatile Channel clientChannel;
119 private volatile Channel serverChannel;
120
121 private volatile SslHandler clientSslHandler;
122 private volatile SslHandler serverSslHandler;
123
124 private final TestHandler clientHandler = new TestHandler(clientException);
125
126 private final TestHandler serverHandler = new TestHandler(serverException);
127
128 @DisabledIf("openSslNotAvailable")
129 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}, delegate = {2}")
130 @MethodSource("data")
131 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
132 public void testSslRenegotiationRejected(SslContext serverCtx, SslContext clientCtx, boolean delegate,
133 TestInfo testInfo) throws Throwable {
134
135 assumeFalse("BoringSSL".equals(OpenSsl.versionString()));
136 run(testInfo, (sb, cb) -> testSslRenegotiationRejected(sb, cb, serverCtx, clientCtx, delegate));
137 }
138
139 private static SslHandler newSslHandler(SslContext sslCtx, BufferAllocator allocator, Executor executor) {
140 if (executor == null) {
141 return sslCtx.newHandler(allocator);
142 } else {
143 return sslCtx.newHandler(allocator, executor);
144 }
145 }
146
147 public void testSslRenegotiationRejected(ServerBootstrap sb, Bootstrap cb, SslContext serverCtx,
148 SslContext clientCtx, boolean delegate) throws Throwable {
149 reset();
150
151 final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;
152
153 try {
154 sb.childHandler(new ChannelInitializer<>() {
155 @Override
156 public void initChannel(Channel sch) throws Exception {
157 serverChannel = sch;
158 serverSslHandler = newSslHandler(serverCtx, sch.bufferAllocator(), executorService);
159
160 serverSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
161 sch.pipeline().addLast("ssl", serverSslHandler);
162 sch.pipeline().addLast("handler", serverHandler);
163 }
164 });
165
166 cb.handler(new ChannelInitializer<>() {
167 @Override
168 public void initChannel(Channel sch) throws Exception {
169 clientChannel = sch;
170 clientSslHandler = newSslHandler(clientCtx, sch.bufferAllocator(), executorService);
171
172 clientSslHandler.engine().setEnabledProtocols(new String[]{"TLSv1.2"});
173 sch.pipeline().addLast("ssl", clientSslHandler);
174 sch.pipeline().addLast("handler", clientHandler);
175 }
176 });
177
178 Channel sc = sb.bind().asStage().get();
179 cb.connect(sc.localAddress()).asStage().sync();
180
181 Future<Channel> clientHandshakeFuture = clientSslHandler.handshakeFuture();
182 clientHandshakeFuture.asStage().sync();
183
184 String renegotiation = clientSslHandler.engine().getEnabledCipherSuites()[0];
185
186 clientSslHandler.engine().setEnabledCipherSuites(new String[]{renegotiation});
187 clientSslHandler.renegotiate().asStage().await();
188 serverChannel.close().asStage().await();
189 clientChannel.close().asStage().await();
190 sc.close().asStage().await();
191 try {
192 if (serverException.get() != null) {
193 throw serverException.get();
194 }
195 fail();
196 } catch (DecoderException e) {
197 assertTrue(e.getCause() instanceof SSLHandshakeException);
198 }
199 if (clientException.get() != null) {
200 throw clientException.get();
201 }
202 } finally {
203 if (executorService != null) {
204 executorService.shutdown();
205 }
206 }
207 }
208
209 private void reset() {
210 clientException.set(null);
211 serverException.set(null);
212 clientHandler.handshakeCounter = 0;
213 serverHandler.handshakeCounter = 0;
214 clientChannel = null;
215 serverChannel = null;
216
217 clientSslHandler = null;
218 serverSslHandler = null;
219 }
220
221 private static final class TestHandler extends SimpleChannelInboundHandler<Buffer> {
222
223 private final AtomicReference<Throwable> exception;
224 private int handshakeCounter;
225
226 TestHandler(AtomicReference<Throwable> exception) {
227 this.exception = exception;
228 }
229
230 @Override
231 public boolean isSharable() {
232 return true;
233 }
234
235 @Override
236 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
237 ctx.flush();
238 }
239
240 @Override
241 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242 exception.compareAndSet(null, cause);
243 ctx.close();
244 }
245
246 @Override
247 public void channelInboundEvent(ChannelHandlerContext ctx, Object evt) throws Exception {
248 if (evt instanceof SslHandshakeCompletionEvent) {
249 SslHandshakeCompletionEvent handshakeEvt = (SslHandshakeCompletionEvent) evt;
250 if (handshakeCounter == 0) {
251 handshakeCounter++;
252 if (handshakeEvt.cause() != null) {
253 logger.warn("Handshake failed:", handshakeEvt.cause());
254 }
255 assertTrue(handshakeEvt.isSuccess());
256 } else {
257 if (ctx.channel().parent() == null) {
258 assertTrue(handshakeEvt.cause() instanceof ClosedChannelException);
259 }
260 }
261 }
262 }
263
264 @Override
265 public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception { }
266 }
267 }