1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import java.nio.channels.ClosedChannelException;
19 import java.util.Collections;
20 import java.util.LinkedHashSet;
21 import java.util.Set;
22
23 import io.netty.buffer.Unpooled;
24 import io.netty.channel.Channel;
25 import io.netty.channel.ChannelFuture;
26 import io.netty.channel.ChannelFutureListener;
27 import io.netty.channel.ChannelHandler;
28 import io.netty.channel.ChannelHandlerContext;
29 import io.netty.channel.ChannelInboundHandlerAdapter;
30 import io.netty.channel.ChannelOutboundInvoker;
31 import io.netty.channel.ChannelPipeline;
32 import io.netty.channel.ChannelPromise;
33 import io.netty.handler.codec.http.DefaultFullHttpRequest;
34 import io.netty.handler.codec.http.EmptyHttpHeaders;
35 import io.netty.handler.codec.http.FullHttpRequest;
36 import io.netty.handler.codec.http.FullHttpResponse;
37 import io.netty.handler.codec.http.HttpContentCompressor;
38 import io.netty.handler.codec.http.HttpHeaders;
39 import io.netty.handler.codec.http.HttpObject;
40 import io.netty.handler.codec.http.HttpObjectAggregator;
41 import io.netty.handler.codec.http.HttpRequest;
42 import io.netty.handler.codec.http.HttpRequestDecoder;
43 import io.netty.handler.codec.http.HttpResponseEncoder;
44 import io.netty.handler.codec.http.HttpServerCodec;
45 import io.netty.handler.codec.http.HttpUtil;
46 import io.netty.handler.codec.http.LastHttpContent;
47 import io.netty.util.ReferenceCountUtil;
48 import io.netty.util.internal.EmptyArrays;
49 import io.netty.util.internal.ObjectUtil;
50 import io.netty.util.internal.logging.InternalLogger;
51 import io.netty.util.internal.logging.InternalLoggerFactory;
52
53
54
55
56 public abstract class WebSocketServerHandshaker {
57 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58
59 private final String uri;
60
61 private final String[] subprotocols;
62
63 private final WebSocketVersion version;
64
65 private final WebSocketDecoderConfig decoderConfig;
66
67 private String selectedSubprotocol;
68
69
70
71
72 public static final String SUB_PROTOCOL_WILDCARD = "*";
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87 protected WebSocketServerHandshaker(
88 WebSocketVersion version, String uri, String subprotocols,
89 int maxFramePayloadLength) {
90 this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91 .maxFramePayloadLength(maxFramePayloadLength)
92 .build());
93 }
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108 protected WebSocketServerHandshaker(
109 WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110 this.version = version;
111 this.uri = uri;
112 if (subprotocols != null) {
113 String[] subprotocolArray = subprotocols.split(",");
114 for (int i = 0; i < subprotocolArray.length; i++) {
115 subprotocolArray[i] = subprotocolArray[i].trim();
116 }
117 this.subprotocols = subprotocolArray;
118 } else {
119 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120 }
121 this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122 }
123
124
125
126
127 public String uri() {
128 return uri;
129 }
130
131
132
133
134 public Set<String> subprotocols() {
135 Set<String> ret = new LinkedHashSet<String>();
136 Collections.addAll(ret, subprotocols);
137 return ret;
138 }
139
140
141
142
143 public WebSocketVersion version() {
144 return version;
145 }
146
147
148
149
150
151
152 public int maxFramePayloadLength() {
153 return decoderConfig.maxFramePayloadLength();
154 }
155
156
157
158
159
160
161 public WebSocketDecoderConfig decoderConfig() {
162 return decoderConfig;
163 }
164
165
166
167
168
169
170
171
172
173
174
175
176 public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
177 return handshake(channel, req, null, channel.newPromise());
178 }
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196 public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
197 HttpHeaders responseHeaders, final ChannelPromise promise) {
198
199 if (logger.isDebugEnabled()) {
200 logger.debug("{} WebSocket version {} server handshake", channel, version());
201 }
202 FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
203 ChannelPipeline p = channel.pipeline();
204 if (p.get(HttpObjectAggregator.class) != null) {
205 p.remove(HttpObjectAggregator.class);
206 }
207 if (p.get(HttpContentCompressor.class) != null) {
208 p.remove(HttpContentCompressor.class);
209 }
210 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
211 final String encoderName;
212 if (ctx == null) {
213
214 ctx = p.context(HttpServerCodec.class);
215 if (ctx == null) {
216 promise.setFailure(
217 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
218 response.release();
219 return promise;
220 }
221 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
222 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
223 encoderName = ctx.name();
224 } else {
225 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
226
227 encoderName = p.context(HttpResponseEncoder.class).name();
228 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
229 }
230 channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
231 @Override
232 public void operationComplete(ChannelFuture future) throws Exception {
233 if (future.isSuccess()) {
234 ChannelPipeline p = future.channel().pipeline();
235 p.remove(encoderName);
236 promise.setSuccess();
237 } else {
238 promise.setFailure(future.cause());
239 }
240 }
241 });
242 return promise;
243 }
244
245
246
247
248
249
250
251
252
253
254
255
256 public ChannelFuture handshake(Channel channel, HttpRequest req) {
257 return handshake(channel, req, null, channel.newPromise());
258 }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276 public final ChannelFuture handshake(final Channel channel, HttpRequest req,
277 final HttpHeaders responseHeaders, final ChannelPromise promise) {
278 if (req instanceof FullHttpRequest) {
279 return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
280 }
281
282 if (logger.isDebugEnabled()) {
283 logger.debug("{} WebSocket version {} server handshake", channel, version());
284 }
285
286 ChannelPipeline p = channel.pipeline();
287 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
288 if (ctx == null) {
289
290 ctx = p.context(HttpServerCodec.class);
291 if (ctx == null) {
292 promise.setFailure(
293 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
294 return promise;
295 }
296 }
297
298 String aggregatorCtx = ctx.name();
299 if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
300 version == WebSocketVersion.V00) {
301
302
303 aggregatorCtx = "httpAggregator";
304 p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
305 }
306
307 p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
308
309 private FullHttpRequest fullHttpRequest;
310
311 @Override
312 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
313 if (msg instanceof HttpObject) {
314 try {
315 handleHandshakeRequest(ctx, (HttpObject) msg);
316 } finally {
317 ReferenceCountUtil.release(msg);
318 }
319 } else {
320 super.channelRead(ctx, msg);
321 }
322 }
323
324 @Override
325 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
326
327 ctx.pipeline().remove(this);
328 promise.tryFailure(cause);
329 ctx.fireExceptionCaught(cause);
330 }
331
332 @Override
333 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
334 try {
335
336 if (!promise.isDone()) {
337 promise.tryFailure(new ClosedChannelException());
338 }
339 ctx.fireChannelInactive();
340 } finally {
341 releaseFullHttpRequest();
342 }
343 }
344
345 @Override
346 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
347 releaseFullHttpRequest();
348 }
349
350 private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
351 if (httpObject instanceof FullHttpRequest) {
352 ctx.pipeline().remove(this);
353 handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
354 return;
355 }
356
357 if (httpObject instanceof LastHttpContent) {
358 assert fullHttpRequest != null;
359 FullHttpRequest handshakeRequest = fullHttpRequest;
360 fullHttpRequest = null;
361 try {
362 ctx.pipeline().remove(this);
363 handshake(channel, handshakeRequest, responseHeaders, promise);
364 } finally {
365 handshakeRequest.release();
366 }
367 return;
368 }
369
370 if (httpObject instanceof HttpRequest) {
371 HttpRequest httpRequest = (HttpRequest) httpObject;
372 fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
373 httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
374 if (httpRequest.decoderResult().isFailure()) {
375 fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
376 }
377 }
378 }
379
380 private void releaseFullHttpRequest() {
381 if (fullHttpRequest != null) {
382 fullHttpRequest.release();
383 fullHttpRequest = null;
384 }
385 }
386 });
387 try {
388 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
389 } catch (Throwable cause) {
390 promise.setFailure(cause);
391 }
392 return promise;
393 }
394
395
396
397
398 protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
399 HttpHeaders responseHeaders);
400
401
402
403
404
405
406
407
408
409
410
411 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
412 ObjectUtil.checkNotNull(channel, "channel");
413 return close(channel, frame, channel.newPromise());
414 }
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
430 return close0(channel, frame, promise);
431 }
432
433
434
435
436
437
438
439
440
441 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
442 ObjectUtil.checkNotNull(ctx, "ctx");
443 return close(ctx, frame, ctx.newPromise());
444 }
445
446
447
448
449
450
451
452
453
454
455
456 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
457 ObjectUtil.checkNotNull(ctx, "ctx");
458 return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
459 }
460
461 private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
462 return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
463 }
464
465
466
467
468
469
470
471
472 protected String selectSubprotocol(String requestedSubprotocols) {
473 if (requestedSubprotocols == null || subprotocols.length == 0) {
474 return null;
475 }
476
477 String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
478 for (String p: requestedSubprotocolArray) {
479 String requestedSubprotocol = p.trim();
480
481 for (String supportedSubprotocol: subprotocols) {
482 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
483 || requestedSubprotocol.equals(supportedSubprotocol)) {
484 selectedSubprotocol = requestedSubprotocol;
485 return requestedSubprotocol;
486 }
487 }
488 }
489
490
491 return null;
492 }
493
494
495
496
497
498
499
500 public String selectedSubprotocol() {
501 return selectedSubprotocol;
502 }
503
504
505
506
507 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
508
509
510
511
512 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
513
514 }