1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import java.nio.channels.ClosedChannelException;
19 import java.util.Collections;
20 import java.util.LinkedHashSet;
21 import java.util.Set;
22
23 import io.netty.buffer.Unpooled;
24 import io.netty.channel.Channel;
25 import io.netty.channel.ChannelFuture;
26 import io.netty.channel.ChannelFutureListener;
27 import io.netty.channel.ChannelHandler;
28 import io.netty.channel.ChannelHandlerContext;
29 import io.netty.channel.ChannelInboundHandlerAdapter;
30 import io.netty.channel.ChannelOutboundInvoker;
31 import io.netty.channel.ChannelPipeline;
32 import io.netty.channel.ChannelPromise;
33 import io.netty.handler.codec.http.DefaultFullHttpRequest;
34 import io.netty.handler.codec.http.EmptyHttpHeaders;
35 import io.netty.handler.codec.http.FullHttpRequest;
36 import io.netty.handler.codec.http.FullHttpResponse;
37 import io.netty.handler.codec.http.HttpContentCompressor;
38 import io.netty.handler.codec.http.HttpHeaders;
39 import io.netty.handler.codec.http.HttpObject;
40 import io.netty.handler.codec.http.HttpObjectAggregator;
41 import io.netty.handler.codec.http.HttpRequest;
42 import io.netty.handler.codec.http.HttpRequestDecoder;
43 import io.netty.handler.codec.http.HttpResponseEncoder;
44 import io.netty.handler.codec.http.HttpServerCodec;
45 import io.netty.handler.codec.http.HttpUtil;
46 import io.netty.handler.codec.http.LastHttpContent;
47 import io.netty.util.ReferenceCountUtil;
48 import io.netty.util.internal.EmptyArrays;
49 import io.netty.util.internal.ObjectUtil;
50 import io.netty.util.internal.logging.InternalLogger;
51 import io.netty.util.internal.logging.InternalLoggerFactory;
52
53
54
55
56 public abstract class WebSocketServerHandshaker {
57 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58
59 private final String uri;
60
61 private final String[] subprotocols;
62
63 private final WebSocketVersion version;
64
65 private final WebSocketDecoderConfig decoderConfig;
66
67 private String selectedSubprotocol;
68
69
70
71
72 public static final String SUB_PROTOCOL_WILDCARD = "*";
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87 protected WebSocketServerHandshaker(
88 WebSocketVersion version, String uri, String subprotocols,
89 int maxFramePayloadLength) {
90 this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91 .maxFramePayloadLength(maxFramePayloadLength)
92 .build());
93 }
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108 protected WebSocketServerHandshaker(
109 WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110 this.version = version;
111 this.uri = uri;
112 if (subprotocols != null) {
113 String[] subprotocolArray = subprotocols.split(",");
114 for (int i = 0; i < subprotocolArray.length; i++) {
115 subprotocolArray[i] = subprotocolArray[i].trim();
116 }
117 this.subprotocols = subprotocolArray;
118 } else {
119 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120 }
121 this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122 }
123
124
125
126
127 @Deprecated
128 public String uri() {
129 return uri;
130 }
131
132
133
134
135 public Set<String> subprotocols() {
136 Set<String> ret = new LinkedHashSet<String>();
137 Collections.addAll(ret, subprotocols);
138 return ret;
139 }
140
141
142
143
144 public WebSocketVersion version() {
145 return version;
146 }
147
148
149
150
151
152
153 public int maxFramePayloadLength() {
154 return decoderConfig.maxFramePayloadLength();
155 }
156
157
158
159
160
161
162 public WebSocketDecoderConfig decoderConfig() {
163 return decoderConfig;
164 }
165
166
167
168
169
170
171
172
173
174
175
176
177 public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
178 return handshake(channel, req, null, channel.newPromise());
179 }
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197 public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
198 HttpHeaders responseHeaders, final ChannelPromise promise) {
199
200 if (logger.isDebugEnabled()) {
201 logger.debug("{} WebSocket version {} server handshake", channel, version());
202 }
203 FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
204 ChannelPipeline p = channel.pipeline();
205 if (p.get(HttpObjectAggregator.class) != null) {
206 p.remove(HttpObjectAggregator.class);
207 }
208 if (p.get(HttpContentCompressor.class) != null) {
209 p.remove(HttpContentCompressor.class);
210 }
211 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
212 final String encoderName;
213 if (ctx == null) {
214
215 ctx = p.context(HttpServerCodec.class);
216 if (ctx == null) {
217 promise.setFailure(
218 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
219 response.release();
220 return promise;
221 }
222 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
223 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
224 encoderName = ctx.name();
225 } else {
226 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
227
228 encoderName = p.context(HttpResponseEncoder.class).name();
229 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
230 }
231 channel.writeAndFlush(response).addListener(future -> {
232 if (future.isSuccess()) {
233 ChannelPipeline p1 = channel.pipeline();
234 p1.remove(encoderName);
235 promise.setSuccess();
236 } else {
237 promise.setFailure(future.cause());
238 }
239 });
240 return promise;
241 }
242
243
244
245
246
247
248
249
250
251
252
253
254 public ChannelFuture handshake(Channel channel, HttpRequest req) {
255 return handshake(channel, req, null, channel.newPromise());
256 }
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274 public final ChannelFuture handshake(final Channel channel, HttpRequest req,
275 final HttpHeaders responseHeaders, final ChannelPromise promise) {
276 if (req instanceof FullHttpRequest) {
277 return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
278 }
279
280 if (logger.isDebugEnabled()) {
281 logger.debug("{} WebSocket version {} server handshake", channel, version());
282 }
283
284 ChannelPipeline p = channel.pipeline();
285 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
286 if (ctx == null) {
287
288 ctx = p.context(HttpServerCodec.class);
289 if (ctx == null) {
290 promise.setFailure(
291 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
292 return promise;
293 }
294 }
295
296 String aggregatorCtx = ctx.name();
297 if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
298 version == WebSocketVersion.V00) {
299
300
301 aggregatorCtx = "httpAggregator";
302 p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
303 }
304
305 p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
306
307 private FullHttpRequest fullHttpRequest;
308
309 @Override
310 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
311 if (msg instanceof HttpObject) {
312 try {
313 handleHandshakeRequest(ctx, (HttpObject) msg);
314 } finally {
315 ReferenceCountUtil.release(msg);
316 }
317 } else {
318 super.channelRead(ctx, msg);
319 }
320 }
321
322 @Override
323 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
324
325 ctx.pipeline().remove(this);
326 promise.tryFailure(cause);
327 ctx.fireExceptionCaught(cause);
328 }
329
330 @Override
331 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
332 try {
333
334 if (!promise.isDone()) {
335 promise.tryFailure(new ClosedChannelException());
336 }
337 ctx.fireChannelInactive();
338 } finally {
339 releaseFullHttpRequest();
340 }
341 }
342
343 @Override
344 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
345 releaseFullHttpRequest();
346 }
347
348 private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
349 if (httpObject instanceof FullHttpRequest) {
350 ctx.pipeline().remove(this);
351 handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
352 return;
353 }
354
355 if (httpObject instanceof LastHttpContent) {
356 assert fullHttpRequest != null;
357 FullHttpRequest handshakeRequest = fullHttpRequest;
358 fullHttpRequest = null;
359 try {
360 ctx.pipeline().remove(this);
361 handshake(channel, handshakeRequest, responseHeaders, promise);
362 } finally {
363 handshakeRequest.release();
364 }
365 return;
366 }
367
368 if (httpObject instanceof HttpRequest) {
369 HttpRequest httpRequest = (HttpRequest) httpObject;
370 fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
371 httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
372 if (httpRequest.decoderResult().isFailure()) {
373 fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
374 }
375 }
376 }
377
378 private void releaseFullHttpRequest() {
379 if (fullHttpRequest != null) {
380 fullHttpRequest.release();
381 fullHttpRequest = null;
382 }
383 }
384 });
385 try {
386 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
387 } catch (Throwable cause) {
388 promise.setFailure(cause);
389 }
390 return promise;
391 }
392
393
394
395
396 protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
397 HttpHeaders responseHeaders);
398
399
400
401
402
403
404
405
406
407
408
409 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
410 ObjectUtil.checkNotNull(channel, "channel");
411 return close(channel, frame, channel.newPromise());
412 }
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
428 return close0(channel, frame, promise);
429 }
430
431
432
433
434
435
436
437
438
439 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
440 ObjectUtil.checkNotNull(ctx, "ctx");
441 return close(ctx, frame, ctx.newPromise());
442 }
443
444
445
446
447
448
449
450
451
452
453
454 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
455 ObjectUtil.checkNotNull(ctx, "ctx");
456 return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
457 }
458
459 private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
460 return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
461 }
462
463
464
465
466
467
468
469
470 protected String selectSubprotocol(String requestedSubprotocols) {
471 if (requestedSubprotocols == null || subprotocols.length == 0) {
472 return null;
473 }
474
475 String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
476 for (String p: requestedSubprotocolArray) {
477 String requestedSubprotocol = p.trim();
478
479 for (String supportedSubprotocol: subprotocols) {
480 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
481 || requestedSubprotocol.equals(supportedSubprotocol)) {
482 selectedSubprotocol = requestedSubprotocol;
483 return requestedSubprotocol;
484 }
485 }
486 }
487
488
489 return null;
490 }
491
492
493
494
495
496
497
498 public String selectedSubprotocol() {
499 return selectedSubprotocol;
500 }
501
502
503
504
505 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
506
507
508
509
510 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
511
512 }