1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.testsuite.transport.socket;
17
18 import io.netty5.bootstrap.Bootstrap;
19 import io.netty5.bootstrap.ServerBootstrap;
20 import io.netty5.buffer.BufferUtil;
21 import io.netty5.buffer.api.Buffer;
22 import io.netty5.buffer.api.DefaultBufferAllocators;
23 import io.netty5.channel.Channel;
24 import io.netty5.channel.ChannelHandlerContext;
25 import io.netty5.channel.ChannelInitializer;
26 import io.netty5.channel.SimpleChannelInboundHandler;
27 import io.netty5.channel.socket.SocketChannel;
28 import io.netty5.handler.ssl.SslContext;
29 import io.netty5.handler.ssl.SslContextBuilder;
30 import io.netty5.handler.ssl.SslHandler;
31 import io.netty5.handler.ssl.SslProvider;
32 import io.netty5.handler.ssl.util.SelfSignedCertificate;
33 import io.netty5.util.internal.logging.InternalLogger;
34 import io.netty5.util.internal.logging.InternalLoggerFactory;
35 import org.junit.jupiter.api.TestInfo;
36 import org.junit.jupiter.api.Timeout;
37 import org.junit.jupiter.params.ParameterizedTest;
38 import org.junit.jupiter.params.provider.MethodSource;
39
40 import javax.net.ssl.SSLEngine;
41 import javax.net.ssl.SSLSessionContext;
42 import java.io.File;
43 import java.io.IOException;
44 import java.net.InetSocketAddress;
45 import java.security.cert.CertificateException;
46 import java.util.Collection;
47 import java.util.Collections;
48 import java.util.Enumeration;
49 import java.util.HashSet;
50 import java.util.Set;
51 import java.util.concurrent.TimeUnit;
52 import java.util.concurrent.atomic.AtomicReference;
53
54 import static org.junit.jupiter.api.Assertions.assertEquals;
55
56 public class SocketSslSessionReuseTest extends AbstractSocketTest {
57
58 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class);
59
60 private static final File CERT_FILE;
61 private static final File KEY_FILE;
62
63 static {
64 SelfSignedCertificate ssc;
65 try {
66 ssc = new SelfSignedCertificate();
67 } catch (CertificateException e) {
68 throw new Error(e);
69 }
70 CERT_FILE = ssc.certificate();
71 KEY_FILE = ssc.privateKey();
72 }
73
74 public static Collection<Object[]> data() throws Exception {
75 return Collections.singletonList(new Object[] {
76 SslContextBuilder.forServer(CERT_FILE, KEY_FILE).sslProvider(SslProvider.JDK).build(),
77 SslContextBuilder.forClient().trustManager(CERT_FILE).sslProvider(SslProvider.JDK).build()
78 });
79 }
80
81 @ParameterizedTest(name = "{index}: serverEngine = {0}, clientEngine = {1}")
82 @MethodSource("data")
83 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
84 public void testSslSessionReuse(SslContext serverCtx, SslContext clientCtx, TestInfo testInfo) throws Throwable {
85 run(testInfo, (sb, cb) -> testSslSessionReuse(sb, cb, serverCtx, clientCtx));
86 }
87
88 public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb,
89 SslContext serverCtx, SslContext clientCtx) throws Throwable {
90 final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true);
91 final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true);
92 final String[] protocols = { "TLSv1", "TLSv1.1", "TLSv1.2" };
93
94 sb.childHandler(new ChannelInitializer<SocketChannel>() {
95 @Override
96 protected void initChannel(SocketChannel sch) throws Exception {
97 SSLEngine engine = serverCtx.newEngine(sch.bufferAllocator());
98 engine.setUseClientMode(false);
99 engine.setEnabledProtocols(protocols);
100
101 sch.pipeline().addLast(new SslHandler(engine));
102 sch.pipeline().addLast(sh);
103 }
104 });
105 final Channel sc = sb.bind().asStage().get();
106
107 cb.handler(new ChannelInitializer<SocketChannel>() {
108 @Override
109 protected void initChannel(SocketChannel sch) throws Exception {
110 InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress();
111 SSLEngine engine = clientCtx.newEngine(
112 sch.bufferAllocator(), serverAddr.getHostString(), serverAddr.getPort());
113 engine.setUseClientMode(true);
114 engine.setEnabledProtocols(protocols);
115
116 sch.pipeline().addLast(new SslHandler(engine));
117 sch.pipeline().addLast(ch);
118 }
119 });
120
121 try {
122 SSLSessionContext clientSessionCtx = clientCtx.sessionContext();
123 Buffer msg = DefaultBufferAllocators.preferredAllocator().copyOf(new byte[] { 0xa, 0xb, 0xc, 0xd });
124 Channel cc = cb.connect(sc.localAddress()).asStage().get();
125 cc.writeAndFlush(msg).asStage().sync();
126 cc.closeFuture().asStage().sync();
127 rethrowHandlerExceptions(sh, ch);
128 Set<String> sessions = sessionIdSet(clientSessionCtx.getIds());
129
130 msg = DefaultBufferAllocators.preferredAllocator().copyOf(new byte[] { 0xa, 0xb, 0xc, 0xd });
131 cc = cb.connect(sc.localAddress()).asStage().get();
132 cc.writeAndFlush(msg).asStage().sync();
133 cc.closeFuture().asStage().sync();
134 assertEquals(sessions, sessionIdSet(clientSessionCtx.getIds()), "Expected no new sessions");
135 rethrowHandlerExceptions(sh, ch);
136 } finally {
137 sc.close().asStage().await();
138 }
139 }
140
141 private static void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable {
142 if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
143 throw sh.exception.get();
144 }
145 if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
146 throw ch.exception.get();
147 }
148 if (sh.exception.get() != null) {
149 throw sh.exception.get();
150 }
151 if (ch.exception.get() != null) {
152 throw ch.exception.get();
153 }
154 }
155
156 private static Set<String> sessionIdSet(Enumeration<byte[]> sessionIds) {
157 Set<String> idSet = new HashSet<>();
158 byte[] id;
159 while (sessionIds.hasMoreElements()) {
160 id = sessionIds.nextElement();
161 idSet.add(BufferUtil.hexDump(id));
162 }
163 return idSet;
164 }
165
166 private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler<Buffer> {
167 final AtomicReference<Throwable> exception = new AtomicReference<>();
168 private final boolean server;
169 private final boolean autoRead;
170
171 ReadAndDiscardHandler(boolean server, boolean autoRead) {
172 this.server = server;
173 this.autoRead = autoRead;
174 }
175
176 @Override
177 public boolean isSharable() {
178 return true;
179 }
180
181 @Override
182 public void messageReceived(ChannelHandlerContext ctx, Buffer in) throws Exception {
183 byte[] actual = new byte[in.readableBytes()];
184 in.readBytes(actual, 0, actual.length);
185 ctx.close();
186 }
187
188 @Override
189 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
190 try {
191 ctx.flush();
192 } finally {
193 if (!autoRead) {
194 ctx.read();
195 }
196 }
197 }
198
199 @Override
200 public void channelExceptionCaught(ChannelHandlerContext ctx,
201 Throwable cause) throws Exception {
202 if (logger.isWarnEnabled()) {
203 logger.warn(
204 "Unexpected exception from the " +
205 (server? "server" : "client") + " side", cause);
206 }
207
208 exception.compareAndSet(null, cause);
209 ctx.close();
210 }
211 }
212 }