View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelOutboundHandler;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.ByteToMessageDecoder;
24  import io.netty.handler.codec.DecoderException;
25  import io.netty.handler.codec.TooLongFrameException;
26  import io.netty.util.concurrent.Future;
27  import io.netty.util.concurrent.FutureListener;
28  import io.netty.util.internal.ObjectUtil;
29  import io.netty.util.internal.PlatformDependent;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.util.List;
35  
36  /**
37   * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
38   */
39  public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40  
41      /**
42       * The maximum length of client hello message as defined by
43       * <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
44       */
45      public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46  
47      // Let's use a default limit of 64kb which should be big enough for almost everything in practice but still
48      // small enough to not allocate to much memory.
49      static final int DEFAULT_MAX_CLIENT_HELLO_LENGTH = 64 * 1024;
50  
51      private static final InternalLogger logger =
52              InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
53  
54      private final int maxClientHelloLength;
55      private boolean handshakeFailed;
56      private boolean suppressRead;
57      private boolean readPending;
58      private ByteBuf handshakeBuffer;
59  
60      public SslClientHelloHandler() {
61          this(DEFAULT_MAX_CLIENT_HELLO_LENGTH);
62      }
63  
64      protected SslClientHelloHandler(int maxClientHelloLength) {
65          // 16MB is the maximum as per RFC:
66          // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
67          this.maxClientHelloLength =
68                  ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
69      }
70  
71      @Override
72      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
73          if (!suppressRead && !handshakeFailed) {
74              try {
75                  int readerIndex = in.readerIndex();
76                  int readableBytes = in.readableBytes();
77                  int handshakeLength = -1;
78  
79                  // Check if we have enough data to determine the record type and length.
80                  while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
81                      final int contentType = in.getUnsignedByte(readerIndex);
82                      switch (contentType) {
83                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
84                              // fall-through
85                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
86                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
87  
88                              // Not an SSL/TLS packet
89                              if (len == SslUtils.NOT_ENCRYPTED) {
90                                  handshakeFailed = true;
91                                  NotSslRecordException e = new NotSslRecordException(
92                                          "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
93                                  in.skipBytes(in.readableBytes());
94                                  ctx.fireUserEventTriggered(new SniCompletionEvent(e));
95                                  SslUtils.handleHandshakeFailure(ctx, e, true);
96                                  throw e;
97                              }
98                              if (len == SslUtils.NOT_ENOUGH_DATA) {
99                                  // Not enough data
100                                 return;
101                             }
102                             // No ClientHello
103                             select(ctx, null);
104                             return;
105                         case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
106                             final int majorVersion = in.getUnsignedByte(readerIndex + 1);
107                             // SSLv3 or TLS
108                             if (majorVersion == 3) {
109                                 int packetLength = in.getUnsignedShort(readerIndex + 3) +
110                                         SslUtils.SSL_RECORD_HEADER_LENGTH;
111 
112                                 if (readableBytes < packetLength) {
113                                     // client hello incomplete; try again to decode once more data is ready.
114                                     return;
115                                 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
116                                     select(ctx, null);
117                                     return;
118                                 }
119 
120                                 final int endOffset = readerIndex + packetLength;
121 
122                                 // Let's check if we already parsed the handshake length or not.
123                                 if (handshakeLength == -1) {
124                                     if (readerIndex + 4 > endOffset) {
125                                         // Need more data to read HandshakeType and handshakeLength (4 bytes)
126                                         return;
127                                     }
128 
129                                     final int handshakeType = in.getUnsignedByte(readerIndex +
130                                             SslUtils.SSL_RECORD_HEADER_LENGTH);
131 
132                                     // Check if this is a clientHello(1)
133                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
134                                     if (handshakeType != 1) {
135                                         select(ctx, null);
136                                         return;
137                                     }
138 
139                                     // Read the length of the handshake as it may arrive in fragments
140                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
141                                     handshakeLength = in.getUnsignedMedium(readerIndex +
142                                             SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
143 
144                                     if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
145                                         TooLongFrameException e = new TooLongFrameException(
146                                                 "ClientHello length exceeds " + maxClientHelloLength +
147                                                         ": " + handshakeLength);
148                                         in.skipBytes(in.readableBytes());
149                                         ctx.fireUserEventTriggered(new SniCompletionEvent(e));
150                                         SslUtils.handleHandshakeFailure(ctx, e, true);
151                                         throw e;
152                                     }
153                                     // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
154                                     readerIndex += 4;
155                                     packetLength -= 4;
156 
157                                     if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
158                                         // We have everything we need in one packet.
159                                         // Skip the record header
160                                         readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
161                                         select(ctx, in.retainedSlice(readerIndex, handshakeLength));
162                                         return;
163                                     } else {
164                                         if (handshakeBuffer == null) {
165                                             handshakeBuffer = ctx.alloc().buffer(handshakeLength);
166                                         } else {
167                                             // Clear the buffer so we can aggregate into it again.
168                                             handshakeBuffer.clear();
169                                         }
170                                     }
171                                 }
172 
173                                 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
174                                 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
175                                         packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
176                                 readerIndex += packetLength;
177                                 readableBytes -= packetLength;
178                                 if (handshakeLength <= handshakeBuffer.readableBytes()) {
179                                     ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
180                                     handshakeBuffer = null;
181 
182                                     select(ctx, clientHello);
183                                     return;
184                                 }
185                                 break;
186                             }
187                             // fall-through
188                         default:
189                             // not tls, ssl or application data
190                             select(ctx, null);
191                             return;
192                     }
193                 }
194             } catch (NotSslRecordException e) {
195                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
196                 throw e;
197             } catch (TooLongFrameException e) {
198                 // Just rethrow as in this case we also closed the channel
199                 throw e;
200             } catch (Exception e) {
201                 // unexpected encoding, ignore sni and use default
202                 if (logger.isDebugEnabled()) {
203                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
204                 }
205                 select(ctx, null);
206             }
207         }
208     }
209 
210     private void releaseHandshakeBuffer() {
211         releaseIfNotNull(handshakeBuffer);
212         handshakeBuffer = null;
213     }
214 
215     private static void releaseIfNotNull(ByteBuf buffer) {
216         if (buffer != null) {
217             buffer.release();
218         }
219     }
220 
221     private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
222         final Future<T> future;
223         try {
224             future = lookup(ctx, clientHello);
225             if (future.isDone()) {
226                 try {
227                     onLookupComplete(ctx, future);
228                 } catch (DecoderException err) {
229                     ctx.fireExceptionCaught(err);
230                 } catch (Exception cause) {
231                     ctx.fireExceptionCaught(new DecoderException(cause));
232                 } catch (Throwable cause) {
233                     ctx.fireExceptionCaught(cause);
234                 }
235             } else {
236                 suppressRead = true;
237                 final ByteBuf finalClientHello = clientHello;
238                 future.addListener(new FutureListener<T>() {
239                     @Override
240                     public void operationComplete(Future<T> future) {
241                         releaseIfNotNull(finalClientHello);
242                         try {
243                             suppressRead = false;
244                             try {
245                                 onLookupComplete(ctx, future);
246                             } catch (DecoderException err) {
247                                 ctx.fireExceptionCaught(err);
248                             } catch (Exception cause) {
249                                 ctx.fireExceptionCaught(new DecoderException(cause));
250                             } catch (Throwable cause) {
251                                 ctx.fireExceptionCaught(cause);
252                             }
253                         } finally {
254                             if (readPending) {
255                                 readPending = false;
256                                 ctx.read();
257                             }
258                         }
259                     }
260                 });
261 
262                 // Ownership was transferred to the FutureListener.
263                 clientHello = null;
264             }
265         } catch (Throwable cause) {
266             PlatformDependent.throwException(cause);
267         } finally {
268             releaseIfNotNull(clientHello);
269         }
270     }
271 
272     @Override
273     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
274         releaseHandshakeBuffer();
275 
276         super.handlerRemoved0(ctx);
277     }
278 
279     /**
280      * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
281      * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
282      *
283      * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
284      *
285      * <pre>
286      * struct {
287      *    ProtocolVersion client_version;
288      *    Random random;
289      *    SessionID session_id;
290      *    CipherSuite cipher_suites<2..2^16-2>;
291      *    CompressionMethod compression_methods<1..2^8-1>;
292      *    select (extensions_present) {
293      *        case false:
294      *            struct {};
295      *        case true:
296      *            Extension extensions<0..2^16-1>;
297      *    };
298      * } ClientHello;
299      * </pre>
300      *
301      * @see #onLookupComplete(ChannelHandlerContext, Future)
302      */
303     protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
304 
305     /**
306      * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
307      *
308      * @see #lookup(ChannelHandlerContext, ByteBuf)
309      */
310     protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
311 
312     @Override
313     public void read(ChannelHandlerContext ctx) throws Exception {
314         if (suppressRead) {
315             readPending = true;
316         } else {
317             ctx.read();
318         }
319     }
320 
321     @Override
322     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
323         ctx.bind(localAddress, promise);
324     }
325 
326     @Override
327     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
328                         ChannelPromise promise) throws Exception {
329         ctx.connect(remoteAddress, localAddress, promise);
330     }
331 
332     @Override
333     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
334         ctx.disconnect(promise);
335     }
336 
337     @Override
338     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
339         ctx.close(promise);
340     }
341 
342     @Override
343     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
344         ctx.deregister(promise);
345     }
346 
347     @Override
348     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
349         ctx.write(msg, promise);
350     }
351 
352     @Override
353     public void flush(ChannelHandlerContext ctx) throws Exception {
354         ctx.flush();
355     }
356 }