View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.util.CharsetUtil;
21  import io.netty.util.concurrent.Future;
22  import io.netty.util.concurrent.ScheduledFuture;
23  
24  import java.util.Locale;
25  import java.util.concurrent.TimeUnit;
26  
27  import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
28  
29  /**
30   * <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
31   * (Server Name Indication)</a> extension for server side SSL. For clients
32   * support SNI, the server could have multiple host name bound on a single IP.
33   * The client will send host name in the handshake data so server could decide
34   * which certificate to choose for the host name.</p>
35   */
36  public abstract class AbstractSniHandler<T> extends SslClientHelloHandler<T> {
37  
38      private static String extractSniHostname(ByteBuf in) {
39          // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
40          //
41          // Decode the ssl client hello packet.
42          //
43          // struct {
44          //    ProtocolVersion client_version;
45          //    Random random;
46          //    SessionID session_id;
47          //    CipherSuite cipher_suites<2..2^16-2>;
48          //    CompressionMethod compression_methods<1..2^8-1>;
49          //    select (extensions_present) {
50          //        case false:
51          //            struct {};
52          //        case true:
53          //            Extension extensions<0..2^16-1>;
54          //    };
55          // } ClientHello;
56          //
57  
58          // We have to skip bytes until SessionID (which sum to 34 bytes in this case).
59          int offset = in.readerIndex();
60          int endOffset = in.writerIndex();
61          offset += 34;
62  
63          if (endOffset - offset >= 6) {
64              final int sessionIdLength = in.getUnsignedByte(offset);
65              offset += sessionIdLength + 1;
66  
67              final int cipherSuitesLength = in.getUnsignedShort(offset);
68              offset += cipherSuitesLength + 2;
69  
70              final int compressionMethodLength = in.getUnsignedByte(offset);
71              offset += compressionMethodLength + 1;
72  
73              final int extensionsLength = in.getUnsignedShort(offset);
74              offset += 2;
75              final int extensionsLimit = offset + extensionsLength;
76  
77              // Extensions should never exceed the record boundary.
78              if (extensionsLimit <= endOffset) {
79                  while (extensionsLimit - offset >= 4) {
80                      final int extensionType = in.getUnsignedShort(offset);
81                      offset += 2;
82  
83                      final int extensionLength = in.getUnsignedShort(offset);
84                      offset += 2;
85  
86                      if (extensionsLimit - offset < extensionLength) {
87                          break;
88                      }
89  
90                      // SNI
91                      // See https://tools.ietf.org/html/rfc6066#page-6
92                      if (extensionType == 0) {
93                          offset += 2;
94                          if (extensionsLimit - offset < 3) {
95                              break;
96                          }
97  
98                          final int serverNameType = in.getUnsignedByte(offset);
99                          offset++;
100 
101                         if (serverNameType == 0) {
102                             final int serverNameLength = in.getUnsignedShort(offset);
103                             offset += 2;
104 
105                             if (extensionsLimit - offset < serverNameLength) {
106                                 break;
107                             }
108 
109                             final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
110                             return hostname.toLowerCase(Locale.US);
111                         } else {
112                             // invalid enum value
113                             break;
114                         }
115                     }
116 
117                     offset += extensionLength;
118                 }
119             }
120         }
121         return null;
122     }
123 
124     protected final long handshakeTimeoutMillis;
125     private ScheduledFuture<?> timeoutFuture;
126     private String hostname;
127 
128     /**
129      * @param handshakeTimeoutMillis    the handshake timeout in milliseconds
130      */
131     protected AbstractSniHandler(long handshakeTimeoutMillis) {
132         this(0, handshakeTimeoutMillis);
133     }
134 
135     /**
136      * @paramm maxClientHelloLength     the maximum length of the client hello message.
137      * @param handshakeTimeoutMillis    the handshake timeout in milliseconds
138      */
139     protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) {
140         super(maxClientHelloLength);
141         this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
142     }
143 
144     public AbstractSniHandler() {
145         this(0, 0L);
146     }
147 
148     @Override
149     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
150         if (ctx.channel().isActive()) {
151             checkStartTimeout(ctx);
152         }
153     }
154 
155     @Override
156     public void channelActive(ChannelHandlerContext ctx) throws Exception {
157         ctx.fireChannelActive();
158         checkStartTimeout(ctx);
159     }
160 
161     private void checkStartTimeout(final ChannelHandlerContext ctx) {
162         if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) {
163             return;
164         }
165         timeoutFuture = ctx.executor().schedule(new Runnable() {
166             @Override
167             public void run() {
168                 if (ctx.channel().isActive()) {
169                     SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException(
170                         "handshake timed out after " + handshakeTimeoutMillis + "ms");
171                     ctx.fireUserEventTriggered(new SniCompletionEvent(exception));
172                     ctx.close();
173                 }
174             }
175         }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
176     }
177 
178     @Override
179     protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
180         hostname = clientHello == null ? null : extractSniHostname(clientHello);
181 
182         return lookup(ctx, hostname);
183     }
184 
185     @Override
186     protected void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception {
187         if (timeoutFuture != null) {
188             timeoutFuture.cancel(false);
189         }
190         try {
191             onLookupComplete(ctx, hostname, future);
192         } finally {
193             fireSniCompletionEvent(ctx, hostname, future);
194         }
195     }
196 
197     /**
198      * Kicks off a lookup for the given SNI value and returns a {@link Future} which in turn will
199      * notify the {@link #onLookupComplete(ChannelHandlerContext, String, Future)} on completion.
200      *
201      * @see #onLookupComplete(ChannelHandlerContext, String, Future)
202      */
203     protected abstract Future<T> lookup(ChannelHandlerContext ctx, String hostname) throws Exception;
204 
205     /**
206      * Called upon completion of the {@link #lookup(ChannelHandlerContext, String)} {@link Future}.
207      *
208      * @see #lookup(ChannelHandlerContext, String)
209      */
210     protected abstract void onLookupComplete(ChannelHandlerContext ctx,
211                                              String hostname, Future<T> future) throws Exception;
212 
213     private static void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<?> future) {
214         Throwable cause = future.cause();
215         if (cause == null) {
216             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
217         } else {
218             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
219         }
220     }
221 }