1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelOutboundHandler;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.ByteToMessageDecoder;
24 import io.netty.handler.codec.DecoderException;
25 import io.netty.handler.codec.TooLongFrameException;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.FutureListener;
28 import io.netty.util.internal.ObjectUtil;
29 import io.netty.util.internal.PlatformDependent;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.util.List;
35
36
37
38
39 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40
41
42
43
44
45 public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46
47
48
49 static final int DEFAULT_MAX_CLIENT_HELLO_LENGTH = 64 * 1024;
50
51 private static final InternalLogger logger =
52 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
53
54 private final int maxClientHelloLength;
55 private boolean handshakeFailed;
56 private boolean suppressRead;
57 private boolean readPending;
58 private ByteBuf handshakeBuffer;
59
60 public SslClientHelloHandler() {
61 this(DEFAULT_MAX_CLIENT_HELLO_LENGTH);
62 }
63
64 protected SslClientHelloHandler(int maxClientHelloLength) {
65
66
67 this.maxClientHelloLength =
68 ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
69 }
70
71 @Override
72 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
73 if (!suppressRead && !handshakeFailed) {
74 try {
75 int readerIndex = in.readerIndex();
76 int readableBytes = in.readableBytes();
77 int handshakeLength = -1;
78
79
80 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
81 final int contentType = in.getUnsignedByte(readerIndex);
82 switch (contentType) {
83 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
84
85 case SslUtils.SSL_CONTENT_TYPE_ALERT:
86 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
87
88
89 if (len == SslUtils.NOT_ENCRYPTED) {
90 handshakeFailed = true;
91 NotSslRecordException e = new NotSslRecordException(
92 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
93 in.skipBytes(in.readableBytes());
94 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
95 SslUtils.handleHandshakeFailure(ctx, e, true);
96 throw e;
97 }
98 if (len == SslUtils.NOT_ENOUGH_DATA) {
99
100 return;
101 }
102
103 select(ctx, null);
104 return;
105 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
106 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
107
108 if (majorVersion == 3) {
109 int packetLength = in.getUnsignedShort(readerIndex + 3) +
110 SslUtils.SSL_RECORD_HEADER_LENGTH;
111
112 if (readableBytes < packetLength) {
113
114 return;
115 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
116 select(ctx, null);
117 return;
118 }
119
120 final int endOffset = readerIndex + packetLength;
121
122
123 if (handshakeLength == -1) {
124 if (readerIndex + 4 > endOffset) {
125
126 return;
127 }
128
129 final int handshakeType = in.getUnsignedByte(readerIndex +
130 SslUtils.SSL_RECORD_HEADER_LENGTH);
131
132
133
134 if (handshakeType != 1) {
135 select(ctx, null);
136 return;
137 }
138
139
140
141 handshakeLength = in.getUnsignedMedium(readerIndex +
142 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
143
144 if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
145 TooLongFrameException e = new TooLongFrameException(
146 "ClientHello length exceeds " + maxClientHelloLength +
147 ": " + handshakeLength);
148 in.skipBytes(in.readableBytes());
149 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
150 SslUtils.handleHandshakeFailure(ctx, e, true);
151 throw e;
152 }
153
154 readerIndex += 4;
155 packetLength -= 4;
156
157 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
158
159
160 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
161 select(ctx, in.retainedSlice(readerIndex, handshakeLength));
162 return;
163 } else {
164 if (handshakeBuffer == null) {
165 handshakeBuffer = ctx.alloc().buffer(handshakeLength);
166 } else {
167
168 handshakeBuffer.clear();
169 }
170 }
171 }
172
173
174 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
175 packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
176 readerIndex += packetLength;
177 readableBytes -= packetLength;
178 if (handshakeLength <= handshakeBuffer.readableBytes()) {
179 ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
180 handshakeBuffer = null;
181
182 select(ctx, clientHello);
183 return;
184 }
185 break;
186 }
187
188 default:
189
190 select(ctx, null);
191 return;
192 }
193 }
194 } catch (NotSslRecordException e) {
195
196 throw e;
197 } catch (TooLongFrameException e) {
198
199 throw e;
200 } catch (Exception e) {
201
202 if (logger.isDebugEnabled()) {
203 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
204 }
205 select(ctx, null);
206 }
207 }
208 }
209
210 private void releaseHandshakeBuffer() {
211 releaseIfNotNull(handshakeBuffer);
212 handshakeBuffer = null;
213 }
214
215 private static void releaseIfNotNull(ByteBuf buffer) {
216 if (buffer != null) {
217 buffer.release();
218 }
219 }
220
221 private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
222 final Future<T> future;
223 try {
224 future = lookup(ctx, clientHello);
225 if (future.isDone()) {
226 try {
227 onLookupComplete(ctx, future);
228 } catch (DecoderException err) {
229 ctx.fireExceptionCaught(err);
230 } catch (Exception cause) {
231 ctx.fireExceptionCaught(new DecoderException(cause));
232 } catch (Throwable cause) {
233 ctx.fireExceptionCaught(cause);
234 }
235 } else {
236 suppressRead = true;
237 final ByteBuf finalClientHello = clientHello;
238 future.addListener(new FutureListener<T>() {
239 @Override
240 public void operationComplete(Future<T> future) {
241 releaseIfNotNull(finalClientHello);
242 try {
243 suppressRead = false;
244 try {
245 onLookupComplete(ctx, future);
246 } catch (DecoderException err) {
247 ctx.fireExceptionCaught(err);
248 } catch (Exception cause) {
249 ctx.fireExceptionCaught(new DecoderException(cause));
250 } catch (Throwable cause) {
251 ctx.fireExceptionCaught(cause);
252 }
253 } finally {
254 if (readPending) {
255 readPending = false;
256 ctx.read();
257 }
258 }
259 }
260 });
261
262
263 clientHello = null;
264 }
265 } catch (Throwable cause) {
266 PlatformDependent.throwException(cause);
267 } finally {
268 releaseIfNotNull(clientHello);
269 }
270 }
271
272 @Override
273 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
274 releaseHandshakeBuffer();
275
276 super.handlerRemoved0(ctx);
277 }
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303 protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
304
305
306
307
308
309
310 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
311
312 @Override
313 public void read(ChannelHandlerContext ctx) throws Exception {
314 if (suppressRead) {
315 readPending = true;
316 } else {
317 ctx.read();
318 }
319 }
320
321 @Override
322 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
323 ctx.bind(localAddress, promise);
324 }
325
326 @Override
327 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
328 ChannelPromise promise) throws Exception {
329 ctx.connect(remoteAddress, localAddress, promise);
330 }
331
332 @Override
333 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
334 ctx.disconnect(promise);
335 }
336
337 @Override
338 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
339 ctx.close(promise);
340 }
341
342 @Override
343 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
344 ctx.deregister(promise);
345 }
346
347 @Override
348 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
349 ctx.write(msg, promise);
350 }
351
352 @Override
353 public void flush(ChannelHandlerContext ctx) throws Exception {
354 ctx.flush();
355 }
356 }