1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufUtil;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelOutboundHandler;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.ByteToMessageDecoder;
24 import io.netty.handler.codec.DecoderException;
25 import io.netty.handler.codec.TooLongFrameException;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.FutureListener;
28 import io.netty.util.internal.ObjectUtil;
29 import io.netty.util.internal.PlatformDependent;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.net.SocketAddress;
34 import java.util.List;
35
36
37
38
39 public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40
41
42
43
44
45 public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46
47 private static final InternalLogger logger =
48 InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49
50 private final int maxClientHelloLength;
51 private boolean handshakeFailed;
52 private boolean suppressRead;
53 private boolean readPending;
54 private ByteBuf handshakeBuffer;
55
56 public SslClientHelloHandler() {
57 this(MAX_CLIENT_HELLO_LENGTH);
58 }
59
60 protected SslClientHelloHandler(int maxClientHelloLength) {
61
62
63 this.maxClientHelloLength =
64 ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65 }
66
67 @Override
68 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69 if (!suppressRead && !handshakeFailed) {
70 try {
71 int readerIndex = in.readerIndex();
72 int readableBytes = in.readableBytes();
73 int handshakeLength = -1;
74
75
76 while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77 final int contentType = in.getUnsignedByte(readerIndex);
78 switch (contentType) {
79 case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80
81 case SslUtils.SSL_CONTENT_TYPE_ALERT:
82 final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83
84
85 if (len == SslUtils.NOT_ENCRYPTED) {
86 handshakeFailed = true;
87 NotSslRecordException e = new NotSslRecordException(
88 "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89 in.skipBytes(in.readableBytes());
90 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91 SslUtils.handleHandshakeFailure(ctx, e, true);
92 throw e;
93 }
94 if (len == SslUtils.NOT_ENOUGH_DATA) {
95
96 return;
97 }
98
99 select(ctx, null);
100 return;
101 case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102 final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103
104 if (majorVersion == 3) {
105 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106 SslUtils.SSL_RECORD_HEADER_LENGTH;
107
108 if (readableBytes < packetLength) {
109
110 return;
111 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112 select(ctx, null);
113 return;
114 }
115
116 final int endOffset = readerIndex + packetLength;
117
118
119 if (handshakeLength == -1) {
120 if (readerIndex + 4 > endOffset) {
121
122 return;
123 }
124
125 final int handshakeType = in.getUnsignedByte(readerIndex +
126 SslUtils.SSL_RECORD_HEADER_LENGTH);
127
128
129
130 if (handshakeType != 1) {
131 select(ctx, null);
132 return;
133 }
134
135
136
137 handshakeLength = in.getUnsignedMedium(readerIndex +
138 SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139
140 if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141 TooLongFrameException e = new TooLongFrameException(
142 "ClientHello length exceeds " + maxClientHelloLength +
143 ": " + handshakeLength);
144 in.skipBytes(in.readableBytes());
145 ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146 SslUtils.handleHandshakeFailure(ctx, e, true);
147 throw e;
148 }
149
150 readerIndex += 4;
151 packetLength -= 4;
152
153 if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154
155
156 readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157 select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158 return;
159 } else {
160 if (handshakeBuffer == null) {
161 handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162 } else {
163
164 handshakeBuffer.clear();
165 }
166 }
167 }
168
169
170 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171 packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172 readerIndex += packetLength;
173 readableBytes -= packetLength;
174 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175 ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176 handshakeBuffer = null;
177
178 select(ctx, clientHello);
179 return;
180 }
181 break;
182 }
183
184 default:
185
186 select(ctx, null);
187 return;
188 }
189 }
190 } catch (NotSslRecordException e) {
191
192 throw e;
193 } catch (TooLongFrameException e) {
194
195 throw e;
196 } catch (Exception e) {
197
198 if (logger.isDebugEnabled()) {
199 logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200 }
201 select(ctx, null);
202 }
203 }
204 }
205
206 private void releaseHandshakeBuffer() {
207 releaseIfNotNull(handshakeBuffer);
208 handshakeBuffer = null;
209 }
210
211 private static void releaseIfNotNull(ByteBuf buffer) {
212 if (buffer != null) {
213 buffer.release();
214 }
215 }
216
217 private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218 final Future<T> future;
219 try {
220 future = lookup(ctx, clientHello);
221 if (future.isDone()) {
222 onLookupComplete(ctx, future);
223 } else {
224 suppressRead = true;
225 final ByteBuf finalClientHello = clientHello;
226 future.addListener((FutureListener<T>) future1 -> {
227 releaseIfNotNull(finalClientHello);
228 try {
229 suppressRead = false;
230 try {
231 onLookupComplete(ctx, future1);
232 } catch (DecoderException err) {
233 ctx.fireExceptionCaught(err);
234 } catch (Exception cause) {
235 ctx.fireExceptionCaught(new DecoderException(cause));
236 } catch (Throwable cause) {
237 ctx.fireExceptionCaught(cause);
238 }
239 } finally {
240 if (readPending) {
241 readPending = false;
242 ctx.read();
243 }
244 }
245 });
246
247
248 clientHello = null;
249 }
250 } catch (Throwable cause) {
251 PlatformDependent.throwException(cause);
252 } finally {
253 releaseIfNotNull(clientHello);
254 }
255 }
256
257 @Override
258 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
259 releaseHandshakeBuffer();
260
261 super.handlerRemoved0(ctx);
262 }
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288 protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
289
290
291
292
293
294
295 protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
296
297 @Override
298 public void read(ChannelHandlerContext ctx) throws Exception {
299 if (suppressRead) {
300 readPending = true;
301 } else {
302 ctx.read();
303 }
304 }
305
306 @Override
307 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
308 ctx.bind(localAddress, promise);
309 }
310
311 @Override
312 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
313 ChannelPromise promise) throws Exception {
314 ctx.connect(remoteAddress, localAddress, promise);
315 }
316
317 @Override
318 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
319 ctx.disconnect(promise);
320 }
321
322 @Override
323 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
324 ctx.close(promise);
325 }
326
327 @Override
328 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
329 ctx.deregister(promise);
330 }
331
332 @Override
333 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
334 ctx.write(msg, promise);
335 }
336
337 @Override
338 public void flush(ChannelHandlerContext ctx) throws Exception {
339 ctx.flush();
340 }
341 }