View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelOutboundHandler;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.ByteToMessageDecoder;
24  import io.netty.handler.codec.DecoderException;
25  import io.netty.handler.codec.TooLongFrameException;
26  import io.netty.util.concurrent.Future;
27  import io.netty.util.concurrent.FutureListener;
28  import io.netty.util.internal.ObjectUtil;
29  import io.netty.util.internal.PlatformDependent;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.util.List;
35  
36  /**
37   * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
38   */
39  public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40  
41      /**
42       * The maximum length of client hello message as defined by
43       * <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
44       */
45      public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46  
47      private static final InternalLogger logger =
48              InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49  
50      private final int maxClientHelloLength;
51      private boolean handshakeFailed;
52      private boolean suppressRead;
53      private boolean readPending;
54      private ByteBuf handshakeBuffer;
55  
56      public SslClientHelloHandler() {
57          this(MAX_CLIENT_HELLO_LENGTH);
58      }
59  
60      protected SslClientHelloHandler(int maxClientHelloLength) {
61          // 16MB is the maximum as per RFC:
62          // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
63          this.maxClientHelloLength =
64                  ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65      }
66  
67      @Override
68      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69          if (!suppressRead && !handshakeFailed) {
70              try {
71                  int readerIndex = in.readerIndex();
72                  int readableBytes = in.readableBytes();
73                  int handshakeLength = -1;
74  
75                  // Check if we have enough data to determine the record type and length.
76                  while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77                      final int contentType = in.getUnsignedByte(readerIndex);
78                      switch (contentType) {
79                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80                              // fall-through
81                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
82                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83  
84                              // Not an SSL/TLS packet
85                              if (len == SslUtils.NOT_ENCRYPTED) {
86                                  handshakeFailed = true;
87                                  NotSslRecordException e = new NotSslRecordException(
88                                          "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89                                  in.skipBytes(in.readableBytes());
90                                  ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91                                  SslUtils.handleHandshakeFailure(ctx, e, true);
92                                  throw e;
93                              }
94                              if (len == SslUtils.NOT_ENOUGH_DATA) {
95                                  // Not enough data
96                                  return;
97                              }
98                              // No ClientHello
99                              select(ctx, null);
100                             return;
101                         case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102                             final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103                             // SSLv3 or TLS
104                             if (majorVersion == 3) {
105                                 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106                                         SslUtils.SSL_RECORD_HEADER_LENGTH;
107 
108                                 if (readableBytes < packetLength) {
109                                     // client hello incomplete; try again to decode once more data is ready.
110                                     return;
111                                 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112                                     select(ctx, null);
113                                     return;
114                                 }
115 
116                                 final int endOffset = readerIndex + packetLength;
117 
118                                 // Let's check if we already parsed the handshake length or not.
119                                 if (handshakeLength == -1) {
120                                     if (readerIndex + 4 > endOffset) {
121                                         // Need more data to read HandshakeType and handshakeLength (4 bytes)
122                                         return;
123                                     }
124 
125                                     final int handshakeType = in.getUnsignedByte(readerIndex +
126                                             SslUtils.SSL_RECORD_HEADER_LENGTH);
127 
128                                     // Check if this is a clientHello(1)
129                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
130                                     if (handshakeType != 1) {
131                                         select(ctx, null);
132                                         return;
133                                     }
134 
135                                     // Read the length of the handshake as it may arrive in fragments
136                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
137                                     handshakeLength = in.getUnsignedMedium(readerIndex +
138                                             SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139 
140                                     if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141                                         TooLongFrameException e = new TooLongFrameException(
142                                                 "ClientHello length exceeds " + maxClientHelloLength +
143                                                         ": " + handshakeLength);
144                                         in.skipBytes(in.readableBytes());
145                                         ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146                                         SslUtils.handleHandshakeFailure(ctx, e, true);
147                                         throw e;
148                                     }
149                                     // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
150                                     readerIndex += 4;
151                                     packetLength -= 4;
152 
153                                     if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154                                         // We have everything we need in one packet.
155                                         // Skip the record header
156                                         readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157                                         select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158                                         return;
159                                     } else {
160                                         if (handshakeBuffer == null) {
161                                             handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162                                         } else {
163                                             // Clear the buffer so we can aggregate into it again.
164                                             handshakeBuffer.clear();
165                                         }
166                                     }
167                                 }
168 
169                                 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
170                                 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171                                         packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172                                 readerIndex += packetLength;
173                                 readableBytes -= packetLength;
174                                 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175                                     ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176                                     handshakeBuffer = null;
177 
178                                     select(ctx, clientHello);
179                                     return;
180                                 }
181                                 break;
182                             }
183                             // fall-through
184                         default:
185                             // not tls, ssl or application data
186                             select(ctx, null);
187                             return;
188                     }
189                 }
190             } catch (NotSslRecordException e) {
191                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
192                 throw e;
193             } catch (TooLongFrameException e) {
194                 // Just rethrow as in this case we also closed the channel
195                 throw e;
196             } catch (Exception e) {
197                 // unexpected encoding, ignore sni and use default
198                 if (logger.isDebugEnabled()) {
199                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200                 }
201                 select(ctx, null);
202             }
203         }
204     }
205 
206     private void releaseHandshakeBuffer() {
207         releaseIfNotNull(handshakeBuffer);
208         handshakeBuffer = null;
209     }
210 
211     private static void releaseIfNotNull(ByteBuf buffer) {
212         if (buffer != null) {
213             buffer.release();
214         }
215     }
216 
217     private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218         final Future<T> future;
219         try {
220             future = lookup(ctx, clientHello);
221             if (future.isDone()) {
222                 try {
223                     onLookupComplete(ctx, future);
224                 } catch (DecoderException err) {
225                     ctx.fireExceptionCaught(err);
226                 } catch (Exception cause) {
227                     ctx.fireExceptionCaught(new DecoderException(cause));
228                 } catch (Throwable cause) {
229                     ctx.fireExceptionCaught(cause);
230                 }
231             } else {
232                 suppressRead = true;
233                 final ByteBuf finalClientHello = clientHello;
234                 future.addListener(new FutureListener<T>() {
235                     @Override
236                     public void operationComplete(Future<T> future) {
237                         releaseIfNotNull(finalClientHello);
238                         try {
239                             suppressRead = false;
240                             try {
241                                 onLookupComplete(ctx, future);
242                             } catch (DecoderException err) {
243                                 ctx.fireExceptionCaught(err);
244                             } catch (Exception cause) {
245                                 ctx.fireExceptionCaught(new DecoderException(cause));
246                             } catch (Throwable cause) {
247                                 ctx.fireExceptionCaught(cause);
248                             }
249                         } finally {
250                             if (readPending) {
251                                 readPending = false;
252                                 ctx.read();
253                             }
254                         }
255                     }
256                 });
257 
258                 // Ownership was transferred to the FutureListener.
259                 clientHello = null;
260             }
261         } catch (Throwable cause) {
262             PlatformDependent.throwException(cause);
263         } finally {
264             releaseIfNotNull(clientHello);
265         }
266     }
267 
268     @Override
269     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
270         releaseHandshakeBuffer();
271 
272         super.handlerRemoved0(ctx);
273     }
274 
275     /**
276      * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
277      * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
278      *
279      * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
280      *
281      * <pre>
282      * struct {
283      *    ProtocolVersion client_version;
284      *    Random random;
285      *    SessionID session_id;
286      *    CipherSuite cipher_suites<2..2^16-2>;
287      *    CompressionMethod compression_methods<1..2^8-1>;
288      *    select (extensions_present) {
289      *        case false:
290      *            struct {};
291      *        case true:
292      *            Extension extensions<0..2^16-1>;
293      *    };
294      * } ClientHello;
295      * </pre>
296      *
297      * @see #onLookupComplete(ChannelHandlerContext, Future)
298      */
299     protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
300 
301     /**
302      * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
303      *
304      * @see #lookup(ChannelHandlerContext, ByteBuf)
305      */
306     protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
307 
308     @Override
309     public void read(ChannelHandlerContext ctx) throws Exception {
310         if (suppressRead) {
311             readPending = true;
312         } else {
313             ctx.read();
314         }
315     }
316 
317     @Override
318     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
319         ctx.bind(localAddress, promise);
320     }
321 
322     @Override
323     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
324                         ChannelPromise promise) throws Exception {
325         ctx.connect(remoteAddress, localAddress, promise);
326     }
327 
328     @Override
329     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
330         ctx.disconnect(promise);
331     }
332 
333     @Override
334     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
335         ctx.close(promise);
336     }
337 
338     @Override
339     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
340         ctx.deregister(promise);
341     }
342 
343     @Override
344     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
345         ctx.write(msg, promise);
346     }
347 
348     @Override
349     public void flush(ChannelHandlerContext ctx) throws Exception {
350         ctx.flush();
351     }
352 }