1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.handler.codec.ByteToMessageDecoder;
22 import io.netty.util.CharsetUtil;
23 import io.netty.util.DomainNameMapping;
24 import io.netty.util.ReferenceCountUtil;
25 import io.netty.util.internal.PlatformDependent;
26 import io.netty.util.internal.logging.InternalLogger;
27 import io.netty.util.internal.logging.InternalLoggerFactory;
28
29 import java.net.IDN;
30 import java.util.List;
31 import java.util.Locale;
32
33
34
35
36
37
38
39
40 public class SniHandler extends ByteToMessageDecoder {
41
42
43 private static final int MAX_SSL_RECORDS = 4;
44
45 private static final InternalLogger logger =
46 InternalLoggerFactory.getInstance(SniHandler.class);
47
48 private static final Selection EMPTY_SELECTION = new Selection(null, null);
49
50 private final DomainNameMapping<SslContext> mapping;
51
52 private boolean handshakeFailed;
53
54 private volatile Selection selection = EMPTY_SELECTION;
55
56
57
58
59
60
61
62 @SuppressWarnings("unchecked")
63 public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
64 if (mapping == null) {
65 throw new NullPointerException("mapping");
66 }
67
68 this.mapping = (DomainNameMapping<SslContext>) mapping;
69 }
70
71
72
73
74 public String hostname() {
75 return selection.hostname;
76 }
77
78
79
80
81 public SslContext sslContext() {
82 return selection.context;
83 }
84
85 @Override
86 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
87 if (!handshakeFailed) {
88 final int writerIndex = in.writerIndex();
89 try {
90 loop:
91 for (int i = 0; i < MAX_SSL_RECORDS; i++) {
92 final int readerIndex = in.readerIndex();
93 final int readableBytes = writerIndex - readerIndex;
94 if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
95
96 return;
97 }
98
99 final int command = in.getUnsignedByte(readerIndex);
100
101
102 switch (command) {
103 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
104 case SslUtils.SSL_CONTENT_TYPE_ALERT:
105 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
106
107
108 if (len == SslUtils.NOT_ENCRYPTED) {
109 handshakeFailed = true;
110 NotSslRecordException e = new NotSslRecordException(
111 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
112 in.skipBytes(in.readableBytes());
113
114 SslUtils.notifyHandshakeFailure(ctx, e, true);
115 throw e;
116 }
117 if (len == SslUtils.NOT_ENOUGH_DATA ||
118 writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
119
120 return;
121 }
122
123 in.skipBytes(len);
124 continue;
125 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
126 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
127
128
129 if (majorVersion == 3) {
130 final int packetLength = in.getUnsignedShort(readerIndex + 3) +
131 SslUtils.SSL_RECORD_HEADER_LENGTH;
132
133 if (readableBytes < packetLength) {
134
135 return;
136 }
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158 final int endOffset = readerIndex + packetLength;
159 int offset = readerIndex + 43;
160
161 if (endOffset - offset < 6) {
162 break loop;
163 }
164
165 final int sessionIdLength = in.getUnsignedByte(offset);
166 offset += sessionIdLength + 1;
167
168 final int cipherSuitesLength = in.getUnsignedShort(offset);
169 offset += cipherSuitesLength + 2;
170
171 final int compressionMethodLength = in.getUnsignedByte(offset);
172 offset += compressionMethodLength + 1;
173
174 final int extensionsLength = in.getUnsignedShort(offset);
175 offset += 2;
176 final int extensionsLimit = offset + extensionsLength;
177
178 if (extensionsLimit > endOffset) {
179
180 break loop;
181 }
182
183 for (;;) {
184 if (extensionsLimit - offset < 4) {
185 break loop;
186 }
187
188 final int extensionType = in.getUnsignedShort(offset);
189 offset += 2;
190
191 final int extensionLength = in.getUnsignedShort(offset);
192 offset += 2;
193
194 if (extensionsLimit - offset < extensionLength) {
195 break loop;
196 }
197
198
199
200 if (extensionType == 0) {
201 offset += 2;
202 if (extensionsLimit - offset < 3) {
203 break loop;
204 }
205
206 final int serverNameType = in.getUnsignedByte(offset);
207 offset++;
208
209 if (serverNameType == 0) {
210 final int serverNameLength = in.getUnsignedShort(offset);
211 offset += 2;
212
213 if (extensionsLimit - offset < serverNameLength) {
214 break loop;
215 }
216
217 final String hostname = in.toString(offset, serverNameLength,
218 CharsetUtil.UTF_8);
219
220 try {
221 select(ctx, IDN.toASCII(hostname,
222 IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
223 } catch (Throwable t) {
224 PlatformDependent.throwException(t);
225 }
226 return;
227 } else {
228
229 break loop;
230 }
231 }
232
233 offset += extensionLength;
234 }
235 }
236
237 default:
238
239 break loop;
240 }
241 }
242 } catch (NotSslRecordException e) {
243
244 throw e;
245 } catch (Exception e) {
246
247 if (logger.isDebugEnabled()) {
248 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
249 }
250 }
251
252 select(ctx, null);
253 }
254 }
255
256 private void select(ChannelHandlerContext ctx, String hostname) {
257 SslHandler sslHandler = null;
258 SslContext selectedContext = mapping.map(hostname);
259 selection = new Selection(selectedContext, hostname);
260 try {
261 sslHandler = selection.context.newHandler(ctx.alloc());
262 ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
263 } catch (Throwable cause) {
264 selection = EMPTY_SELECTION;
265
266
267
268 if (sslHandler != null) {
269 ReferenceCountUtil.safeRelease(sslHandler.engine());
270 }
271 PlatformDependent.throwException(cause);
272 }
273 }
274
275 private static final class Selection {
276 final SslContext context;
277 final String hostname;
278
279 Selection(SslContext context, String hostname) {
280 this.context = context;
281 this.hostname = hostname;
282 }
283 }
284 }