1 /*
2 * Copyright 2014 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBufAllocator;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.handler.codec.DecoderException;
21 import io.netty.util.AsyncMapping;
22 import io.netty.util.DomainNameMapping;
23 import io.netty.util.Mapping;
24 import io.netty.util.ReferenceCountUtil;
25 import io.netty.util.concurrent.Future;
26 import io.netty.util.concurrent.Promise;
27 import io.netty.util.internal.ObjectUtil;
28 import io.netty.util.internal.PlatformDependent;
29
30 /**
31 * <p>Enables <a href="https://tools.ietf.org/html/rfc3546#section-3.1">SNI
32 * (Server Name Indication)</a> extension for server side SSL. For clients
33 * support SNI, the server could have multiple host name bound on a single IP.
34 * The client will send host name in the handshake data so server could decide
35 * which certificate to choose for the host name.</p>
36 */
37 public class SniHandler extends AbstractSniHandler<SslContext> {
38 private static final Selection EMPTY_SELECTION = new Selection(null, null);
39
40 protected final AsyncMapping<String, SslContext> mapping;
41
42 private volatile Selection selection = EMPTY_SELECTION;
43
44 /**
45 * Creates a SNI detection handler with configured {@link SslContext}
46 * maintained by {@link Mapping}
47 *
48 * @param mapping the mapping of domain name to {@link SslContext}
49 */
50 public SniHandler(Mapping<? super String, ? extends SslContext> mapping) {
51 this(new AsyncMappingAdapter(mapping));
52 }
53
54 /**
55 * Creates a SNI detection handler with configured {@link SslContext}
56 * maintained by {@link Mapping}
57 *
58 * @param mapping the mapping of domain name to {@link SslContext}
59 * @param maxClientHelloLength the maximum length of the client hello message
60 * @param handshakeTimeoutMillis the handshake timeout in milliseconds
61 */
62 public SniHandler(Mapping<? super String, ? extends SslContext> mapping,
63 int maxClientHelloLength, long handshakeTimeoutMillis) {
64 this(new AsyncMappingAdapter(mapping), maxClientHelloLength, handshakeTimeoutMillis);
65 }
66
67 /**
68 * Creates a SNI detection handler with configured {@link SslContext}
69 * maintained by {@link DomainNameMapping}
70 *
71 * @param mapping the mapping of domain name to {@link SslContext}
72 */
73 public SniHandler(DomainNameMapping<? extends SslContext> mapping) {
74 this((Mapping<String, ? extends SslContext>) mapping);
75 }
76
77 /**
78 * Creates a SNI detection handler with configured {@link SslContext}
79 * maintained by {@link AsyncMapping}
80 *
81 * @param mapping the mapping of domain name to {@link SslContext}
82 */
83 @SuppressWarnings("unchecked")
84 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping) {
85 this(mapping, 0, 0L);
86 }
87
88 /**
89 * Creates a SNI detection handler with configured {@link SslContext}
90 * maintained by {@link AsyncMapping}
91 *
92 * @param mapping the mapping of domain name to {@link SslContext}
93 * @param maxClientHelloLength the maximum length of the client hello message
94 * @param handshakeTimeoutMillis the handshake timeout in milliseconds
95 */
96 @SuppressWarnings("unchecked")
97 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping,
98 int maxClientHelloLength, long handshakeTimeoutMillis) {
99 super(maxClientHelloLength, handshakeTimeoutMillis);
100 this.mapping = (AsyncMapping<String, SslContext>) ObjectUtil.checkNotNull(mapping, "mapping");
101 }
102
103 /**
104 * Creates a SNI detection handler with configured {@link SslContext}
105 * maintained by {@link Mapping}
106 *
107 * @param mapping the mapping of domain name to {@link SslContext}
108 * @param handshakeTimeoutMillis the handshake timeout in milliseconds
109 */
110 public SniHandler(Mapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
111 this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis);
112 }
113
114 /**
115 * Creates a SNI detection handler with configured {@link SslContext}
116 * maintained by {@link AsyncMapping}
117 *
118 * @param mapping the mapping of domain name to {@link SslContext}
119 * @param handshakeTimeoutMillis the handshake timeout in milliseconds
120 */
121 public SniHandler(AsyncMapping<? super String, ? extends SslContext> mapping, long handshakeTimeoutMillis) {
122 this(mapping, 0, handshakeTimeoutMillis);
123 }
124
125 /**
126 * @return the selected hostname
127 */
128 public String hostname() {
129 return selection.hostname;
130 }
131
132 /**
133 * @return the selected {@link SslContext}
134 */
135 public SslContext sslContext() {
136 return selection.context;
137 }
138
139 /**
140 * The default implementation will simply call {@link AsyncMapping#map(Object, Promise)} but
141 * users can override this method to implement custom behavior.
142 *
143 * @see AsyncMapping#map(Object, Promise)
144 */
145 @Override
146 protected Future<SslContext> lookup(ChannelHandlerContext ctx, String hostname) throws Exception {
147 return mapping.map(hostname, ctx.executor().<SslContext>newPromise());
148 }
149
150 @Override
151 protected final void onLookupComplete(ChannelHandlerContext ctx,
152 String hostname, Future<SslContext> future) throws Exception {
153 if (!future.isSuccess()) {
154 final Throwable cause = future.cause();
155 if (cause instanceof Error) {
156 throw (Error) cause;
157 }
158 throw new DecoderException("failed to get the SslContext for " + hostname, cause);
159 }
160
161 SslContext sslContext = future.getNow();
162 selection = new Selection(sslContext, hostname);
163 try {
164 replaceHandler(ctx, hostname, sslContext);
165 } catch (Throwable cause) {
166 selection = EMPTY_SELECTION;
167 PlatformDependent.throwException(cause);
168 }
169 }
170
171 /**
172 * The default implementation of this method will simply replace {@code this} {@link SniHandler}
173 * instance with a {@link SslHandler}. Users may override this method to implement custom behavior.
174 *
175 * Please be aware that this method may get called after a client has already disconnected and
176 * custom implementations must take it into consideration when overriding this method.
177 *
178 * It's also possible for the hostname argument to be {@code null}.
179 */
180 protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception {
181 SslHandler sslHandler = null;
182 try {
183 sslHandler = newSslHandler(sslContext, ctx.alloc());
184 ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
185 sslHandler = null;
186 } finally {
187 // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
188 // transferred to the SslHandler.
189 // See https://github.com/netty/netty/issues/5678
190 if (sslHandler != null) {
191 ReferenceCountUtil.safeRelease(sslHandler.engine());
192 }
193 }
194 }
195
196 /**
197 * Returns a new {@link SslHandler} using the given {@link SslContext} and {@link ByteBufAllocator}.
198 * Users may override this method to implement custom behavior.
199 */
200 protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) {
201 SslHandler sslHandler = context.newHandler(allocator);
202 sslHandler.setHandshakeTimeoutMillis(handshakeTimeoutMillis);
203 return sslHandler;
204 }
205
206 private static final class AsyncMappingAdapter implements AsyncMapping<String, SslContext> {
207 private final Mapping<? super String, ? extends SslContext> mapping;
208
209 private AsyncMappingAdapter(Mapping<? super String, ? extends SslContext> mapping) {
210 this.mapping = ObjectUtil.checkNotNull(mapping, "mapping");
211 }
212
213 @Override
214 public Future<SslContext> map(String input, Promise<SslContext> promise) {
215 final SslContext context;
216 try {
217 context = mapping.map(input);
218 } catch (Throwable cause) {
219 return promise.setFailure(cause);
220 }
221 return promise.setSuccess(context);
222 }
223 }
224
225 private static final class Selection {
226 final SslContext context;
227 final String hostname;
228
229 Selection(SslContext context, String hostname) {
230 this.context = context;
231 this.hostname = hostname;
232 }
233 }
234 }