1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.http.websocketx;
17
18 import io.netty5.buffer.api.BufferAllocator;
19 import io.netty5.channel.Channel;
20 import io.netty5.channel.ChannelHandler;
21 import io.netty5.channel.ChannelHandlerContext;
22 import io.netty5.channel.ChannelOutboundInvoker;
23 import io.netty5.channel.ChannelPipeline;
24 import io.netty5.channel.SimpleChannelInboundHandler;
25 import io.netty5.handler.codec.http.FullHttpRequest;
26 import io.netty5.handler.codec.http.FullHttpResponse;
27 import io.netty5.handler.codec.http.HttpClientCodec;
28 import io.netty5.handler.codec.http.HttpContentDecompressor;
29 import io.netty5.handler.codec.http.HttpHeaderNames;
30 import io.netty5.handler.codec.http.HttpHeaders;
31 import io.netty5.handler.codec.http.HttpObjectAggregator;
32 import io.netty5.handler.codec.http.HttpRequestEncoder;
33 import io.netty5.handler.codec.http.HttpResponse;
34 import io.netty5.handler.codec.http.HttpResponseDecoder;
35 import io.netty5.handler.codec.http.HttpScheme;
36 import io.netty5.util.NetUtil;
37 import io.netty5.util.ReferenceCountUtil;
38 import io.netty5.util.concurrent.Future;
39 import io.netty5.util.concurrent.Promise;
40
41 import java.net.URI;
42 import java.nio.channels.ClosedChannelException;
43 import java.util.concurrent.TimeUnit;
44 import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
45
46 import static java.util.Objects.requireNonNull;
47
48
49
50
51 public abstract class WebSocketClientHandshaker {
52
53 protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
54
55 private final URI uri;
56
57 private final WebSocketVersion version;
58
59 private volatile boolean handshakeComplete;
60
61 private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;
62
63 private volatile int forceCloseInit;
64
65 private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
66 AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
67
68 private volatile boolean forceCloseComplete;
69
70 private final String expectedSubprotocol;
71
72 private volatile String actualSubprotocol;
73
74 protected final HttpHeaders customHeaders;
75
76 private final int maxFramePayloadLength;
77
78 private final boolean absoluteUpgradeUrl;
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
96 HttpHeaders customHeaders, int maxFramePayloadLength) {
97 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
98 }
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
118 HttpHeaders customHeaders, int maxFramePayloadLength,
119 long forceCloseTimeoutMillis) {
120 this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
121 }
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143 protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
144 HttpHeaders customHeaders, int maxFramePayloadLength,
145 long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
146 this.uri = uri;
147 this.version = version;
148 expectedSubprotocol = subprotocol;
149 this.customHeaders = customHeaders;
150 this.maxFramePayloadLength = maxFramePayloadLength;
151 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
152 this.absoluteUpgradeUrl = absoluteUpgradeUrl;
153 }
154
155
156
157
158 public URI uri() {
159 return uri;
160 }
161
162
163
164
165 public WebSocketVersion version() {
166 return version;
167 }
168
169
170
171
172 public int maxFramePayloadLength() {
173 return maxFramePayloadLength;
174 }
175
176
177
178
179 public boolean isHandshakeComplete() {
180 return handshakeComplete;
181 }
182
183 private void setHandshakeComplete() {
184 handshakeComplete = true;
185 }
186
187
188
189
190 public String expectedSubprotocol() {
191 return expectedSubprotocol;
192 }
193
194
195
196
197
198 public String actualSubprotocol() {
199 return actualSubprotocol;
200 }
201
202 private void setActualSubprotocol(String actualSubprotocol) {
203 this.actualSubprotocol = actualSubprotocol;
204 }
205
206 public long forceCloseTimeoutMillis() {
207 return forceCloseTimeoutMillis;
208 }
209
210
211
212
213
214 protected boolean isForceCloseComplete() {
215 return forceCloseComplete;
216 }
217
218
219
220
221
222
223
224 public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
225 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
226 return this;
227 }
228
229
230
231
232
233
234
235 public Future<Void> handshake(Channel channel) {
236 requireNonNull(channel, "channel");
237 ChannelPipeline pipeline = channel.pipeline();
238 HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
239 if (decoder == null) {
240 HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
241 if (codec == null) {
242 return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " +
243 "an HttpResponseDecoder or HttpClientCodec"));
244 }
245 }
246
247 FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator());
248
249 Promise<Void> promise = channel.newPromise();
250 channel.writeAndFlush(request).addListener(channel, (ch, future) -> {
251 if (future.isSuccess()) {
252 ChannelPipeline p = ch.pipeline();
253 ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
254 if (ctx == null) {
255 ctx = p.context(HttpClientCodec.class);
256 }
257 if (ctx == null) {
258 promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
259 "an HttpRequestEncoder or HttpClientCodec"));
260 return;
261 }
262 p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
263
264 promise.setSuccess(null);
265 } else {
266 promise.setFailure(future.cause());
267 }
268 });
269 return promise.asFuture();
270 }
271
272
273
274
275 protected abstract FullHttpRequest newHandshakeRequest(BufferAllocator allocator);
276
277
278
279
280
281
282
283
284
285 public final void finishHandshake(Channel channel, FullHttpResponse response) {
286 verify(response);
287
288
289
290 String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
291 receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
292 String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
293 boolean protocolValid = false;
294
295 if (expectedProtocol.isEmpty() && receivedProtocol == null) {
296
297 protocolValid = true;
298 setActualSubprotocol(expectedSubprotocol);
299 } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
300
301 for (String protocol : expectedProtocol.split(",")) {
302 if (protocol.trim().equals(receivedProtocol)) {
303 protocolValid = true;
304 setActualSubprotocol(receivedProtocol);
305 break;
306 }
307 }
308 }
309
310 if (!protocolValid) {
311 throw new WebSocketClientHandshakeException(String.format(
312 "Invalid subprotocol. Actual: %s. Expected one of: %s",
313 receivedProtocol, expectedSubprotocol), response);
314 }
315
316 setHandshakeComplete();
317
318 final ChannelPipeline p = channel.pipeline();
319
320 HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
321 if (decompressor != null) {
322 p.remove(decompressor);
323 }
324
325
326 HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
327 if (aggregator != null) {
328 p.remove(aggregator);
329 }
330
331 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
332 if (ctx == null) {
333 ctx = p.context(HttpClientCodec.class);
334 if (ctx == null) {
335 throw new IllegalStateException("ChannelPipeline does not contain " +
336 "an HttpRequestEncoder or HttpClientCodec");
337 }
338 final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
339
340 codec.removeOutboundHandler();
341
342 p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
343
344
345
346
347 channel.executor().execute(() -> p.remove(codec));
348 } else {
349 if (p.get(HttpRequestEncoder.class) != null) {
350
351 p.remove(HttpRequestEncoder.class);
352 }
353 final ChannelHandlerContext context = ctx;
354 p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
355
356
357
358
359 channel.executor().execute(() -> p.remove(context.handler()));
360 }
361 }
362
363
364
365
366
367
368
369
370
371
372
373 public final Future<Void> processHandshake(final Channel channel, HttpResponse response) {
374 if (response instanceof FullHttpResponse) {
375 try {
376 finishHandshake(channel, (FullHttpResponse) response);
377 return channel.newSucceededFuture();
378 } catch (Throwable cause) {
379 return channel.newFailedFuture(cause);
380 }
381 } else {
382 ChannelPipeline p = channel.pipeline();
383 ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
384 if (ctx == null) {
385 ctx = p.context(HttpClientCodec.class);
386 if (ctx == null) {
387 return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " +
388 "an HttpResponseDecoder or HttpClientCodec"));
389 }
390 }
391
392 Promise<Void> promise = channel.newPromise();
393
394
395
396
397 String aggregatorName = "httpAggregator";
398 p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
399 p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
400 @Override
401 protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
402
403 ctx.pipeline().remove(this);
404 try {
405 finishHandshake(channel, msg);
406 promise.setSuccess(null);
407 } catch (Throwable cause) {
408 promise.setFailure(cause);
409 }
410 }
411
412 @Override
413 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
414
415 ctx.pipeline().remove(this);
416 promise.setFailure(cause);
417 }
418
419 @Override
420 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
421
422 if (!promise.isDone()) {
423 promise.tryFailure(new ClosedChannelException());
424 }
425 ctx.fireChannelInactive();
426 }
427 });
428 try {
429 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
430 } catch (Throwable cause) {
431 promise.setFailure(cause);
432 }
433 return promise.asFuture();
434 }
435 }
436
437
438
439
440 protected abstract void verify(FullHttpResponse response);
441
442
443
444
445 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
446
447
448
449
450 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
451
452
453
454
455
456
457
458
459
460
461
462
463 public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
464 requireNonNull(channel, "channel");
465 return close0(channel, channel, frame);
466 }
467
468
469
470
471
472
473
474
475
476 public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
477 requireNonNull(ctx, "ctx");
478 return close0(ctx, ctx.channel(), frame);
479 }
480
481 private Future<Void> close0(final ChannelOutboundInvoker invoker, final Channel channel,
482 CloseWebSocketFrame frame) {
483 Future<Void> f = invoker.writeAndFlush(frame);
484 final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
485 final WebSocketClientHandshaker handshaker = this;
486 if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
487 return f;
488 }
489
490 f.addListener(future -> {
491
492
493
494
495 if (future.isSuccess() && channel.isActive() &&
496 FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
497 final Future<?> forceCloseFuture = channel.executor().schedule(() -> {
498 if (channel.isActive()) {
499 channel.close();
500 forceCloseComplete = true;
501 }
502 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
503
504 channel.closeFuture().addListener(ignore -> forceCloseFuture.cancel());
505 }
506 });
507 return f;
508 }
509
510
511
512
513 protected String upgradeUrl(URI wsURL) {
514 if (absoluteUpgradeUrl) {
515 return wsURL.toString();
516 }
517
518 String path = wsURL.getRawPath();
519 path = path == null || path.isEmpty() ? "/" : path;
520 String query = wsURL.getRawQuery();
521 return query != null && !query.isEmpty() ? path + '?' + query : path;
522 }
523
524 static CharSequence websocketHostValue(URI wsURL) {
525 int port = wsURL.getPort();
526 if (port == -1) {
527 return wsURL.getHost();
528 }
529 String host = wsURL.getHost();
530 String scheme = wsURL.getScheme();
531 if (port == HttpScheme.HTTP.port()) {
532 return HttpScheme.HTTP.name().contentEquals(scheme)
533 || WebSocketScheme.WS.name().contentEquals(scheme) ?
534 host : NetUtil.toSocketAddressString(host, port);
535 }
536 if (port == HttpScheme.HTTPS.port()) {
537 return HttpScheme.HTTPS.name().contentEquals(scheme)
538 || WebSocketScheme.WSS.name().contentEquals(scheme) ?
539 host : NetUtil.toSocketAddressString(host, port);
540 }
541
542
543
544 return NetUtil.toSocketAddressString(host, port);
545 }
546 }