View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelOutboundHandler;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.ByteToMessageDecoder;
24  import io.netty.handler.codec.DecoderException;
25  import io.netty.util.CharsetUtil;
26  import io.netty.util.concurrent.Future;
27  import io.netty.util.concurrent.FutureListener;
28  import io.netty.util.internal.PlatformDependent;
29  import io.netty.util.internal.logging.InternalLogger;
30  import io.netty.util.internal.logging.InternalLoggerFactory;
31  
32  import java.net.SocketAddress;
33  import java.util.List;
34  import java.util.Locale;
35  
36  /**
37   * <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
38   * (Server Name Indication)</a> extension for server side SSL. For clients
39   * support SNI, the server could have multiple host name bound on a single IP.
40   * The client will send host name in the handshake data so server could decide
41   * which certificate to choose for the host name.</p>
42   */
43  public abstract class AbstractSniHandler<T> extends ByteToMessageDecoder implements ChannelOutboundHandler {
44  
45      // Maximal number of ssl records to inspect before fallback to the default SslContext.
46      private static final int MAX_SSL_RECORDS = 4;
47  
48      private static final InternalLogger logger =
49              InternalLoggerFactory.getInstance(AbstractSniHandler.class);
50  
51      private boolean handshakeFailed;
52      private boolean suppressRead;
53      private boolean readPending;
54  
55      @Override
56      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
57          if (!suppressRead && !handshakeFailed) {
58              final int writerIndex = in.writerIndex();
59              try {
60                  loop:
61                  for (int i = 0; i < MAX_SSL_RECORDS; i++) {
62                      final int readerIndex = in.readerIndex();
63                      final int readableBytes = writerIndex - readerIndex;
64                      if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
65                          // Not enough data to determine the record type and length.
66                          return;
67                      }
68  
69                      final int command = in.getUnsignedByte(readerIndex);
70  
71                      // tls, but not handshake command
72                      switch (command) {
73                          case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
74                          case SslUtils.SSL_CONTENT_TYPE_ALERT:
75                              final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
76  
77                              // Not an SSL/TLS packet
78                              if (len == SslUtils.NOT_ENCRYPTED) {
79                                  handshakeFailed = true;
80                                  NotSslRecordException e = new NotSslRecordException(
81                                          "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
82                                  in.skipBytes(in.readableBytes());
83                                  ctx.fireUserEventTriggered(new SniCompletionEvent(e));
84                                  SslUtils.handleHandshakeFailure(ctx, e, true);
85                                  throw e;
86                              }
87                              if (len == SslUtils.NOT_ENOUGH_DATA ||
88                                      writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
89                                  // Not enough data
90                                  return;
91                              }
92                              // increase readerIndex and try again.
93                              in.skipBytes(len);
94                              continue;
95                          case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
96                              final int majorVersion = in.getUnsignedByte(readerIndex + 1);
97  
98                              // SSLv3 or TLS
99                              if (majorVersion == 3) {
100                                 final int packetLength = in.getUnsignedShort(readerIndex + 3) +
101                                         SslUtils.SSL_RECORD_HEADER_LENGTH;
102 
103                                 if (readableBytes < packetLength) {
104                                     // client hello incomplete; try again to decode once more data is ready.
105                                     return;
106                                 }
107 
108                                 // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
109                                 //
110                                 // Decode the ssl client hello packet.
111                                 // We have to skip bytes until SessionID (which sum to 43 bytes).
112                                 //
113                                 // struct {
114                                 //    ProtocolVersion client_version;
115                                 //    Random random;
116                                 //    SessionID session_id;
117                                 //    CipherSuite cipher_suites<2..2^16-2>;
118                                 //    CompressionMethod compression_methods<1..2^8-1>;
119                                 //    select (extensions_present) {
120                                 //        case false:
121                                 //            struct {};
122                                 //        case true:
123                                 //            Extension extensions<0..2^16-1>;
124                                 //    };
125                                 // } ClientHello;
126                                 //
127 
128                                 final int endOffset = readerIndex + packetLength;
129                                 int offset = readerIndex + 43;
130 
131                                 if (endOffset - offset < 6) {
132                                     break loop;
133                                 }
134 
135                                 final int sessionIdLength = in.getUnsignedByte(offset);
136                                 offset += sessionIdLength + 1;
137 
138                                 final int cipherSuitesLength = in.getUnsignedShort(offset);
139                                 offset += cipherSuitesLength + 2;
140 
141                                 final int compressionMethodLength = in.getUnsignedByte(offset);
142                                 offset += compressionMethodLength + 1;
143 
144                                 final int extensionsLength = in.getUnsignedShort(offset);
145                                 offset += 2;
146                                 final int extensionsLimit = offset + extensionsLength;
147 
148                                 if (extensionsLimit > endOffset) {
149                                     // Extensions should never exceed the record boundary.
150                                     break loop;
151                                 }
152 
153                                 for (;;) {
154                                     if (extensionsLimit - offset < 4) {
155                                         break loop;
156                                     }
157 
158                                     final int extensionType = in.getUnsignedShort(offset);
159                                     offset += 2;
160 
161                                     final int extensionLength = in.getUnsignedShort(offset);
162                                     offset += 2;
163 
164                                     if (extensionsLimit - offset < extensionLength) {
165                                         break loop;
166                                     }
167 
168                                     // SNI
169                                     // See https://tools.ietf.org/html/rfc6066#page-6
170                                     if (extensionType == 0) {
171                                         offset += 2;
172                                         if (extensionsLimit - offset < 3) {
173                                             break loop;
174                                         }
175 
176                                         final int serverNameType = in.getUnsignedByte(offset);
177                                         offset++;
178 
179                                         if (serverNameType == 0) {
180                                             final int serverNameLength = in.getUnsignedShort(offset);
181                                             offset += 2;
182 
183                                             if (extensionsLimit - offset < serverNameLength) {
184                                                 break loop;
185                                             }
186 
187                                             final String hostname = in.toString(offset, serverNameLength,
188                                                     CharsetUtil.US_ASCII);
189 
190                                             try {
191                                                 select(ctx, hostname.toLowerCase(Locale.US));
192                                             } catch (Throwable t) {
193                                                 PlatformDependent.throwException(t);
194                                             }
195                                             return;
196                                         } else {
197                                             // invalid enum value
198                                             break loop;
199                                         }
200                                     }
201 
202                                     offset += extensionLength;
203                                 }
204                             }
205                             // Fall-through
206                         default:
207                             //not tls, ssl or application data, do not try sni
208                             break loop;
209                     }
210                 }
211             } catch (NotSslRecordException e) {
212                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
213                 throw e;
214             } catch (Exception e) {
215                 // unexpected encoding, ignore sni and use default
216                 if (logger.isDebugEnabled()) {
217                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
218                 }
219             }
220             // Just select the default SslContext
221             select(ctx, null);
222         }
223     }
224 
225     private void select(final ChannelHandlerContext ctx, final String hostname) throws Exception {
226         Future<T> future = lookup(ctx, hostname);
227         if (future.isDone()) {
228             fireSniCompletionEvent(ctx, hostname, future);
229             onLookupComplete(ctx, hostname, future);
230         } else {
231             suppressRead = true;
232             future.addListener(new FutureListener<T>() {
233                 @Override
234                 public void operationComplete(Future<T> future) throws Exception {
235                     try {
236                         suppressRead = false;
237                         try {
238                             fireSniCompletionEvent(ctx, hostname, future);
239                             onLookupComplete(ctx, hostname, future);
240                         } catch (DecoderException err) {
241                             ctx.fireExceptionCaught(err);
242                         } catch (Exception cause) {
243                             ctx.fireExceptionCaught(new DecoderException(cause));
244                         } catch (Throwable cause) {
245                             ctx.fireExceptionCaught(cause);
246                         }
247                     } finally {
248                         if (readPending) {
249                             readPending = false;
250                             ctx.read();
251                         }
252                     }
253                 }
254             });
255         }
256     }
257 
258     private void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<T> future) {
259         Throwable cause = future.cause();
260         if (cause == null) {
261             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
262         } else {
263             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
264         }
265     }
266 
267     /**
268      * Kicks off a lookup for the given SNI value and returns a {@link Future} which in turn will
269      * notify the {@link #onLookupComplete(ChannelHandlerContext, String, Future)} on completion.
270      *
271      * @see #onLookupComplete(ChannelHandlerContext, String, Future)
272      */
273     protected abstract Future<T> lookup(ChannelHandlerContext ctx, String hostname) throws Exception;
274 
275     /**
276      * Called upon completion of the {@link #lookup(ChannelHandlerContext, String)} {@link Future}.
277      *
278      * @see #lookup(ChannelHandlerContext, String)
279      */
280     protected abstract void onLookupComplete(ChannelHandlerContext ctx,
281                                              String hostname, Future<T> future) throws Exception;
282 
283     @Override
284     public void read(ChannelHandlerContext ctx) throws Exception {
285         if (suppressRead) {
286             readPending = true;
287         } else {
288             ctx.read();
289         }
290     }
291 
292     @Override
293     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
294         ctx.bind(localAddress, promise);
295     }
296 
297     @Override
298     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
299                         ChannelPromise promise) throws Exception {
300         ctx.connect(remoteAddress, localAddress, promise);
301     }
302 
303     @Override
304     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
305         ctx.disconnect(promise);
306     }
307 
308     @Override
309     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
310         ctx.close(promise);
311     }
312 
313     @Override
314     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
315         ctx.deregister(promise);
316     }
317 
318     @Override
319     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
320         ctx.write(msg, promise);
321     }
322 
323     @Override
324     public void flush(ChannelHandlerContext ctx) throws Exception {
325         ctx.flush();
326     }
327 }