View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelOutboundHandler;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.ByteToMessageDecoder;
24  import io.netty.handler.codec.DecoderException;
25  import io.netty.handler.codec.TooLongFrameException;
26  import io.netty.util.concurrent.Future;
27  import io.netty.util.concurrent.FutureListener;
28  import io.netty.util.internal.ObjectUtil;
29  import io.netty.util.internal.PlatformDependent;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.net.SocketAddress;
34  import java.util.List;
35  
36  /**
37   * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
38   */
39  public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
40  
41      /**
42       * The maximum length of client hello message as defined by
43       * <a href="https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1">RFC5246</a>.
44       */
45      public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF;
46  
47      private static final InternalLogger logger =
48              InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
49  
50      private final int maxClientHelloLength;
51      private boolean handshakeFailed;
52      private boolean suppressRead;
53      private boolean readPending;
54      private ByteBuf handshakeBuffer;
55  
56      public SslClientHelloHandler() {
57          this(MAX_CLIENT_HELLO_LENGTH);
58      }
59  
60      protected SslClientHelloHandler(int maxClientHelloLength) {
61          // 16MB is the maximum as per RFC:
62          // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1
63          this.maxClientHelloLength =
64                  ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength");
65      }
66  
67      @Override
68      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69          if (!suppressRead && !handshakeFailed) {
70              try {
71                  int readerIndex = in.readerIndex();
72                  int readableBytes = in.readableBytes();
73                  int handshakeLength = -1;
74  
75                  // Check if we have enough data to determine the record type and length.
76                  while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
77                      final int contentType = in.getUnsignedByte(readerIndex);
78                      switch (contentType) {
79                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
80                              // fall-through
81                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
82                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex, true);
83  
84                              // Not an SSL/TLS packet
85                              if (len == SslUtils.NOT_ENCRYPTED) {
86                                  handshakeFailed = true;
87                                  NotSslRecordException e = new NotSslRecordException(
88                                          "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
89                                  in.skipBytes(in.readableBytes());
90                                  ctx.fireUserEventTriggered(new SniCompletionEvent(e));
91                                  SslUtils.handleHandshakeFailure(ctx, e, true);
92                                  throw e;
93                              }
94                              if (len == SslUtils.NOT_ENOUGH_DATA) {
95                                  // Not enough data
96                                  return;
97                              }
98                              // No ClientHello
99                              select(ctx, null);
100                             return;
101                         case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
102                             final int majorVersion = in.getUnsignedByte(readerIndex + 1);
103                             // SSLv3 or TLS
104                             if (majorVersion == 3) {
105                                 int packetLength = in.getUnsignedShort(readerIndex + 3) +
106                                         SslUtils.SSL_RECORD_HEADER_LENGTH;
107 
108                                 if (readableBytes < packetLength) {
109                                     // client hello incomplete; try again to decode once more data is ready.
110                                     return;
111                                 } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
112                                     select(ctx, null);
113                                     return;
114                                 }
115 
116                                 final int endOffset = readerIndex + packetLength;
117 
118                                 // Let's check if we already parsed the handshake length or not.
119                                 if (handshakeLength == -1) {
120                                     if (readerIndex + 4 > endOffset) {
121                                         // Need more data to read HandshakeType and handshakeLength (4 bytes)
122                                         return;
123                                     }
124 
125                                     final int handshakeType = in.getUnsignedByte(readerIndex +
126                                             SslUtils.SSL_RECORD_HEADER_LENGTH);
127 
128                                     // Check if this is a clientHello(1)
129                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
130                                     if (handshakeType != 1) {
131                                         select(ctx, null);
132                                         return;
133                                     }
134 
135                                     // Read the length of the handshake as it may arrive in fragments
136                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
137                                     handshakeLength = in.getUnsignedMedium(readerIndex +
138                                             SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
139 
140                                     if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) {
141                                         TooLongFrameException e = new TooLongFrameException(
142                                                 "ClientHello length exceeds " + maxClientHelloLength +
143                                                         ": " + handshakeLength);
144                                         in.skipBytes(in.readableBytes());
145                                         ctx.fireUserEventTriggered(new SniCompletionEvent(e));
146                                         SslUtils.handleHandshakeFailure(ctx, e, true);
147                                         throw e;
148                                     }
149                                     // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
150                                     readerIndex += 4;
151                                     packetLength -= 4;
152 
153                                     if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
154                                         // We have everything we need in one packet.
155                                         // Skip the record header
156                                         readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
157                                         select(ctx, in.retainedSlice(readerIndex, handshakeLength));
158                                         return;
159                                     } else {
160                                         if (handshakeBuffer == null) {
161                                             handshakeBuffer = ctx.alloc().buffer(handshakeLength);
162                                         } else {
163                                             // Clear the buffer so we can aggregate into it again.
164                                             handshakeBuffer.clear();
165                                         }
166                                     }
167                                 }
168 
169                                 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
170                                 handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
171                                         packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH);
172                                 readerIndex += packetLength;
173                                 readableBytes -= packetLength;
174                                 if (handshakeLength <= handshakeBuffer.readableBytes()) {
175                                     ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength);
176                                     handshakeBuffer = null;
177 
178                                     select(ctx, clientHello);
179                                     return;
180                                 }
181                                 break;
182                             }
183                             // fall-through
184                         default:
185                             // not tls, ssl or application data
186                             select(ctx, null);
187                             return;
188                     }
189                 }
190             } catch (NotSslRecordException e) {
191                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
192                 throw e;
193             } catch (TooLongFrameException e) {
194                 // Just rethrow as in this case we also closed the channel
195                 throw e;
196             } catch (Exception e) {
197                 // unexpected encoding, ignore sni and use default
198                 if (logger.isDebugEnabled()) {
199                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
200                 }
201                 select(ctx, null);
202             }
203         }
204     }
205 
206     private void releaseHandshakeBuffer() {
207         releaseIfNotNull(handshakeBuffer);
208         handshakeBuffer = null;
209     }
210 
211     private static void releaseIfNotNull(ByteBuf buffer) {
212         if (buffer != null) {
213             buffer.release();
214         }
215     }
216 
217     private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
218         final Future<T> future;
219         try {
220             future = lookup(ctx, clientHello);
221             if (future.isDone()) {
222                 onLookupComplete(ctx, future);
223             } else {
224                 suppressRead = true;
225                 final ByteBuf finalClientHello = clientHello;
226                 future.addListener((FutureListener<T>) future1 -> {
227                     releaseIfNotNull(finalClientHello);
228                     try {
229                         suppressRead = false;
230                         try {
231                             onLookupComplete(ctx, future1);
232                         } catch (DecoderException err) {
233                             ctx.fireExceptionCaught(err);
234                         } catch (Exception cause) {
235                             ctx.fireExceptionCaught(new DecoderException(cause));
236                         } catch (Throwable cause) {
237                             ctx.fireExceptionCaught(cause);
238                         }
239                     } finally {
240                         if (readPending) {
241                             readPending = false;
242                             ctx.read();
243                         }
244                     }
245                 });
246 
247                 // Ownership was transferred to the FutureListener.
248                 clientHello = null;
249             }
250         } catch (Throwable cause) {
251             PlatformDependent.throwException(cause);
252         } finally {
253             releaseIfNotNull(clientHello);
254         }
255     }
256 
257     @Override
258     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
259         releaseHandshakeBuffer();
260 
261         super.handlerRemoved0(ctx);
262     }
263 
264     /**
265      * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
266      * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
267      *
268      * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
269      *
270      * <pre>
271      * struct {
272      *    ProtocolVersion client_version;
273      *    Random random;
274      *    SessionID session_id;
275      *    CipherSuite cipher_suites<2..2^16-2>;
276      *    CompressionMethod compression_methods<1..2^8-1>;
277      *    select (extensions_present) {
278      *        case false:
279      *            struct {};
280      *        case true:
281      *            Extension extensions<0..2^16-1>;
282      *    };
283      * } ClientHello;
284      * </pre>
285      *
286      * @see #onLookupComplete(ChannelHandlerContext, Future)
287      */
288     protected abstract Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception;
289 
290     /**
291      * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}.
292      *
293      * @see #lookup(ChannelHandlerContext, ByteBuf)
294      */
295     protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception;
296 
297     @Override
298     public void read(ChannelHandlerContext ctx) throws Exception {
299         if (suppressRead) {
300             readPending = true;
301         } else {
302             ctx.read();
303         }
304     }
305 
306     @Override
307     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
308         ctx.bind(localAddress, promise);
309     }
310 
311     @Override
312     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
313                         ChannelPromise promise) throws Exception {
314         ctx.connect(remoteAddress, localAddress, promise);
315     }
316 
317     @Override
318     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
319         ctx.disconnect(promise);
320     }
321 
322     @Override
323     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
324         ctx.close(promise);
325     }
326 
327     @Override
328     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
329         ctx.deregister(promise);
330     }
331 
332     @Override
333     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
334         ctx.write(msg, promise);
335     }
336 
337     @Override
338     public void flush(ChannelHandlerContext ctx) throws Exception {
339         ctx.flush();
340     }
341 }