View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelOutboundHandler;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.ByteToMessageDecoder;
24  import io.netty.handler.codec.DecoderException;
25  import io.netty.handler.codec.TooLongFrameException;
26  import io.netty.util.concurrent.Future;
27  import io.netty.util.concurrent.FutureListener;
28  import io.netty.util.internal.ObjectUtil;
29  import io.netty.util.internal.PlatformDependent;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.util.List;
35  
36  /**
37   * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
38   */
39  public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40  
41      /**
42       * The maximum length of client hello message as defined by
43       * <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
44       */
45      public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46  
47      private static final InternalLogger logger =
48              InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49  
50      private final int maxClientHelloLength;
51      private boolean handshakeFailed;
52      private boolean suppressRead;
53      private boolean readPending;
54      private ByteBuf handshakeBuffer;
55  
56      public SslClientHelloHandler() {
57          this(MAX_CLIENT_HELLO_LENGTH);
58      }
59  
60      protected SslClientHelloHandler(int maxClientHelloLength) {
61          // 16MB is the maximum as per RFC:
62          // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
63          this.maxClientHelloLength =
64                  ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65      }
66  
67      @Override
68      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69          if (!suppressRead && !handshakeFailed) {
70              try {
71                  int readerIndex = in.readerIndex();
72                  int readableBytes = in.readableBytes();
73                  int handshakeLength = -1;
74  
75                  // Check if we have enough data to determine the record type and length.
76                  while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77                      final int contentType = in.getUnsignedByte(readerIndex);
78                      switch (contentType) {
79                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80                              // fall-through
81                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
82                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83  
84                              // Not an SSL/TLS packet
85                              if (len == SslUtils.NOT_ENCRYPTED) {
86                                  handshakeFailed = true;
87                                  NotSslRecordException e = new NotSslRecordException(
88                                          "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89                                  in.skipBytes(in.readableBytes());
90                                  ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91                                  SslUtils.handleHandshakeFailure(ctx, e, true);
92                                  throw e;
93                              }
94                              if (len == SslUtils.NOT_ENOUGH_DATA) {
95                                  // Not enough data
96                                  return;
97                              }
98                              // No ClientHello
99                              select(ctx, null);
100                             return;
101                         case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102                             final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103                             // SSLv3 or TLS
104                             if (majorVersion == 3) {
105                                 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106                                         SslUtils.SSL_RECORD_HEADER_LENGTH;
107 
108                                 if (readableBytes < packetLength) {
109                                     // client hello incomplete; try again to decode once more data is ready.
110                                     return;
111                                 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112                                     select(ctx, null);
113                                     return;
114                                 }
115 
116                                 final int endOffset = readerIndex + packetLength;
117 
118                                 // Let's check if we already parsed the handshake length or not.
119                                 if (handshakeLength == -1) {
120                                     if (readerIndex + 4 > endOffset) {
121                                         // Need more data to read HandshakeType and handshakeLength (4 bytes)
122                                         return;
123                                     }
124 
125                                     final int handshakeType = in.getUnsignedByte(readerIndex +
126                                             SslUtils.SSL_RECORD_HEADER_LENGTH);
127 
128                                     // Check if this is a clientHello(1)
129                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
130                                     if (handshakeType != 1) {
131                                         select(ctx, null);
132                                         return;
133                                     }
134 
135                                     // Read the length of the handshake as it may arrive in fragments
136                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
137                                     handshakeLength = in.getUnsignedMedium(readerIndex +
138                                             SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139 
140                                     if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141                                         TooLongFrameException e = new TooLongFrameException(
142                                                 "ClientHello length exceeds " + maxClientHelloLength +
143                                                         ": " + handshakeLength);
144                                         in.skipBytes(in.readableBytes());
145                                         ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146                                         SslUtils.handleHandshakeFailure(ctx, e, true);
147                                         throw e;
148                                     }
149                                     // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
150                                     readerIndex += 4;
151                                     packetLength -= 4;
152 
153                                     if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154                                         // We have everything we need in one packet.
155                                         // Skip the record header
156                                         readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157                                         select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158                                         return;
159                                     } else {
160                                         if (handshakeBuffer == null) {
161                                             handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162                                         } else {
163                                             // Clear the buffer so we can aggregate into it again.
164                                             handshakeBuffer.clear();
165                                         }
166                                     }
167                                 }
168 
169                                 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
170                                 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171                                         packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172                                 readerIndex += packetLength;
173                                 readableBytes -= packetLength;
174                                 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175                                     ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176                                     handshakeBuffer = null;
177 
178                                     select(ctx, clientHello);
179                                     return;
180                                 }
181                                 break;
182                             }
183                             // fall-through
184                         default:
185                             // not tls, ssl or application data
186                             select(ctx, null);
187                             return;
188                     }
189                 }
190             } catch (NotSslRecordException e) {
191                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
192                 throw e;
193             } catch (TooLongFrameException e) {
194                 // Just rethrow as in this case we also closed the channel
195                 throw e;
196             } catch (Exception e) {
197                 // unexpected encoding, ignore sni and use default
198                 if (logger.isDebugEnabled()) {
199                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200                 }
201                 select(ctx, null);
202             }
203         }
204     }
205 
206     private void releaseHandshakeBuffer() {
207         releaseIfNotNull(handshakeBuffer);
208         handshakeBuffer = null;
209     }
210 
211     private static void releaseIfNotNull(ByteBuf buffer) {
212         if (buffer != null) {
213             buffer.release();
214         }
215     }
216 
217     private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218         final Future<T> future;
219         try {
220             future = lookup(ctx, clientHello);
221             if (future.isDone()) {
222                 onLookupComplete(ctx, future);
223             } else {
224                 suppressRead = true;
225                 final ByteBuf finalClientHello = clientHello;
226                 future.addListener(new FutureListener<T>() {
227                     @Override
228                     public void operationComplete(Future<T> future) {
229                         releaseIfNotNull(finalClientHello);
230                         try {
231                             suppressRead = false;
232                             try {
233                                 onLookupComplete(ctx, future);
234                             } catch (DecoderException err) {
235                                 ctx.fireExceptionCaught(err);
236                             } catch (Exception cause) {
237                                 ctx.fireExceptionCaught(new DecoderException(cause));
238                             } catch (Throwable cause) {
239                                 ctx.fireExceptionCaught(cause);
240                             }
241                         } finally {
242                             if (readPending) {
243                                 readPending = false;
244                                 ctx.read();
245                             }
246                         }
247                     }
248                 });
249 
250                 // Ownership was transferred to the FutureListener.
251                 clientHello = null;
252             }
253         } catch (Throwable cause) {
254             PlatformDependent.throwException(cause);
255         } finally {
256             releaseIfNotNull(clientHello);
257         }
258     }
259 
260     @Override
261     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
262         releaseHandshakeBuffer();
263 
264         super.handlerRemoved0(ctx);
265     }
266 
267     /**
268      * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
269      * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
270      *
271      * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
272      *
273      * <pre>
274      * struct {
275      *    ProtocolVersion client_version;
276      *    Random random;
277      *    SessionID session_id;
278      *    CipherSuite cipher_suites<2..2^16-2>;
279      *    CompressionMethod compression_methods<1..2^8-1>;
280      *    select (extensions_present) {
281      *        case false:
282      *            struct {};
283      *        case true:
284      *            Extension extensions<0..2^16-1>;
285      *    };
286      * } ClientHello;
287      * </pre>
288      *
289      * @see #onLookupComplete(ChannelHandlerContext, Future)
290      */
291     protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
292 
293     /**
294      * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
295      *
296      * @see #lookup(ChannelHandlerContext, ByteBuf)
297      */
298     protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
299 
300     @Override
301     public void read(ChannelHandlerContext ctx) throws Exception {
302         if (suppressRead) {
303             readPending = true;
304         } else {
305             ctx.read();
306         }
307     }
308 
309     @Override
310     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
311         ctx.bind(localAddress, promise);
312     }
313 
314     @Override
315     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
316                         ChannelPromise promise) throws Exception {
317         ctx.connect(remoteAddress, localAddress, promise);
318     }
319 
320     @Override
321     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
322         ctx.disconnect(promise);
323     }
324 
325     @Override
326     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
327         ctx.close(promise);
328     }
329 
330     @Override
331     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
332         ctx.deregister(promise);
333     }
334 
335     @Override
336     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
337         ctx.write(msg, promise);
338     }
339 
340     @Override
341     public void flush(ChannelHandlerContext ctx) throws Exception {
342         ctx.flush();
343     }
344 }