1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.Channel;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.FullHttpRequest;
26 import io.netty.handler.codec.http.FullHttpResponse;
27 import io.netty.handler.codec.http.HttpClientCodec;
28 import io.netty.handler.codec.http.HttpContentDecompressor;
29 import io.netty.handler.codec.http.HttpHeaders;
30 import io.netty.handler.codec.http.HttpObjectAggregator;
31 import io.netty.handler.codec.http.HttpRequestEncoder;
32 import io.netty.handler.codec.http.HttpResponse;
33 import io.netty.handler.codec.http.HttpResponseDecoder;
34 import io.netty.util.NetUtil;
35 import io.netty.util.ReferenceCountUtil;
36 import io.netty.util.internal.ThrowableUtil;
37
38 import java.net.URI;
39 import java.nio.channels.ClosedChannelException;
40 import java.util.Locale;
41
42
43
44
45 public abstract class WebSocketClientHandshaker {
46 private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
47 new ClosedChannelException(), WebSocketClientHandshaker.class, "processHandshake(...)");
48
49 private static final String HTTP_SCHEME_PREFIX = "http://";
50 private static final String HTTPS_SCHEME_PREFIX = "https://";
51
52 private final URI uri;
53
54 private final WebSocketVersion version;
55
56 private volatile boolean handshakeComplete;
57
58 private final String expectedSubprotocol;
59
60 private volatile String actualSubprotocol;
61
62 protected final HttpHeaders customHeaders;
63
64 private final int maxFramePayloadLength;
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
82 HttpHeaders customHeaders, int maxFramePayloadLength) {
83 this.uri = uri;
84 this.version = version;
85 expectedSubprotocol = subprotocol;
86 this.customHeaders = customHeaders;
87 this.maxFramePayloadLength = maxFramePayloadLength;
88 }
89
90
91
92
93 public URI uri() {
94 return uri;
95 }
96
97
98
99
100 public WebSocketVersion version() {
101 return version;
102 }
103
104
105
106
107 public int maxFramePayloadLength() {
108 return maxFramePayloadLength;
109 }
110
111
112
113
114 public boolean isHandshakeComplete() {
115 return handshakeComplete;
116 }
117
118 private void setHandshakeComplete() {
119 handshakeComplete = true;
120 }
121
122
123
124
125 public String expectedSubprotocol() {
126 return expectedSubprotocol;
127 }
128
129
130
131
132
133 public String actualSubprotocol() {
134 return actualSubprotocol;
135 }
136
137 private void setActualSubprotocol(String actualSubprotocol) {
138 this.actualSubprotocol = actualSubprotocol;
139 }
140
141
142
143
144
145
146
147 public ChannelFuture handshake(Channel channel) {
148 if (channel == null) {
149 throw new NullPointerException("channel");
150 }
151 return handshake(channel, channel.newPromise());
152 }
153
154
155
156
157
158
159
160
161
162 public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
163 FullHttpRequest request = newHandshakeRequest();
164
165 HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
166 if (decoder == null) {
167 HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
168 if (codec == null) {
169 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
170 "a HttpResponseDecoder or HttpClientCodec"));
171 return promise;
172 }
173 }
174
175 channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
176 @Override
177 public void operationComplete(ChannelFuture future) {
178 if (future.isSuccess()) {
179 ChannelPipeline p = future.channel().pipeline();
180 ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
181 if (ctx == null) {
182 ctx = p.context(HttpClientCodec.class);
183 }
184 if (ctx == null) {
185 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
186 "a HttpRequestEncoder or HttpClientCodec"));
187 return;
188 }
189 p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
190
191 promise.setSuccess();
192 } else {
193 promise.setFailure(future.cause());
194 }
195 }
196 });
197 return promise;
198 }
199
200
201
202
203 protected abstract FullHttpRequest newHandshakeRequest();
204
205
206
207
208
209
210
211
212
213 public final void finishHandshake(Channel channel, FullHttpResponse response) {
214 verify(response);
215
216
217
218 String receivedProtocol = response.headers().get(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
219 receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
220 String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
221 boolean protocolValid = false;
222
223 if (expectedProtocol.isEmpty() && receivedProtocol == null) {
224
225 protocolValid = true;
226 setActualSubprotocol(expectedSubprotocol);
227 } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
228
229 for (String protocol : expectedProtocol.split(",")) {
230 if (protocol.trim().equals(receivedProtocol)) {
231 protocolValid = true;
232 setActualSubprotocol(receivedProtocol);
233 break;
234 }
235 }
236 }
237
238 if (!protocolValid) {
239 throw new WebSocketHandshakeException(String.format(
240 "Invalid subprotocol. Actual: %s. Expected one of: %s",
241 receivedProtocol, expectedSubprotocol));
242 }
243
244 setHandshakeComplete();
245
246 final ChannelPipeline p = channel.pipeline();
247
248 HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
249 if (decompressor != null) {
250 p.remove(decompressor);
251 }
252
253
254 HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
255 if (aggregator != null) {
256 p.remove(aggregator);
257 }
258
259 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
260 if (ctx == null) {
261 ctx = p.context(HttpClientCodec.class);
262 if (ctx == null) {
263 throw new IllegalStateException("ChannelPipeline does not contain " +
264 "a HttpRequestEncoder or HttpClientCodec");
265 }
266 final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
267
268 codec.removeOutboundHandler();
269
270 p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
271
272
273
274
275 channel.eventLoop().execute(new Runnable() {
276 @Override
277 public void run() {
278 p.remove(codec);
279 }
280 });
281 } else {
282 if (p.get(HttpRequestEncoder.class) != null) {
283
284 p.remove(HttpRequestEncoder.class);
285 }
286 final ChannelHandlerContext context = ctx;
287 p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
288
289
290
291
292 channel.eventLoop().execute(new Runnable() {
293 @Override
294 public void run() {
295 p.remove(context.handler());
296 }
297 });
298 }
299 }
300
301
302
303
304
305
306
307
308
309
310
311 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
312 return processHandshake(channel, response, channel.newPromise());
313 }
314
315
316
317
318
319
320
321
322
323
324
325
326
327 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
328 final ChannelPromise promise) {
329 if (response instanceof FullHttpResponse) {
330 try {
331 finishHandshake(channel, (FullHttpResponse) response);
332 promise.setSuccess();
333 } catch (Throwable cause) {
334 promise.setFailure(cause);
335 }
336 } else {
337 ChannelPipeline p = channel.pipeline();
338 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
339 if (ctx == null) {
340 ctx = p.context(HttpClientCodec.class);
341 if (ctx == null) {
342 return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
343 "a HttpResponseDecoder or HttpClientCodec"));
344 }
345 }
346
347
348
349
350 String aggregatorName = "httpAggregator";
351 p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
352 p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
353 @Override
354 protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
355
356 ctx.pipeline().remove(this);
357 try {
358 finishHandshake(channel, msg);
359 promise.setSuccess();
360 } catch (Throwable cause) {
361 promise.setFailure(cause);
362 }
363 }
364
365 @Override
366 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
367
368 ctx.pipeline().remove(this);
369 promise.setFailure(cause);
370 }
371
372 @Override
373 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
374
375 promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
376 ctx.fireChannelInactive();
377 }
378 });
379 try {
380 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
381 } catch (Throwable cause) {
382 promise.setFailure(cause);
383 }
384 }
385 return promise;
386 }
387
388
389
390
391 protected abstract void verify(FullHttpResponse response);
392
393
394
395
396 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
397
398
399
400
401 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
402
403
404
405
406
407
408
409
410
411 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
412 if (channel == null) {
413 throw new NullPointerException("channel");
414 }
415 return close(channel, frame, channel.newPromise());
416 }
417
418
419
420
421
422
423
424
425
426
427
428 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
429 if (channel == null) {
430 throw new NullPointerException("channel");
431 }
432 return channel.writeAndFlush(frame, promise);
433 }
434
435
436
437
438 static String rawPath(URI wsURL) {
439 String path = wsURL.getRawPath();
440 String query = wsURL.getRawQuery();
441 if (query != null && !query.isEmpty()) {
442 path = path + '?' + query;
443 }
444
445 return path == null || path.isEmpty() ? "/" : path;
446 }
447
448 static CharSequence websocketHostValue(URI wsURL) {
449 int port = wsURL.getPort();
450 if (port == -1) {
451 return wsURL.getHost();
452 }
453 String host = wsURL.getHost();
454 if (port == 80) {
455 return "http".equals(wsURL.getScheme())
456 || "ws".equals(wsURL.getScheme()) ?
457 host : NetUtil.toSocketAddressString(host, 80);
458 }
459 if (port == 443) {
460 return "https".equals(wsURL.getScheme())
461 || "wss".equals(wsURL.getScheme()) ?
462 host : NetUtil.toSocketAddressString(host, 443);
463 }
464
465
466
467 return NetUtil.toSocketAddressString(host, port);
468 }
469
470 static CharSequence websocketOriginValue(URI wsURL) {
471 String scheme = wsURL.getScheme();
472 final String schemePrefix;
473 int port = wsURL.getPort();
474 final int defaultPort;
475 if ("wss".equals(scheme)
476 || "https".equals(scheme)
477 || (scheme == null && port == 443)) {
478
479 schemePrefix = HTTPS_SCHEME_PREFIX;
480 defaultPort = 443;
481 } else {
482 schemePrefix = HTTP_SCHEME_PREFIX;
483 defaultPort = 80;
484 }
485
486
487 String host = wsURL.getHost().toLowerCase(Locale.US);
488
489 if (port != defaultPort && port != -1) {
490
491
492 return schemePrefix + NetUtil.toSocketAddressString(host, port);
493 }
494 return schemePrefix + host;
495 }
496 }