View Javadoc
1   /*
2    * Copyright 2017 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.ssl;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.util.CharsetUtil;
21  import io.netty.util.concurrent.Future;
22  import io.netty.util.concurrent.ScheduledFuture;
23  
24  import java.util.Locale;
25  import java.util.concurrent.TimeUnit;
26  
27  import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
28  
29  /**
30   * <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
31   * (Server Name Indication)</a> extension for server side SSL. For clients
32   * support SNI, the server could have multiple host name bound on a single IP.
33   * The client will send host name in the handshake data so server could decide
34   * which certificate to choose for the host name.</p>
35   */
36  public abstract class AbstractSniHandler<T> extends SslClientHelloHandler<T> {
37  
38      private static String extractSniHostname(ByteBuf in) {
39          // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
40          //
41          // Decode the ssl client hello packet.
42          //
43          // struct {
44          //    ProtocolVersion client_version;
45          //    Random random;
46          //    SessionID session_id;
47          //    CipherSuite cipher_suites<2..2^16-2>;
48          //    CompressionMethod compression_methods<1..2^8-1>;
49          //    select (extensions_present) {
50          //        case false:
51          //            struct {};
52          //        case true:
53          //            Extension extensions<0..2^16-1>;
54          //    };
55          // } ClientHello;
56          //
57  
58          // We have to skip bytes until SessionID (which sum to 34 bytes in this case).
59          int offset = in.readerIndex();
60          int endOffset = in.writerIndex();
61          offset += 34;
62  
63          if (endOffset - offset >= 6) {
64              final int sessionIdLength = in.getUnsignedByte(offset);
65              offset += sessionIdLength + 1;
66  
67              final int cipherSuitesLength = in.getUnsignedShort(offset);
68              offset += cipherSuitesLength + 2;
69  
70              final int compressionMethodLength = in.getUnsignedByte(offset);
71              offset += compressionMethodLength + 1;
72  
73              final int extensionsLength = in.getUnsignedShort(offset);
74              offset += 2;
75              final int extensionsLimit = offset + extensionsLength;
76  
77              // Extensions should never exceed the record boundary.
78              if (extensionsLimit <= endOffset) {
79                  while (extensionsLimit - offset >= 4) {
80                      final int extensionType = in.getUnsignedShort(offset);
81                      offset += 2;
82  
83                      final int extensionLength = in.getUnsignedShort(offset);
84                      offset += 2;
85  
86                      if (extensionsLimit - offset < extensionLength) {
87                          break;
88                      }
89  
90                      // SNI
91                      // See https://tools.ietf.org/html/rfc6066#page-6
92                      if (extensionType == 0) {
93                          offset += 2;
94                          if (extensionsLimit - offset < 3) {
95                              break;
96                          }
97  
98                          final int serverNameType = in.getUnsignedByte(offset);
99                          offset++;
100 
101                         if (serverNameType == 0) {
102                             final int serverNameLength = in.getUnsignedShort(offset);
103                             offset += 2;
104 
105                             if (extensionsLimit - offset < serverNameLength) {
106                                 break;
107                             }
108 
109                             final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
110                             return hostname.toLowerCase(Locale.US);
111                         } else {
112                             // invalid enum value
113                             break;
114                         }
115                     }
116 
117                     offset += extensionLength;
118                 }
119             }
120         }
121         return null;
122     }
123 
124     static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(10);
125     protected final long handshakeTimeoutMillis;
126     private ScheduledFuture<?> timeoutFuture;
127     private String hostname;
128 
129     /**
130      * @param handshakeTimeoutMillis    the handshake timeout in milliseconds
131      */
132     protected AbstractSniHandler(long handshakeTimeoutMillis) {
133         this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, handshakeTimeoutMillis);
134     }
135 
136     /**
137      * @paramm maxClientHelloLength     the maximum length of the client hello message.
138      * @param handshakeTimeoutMillis    the handshake timeout in milliseconds
139      */
140     protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) {
141         super(maxClientHelloLength);
142         this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
143     }
144 
145     public AbstractSniHandler() {
146         this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
147     }
148 
149     @Override
150     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
151         if (ctx.channel().isActive()) {
152             checkStartTimeout(ctx);
153         }
154     }
155 
156     @Override
157     public void channelActive(ChannelHandlerContext ctx) throws Exception {
158         ctx.fireChannelActive();
159         checkStartTimeout(ctx);
160     }
161 
162     private void checkStartTimeout(final ChannelHandlerContext ctx) {
163         if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) {
164             return;
165         }
166         timeoutFuture = ctx.executor().schedule(new Runnable() {
167             @Override
168             public void run() {
169                 if (ctx.channel().isActive()) {
170                     SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException(
171                         "handshake timed out after " + handshakeTimeoutMillis + "ms");
172                     ctx.fireUserEventTriggered(new SniCompletionEvent(exception));
173                     ctx.close();
174                 }
175             }
176         }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
177     }
178 
179     @Override
180     protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
181         hostname = clientHello == null ? null : extractSniHostname(clientHello);
182 
183         return lookup(ctx, hostname);
184     }
185 
186     @Override
187     protected void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception {
188         if (timeoutFuture != null) {
189             timeoutFuture.cancel(false);
190         }
191         try {
192             onLookupComplete(ctx, hostname, future);
193         } finally {
194             fireSniCompletionEvent(ctx, hostname, future);
195         }
196     }
197 
198     /**
199      * Kicks off a lookup for the given SNI value and returns a {@link Future} which in turn will
200      * notify the {@link #onLookupComplete(ChannelHandlerContext, String, Future)} on completion.
201      *
202      * @see #onLookupComplete(ChannelHandlerContext, String, Future)
203      */
204     protected abstract Future<T> lookup(ChannelHandlerContext ctx, String hostname) throws Exception;
205 
206     /**
207      * Called upon completion of the {@link #lookup(ChannelHandlerContext, String)} {@link Future}.
208      *
209      * @see #lookup(ChannelHandlerContext, String)
210      */
211     protected abstract void onLookupComplete(ChannelHandlerContext ctx,
212                                              String hostname, Future<T> future) throws Exception;
213 
214     private static void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<?> future) {
215         Throwable cause = future.cause();
216         if (cause == null) {
217             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
218         } else {
219             ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
220         }
221     }
222 }