1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelOutboundHandler;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.ByteToMessageDecoder;
24 import io.netty.handler.codec.DecoderException;
25 import io.netty.handler.codec.TooLongFrameException;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.FutureListener;
28 import io.netty.util.internal.ObjectUtil;
29 import io.netty.util.internal.PlatformDependent;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.util.List;
35
36
37
38
39 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40
41
42
43
44
45 public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46
47 private static final InternalLogger logger =
48 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49
50 private final int maxClientHelloLength;
51 private boolean handshakeFailed;
52 private boolean suppressRead;
53 private boolean readPending;
54 private ByteBuf handshakeBuffer;
55
56 public SslClientHelloHandler() {
57 this(MAX_CLIENT_HELLO_LENGTH);
58 }
59
60 protected SslClientHelloHandler(int maxClientHelloLength) {
61
62
63 this.maxClientHelloLength =
64 ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65 }
66
67 @Override
68 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69 if (!suppressRead && !handshakeFailed) {
70 try {
71 int readerIndex = in.readerIndex();
72 int readableBytes = in.readableBytes();
73 int handshakeLength = -1;
74
75
76 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77 final int contentType = in.getUnsignedByte(readerIndex);
78 switch (contentType) {
79 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80
81 case SslUtils.SSL_CONTENT_TYPE_ALERT:
82 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83
84
85 if (len == SslUtils.NOT_ENCRYPTED) {
86 handshakeFailed = true;
87 NotSslRecordException e = new NotSslRecordException(
88 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89 in.skipBytes(in.readableBytes());
90 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91 SslUtils.handleHandshakeFailure(ctx, e, true);
92 throw e;
93 }
94 if (len == SslUtils.NOT_ENOUGH_DATA) {
95
96 return;
97 }
98
99 select(ctx, null);
100 return;
101 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103
104 if (majorVersion == 3) {
105 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106 SslUtils.SSL_RECORD_HEADER_LENGTH;
107
108 if (readableBytes < packetLength) {
109
110 return;
111 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112 select(ctx, null);
113 return;
114 }
115
116 final int endOffset = readerIndex + packetLength;
117
118
119 if (handshakeLength == -1) {
120 if (readerIndex + 4 > endOffset) {
121
122 return;
123 }
124
125 final int handshakeType = in.getUnsignedByte(readerIndex +
126 SslUtils.SSL_RECORD_HEADER_LENGTH);
127
128
129
130 if (handshakeType != 1) {
131 select(ctx, null);
132 return;
133 }
134
135
136
137 handshakeLength = in.getUnsignedMedium(readerIndex +
138 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139
140 if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141 TooLongFrameException e = new TooLongFrameException(
142 "ClientHello length exceeds " + maxClientHelloLength +
143 ": " + handshakeLength);
144 in.skipBytes(in.readableBytes());
145 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146 SslUtils.handleHandshakeFailure(ctx, e, true);
147 throw e;
148 }
149
150 readerIndex += 4;
151 packetLength -= 4;
152
153 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154
155
156 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157 select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158 return;
159 } else {
160 if (handshakeBuffer == null) {
161 handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162 } else {
163
164 handshakeBuffer.clear();
165 }
166 }
167 }
168
169
170 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171 packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172 readerIndex += packetLength;
173 readableBytes -= packetLength;
174 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175 ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176 handshakeBuffer = null;
177
178 select(ctx, clientHello);
179 return;
180 }
181 break;
182 }
183
184 default:
185
186 select(ctx, null);
187 return;
188 }
189 }
190 } catch (NotSslRecordException e) {
191
192 throw e;
193 } catch (TooLongFrameException e) {
194
195 throw e;
196 } catch (Exception e) {
197
198 if (logger.isDebugEnabled()) {
199 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200 }
201 select(ctx, null);
202 }
203 }
204 }
205
206 private void releaseHandshakeBuffer() {
207 releaseIfNotNull(handshakeBuffer);
208 handshakeBuffer = null;
209 }
210
211 private static void releaseIfNotNull(ByteBuf buffer) {
212 if (buffer != null) {
213 buffer.release();
214 }
215 }
216
217 private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218 final Future<T> future;
219 try {
220 future = lookup(ctx, clientHello);
221 if (future.isDone()) {
222 onLookupComplete(ctx, future);
223 } else {
224 suppressRead = true;
225 final ByteBuf finalClientHello = clientHello;
226 future.addListener(new FutureListener<T>() {
227 @Override
228 public void operationComplete(Future<T> future) {
229 releaseIfNotNull(finalClientHello);
230 try {
231 suppressRead = false;
232 try {
233 onLookupComplete(ctx, future);
234 } catch (DecoderException err) {
235 ctx.fireExceptionCaught(err);
236 } catch (Exception cause) {
237 ctx.fireExceptionCaught(new DecoderException(cause));
238 } catch (Throwable cause) {
239 ctx.fireExceptionCaught(cause);
240 }
241 } finally {
242 if (readPending) {
243 readPending = false;
244 ctx.read();
245 }
246 }
247 }
248 });
249
250
251 clientHello = null;
252 }
253 } catch (Throwable cause) {
254 PlatformDependent.throwException(cause);
255 } finally {
256 releaseIfNotNull(clientHello);
257 }
258 }
259
260 @Override
261 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
262 releaseHandshakeBuffer();
263
264 super.handlerRemoved0(ctx);
265 }
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291 protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
292
293
294
295
296
297
298 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
299
300 @Override
301 public void read(ChannelHandlerContext ctx) throws Exception {
302 if (suppressRead) {
303 readPending = true;
304 } else {
305 ctx.read();
306 }
307 }
308
309 @Override
310 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
311 ctx.bind(localAddress, promise);
312 }
313
314 @Override
315 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
316 ChannelPromise promise) throws Exception {
317 ctx.connect(remoteAddress, localAddress, promise);
318 }
319
320 @Override
321 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
322 ctx.disconnect(promise);
323 }
324
325 @Override
326 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
327 ctx.close(promise);
328 }
329
330 @Override
331 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
332 ctx.deregister(promise);
333 }
334
335 @Override
336 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
337 ctx.write(msg, promise);
338 }
339
340 @Override
341 public void flush(ChannelHandlerContext ctx) throws Exception {
342 ctx.flush();
343 }
344 }