1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufAllocator;
20 import io.netty.buffer.ByteBufUtil;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.handler.codec.base64.Base64;
23 import io.netty.handler.codec.base64.Base64Dialect;
24 import io.netty.util.NetUtil;
25 import io.netty.util.internal.EmptyArrays;
26 import io.netty.util.internal.StringUtil;
27 import io.netty.util.internal.logging.InternalLogger;
28 import io.netty.util.internal.logging.InternalLoggerFactory;
29
30 import java.nio.ByteBuffer;
31 import java.nio.ByteOrder;
32 import java.security.KeyManagementException;
33 import java.security.NoSuchAlgorithmException;
34 import java.security.NoSuchProviderException;
35 import java.security.Provider;
36 import java.util.Collections;
37 import java.util.LinkedHashSet;
38 import java.util.List;
39 import java.util.Set;
40
41 import javax.net.ssl.SSLContext;
42 import javax.net.ssl.SSLHandshakeException;
43 import javax.net.ssl.TrustManager;
44
45 import static java.util.Arrays.asList;
46
47
48
49
50 final class SslUtils {
/** Logger used by the TLSv1.3 capability probes to report detection failures at debug level. */
private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslUtils.class);

/**
 * Cipher suite names that {@link #isTLSv13Cipher(String)} recognises as TLSv1.3 suites.
 * The set is insertion-ordered and unmodifiable.
 */
static final Set<String> TLSV13_CIPHERS = Collections.unmodifiableSet(new LinkedHashSet<String>(
        asList("TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256",
               "TLS_AES_128_GCM_SHA256", "TLS_AES_128_CCM_8_SHA256",
               "TLS_AES_128_CCM_SHA256")));

/**
 * Two-byte record version that the packet-length parsers accept in addition to major version 3.
 * NOTE(review): presumably the GMSSL protocol version -- confirm against the GMSSL specification.
 */
static final int GMSSL_PROTOCOL_VERSION = 0x101;

/**
 * Cipher suite name {@code SSL_NULL_WITH_NULL_NULL}.
 * NOTE(review): not referenced in this file; presumably a marker for "no valid cipher" -- confirm at callers.
 */
static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL";

/**
 * Record content type accepted by the packet-length parsers: change_cipher_spec.
 */
static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20;

/**
 * Record content type accepted by the packet-length parsers: alert.
 */
static final int SSL_CONTENT_TYPE_ALERT = 21;

/**
 * Record content type accepted by the packet-length parsers: handshake.
 */
static final int SSL_CONTENT_TYPE_HANDSHAKE = 22;

/**
 * Record content type accepted by the packet-length parsers: application_data.
 */
static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23;

/**
 * Record content type accepted by the packet-length parsers: heartbeat extension.
 */
static final int SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT = 24;

/**
 * Size in bytes of the SSLv3/TLS record header; the parsers add it to the record's length field.
 */
static final int SSL_RECORD_HEADER_LENGTH = 5;

/**
 * Returned by the packet-length parsers when the SSLv2-style length does not exceed the header size.
 */
static final int NOT_ENOUGH_DATA = -1;

/**
 * Returned by the packet-length parsers when the data carries no recognised SSL/TLS version byte.
 */
static final int NOT_ENCRYPTED = -2;

/** Fixed list of ECDHE/RSA AES suites followed by {@link #DEFAULT_TLSV13_CIPHER_SUITES}; set in the static initializer. */
static final String[] DEFAULT_CIPHER_SUITES;
/** {@link #TLSV13_CIPHER_SUITES} when the default JDK provider supports TLSv1.3, otherwise an empty array. */
static final String[] DEFAULT_TLSV13_CIPHER_SUITES;
/** TLSv1.3 suites used as the default TLSv1.3 selection. */
static final String[] TLSV13_CIPHER_SUITES = { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" };

/** Whether a default-provider {@code SSLContext} lists TLSv1.3 among its supported protocols. */
private static final boolean TLSV1_3_JDK_SUPPORTED;
/** Whether a default-provider {@code SSLContext} lists TLSv1.3 among its default protocols. */
private static final boolean TLSV1_3_JDK_DEFAULT_ENABLED;
112
113 static {
114 TLSV1_3_JDK_SUPPORTED = isTLSv13SupportedByJDK0(null);
115 TLSV1_3_JDK_DEFAULT_ENABLED = isTLSv13EnabledByJDK0(null);
116 if (TLSV1_3_JDK_SUPPORTED) {
117 DEFAULT_TLSV13_CIPHER_SUITES = TLSV13_CIPHER_SUITES;
118 } else {
119 DEFAULT_TLSV13_CIPHER_SUITES = EmptyArrays.EMPTY_STRINGS;
120 }
121
122 Set<String> defaultCiphers = new LinkedHashSet<String>();
123
124 defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384");
125 defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256");
126 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256");
127 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384");
128 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA");
129
130 defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA");
131
132 defaultCiphers.add("TLS_RSA_WITH_AES_128_GCM_SHA256");
133 defaultCiphers.add("TLS_RSA_WITH_AES_128_CBC_SHA");
134
135 defaultCiphers.add("TLS_RSA_WITH_AES_256_CBC_SHA");
136
137 Collections.addAll(defaultCiphers, DEFAULT_TLSV13_CIPHER_SUITES);
138
139 DEFAULT_CIPHER_SUITES = defaultCiphers.toArray(EmptyArrays.EMPTY_STRINGS);
140 }
141
142
143
144
145 static boolean isTLSv13SupportedByJDK(Provider provider) {
146 if (provider == null) {
147 return TLSV1_3_JDK_SUPPORTED;
148 }
149 return isTLSv13SupportedByJDK0(provider);
150 }
151
152 private static boolean isTLSv13SupportedByJDK0(Provider provider) {
153 try {
154 return arrayContains(newInitContext(provider)
155 .getSupportedSSLParameters().getProtocols(), SslProtocols.TLS_v1_3);
156 } catch (Throwable cause) {
157 logger.debug("Unable to detect if JDK SSLEngine with provider {} supports TLSv1.3, assuming no",
158 provider, cause);
159 return false;
160 }
161 }
162
163
164
165
166 static boolean isTLSv13EnabledByJDK(Provider provider) {
167 if (provider == null) {
168 return TLSV1_3_JDK_DEFAULT_ENABLED;
169 }
170 return isTLSv13EnabledByJDK0(provider);
171 }
172
173 private static boolean isTLSv13EnabledByJDK0(Provider provider) {
174 try {
175 return arrayContains(newInitContext(provider)
176 .getDefaultSSLParameters().getProtocols(), SslProtocols.TLS_v1_3);
177 } catch (Throwable cause) {
178 logger.debug("Unable to detect if JDK SSLEngine with provider {} enables TLSv1.3 by default," +
179 " assuming no", provider, cause);
180 return false;
181 }
182 }
183
184 private static SSLContext newInitContext(Provider provider)
185 throws NoSuchAlgorithmException, KeyManagementException {
186 final SSLContext context;
187 if (provider == null) {
188 context = SSLContext.getInstance("TLS");
189 } else {
190 context = SSLContext.getInstance("TLS", provider);
191 }
192 context.init(null, new TrustManager[0], null);
193 return context;
194 }
195
196 static SSLContext getSSLContext(String provider)
197 throws NoSuchAlgorithmException, KeyManagementException, NoSuchProviderException {
198 final SSLContext context;
199 if (StringUtil.isNullOrEmpty(provider)) {
200 context = SSLContext.getInstance(getTlsVersion());
201 } else {
202 context = SSLContext.getInstance(getTlsVersion(), provider);
203 }
204 context.init(null, new TrustManager[0], null);
205 return context;
206 }
207
208 private static String getTlsVersion() {
209 return TLSV1_3_JDK_SUPPORTED ? SslProtocols.TLS_v1_3 : SslProtocols.TLS_v1_2;
210 }
211
212 static boolean arrayContains(String[] array, String value) {
213 for (String v: array) {
214 if (value.equals(v)) {
215 return true;
216 }
217 }
218 return false;
219 }
220
221
222
223
224 static void addIfSupported(Set<String> supported, List<String> enabled, String... names) {
225 for (String n: names) {
226 if (supported.contains(n)) {
227 enabled.add(n);
228 }
229 }
230 }
231
232 static void useFallbackCiphersIfDefaultIsEmpty(List<String> defaultCiphers, Iterable<String> fallbackCiphers) {
233 if (defaultCiphers.isEmpty()) {
234 for (String cipher : fallbackCiphers) {
235 if (cipher.startsWith("SSL_") || cipher.contains("_RC4_")) {
236 continue;
237 }
238 defaultCiphers.add(cipher);
239 }
240 }
241 }
242
243 static void useFallbackCiphersIfDefaultIsEmpty(List<String> defaultCiphers, String... fallbackCiphers) {
244 useFallbackCiphersIfDefaultIsEmpty(defaultCiphers, asList(fallbackCiphers));
245 }
246
247
248
249
250 static SSLHandshakeException toSSLHandshakeException(Throwable e) {
251 if (e instanceof SSLHandshakeException) {
252 return (SSLHandshakeException) e;
253 }
254
255 return (SSLHandshakeException) new SSLHandshakeException(e.getMessage()).initCause(e);
256 }
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275 static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
276 int packetLength = 0;
277
278
279 boolean tls;
280 switch (buffer.getUnsignedByte(offset)) {
281 case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
282 case SSL_CONTENT_TYPE_ALERT:
283 case SSL_CONTENT_TYPE_HANDSHAKE:
284 case SSL_CONTENT_TYPE_APPLICATION_DATA:
285 case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
286 tls = true;
287 break;
288 default:
289
290 tls = false;
291 }
292
293 if (tls) {
294
295 int majorVersion = buffer.getUnsignedByte(offset + 1);
296 if (majorVersion == 3 || buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) {
297
298 packetLength = unsignedShortBE(buffer, offset + 3) + SSL_RECORD_HEADER_LENGTH;
299 if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
300
301 tls = false;
302 }
303 } else {
304
305 tls = false;
306 }
307 }
308
309 if (!tls) {
310
311 int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
312 int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
313 if (majorVersion == 2 || majorVersion == 3) {
314
315 packetLength = headerLength == 2 ?
316 (shortBE(buffer, offset) & 0x7FFF) + 2 : (shortBE(buffer, offset) & 0x3FFF) + 3;
317 if (packetLength <= headerLength) {
318 return NOT_ENOUGH_DATA;
319 }
320 } else {
321 return NOT_ENCRYPTED;
322 }
323 }
324 return packetLength;
325 }
326
327
328 @SuppressWarnings("deprecation")
329 private static int unsignedShortBE(ByteBuf buffer, int offset) {
330 int value = buffer.getUnsignedShort(offset);
331 if (buffer.order() == ByteOrder.LITTLE_ENDIAN) {
332 value = Integer.reverseBytes(value) >>> Short.SIZE;
333 }
334 return value;
335 }
336
337
338 @SuppressWarnings("deprecation")
339 private static short shortBE(ByteBuf buffer, int offset) {
340 short value = buffer.getShort(offset);
341 if (buffer.order() == ByteOrder.LITTLE_ENDIAN) {
342 value = Short.reverseBytes(value);
343 }
344 return value;
345 }
346
347 private static short unsignedByte(byte b) {
348 return (short) (b & 0xFF);
349 }
350
351
352 private static int unsignedShortBE(ByteBuffer buffer, int offset) {
353 return shortBE(buffer, offset) & 0xFFFF;
354 }
355
356
357 private static short shortBE(ByteBuffer buffer, int offset) {
358 return buffer.order() == ByteOrder.BIG_ENDIAN ?
359 buffer.getShort(offset) : ByteBufUtil.swapShort(buffer.getShort(offset));
360 }
361
362 static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
363 ByteBuffer buffer = buffers[offset];
364
365
366 if (buffer.remaining() >= SSL_RECORD_HEADER_LENGTH) {
367 return getEncryptedPacketLength(buffer);
368 }
369
370
371 ByteBuffer tmp = ByteBuffer.allocate(5);
372
373 do {
374 buffer = buffers[offset++].duplicate();
375 if (buffer.remaining() > tmp.remaining()) {
376 buffer.limit(buffer.position() + tmp.remaining());
377 }
378 tmp.put(buffer);
379 } while (tmp.hasRemaining());
380
381
382 tmp.flip();
383 return getEncryptedPacketLength(tmp);
384 }
385
386 private static int getEncryptedPacketLength(ByteBuffer buffer) {
387 int packetLength = 0;
388 int pos = buffer.position();
389
390 boolean tls;
391 switch (unsignedByte(buffer.get(pos))) {
392 case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
393 case SSL_CONTENT_TYPE_ALERT:
394 case SSL_CONTENT_TYPE_HANDSHAKE:
395 case SSL_CONTENT_TYPE_APPLICATION_DATA:
396 case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
397 tls = true;
398 break;
399 default:
400
401 tls = false;
402 }
403
404 if (tls) {
405
406 int majorVersion = unsignedByte(buffer.get(pos + 1));
407 if (majorVersion == 3 || buffer.getShort(pos + 1) == GMSSL_PROTOCOL_VERSION) {
408
409 packetLength = unsignedShortBE(buffer, pos + 3) + SSL_RECORD_HEADER_LENGTH;
410 if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
411
412 tls = false;
413 }
414 } else {
415
416 tls = false;
417 }
418 }
419
420 if (!tls) {
421
422 int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3;
423 int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1));
424 if (majorVersion == 2 || majorVersion == 3) {
425
426 packetLength = headerLength == 2 ?
427 (shortBE(buffer, pos) & 0x7FFF) + 2 : (shortBE(buffer, pos) & 0x3FFF) + 3;
428 if (packetLength <= headerLength) {
429 return NOT_ENOUGH_DATA;
430 }
431 } else {
432 return NOT_ENCRYPTED;
433 }
434 }
435 return packetLength;
436 }
437
/**
 * Reacts to a failed handshake: flushes the context, optionally fires a
 * {@link SslHandshakeCompletionEvent} carrying {@code cause}, and closes the context.
 *
 * @param ctx the context of the channel whose handshake failed
 * @param cause the reason the handshake failed
 * @param notify whether a {@link SslHandshakeCompletionEvent} should be fired as a user event
 */
static void handleHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean notify) {
    // Flush before closing so that anything already written is pushed out first.
    // NOTE(review): presumably this covers data written before the failure was detected -- confirm.
    ctx.flush();
    if (notify) {
        ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
    }
    ctx.close();
}
447
448
449
450
451 static void zeroout(ByteBuf buffer) {
452 if (!buffer.isReadOnly()) {
453 buffer.setZero(0, buffer.capacity());
454 }
455 }
456
457
458
459
/**
 * Wipes {@code buffer} via {@link #zeroout(ByteBuf)} and then releases it.
 *
 * @param buffer the buffer to wipe and release
 */
static void zerooutAndRelease(ByteBuf buffer) {
    // Wipe first: the content must be cleared while this code still holds a reference.
    zeroout(buffer);
    buffer.release();
}
464
465
466
467
468
469
470 static ByteBuf toBase64(ByteBufAllocator allocator, ByteBuf src) {
471 ByteBuf dst = Base64.encode(src, src.readerIndex(),
472 src.readableBytes(), true, Base64Dialect.STANDARD, allocator);
473 src.readerIndex(src.writerIndex());
474 return dst;
475 }
476
477
478
479
480 static boolean isValidHostNameForSNI(String hostname) {
481
482 return hostname != null &&
483
484
485 hostname.indexOf('.') > 0 &&
486 !hostname.endsWith(".") && !hostname.startsWith("/") &&
487 !NetUtil.isValidIpV4Address(hostname) &&
488 !NetUtil.isValidIpV6Address(hostname);
489 }
490
491
492
493
/**
 * Tells whether {@code cipher} is one of the names listed in {@link #TLSV13_CIPHERS}.
 *
 * @param cipher the cipher suite name to check
 * @return {@code true} if the name is contained in {@link #TLSV13_CIPHERS}
 */
static boolean isTLSv13Cipher(String cipher) {
    // Exact, case-sensitive lookup against the fixed set of TLSv1.3 suite names.
    return TLSV13_CIPHERS.contains(cipher);
}
498
// Static utility class: not meant to be instantiated.
private SslUtils() {
}
501 }