1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.ssl;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.util.CharsetUtil;
21 import io.netty.util.concurrent.Future;
22 import io.netty.util.concurrent.ScheduledFuture;
23
24 import java.util.Locale;
25 import java.util.concurrent.TimeUnit;
26
27 import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
28
29
30
31
32
33
34
35
36 public abstract class AbstractSniHandler<T> extends SslClientHelloHandler<T> {
37
38 private static String extractSniHostname(ByteBuf in) {
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59 int offset = in.readerIndex();
60 int endOffset = in.writerIndex();
61 offset += 34;
62
63 if (endOffset - offset >= 6) {
64 final int sessionIdLength = in.getUnsignedByte(offset);
65 offset += sessionIdLength + 1;
66
67 final int cipherSuitesLength = in.getUnsignedShort(offset);
68 offset += cipherSuitesLength + 2;
69
70 final int compressionMethodLength = in.getUnsignedByte(offset);
71 offset += compressionMethodLength + 1;
72
73 final int extensionsLength = in.getUnsignedShort(offset);
74 offset += 2;
75 final int extensionsLimit = offset + extensionsLength;
76
77
78 if (extensionsLimit <= endOffset) {
79 while (extensionsLimit - offset >= 4) {
80 final int extensionType = in.getUnsignedShort(offset);
81 offset += 2;
82
83 final int extensionLength = in.getUnsignedShort(offset);
84 offset += 2;
85
86 if (extensionsLimit - offset < extensionLength) {
87 break;
88 }
89
90
91
92 if (extensionType == 0) {
93 offset += 2;
94 if (extensionsLimit - offset < 3) {
95 break;
96 }
97
98 final int serverNameType = in.getUnsignedByte(offset);
99 offset++;
100
101 if (serverNameType == 0) {
102 final int serverNameLength = in.getUnsignedShort(offset);
103 offset += 2;
104
105 if (extensionsLimit - offset < serverNameLength) {
106 break;
107 }
108
109 final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
110 return hostname.toLowerCase(Locale.US);
111 } else {
112
113 break;
114 }
115 }
116
117 offset += extensionLength;
118 }
119 }
120 }
121 return null;
122 }
123
124 static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(10);
125 protected final long handshakeTimeoutMillis;
126 private ScheduledFuture<?> timeoutFuture;
127 private String hostname;
128
129
130
131
132 protected AbstractSniHandler(long handshakeTimeoutMillis) {
133 this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, handshakeTimeoutMillis);
134 }
135
136
137
138
139
140 protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) {
141 super(maxClientHelloLength);
142 this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis");
143 }
144
145 public AbstractSniHandler() {
146 this(DEFAULT_MAX_CLIENT_HELLO_LENGTH, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
147 }
148
149 @Override
150 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
151 if (ctx.channel().isActive()) {
152 checkStartTimeout(ctx);
153 }
154 }
155
156 @Override
157 public void channelActive(ChannelHandlerContext ctx) throws Exception {
158 ctx.fireChannelActive();
159 checkStartTimeout(ctx);
160 }
161
162 private void checkStartTimeout(final ChannelHandlerContext ctx) {
163 if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) {
164 return;
165 }
166 timeoutFuture = ctx.executor().schedule(new Runnable() {
167 @Override
168 public void run() {
169 if (ctx.channel().isActive()) {
170 SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException(
171 "handshake timed out after " + handshakeTimeoutMillis + "ms");
172 ctx.fireUserEventTriggered(new SniCompletionEvent(exception));
173 ctx.close();
174 }
175 }
176 }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
177 }
178
179 @Override
180 protected Future<T> lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception {
181 hostname = clientHello == null ? null : extractSniHostname(clientHello);
182
183 return lookup(ctx, hostname);
184 }
185
186 @Override
187 protected void onLookupComplete(ChannelHandlerContext ctx, Future<T> future) throws Exception {
188 if (timeoutFuture != null) {
189 timeoutFuture.cancel(false);
190 }
191 try {
192 onLookupComplete(ctx, hostname, future);
193 } finally {
194 fireSniCompletionEvent(ctx, hostname, future);
195 }
196 }
197
198
199
200
201
202
203
204 protected abstract Future<T> lookup(ChannelHandlerContext ctx, String hostname) throws Exception;
205
206
207
208
209
210
211 protected abstract void onLookupComplete(ChannelHandlerContext ctx,
212 String hostname, Future<T> future) throws Exception;
213
214 private static void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future<?> future) {
215 Throwable cause = future.cause();
216 if (cause == null) {
217 ctx.fireUserEventTriggered(new SniCompletionEvent(hostname));
218 } else {
219 ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause));
220 }
221 }
222 }