1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelOutboundHandler;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.ByteToMessageDecoder;
24 import io.netty.handler.codec.DecoderException;
25 import io.netty.handler.codec.TooLongFrameException;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.FutureListener;
28 import io.netty.util.internal.ObjectUtil;
29 import io.netty.util.internal.PlatformDependent;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.util.List;
35
36
37
38
39 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40
41
42
43
44
45 public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46
47 private static final InternalLogger logger =
48 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49
50 private final int maxClientHelloLength;
51 private boolean handshakeFailed;
52 private boolean suppressRead;
53 private boolean readPending;
54 private ByteBuf handshakeBuffer;
55
56 public SslClientHelloHandler() {
57 this(MAX_CLIENT_HELLO_LENGTH);
58 }
59
60 protected SslClientHelloHandler(int maxClientHelloLength) {
61
62
63 this.maxClientHelloLength =
64 ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65 }
66
67 @Override
68 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69 if (!suppressRead && !handshakeFailed) {
70 try {
71 int readerIndex = in.readerIndex();
72 int readableBytes = in.readableBytes();
73 int handshakeLength = -1;
74
75
76 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77 final int contentType = in.getUnsignedByte(readerIndex);
78 switch (contentType) {
79 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80
81 case SslUtils.SSL_CONTENT_TYPE_ALERT:
82 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83
84
85 if (len == SslUtils.NOT_ENCRYPTED) {
86 handshakeFailed = true;
87 NotSslRecordException e = new NotSslRecordException(
88 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89 in.skipBytes(in.readableBytes());
90 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91 SslUtils.handleHandshakeFailure(ctx, e, true);
92 throw e;
93 }
94 if (len == SslUtils.NOT_ENOUGH_DATA) {
95
96 return;
97 }
98
99 select(ctx, null);
100 return;
101 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103
104 if (majorVersion == 3) {
105 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106 SslUtils.SSL_RECORD_HEADER_LENGTH;
107
108 if (readableBytes < packetLength) {
109
110 return;
111 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112 select(ctx, null);
113 return;
114 }
115
116 final int endOffset = readerIndex + packetLength;
117
118
119 if (handshakeLength == -1) {
120 if (readerIndex + 4 > endOffset) {
121
122 return;
123 }
124
125 final int handshakeType = in.getUnsignedByte(readerIndex +
126 SslUtils.SSL_RECORD_HEADER_LENGTH);
127
128
129
130 if (handshakeType != 1) {
131 select(ctx, null);
132 return;
133 }
134
135
136
137 handshakeLength = in.getUnsignedMedium(readerIndex +
138 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139
140 if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141 TooLongFrameException e = new TooLongFrameException(
142 "ClientHello length exceeds " + maxClientHelloLength +
143 ": " + handshakeLength);
144 in.skipBytes(in.readableBytes());
145 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146 SslUtils.handleHandshakeFailure(ctx, e, true);
147 throw e;
148 }
149
150 readerIndex += 4;
151 packetLength -= 4;
152
153 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154
155
156 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157 select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158 return;
159 } else {
160 if (handshakeBuffer == null) {
161 handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162 } else {
163
164 handshakeBuffer.clear();
165 }
166 }
167 }
168
169
170 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171 packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172 readerIndex += packetLength;
173 readableBytes -= packetLength;
174 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175 ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176 handshakeBuffer = null;
177
178 select(ctx, clientHello);
179 return;
180 }
181 break;
182 }
183
184 default:
185
186 select(ctx, null);
187 return;
188 }
189 }
190 } catch (NotSslRecordException e) {
191
192 throw e;
193 } catch (TooLongFrameException e) {
194
195 throw e;
196 } catch (Exception e) {
197
198 if (logger.isDebugEnabled()) {
199 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200 }
201 select(ctx, null);
202 }
203 }
204 }
205
206 private void releaseHandshakeBuffer() {
207 releaseIfNotNull(handshakeBuffer);
208 handshakeBuffer = null;
209 }
210
211 private static void releaseIfNotNull(ByteBuf buffer) {
212 if (buffer != null) {
213 buffer.release();
214 }
215 }
216
217 private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218 final Future<T> future;
219 try {
220 future = lookup(ctx, clientHello);
221 if (future.isDone()) {
222 try {
223 onLookupComplete(ctx, future);
224 } catch (DecoderException err) {
225 ctx.fireExceptionCaught(err);
226 } catch (Exception cause) {
227 ctx.fireExceptionCaught(new DecoderException(cause));
228 } catch (Throwable cause) {
229 ctx.fireExceptionCaught(cause);
230 }
231 } else {
232 suppressRead = true;
233 final ByteBuf finalClientHello = clientHello;
234 future.addListener(new FutureListener<T>() {
235 @Override
236 public void operationComplete(Future<T> future) {
237 releaseIfNotNull(finalClientHello);
238 try {
239 suppressRead = false;
240 try {
241 onLookupComplete(ctx, future);
242 } catch (DecoderException err) {
243 ctx.fireExceptionCaught(err);
244 } catch (Exception cause) {
245 ctx.fireExceptionCaught(new DecoderException(cause));
246 } catch (Throwable cause) {
247 ctx.fireExceptionCaught(cause);
248 }
249 } finally {
250 if (readPending) {
251 readPending = false;
252 ctx.read();
253 }
254 }
255 }
256 });
257
258
259 clientHello = null;
260 }
261 } catch (Throwable cause) {
262 PlatformDependent.throwException(cause);
263 } finally {
264 releaseIfNotNull(clientHello);
265 }
266 }
267
268 @Override
269 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
270 releaseHandshakeBuffer();
271
272 super.handlerRemoved0(ctx);
273 }
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299 protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
300
301
302
303
304
305
306 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
307
308 @Override
309 public void read(ChannelHandlerContext ctx) throws Exception {
310 if (suppressRead) {
311 readPending = true;
312 } else {
313 ctx.read();
314 }
315 }
316
317 @Override
318 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
319 ctx.bind(localAddress, promise);
320 }
321
322 @Override
323 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
324 ChannelPromise promise) throws Exception {
325 ctx.connect(remoteAddress, localAddress, promise);
326 }
327
328 @Override
329 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
330 ctx.disconnect(promise);
331 }
332
333 @Override
334 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
335 ctx.close(promise);
336 }
337
338 @Override
339 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
340 ctx.deregister(promise);
341 }
342
343 @Override
344 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
345 ctx.write(msg, promise);
346 }
347
348 @Override
349 public void flush(ChannelHandlerContext ctx) throws Exception {
350 ctx.flush();
351 }
352 }