View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.buffer.ByteBufUtil;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.handler.codec.ByteToMessageDecoder;
22  import io.netty.util.CharsetUtil;
23  import io.netty.util.DomainNameMapping;
24  import io.netty.util.ReferenceCountUtil;
25  import io.netty.util.internal.PlatformDependent;
26  import io.netty.util.internal.logging.InternalLogger;
27  import io.netty.util.internal.logging.InternalLoggerFactory;
28  
29  import java.net.IDN;
30  import java.util.List;
31  import java.util.Locale;
32  
33  /**
34   * <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
35   * (Server Name Indication)</a> extension for server side SSL. For clients
36   * support SNI, the server could have multiple host name bound on a single IP.
37   * The client will send host name in the handshake data so server could decide
38   * which certificate to choose for the host name. </p>
39   */
40  public class SniHandler extends ByteToMessageDecoder {
41  
42      // Maximal number of ssl records to inspect before fallback to the default SslContext.
43      private static final int MAX_SSL_RECORDS = 4;
44  
45      private static final InternalLogger logger =
46              InternalLoggerFactory.getInstance(SniHandler.class);
47  
48      private static final Selection EMPTY_SELECTION = new Selection(null, null);
49  
50      private final DomainNameMapping<SslContext> mapping;
51  
52      private boolean handshakeFailed;
53  
54      private volatile Selection selection = EMPTY_SELECTION;
55  
56      /**
57       * Create a SNI detection handler with configured {@link SslContext}
58       * maintained by {@link DomainNameMapping}
59       *
60       * @param mapping the mapping of domain name to {@link SslContext}
61       */
62      @SuppressWarnings("unchecked")
63      public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
64          if (mapping == null) {
65              throw new NullPointerException("mapping");
66          }
67  
68          this.mapping = (DomainNameMapping<SslContext>) mapping;
69      }
70  
71      /**
72       * @return the selected hostname
73       */
74      public String hostname() {
75          return selection.hostname;
76      }
77  
78      /**
79       * @return the selected sslcontext
80       */
81      public SslContext sslContext() {
82          return selection.context;
83      }
84  
85      @Override
86      protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
87          if (!handshakeFailed) {
88              final int writerIndex = in.writerIndex();
89              try {
90                  loop:
91                  for (int i = 0; i < MAX_SSL_RECORDS; i++) {
92                      final int readerIndex = in.readerIndex();
93                      final int readableBytes = writerIndex - readerIndex;
94                      if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) {
95                          // Not enough data to determine the record type and length.
96                          return;
97                      }
98  
99                      final int command = in.getUnsignedByte(readerIndex);
100 
101                     // tls, but not handshake command
102                     switch (command) {
103                         case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
104                         case SslUtils.SSL_CONTENT_TYPE_ALERT:
105                             final int len = SslUtils.getEncryptedPacketLength(in, readerIndex);
106 
107                             // Not an SSL/TLS packet
108                             if (len == SslUtils.NOT_ENCRYPTED) {
109                                 handshakeFailed = true;
110                                 NotSslRecordException e = new NotSslRecordException(
111                                         "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
112                                 in.skipBytes(in.readableBytes());
113 
114                                 SslUtils.notifyHandshakeFailure(ctx, e, true);
115                                 throw e;
116                             }
117                             if (len == SslUtils.NOT_ENOUGH_DATA ||
118                                     writerIndex - readerIndex - SslUtils.SSL_RECORD_HEADER_LENGTH < len) {
119                                 // Not enough data
120                                 return;
121                             }
122                             // increase readerIndex and try again.
123                             in.skipBytes(len);
124                             continue;
125                         case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE:
126                             final int majorVersion = in.getUnsignedByte(readerIndex + 1);
127 
128                             // SSLv3 or TLS
129                             if (majorVersion == 3) {
130                                 final int packetLength = in.getUnsignedShort(readerIndex + 3) +
131                                                          SslUtils.SSL_RECORD_HEADER_LENGTH;
132 
133                                 if (readableBytes < packetLength) {
134                                     // client hello incomplete; try again to decode once more data is ready.
135                                     return;
136                                 }
137 
138                                 // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
139                                 //
140                                 // Decode the ssl client hello packet.
141                                 // We have to skip bytes until SessionID (which sum to 43 bytes).
142                                 //
143                                 // struct {
144                                 //    ProtocolVersion client_version;
145                                 //    Random random;
146                                 //    SessionID session_id;
147                                 //    CipherSuite cipher_suites<2..2^16-2>;
148                                 //    CompressionMethod compression_methods<1..2^8-1>;
149                                 //    select (extensions_present) {
150                                 //        case false:
151                                 //            struct {};
152                                 //        case true:
153                                 //            Extension extensions<0..2^16-1>;
154                                 //    };
155                                 // } ClientHello;
156                                 //
157 
158                                 final int endOffset = readerIndex + packetLength;
159                                 int offset = readerIndex + 43;
160 
161                                 if (endOffset - offset < 6) {
162                                     break loop;
163                                 }
164 
165                                 final int sessionIdLength = in.getUnsignedByte(offset);
166                                 offset += sessionIdLength + 1;
167 
168                                 final int cipherSuitesLength = in.getUnsignedShort(offset);
169                                 offset += cipherSuitesLength + 2;
170 
171                                 final int compressionMethodLength = in.getUnsignedByte(offset);
172                                 offset += compressionMethodLength + 1;
173 
174                                 final int extensionsLength = in.getUnsignedShort(offset);
175                                 offset += 2;
176                                 final int extensionsLimit = offset + extensionsLength;
177 
178                                 if (extensionsLimit > endOffset) {
179                                     // Extensions should never exceed the record boundary.
180                                     break loop;
181                                 }
182 
183                                 for (;;) {
184                                     if (extensionsLimit - offset < 4) {
185                                         break loop;
186                                     }
187 
188                                     final int extensionType = in.getUnsignedShort(offset);
189                                     offset += 2;
190 
191                                     final int extensionLength = in.getUnsignedShort(offset);
192                                     offset += 2;
193 
194                                     if (extensionsLimit - offset < extensionLength) {
195                                         break loop;
196                                     }
197 
198                                     // SNI
199                                     // See https://tools.ietf.org/html/rfc6066#page-6
200                                     if (extensionType == 0) {
201                                         offset += 2;
202                                         if (extensionsLimit - offset < 3) {
203                                             break loop;
204                                         }
205 
206                                         final int serverNameType = in.getUnsignedByte(offset);
207                                         offset++;
208 
209                                         if (serverNameType == 0) {
210                                             final int serverNameLength = in.getUnsignedShort(offset);
211                                             offset += 2;
212 
213                                             if (extensionsLimit - offset < serverNameLength) {
214                                                 break loop;
215                                             }
216 
217                                             final String hostname = in.toString(offset, serverNameLength,
218                                                                                 CharsetUtil.UTF_8);
219 
220                                             try {
221                                                 select(ctx, IDN.toASCII(hostname,
222                                                                         IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
223                                             } catch (Throwable t) {
224                                                 PlatformDependent.throwException(t);
225                                             }
226                                             return;
227                                         } else {
228                                             // invalid enum value
229                                             break loop;
230                                         }
231                                     }
232 
233                                     offset += extensionLength;
234                                 }
235                             }
236                             // Fall-through
237                         default:
238                             //not tls, ssl or application data, do not try sni
239                             break loop;
240                     }
241                 }
242             } catch (NotSslRecordException e) {
243                 // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler.
244                 throw e;
245             } catch (Exception e) {
246                 // unexpected encoding, ignore sni and use default
247                 if (logger.isDebugEnabled()) {
248                     logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e);
249                 }
250             }
251             // Just select the default SslContext
252             select(ctx, null);
253         }
254     }
255 
256     private void select(ChannelHandlerContext ctx, String hostname) {
257         SslHandler sslHandler = null;
258         SslContext selectedContext = mapping.map(hostname);
259         selection = new Selection(selectedContext, hostname);
260         try {
261             sslHandler = selection.context.newHandler(ctx.alloc());
262             ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
263         } catch (Throwable cause) {
264             selection = EMPTY_SELECTION;
265             // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
266             // transferred to the SslHandler.
267             // See https://github.com/netty/netty/issues/5678
268             if (sslHandler != null) {
269                 ReferenceCountUtil.safeRelease(sslHandler.engine());
270             }
271             PlatformDependent.throwException(cause);
272         }
273     }
274 
275     private static final class Selection {
276         final SslContext context;
277         final String hostname;
278 
279         Selection(SslContext context, String hostname) {
280             this.context = context;
281             this.hostname = hostname;
282         }
283     }
284 }