1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.ssl;
17
18 import io.netty5.buffer.api.Buffer;
19 import io.netty5.buffer.api.BufferAllocator;
20 import io.netty5.channel.ChannelHandlerContext;
21 import io.netty5.handler.codec.base64.Base64;
22 import io.netty5.handler.codec.base64.Base64Dialect;
23 import io.netty5.util.NetUtil;
24 import io.netty5.util.internal.EmptyArrays;
25 import io.netty5.util.internal.StringUtil;
26 import io.netty5.util.internal.logging.InternalLogger;
27 import io.netty5.util.internal.logging.InternalLoggerFactory;
28
29 import javax.net.ssl.SSLContext;
30 import javax.net.ssl.SSLHandshakeException;
31 import javax.net.ssl.SSLSession;
32 import javax.net.ssl.TrustManager;
33 import java.nio.ByteBuffer;
34 import java.nio.ByteOrder;
35 import java.security.KeyManagementException;
36 import java.security.NoSuchAlgorithmException;
37 import java.security.NoSuchProviderException;
38 import java.security.Provider;
39 import java.util.Collections;
40 import java.util.LinkedHashSet;
41 import java.util.List;
42 import java.util.Set;
43
44 import static java.util.Arrays.asList;
45
46
47
48
49 final class SslUtils {
50 private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslUtils.class);
51
52
53 static final Set<String> TLSV13_CIPHERS = Collections.unmodifiableSet(new LinkedHashSet<>(
54 asList("TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256",
55 "TLS_AES_128_GCM_SHA256", "TLS_AES_128_CCM_8_SHA256",
56 "TLS_AES_128_CCM_SHA256")));
57
58
59
60
61 static final int GMSSL_PROTOCOL_VERSION = 0x101;
62
63 static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL";
64
65
66
67
68 static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;
69
70
71
72
73 static final int SSL_CONTENT_TYPE_ALERT = 21;
74
75
76
77
78 static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;
79
80
81
82
83 static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;
84
85
86
87
88 static final int SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT = 24;
89
90
91
92
93 static final int SSL_RECORD_HEADER_LENGTH = 5;
94
95
96
97
98 static final int NOT_ENOUGH_DATA = -1;
99
100
101
102
103 static final int NOT_ENCRYPTED = -2;
104
105 static final String[] DEFAULT_CIPHER_SUITES;
106 static final String[] DEFAULT_TLSV13_CIPHER_SUITES;
107 static final String[] TLSV13_CIPHER_SUITES = { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" };
108
109 private static final boolean TLSV1_3_JDK_SUPPORTED;
110 private static final boolean TLSV1_3_JDK_DEFAULT_ENABLED;
111 static final TrustManager[] EMPTY_TRUST_MANAGERS = new TrustManager[0];
112
113 static {
114 TLSV1_3_JDK_SUPPORTED = isTLSv13SupportedByJDK0(null);
115 TLSV1_3_JDK_DEFAULT_ENABLED = isTLSv13EnabledByJDK0(null);
116 if (TLSV1_3_JDK_SUPPORTED) {
117 DEFAULT_TLSV13_CIPHER_SUITES = TLSV13_CIPHER_SUITES;
118 } else {
119 DEFAULT_TLSV13_CIPHER_SUITES = EmptyArrays.EMPTY_STRINGS;
120 }
121
122 Set<String> defaultCiphers = new LinkedHashSet<>();
123
124 defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384");
125 defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256");
126 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256");
127 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384");
128 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA");
129
130 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA");
131
132 defaultCiphers.add("TLS_RSA_WITH_AES_128_GCM_SHA256");
133 defaultCiphers.add("TLS_RSA_WITH_AES_128_CBC_SHA");
134
135 defaultCiphers.add("TLS_RSA_WITH_AES_256_CBC_SHA");
136
137 Collections.addAll(defaultCiphers, DEFAULT_TLSV13_CIPHER_SUITES);
138
139 DEFAULT_CIPHER_SUITES = defaultCiphers.toArray(EmptyArrays.EMPTY_STRINGS);
140 }
141
142
143
144
145 static boolean isTLSv13SupportedByJDK(Provider provider) {
146 if (provider == null) {
147 return TLSV1_3_JDK_SUPPORTED;
148 }
149 return isTLSv13SupportedByJDK0(provider);
150 }
151
152 private static boolean isTLSv13SupportedByJDK0(Provider provider) {
153 try {
154 return arrayContains(newInitContext(provider)
155 .getSupportedSSLParameters().getProtocols(), SslProtocols.TLS_v1_3);
156 } catch (Throwable cause) {
157 logger.debug("Unable to detect if JDK SSLEngine with provider {} supports TLSv1.3, assuming no",
158 provider, cause);
159 return false;
160 }
161 }
162
163
164
165
166 static boolean isTLSv13EnabledByJDK(Provider provider) {
167 if (provider == null) {
168 return TLSV1_3_JDK_DEFAULT_ENABLED;
169 }
170 return isTLSv13EnabledByJDK0(provider);
171 }
172
173 private static boolean isTLSv13EnabledByJDK0(Provider provider) {
174 try {
175 return arrayContains(newInitContext(provider)
176 .getDefaultSSLParameters().getProtocols(), SslProtocols.TLS_v1_3);
177 } catch (Throwable cause) {
178 logger.debug("Unable to detect if JDK SSLEngine with provider {} enables TLSv1.3 by default," +
179 " assuming no", provider, cause);
180 return false;
181 }
182 }
183
184 private static SSLContext newInitContext(Provider provider)
185 throws NoSuchAlgorithmException, KeyManagementException {
186 final SSLContext context;
187 if (provider == null) {
188 context = SSLContext.getInstance("TLS");
189 } else {
190 context = SSLContext.getInstance("TLS", provider);
191 }
192 context.init(null, EMPTY_TRUST_MANAGERS, null);
193 return context;
194 }
195
196 static SSLContext getSSLContext(String provider)
197 throws NoSuchAlgorithmException, KeyManagementException, NoSuchProviderException {
198 final SSLContext context;
199 if (StringUtil.isNullOrEmpty(provider)) {
200 context = SSLContext.getInstance(getTlsVersion());
201 } else {
202 context = SSLContext.getInstance(getTlsVersion(), provider);
203 }
204 context.init(null, EMPTY_TRUST_MANAGERS, null);
205 return context;
206 }
207
208 private static String getTlsVersion() {
209 return TLSV1_3_JDK_SUPPORTED ? SslProtocols.TLS_v1_3 : SslProtocols.TLS_v1_2;
210 }
211
212 static boolean arrayContains(String[] array, String value) {
213 for (String v: array) {
214 if (value.equals(v)) {
215 return true;
216 }
217 }
218 return false;
219 }
220
221
222
223
224 static void addIfSupported(Set<String> supported, List<String> enabled, String... names) {
225 for (String n: names) {
226 if (supported.contains(n)) {
227 enabled.add(n);
228 }
229 }
230 }
231
232 static void useFallbackCiphersIfDefaultIsEmpty(List<String> defaultCiphers, Iterable<String> fallbackCiphers) {
233 if (defaultCiphers.isEmpty()) {
234 for (String cipher : fallbackCiphers) {
235 if (cipher.startsWith("SSL_") || cipher.contains("_RC4_")) {
236 continue;
237 }
238 defaultCiphers.add(cipher);
239 }
240 }
241 }
242
243 static void useFallbackCiphersIfDefaultIsEmpty(List<String> defaultCiphers, String... fallbackCiphers) {
244 useFallbackCiphersIfDefaultIsEmpty(defaultCiphers, asList(fallbackCiphers));
245 }
246
247
248
249
250 static SSLHandshakeException toSSLHandshakeException(Throwable e) {
251 if (e instanceof SSLHandshakeException) {
252 return (SSLHandshakeException) e;
253 }
254
255 return (SSLHandshakeException) new SSLHandshakeException(e.getMessage()).initCause(e);
256 }
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275 static int getEncryptedPacketLength(Buffer buffer, int offset) {
276 int packetLength = 0;
277
278
279 boolean tls;
280 switch (buffer.getUnsignedByte(offset)) {
281 case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
282 case SSL_CONTENT_TYPE_ALERT:
283 case SSL_CONTENT_TYPE_HANDSHAKE:
284 case SSL_CONTENT_TYPE_APPLICATION_DATA:
285 case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
286 tls = true;
287 break;
288 default:
289
290 tls = false;
291 }
292
293 if (tls) {
294
295 int majorVersion = buffer.getUnsignedByte(offset + 1);
296 if (majorVersion == 3 || buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
297
298 packetLength = buffer.getUnsignedShort(offset + 3) + SSL_RECORD_HEADER_LENGTH;
299 if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
300
301 tls = false;
302 }
303 } else {
304
305 tls = false;
306 }
307 }
308
309 if (!tls) {
310
311 int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
312 int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
313 if (majorVersion == 2 || majorVersion == 3) {
314
315 packetLength = headerLength == 2 ?
316 (buffer.getShort(offset) & 0x7FFF) + 2 : (buffer.getShort(offset) & 0x3FFF) + 3;
317 if (packetLength <= headerLength) {
318 return NOT_ENOUGH_DATA;
319 }
320 } else {
321 return NOT_ENCRYPTED;
322 }
323 }
324 return packetLength;
325 }
326
327 private static short unsignedByte(byte b) {
328 return (short) (b & 0xFF);
329 }
330
331
332 private static int unsignedShortBE(ByteBuffer buffer, int offset) {
333 return shortBE(buffer, offset) & 0xFFFF;
334 }
335
336
337 private static short shortBE(ByteBuffer buffer, int offset) {
338 return buffer.order() == ByteOrder.BIG_ENDIAN ?
339 buffer.getShort(offset) : Short.reverseBytes(buffer.getShort(offset));
340 }
341
342 static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
343 ByteBuffer buffer = buffers[offset];
344
345
346 if (buffer.remaining() >= SSL_RECORD_HEADER_LENGTH) {
347 return getEncryptedPacketLength(buffer);
348 }
349
350
351 ByteBuffer tmp = ByteBuffer.allocate(5);
352
353 do {
354 buffer = buffers[offset++].duplicate();
355 if (buffer.remaining() > tmp.remaining()) {
356 buffer.limit(buffer.position() + tmp.remaining());
357 }
358 tmp.put(buffer);
359 } while (tmp.hasRemaining());
360
361
362 tmp.flip();
363 return getEncryptedPacketLength(tmp);
364 }
365
366 private static int getEncryptedPacketLength(ByteBuffer buffer) {
367 int packetLength = 0;
368 int pos = buffer.position();
369
370 boolean tls;
371 switch (unsignedByte(buffer.get(pos))) {
372 case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
373 case SSL_CONTENT_TYPE_ALERT:
374 case SSL_CONTENT_TYPE_HANDSHAKE:
375 case SSL_CONTENT_TYPE_APPLICATION_DATA:
376 case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
377 tls = true;
378 break;
379 default:
380
381 tls = false;
382 }
383
384 if (tls) {
385
386 int majorVersion = unsignedByte(buffer.get(pos + 1));
387 if (majorVersion == 3 || buffer.getShort(pos + 1) == GMSSL_PROTOCOL_VERSION) {
388
389 packetLength = unsignedShortBE(buffer, pos + 3) + SSL_RECORD_HEADER_LENGTH;
390 if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
391
392 tls = false;
393 }
394 } else {
395
396 tls = false;
397 }
398 }
399
400 if (!tls) {
401
402 int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3;
403 int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1));
404 if (majorVersion == 2 || majorVersion == 3) {
405
406 packetLength = headerLength == 2 ?
407 (shortBE(buffer, pos) & 0x7FFF) + 2 : (shortBE(buffer, pos) & 0x3FFF) + 3;
408 if (packetLength <= headerLength) {
409 return NOT_ENOUGH_DATA;
410 }
411 } else {
412 return NOT_ENCRYPTED;
413 }
414 }
415 return packetLength;
416 }
417
418 static void handleHandshakeFailure(ChannelHandlerContext ctx, SSLSession session, String applicationProtocol,
419 Throwable cause, boolean notify) {
420
421
422 ctx.flush();
423 if (notify) {
424 ctx.fireChannelInboundEvent(new SslHandshakeCompletionEvent(session, applicationProtocol, cause));
425 }
426 }
427
428
429
430
431
432
433
434 static Buffer toBase64(BufferAllocator allocator, Buffer src) {
435 Buffer dst = Base64.encode(src, src.readerOffset(),
436 src.readableBytes(), true, Base64Dialect.STANDARD, allocator);
437 src.readerOffset(src.writerOffset());
438 return dst;
439 }
440
441
442
443
444 static boolean isValidHostNameForSNI(String hostname) {
445
446 return hostname != null &&
447
448
449 hostname.indexOf('.') > 0 &&
450 !hostname.endsWith(".") && !hostname.startsWith("/") &&
451 !NetUtil.isValidIpV4Address(hostname) &&
452 !NetUtil.isValidIpV6Address(hostname);
453 }
454
455
456
457
458 static boolean isTLSv13Cipher(String cipher) {
459
460 return TLSV13_CIPHERS.contains(cipher);
461 }
462
463 private SslUtils() {
464 }
465 }