1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelHandler;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInboundHandlerAdapter;
24 import io.netty.channel.ChannelOutboundInvoker;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.channel.ChannelPromise;
27 import io.netty.handler.codec.http.DefaultFullHttpResponse;
28 import io.netty.handler.codec.http.EmptyHttpHeaders;
29 import io.netty.handler.codec.http.FullHttpRequest;
30 import io.netty.handler.codec.http.FullHttpResponse;
31 import io.netty.handler.codec.http.HttpClientCodec;
32 import io.netty.handler.codec.http.HttpContentDecompressor;
33 import io.netty.handler.codec.http.HttpHeaderNames;
34 import io.netty.handler.codec.http.HttpHeaders;
35 import io.netty.handler.codec.http.HttpObject;
36 import io.netty.handler.codec.http.HttpObjectAggregator;
37 import io.netty.handler.codec.http.HttpRequestEncoder;
38 import io.netty.handler.codec.http.HttpResponse;
39 import io.netty.handler.codec.http.HttpResponseDecoder;
40 import io.netty.handler.codec.http.HttpScheme;
41 import io.netty.handler.codec.http.LastHttpContent;
42 import io.netty.util.NetUtil;
43 import io.netty.util.ReferenceCountUtil;
44 import io.netty.util.internal.ObjectUtil;
45
46 import java.net.URI;
47 import java.nio.channels.ClosedChannelException;
48 import java.util.Locale;
49 import java.util.concurrent.Future;
50 import java.util.concurrent.TimeUnit;
51 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
52
53
54
55
56 public abstract class WebSocketClientHandshaker {
57
58 private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
59 private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
60 protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
61
62 private final URI uri;
63
64 private final WebSocketVersion version;
65
66 private volatile boolean handshakeComplete;
67
68 private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;
69
70 private volatile int forceCloseInit;
71
72 private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
73 AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
74
75 private volatile boolean forceCloseComplete;
76
77 private final String expectedSubprotocol;
78
79 private volatile String actualSubprotocol;
80
81 protected final HttpHeaders customHeaders;
82
83 private final int maxFramePayloadLength;
84
85 private final boolean absoluteUpgradeUrl;
86
87 protected final boolean generateOriginHeader;
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
105 HttpHeaders customHeaders, int maxFramePayloadLength) {
106 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
107 }
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
127 HttpHeaders customHeaders, int maxFramePayloadLength,
128 long forceCloseTimeoutMillis) {
129 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
130 }
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
153 HttpHeaders customHeaders, int maxFramePayloadLength,
154 long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
155 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
156 absoluteUpgradeUrl, true);
157 }
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
183 HttpHeaders customHeaders, int maxFramePayloadLength,
184 long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
185 this.uri = uri;
186 this.version = version;
187 expectedSubprotocol = subprotocol;
188 this.customHeaders = customHeaders;
189 this.maxFramePayloadLength = maxFramePayloadLength;
190 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
191 this.absoluteUpgradeUrl = absoluteUpgradeUrl;
192 this.generateOriginHeader = generateOriginHeader;
193 }
194
195
196
197
198 public URI uri() {
199 return uri;
200 }
201
202
203
204
205 public WebSocketVersion version() {
206 return version;
207 }
208
209
210
211
212 public int maxFramePayloadLength() {
213 return maxFramePayloadLength;
214 }
215
216
217
218
219 public boolean isHandshakeComplete() {
220 return handshakeComplete;
221 }
222
223 private void setHandshakeComplete() {
224 handshakeComplete = true;
225 }
226
227
228
229
230 public String expectedSubprotocol() {
231 return expectedSubprotocol;
232 }
233
234
235
236
237
238 public String actualSubprotocol() {
239 return actualSubprotocol;
240 }
241
242 private void setActualSubprotocol(String actualSubprotocol) {
243 this.actualSubprotocol = actualSubprotocol;
244 }
245
246 public long forceCloseTimeoutMillis() {
247 return forceCloseTimeoutMillis;
248 }
249
250
251
252
253
254 protected boolean isForceCloseComplete() {
255 return forceCloseComplete;
256 }
257
258
259
260
261
262
263
264 public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
265 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
266 return this;
267 }
268
269
270
271
272
273
274
275 public ChannelFuture handshake(Channel channel) {
276 ObjectUtil.checkNotNull(channel, "channel");
277 return handshake(channel, channel.newPromise());
278 }
279
280
281
282
283
284
285
286
287
288 public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
289 final ChannelPipeline pipeline = channel.pipeline();
290 HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
291 if (decoder == null) {
292 HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
293 if (codec == null) {
294 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
295 "an HttpResponseDecoder or HttpClientCodec"));
296 return promise;
297 }
298 }
299
300 if (uri.getHost() == null) {
301 if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {
302 promise.setFailure(new IllegalArgumentException("Cannot generate the 'host' header value," +
303 " webSocketURI should contain host or passed through customHeaders"));
304 return promise;
305 }
306
307 if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
308 final String originName;
309 if (version == WebSocketVersion.V07 || version == WebSocketVersion.V08) {
310 originName = HttpHeaderNames.SEC_WEBSOCKET_ORIGIN.toString();
311 } else {
312 originName = HttpHeaderNames.ORIGIN.toString();
313 }
314
315 promise.setFailure(new IllegalArgumentException("Cannot generate the '" + originName + "' header" +
316 " value, webSocketURI should contain host or disable generateOriginHeader or pass value" +
317 " through customHeaders"));
318 return promise;
319 }
320 }
321
322 FullHttpRequest request = newHandshakeRequest();
323
324 channel.writeAndFlush(request).addListener(future -> {
325 if (future.isSuccess()) {
326 ChannelHandlerContext ctx = pipeline.context(HttpRequestEncoder.class);
327 if (ctx == null) {
328 ctx = pipeline.context(HttpClientCodec.class);
329 }
330 if (ctx == null) {
331 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
332 "an HttpRequestEncoder or HttpClientCodec"));
333 return;
334 }
335 pipeline.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
336
337 promise.setSuccess();
338 } else {
339 promise.setFailure(future.cause());
340 }
341 });
342 return promise;
343 }
344
345
346
347
348 protected abstract FullHttpRequest newHandshakeRequest();
349
350
351
352
353
354
355
356
357
358 public final void finishHandshake(Channel channel, FullHttpResponse response) {
359 verify(response);
360
361
362
363 String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
364 receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
365 String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
366 boolean protocolValid = false;
367
368 if (expectedProtocol.isEmpty() && receivedProtocol == null) {
369
370 protocolValid = true;
371 setActualSubprotocol(expectedSubprotocol);
372 } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
373
374 for (String protocol : expectedProtocol.split(",")) {
375 if (protocol.trim().equals(receivedProtocol)) {
376 protocolValid = true;
377 setActualSubprotocol(receivedProtocol);
378 break;
379 }
380 }
381 }
382
383 if (!protocolValid) {
384 throw new WebSocketClientHandshakeException(String.format(
385 "Invalid subprotocol. Actual: %s. Expected one of: %s",
386 receivedProtocol, expectedSubprotocol), response);
387 }
388
389 setHandshakeComplete();
390
391 final ChannelPipeline p = channel.pipeline();
392
393 HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
394 if (decompressor != null) {
395 p.remove(decompressor);
396 }
397
398
399 HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
400 if (aggregator != null) {
401 p.remove(aggregator);
402 }
403
404 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
405 if (ctx == null) {
406 ctx = p.context(HttpClientCodec.class);
407 if (ctx == null) {
408 throw new IllegalStateException("ChannelPipeline does not contain " +
409 "an HttpRequestEncoder or HttpClientCodec");
410 }
411 final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
412
413 codec.removeOutboundHandler();
414
415 p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
416
417
418
419
420 channel.eventLoop().execute(new Runnable() {
421 @Override
422 public void run() {
423 p.remove(codec);
424 }
425 });
426 } else {
427 if (p.get(HttpRequestEncoder.class) != null) {
428
429 p.remove(HttpRequestEncoder.class);
430 }
431 final ChannelHandlerContext context = ctx;
432 p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
433
434
435
436
437 channel.eventLoop().execute(new Runnable() {
438 @Override
439 public void run() {
440 p.remove(context.handler());
441 }
442 });
443 }
444 }
445
446
447
448
449
450
451
452
453
454
455
456 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
457 return processHandshake(channel, response, channel.newPromise());
458 }
459
460
461
462
463
464
465
466
467
468
469
470
471
472 public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
473 final ChannelPromise promise) {
474 if (response instanceof FullHttpResponse) {
475 try {
476 finishHandshake(channel, (FullHttpResponse) response);
477 promise.setSuccess();
478 } catch (Throwable cause) {
479 promise.setFailure(cause);
480 }
481 } else {
482 ChannelPipeline p = channel.pipeline();
483 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
484 if (ctx == null) {
485 ctx = p.context(HttpClientCodec.class);
486 if (ctx == null) {
487 return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
488 "an HttpResponseDecoder or HttpClientCodec"));
489 }
490 }
491
492 String aggregatorCtx = ctx.name();
493
494 if (version == WebSocketVersion.V00) {
495
496
497 aggregatorCtx = "httpAggregator";
498 p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
499 }
500
501 p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
502
503 private FullHttpResponse fullHttpResponse;
504
505 @Override
506 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
507 if (msg instanceof HttpObject) {
508 try {
509 handleHandshakeResponse(ctx, (HttpObject) msg);
510 } finally {
511 ReferenceCountUtil.release(msg);
512 }
513 } else {
514 super.channelRead(ctx, msg);
515 }
516 }
517
518 @Override
519 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
520
521 ctx.pipeline().remove(this);
522 promise.setFailure(cause);
523 }
524
525 @Override
526 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
527 try {
528
529 if (!promise.isDone()) {
530 promise.tryFailure(new ClosedChannelException());
531 }
532 ctx.fireChannelInactive();
533 } finally {
534 releaseFullHttpResponse();
535 }
536 }
537
538 @Override
539 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
540 releaseFullHttpResponse();
541 }
542
543 private void handleHandshakeResponse(ChannelHandlerContext ctx, HttpObject response) {
544 if (response instanceof FullHttpResponse) {
545 ctx.pipeline().remove(this);
546 tryFinishHandshake((FullHttpResponse) response);
547 return;
548 }
549
550 if (response instanceof LastHttpContent) {
551 assert fullHttpResponse != null;
552 FullHttpResponse handshakeResponse = fullHttpResponse;
553 fullHttpResponse = null;
554 try {
555 ctx.pipeline().remove(this);
556 tryFinishHandshake(handshakeResponse);
557 } finally {
558 handshakeResponse.release();
559 }
560 return;
561 }
562
563 if (response instanceof HttpResponse) {
564 HttpResponse httpResponse = (HttpResponse) response;
565 fullHttpResponse = new DefaultFullHttpResponse(httpResponse.protocolVersion(),
566 httpResponse.status(), Unpooled.EMPTY_BUFFER, httpResponse.headers(),
567 EmptyHttpHeaders.INSTANCE);
568 if (httpResponse.decoderResult().isFailure()) {
569 fullHttpResponse.setDecoderResult(httpResponse.decoderResult());
570 }
571 }
572 }
573
574 private void tryFinishHandshake(FullHttpResponse fullHttpResponse) {
575 try {
576 finishHandshake(channel, fullHttpResponse);
577 promise.setSuccess();
578 } catch (Throwable cause) {
579 promise.setFailure(cause);
580 }
581 }
582
583 private void releaseFullHttpResponse() {
584 if (fullHttpResponse != null) {
585 fullHttpResponse.release();
586 fullHttpResponse = null;
587 }
588 }
589 });
590 try {
591 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
592 } catch (Throwable cause) {
593 promise.setFailure(cause);
594 }
595 }
596 return promise;
597 }
598
599
600
601
602 protected abstract void verify(FullHttpResponse response);
603
604
605
606
607 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
608
609
610
611
612 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
613
614
615
616
617
618
619
620
621
622
623
624
625 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
626 ObjectUtil.checkNotNull(channel, "channel");
627 return close(channel, frame, channel.newPromise());
628 }
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
644 ObjectUtil.checkNotNull(channel, "channel");
645 return close0(channel, channel, frame, promise);
646 }
647
648
649
650
651
652
653
654
655
656 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
657 ObjectUtil.checkNotNull(ctx, "ctx");
658 return close(ctx, frame, ctx.newPromise());
659 }
660
661
662
663
664
665
666
667
668
669
670
671 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
672 ObjectUtil.checkNotNull(ctx, "ctx");
673 return close0(ctx, ctx.channel(), frame, promise);
674 }
675
676 private ChannelFuture close0(final ChannelOutboundInvoker invoker, final Channel channel,
677 CloseWebSocketFrame frame, ChannelPromise promise) {
678 invoker.writeAndFlush(frame, promise);
679 final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
680 final WebSocketClientHandshaker handshaker = this;
681 if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
682 return promise;
683 }
684
685 promise.addListener(future -> {
686
687
688
689
690 if (future.isSuccess() && channel.isActive() &&
691 FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
692 final Future<?> forceCloseFuture = channel.eventLoop().schedule(new Runnable() {
693 @Override
694 public void run() {
695 if (channel.isActive()) {
696 invoker.close();
697 forceCloseComplete = true;
698 }
699 }
700 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
701 channel.closeFuture().addListener(f -> forceCloseFuture.cancel(false));
702 }
703 });
704 return promise;
705 }
706
707
708
709
710 protected String upgradeUrl(URI wsURL) {
711 if (absoluteUpgradeUrl) {
712 return wsURL.toString();
713 }
714
715 String path = wsURL.getRawPath();
716 path = path == null || path.isEmpty() ? "/" : path;
717 String query = wsURL.getRawQuery();
718 return query != null && !query.isEmpty() ? path + '?' + query : path;
719 }
720
721 static CharSequence websocketHostValue(URI wsURL) {
722 int port = wsURL.getPort();
723 if (port == -1) {
724 return wsURL.getHost();
725 }
726 String host = wsURL.getHost();
727 String scheme = wsURL.getScheme();
728 if (port == HttpScheme.HTTP.port()) {
729 return HttpScheme.HTTP.name().contentEquals(scheme)
730 || WebSocketScheme.WS.name().contentEquals(scheme) ?
731 host : NetUtil.toSocketAddressString(host, port);
732 }
733 if (port == HttpScheme.HTTPS.port()) {
734 return HttpScheme.HTTPS.name().contentEquals(scheme)
735 || WebSocketScheme.WSS.name().contentEquals(scheme) ?
736 host : NetUtil.toSocketAddressString(host, port);
737 }
738
739
740
741 return NetUtil.toSocketAddressString(host, port);
742 }
743
744 static CharSequence websocketOriginValue(URI wsURL) {
745 String scheme = wsURL.getScheme();
746 final String schemePrefix;
747 int port = wsURL.getPort();
748 final int defaultPort;
749 if (WebSocketScheme.WSS.name().contentEquals(scheme)
750 || HttpScheme.HTTPS.name().contentEquals(scheme)
751 || (scheme == null && port == WebSocketScheme.WSS.port())) {
752
753 schemePrefix = HTTPS_SCHEME_PREFIX;
754 defaultPort = WebSocketScheme.WSS.port();
755 } else {
756 schemePrefix = HTTP_SCHEME_PREFIX;
757 defaultPort = WebSocketScheme.WS.port();
758 }
759
760
761 String host = wsURL.getHost().toLowerCase(Locale.US);
762
763 if (port != defaultPort && port != -1) {
764
765
766 return schemePrefix + NetUtil.toSocketAddressString(host, port);
767 }
768 return schemePrefix + host;
769 }
770 }