View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.handler.ssl;
17  
18  import io.netty5.buffer.BufferUtil;
19  import io.netty5.buffer.api.Buffer;
20  import io.netty5.util.Resource;
21  import io.netty5.channel.ChannelHandlerContext;
22  import io.netty5.handler.codec.ByteToMessageDecoder;
23  import io.netty5.handler.codec.DecoderException;
24  import io.netty5.util.concurrent.Future;
25  import io.netty5.util.internal.PlatformDependent;
26  import io.netty5.util.internal.logging.InternalLogger;
27  import io.netty5.util.internal.logging.InternalLoggerFactory;
28  
29  /**
30   * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received.
31   */
32  public abstract class SslClientHelloHandler<T> extends ByteToMessageDecoder {
33  
34      private static final InternalLogger logger =
35              InternalLoggerFactory.getInstance(SslClientHelloHandler.class);
36  
37      private boolean handshakeFailed;
38      private boolean suppressRead;
39      private boolean readPending;
40      private Buffer handshakeBuffer;
41  
42      @Override
43      protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
44          // TODO It ought to be possible to simplify this by using split() to grab the handshakes,
45          //  and avoid awkward copying and offsets book-keeping.
46          if (!suppressRead && !handshakeFailed) {
47              try {
48                  int readerIndex = in.readerOffset();
49                  int readableBytes = in.readableBytes();
50                  int handshakeLength = -1;
51  
52                  // Check if we have enough data to determine the record type and length.
53                  while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) {
54                      final int contentType = in.getUnsignedByte(readerIndex);
55                      switch (contentType) {
56                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
57                              // fall-through
58                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
59                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
60  
61                              // Not an SSL/TLS packet
62                              if (len == SslUtils.NOT_ENCRYPTED) {
63                                  handshakeFailed = true;
64                                  NotSslRecordException e = new NotSslRecordException(
65                                          "not an SSL/TLS record: " + BufferUtil.hexDump(in));
66                                  in.skipReadableBytes(in.readableBytes());
67                                  ctx.fireChannelInboundEvent(new SniCompletionEvent(e));
68                                  ctx.fireChannelInboundEvent(new SslHandshakeCompletionEvent(e));
69                                  throw e;
70                              }
71                              if (len == SslUtils.NOT_ENOUGH_DATA) {
72                                  // Not enough data
73                                  return;
74                              }
75                              // No ClientHello
76                              select(ctx, null);
77                              return;
78                          case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
79                              final int majorVersion = in.getUnsignedByte(readerIndex + 1);
80                              // SSLv3 or TLS
81                              if (majorVersion == 3) {
82                                  int packetLength = in.getUnsignedShort(readerIndex + 3) +
83                                          SslUtils.SSL_RECORD_HEADER_LENGTH;
84  
85                                  if (readableBytes < packetLength) {
86                                      // client hello incomplete; try again to decode once more data is ready.
87                                      return;
88                                  }
89                                  if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) {
90                                      select(ctx, null);
91                                      return;
92                                  }
93  
94                                  final int endOffset = readerIndex + packetLength;
95  
96                                  // Let's check if we already parsed the handshake length or not.
97                                  if (handshakeLength == -1) {
98                                      if (readerIndex + 4 > endOffset) {
99                                          // Need more data to read HandshakeType and handshakeLength (4 bytes)
100                                         return;
101                                     }
102 
103                                     final int handshakeType = in.getUnsignedByte(readerIndex +
104                                             SslUtils.SSL_RECORD_HEADER_LENGTH);
105 
106                                     // Check if this is a clientHello(1)
107                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
108                                     if (handshakeType != 1) {
109                                         select(ctx, null);
110                                         return;
111                                     }
112 
113                                     // Read the length of the handshake as it may arrive in fragments
114                                     // See https://tools.ietf.org/html/rfc5246#section-7.4
115                                     handshakeLength = in.getUnsignedMedium(readerIndex +
116                                             SslUtils.SSL_RECORD_HEADER_LENGTH + 1);
117 
118                                     // Consume handshakeType and handshakeLength (this sums up as 4 bytes)
119                                     readerIndex += 4;
120                                     packetLength -= 4;
121 
122                                     if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) {
123                                         // We have everything we need in one packet.
124                                         // Skip the record header
125                                         readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH;
126                                         in.readerOffset(readerIndex);
127                                         select(ctx, in.readSplit(handshakeLength));
128                                         return;
129                                     } else {
130                                         if (handshakeBuffer == null) {
131                                             handshakeBuffer = ctx.bufferAllocator().allocate(handshakeLength);
132                                         } else {
133                                             // Reset the buffer offsets, so we can aggregate into it again.
134                                             handshakeBuffer.resetOffsets();
135                                         }
136                                     }
137                                 }
138 
139                                 // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER
140                                 int hsLen = packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH;
141                                 in.copyInto(readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH,
142                                             handshakeBuffer, handshakeBuffer.writerOffset(), hsLen);
143                                 handshakeBuffer.skipWritableBytes(hsLen);
144                                 readerIndex += packetLength;
145                                 readableBytes -= packetLength;
146                                 if (handshakeLength <= handshakeBuffer.readableBytes()) {
147                                     Buffer clientHello = handshakeBuffer.readerOffset(0).writerOffset(handshakeLength);
148                                     handshakeBuffer = null;
149 
150                                     select(ctx, clientHello);
151                                     return;
152                                 }
153                                 break;
154                             }
155                             // fall-through
156                         default:
157                             // not tls, ssl or application data
158                             select(ctx, null);
159                             return;
160                     }
161                 }
162             } catch (NotSslRecordException e) {
163                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
164                 throw e;
165             } catch (Exception e) {
166                 // unexpected encoding, ignore sni and use default
167                 if (logger.isDebugEnabled()) {
168                     logger.debug("Unexpected client hello packet: " + BufferUtil.hexDump(in), e);
169                 }
170                 select(ctx, null);
171             }
172         }
173     }
174 
175     private void releaseHandshakeBuffer() {
176         Resource.dispose(handshakeBuffer);
177         handshakeBuffer = null;
178     }
179 
180     private void select(final ChannelHandlerContext ctx, Buffer clientHello) {
181         final Future<T> future;
182         try {
183             future = lookup(ctx, clientHello);
184             if (future.isDone()) {
185                 Resource.dispose(clientHello); // Future is completed. We can dispose it immediately.
186                 onLookupComplete(ctx, future);
187             } else {
188                 suppressRead = true;
189                 future.addListener(f -> {
190                     Resource.dispose(clientHello); // Delay disposing until the future completes.
191                     try {
192                         suppressRead = false;
193                         try {
194                             onLookupComplete(ctx, f);
195                         } catch (DecoderException err) {
196                             ctx.fireChannelExceptionCaught(err);
197                         } catch (Exception cause) {
198                             ctx.fireChannelExceptionCaught(new DecoderException(cause));
199                         } catch (Throwable cause) {
200                             ctx.fireChannelExceptionCaught(cause);
201                         }
202                     } finally {
203                         if (readPending) {
204                             readPending = false;
205                             ctx.read();
206                         }
207                     }
208                 });
209             }
210         } catch (Throwable cause) {
211             PlatformDependent.throwException(cause);
212         }
213     }
214 
215     @Override
216     protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
217         releaseHandshakeBuffer();
218 
219         super.handlerRemoved0(ctx);
220     }
221 
222     /**
223      * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will
224      * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion.
225      *
226      * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
227      *
228      * <pre>
229      * struct {
230      *    ProtocolVersion client_version;
231      *    Random random;
232      *    SessionID session_id;
233      *    CipherSuite cipher_suites<2..2^16-2>;
234      *    CompressionMethod compression_methods<1..2^8-1>;
235      *    select (extensions_present) {
236      *        case false:
237      *            struct {};
238      *        case true:
239      *            Extension extensions<0..2^16-1>;
240      *    };
241      * } ClientHello;
242      * </pre>
243      *
244      * @see #onLookupComplete(ChannelHandlerContext, Future)
245      */
246     protected abstract Future<T> lookup(ChannelHandlerContext ctx, Buffer clientHello) throws Exception;
247 
248     /**
249      * Called upon completion of the {@link #lookup(ChannelHandlerContext, Buffer)} {@link Future}.
250      *
251      * @see #lookup(ChannelHandlerContext, Buffer)
252      */
253     protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future<? extends T> future) throws Exception;
254 
255     @Override
256     public void read(ChannelHandlerContext ctx) {
257         if (suppressRead) {
258             readPending = true;
259         } else {
260             ctx.read();
261         }
262     }
263 }