1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import java.nio.channels.ClosedChannelException;
19 import java.util.Collections;
20 import java.util.LinkedHashSet;
21 import java.util.Set;
22
23 import io.netty.buffer.Unpooled;
24 import io.netty.channel.Channel;
25 import io.netty.channel.ChannelFuture;
26 import io.netty.channel.ChannelFutureListener;
27 import io.netty.channel.ChannelHandler;
28 import io.netty.channel.ChannelHandlerContext;
29 import io.netty.channel.ChannelInboundHandlerAdapter;
30 import io.netty.channel.ChannelOutboundInvoker;
31 import io.netty.channel.ChannelPipeline;
32 import io.netty.channel.ChannelPromise;
33 import io.netty.handler.codec.http.DefaultFullHttpRequest;
34 import io.netty.handler.codec.http.EmptyHttpHeaders;
35 import io.netty.handler.codec.http.FullHttpRequest;
36 import io.netty.handler.codec.http.FullHttpResponse;
37 import io.netty.handler.codec.http.HttpContentCompressor;
38 import io.netty.handler.codec.http.HttpHeaders;
39 import io.netty.handler.codec.http.HttpObject;
40 import io.netty.handler.codec.http.HttpObjectAggregator;
41 import io.netty.handler.codec.http.HttpRequest;
42 import io.netty.handler.codec.http.HttpRequestDecoder;
43 import io.netty.handler.codec.http.HttpResponseEncoder;
44 import io.netty.handler.codec.http.HttpServerCodec;
45 import io.netty.handler.codec.http.HttpUtil;
46 import io.netty.handler.codec.http.LastHttpContent;
47 import io.netty.util.ReferenceCountUtil;
48 import io.netty.util.internal.EmptyArrays;
49 import io.netty.util.internal.ObjectUtil;
50 import io.netty.util.internal.logging.InternalLogger;
51 import io.netty.util.internal.logging.InternalLoggerFactory;
52
53
54
55
56 public abstract class WebSocketServerHandshaker {
57 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58
59 private final String uri;
60
61 private final String[] subprotocols;
62
63 private final WebSocketVersion version;
64
65 private final WebSocketDecoderConfig decoderConfig;
66
67 private String selectedSubprotocol;
68
69
70
71
72 public static final String SUB_PROTOCOL_WILDCARD = "*";
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87 protected WebSocketServerHandshaker(
88 WebSocketVersion version, String uri, String subprotocols,
89 int maxFramePayloadLength) {
90 this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91 .maxFramePayloadLength(maxFramePayloadLength)
92 .build());
93 }
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108 protected WebSocketServerHandshaker(
109 WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110 this.version = version;
111 this.uri = uri;
112 if (subprotocols != null) {
113 String[] subprotocolArray = subprotocols.split(",");
114 for (int i = 0; i < subprotocolArray.length; i++) {
115 subprotocolArray[i] = subprotocolArray[i].trim();
116 }
117 this.subprotocols = subprotocolArray;
118 } else {
119 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120 }
121 this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122 }
123
124
125
126
127 @Deprecated
128 public String uri() {
129 return uri;
130 }
131
132
133
134
135 public Set<String> subprotocols() {
136 Set<String> ret = new LinkedHashSet<String>();
137 Collections.addAll(ret, subprotocols);
138 return ret;
139 }
140
141
142
143
144 public WebSocketVersion version() {
145 return version;
146 }
147
148
149
150
151
152
153 public int maxFramePayloadLength() {
154 return decoderConfig.maxFramePayloadLength();
155 }
156
157
158
159
160
161
162 public WebSocketDecoderConfig decoderConfig() {
163 return decoderConfig;
164 }
165
166
167
168
169
170
171
172
173
174
175
176
177 public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
178 return handshake(channel, req, null, channel.newPromise());
179 }
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197 public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
198 HttpHeaders responseHeaders, final ChannelPromise promise) {
199
200 if (logger.isDebugEnabled()) {
201 logger.debug("{} WebSocket version {} server handshake", channel, version());
202 }
203 FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
204 ChannelPipeline p = channel.pipeline();
205 if (p.get(HttpObjectAggregator.class) != null) {
206 p.remove(HttpObjectAggregator.class);
207 }
208 if (p.get(HttpContentCompressor.class) != null) {
209 p.remove(HttpContentCompressor.class);
210 }
211 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
212 final String encoderName;
213 if (ctx == null) {
214
215 ctx = p.context(HttpServerCodec.class);
216 if (ctx == null) {
217 promise.setFailure(
218 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
219 response.release();
220 return promise;
221 }
222 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
223 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
224 encoderName = ctx.name();
225 } else {
226 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
227
228 encoderName = p.context(HttpResponseEncoder.class).name();
229 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
230 }
231 channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
232 @Override
233 public void operationComplete(ChannelFuture future) throws Exception {
234 if (future.isSuccess()) {
235 ChannelPipeline p = future.channel().pipeline();
236 p.remove(encoderName);
237 promise.setSuccess();
238 } else {
239 promise.setFailure(future.cause());
240 }
241 }
242 });
243 return promise;
244 }
245
246
247
248
249
250
251
252
253
254
255
256
257 public ChannelFuture handshake(Channel channel, HttpRequest req) {
258 return handshake(channel, req, null, channel.newPromise());
259 }
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277 public final ChannelFuture handshake(final Channel channel, HttpRequest req,
278 final HttpHeaders responseHeaders, final ChannelPromise promise) {
279 if (req instanceof FullHttpRequest) {
280 return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
281 }
282
283 if (logger.isDebugEnabled()) {
284 logger.debug("{} WebSocket version {} server handshake", channel, version());
285 }
286
287 ChannelPipeline p = channel.pipeline();
288 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
289 if (ctx == null) {
290
291 ctx = p.context(HttpServerCodec.class);
292 if (ctx == null) {
293 promise.setFailure(
294 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
295 return promise;
296 }
297 }
298
299 String aggregatorCtx = ctx.name();
300 if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
301 version == WebSocketVersion.V00) {
302
303
304 aggregatorCtx = "httpAggregator";
305 p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
306 }
307
308 p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
309
310 private FullHttpRequest fullHttpRequest;
311
312 @Override
313 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
314 if (msg instanceof HttpObject) {
315 try {
316 handleHandshakeRequest(ctx, (HttpObject) msg);
317 } finally {
318 ReferenceCountUtil.release(msg);
319 }
320 } else {
321 super.channelRead(ctx, msg);
322 }
323 }
324
325 @Override
326 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
327
328 ctx.pipeline().remove(this);
329 promise.tryFailure(cause);
330 ctx.fireExceptionCaught(cause);
331 }
332
333 @Override
334 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
335 try {
336
337 if (!promise.isDone()) {
338 promise.tryFailure(new ClosedChannelException());
339 }
340 ctx.fireChannelInactive();
341 } finally {
342 releaseFullHttpRequest();
343 }
344 }
345
346 @Override
347 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
348 releaseFullHttpRequest();
349 }
350
351 private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
352 if (httpObject instanceof FullHttpRequest) {
353 ctx.pipeline().remove(this);
354 handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
355 return;
356 }
357
358 if (httpObject instanceof LastHttpContent) {
359 assert fullHttpRequest != null;
360 FullHttpRequest handshakeRequest = fullHttpRequest;
361 fullHttpRequest = null;
362 try {
363 ctx.pipeline().remove(this);
364 handshake(channel, handshakeRequest, responseHeaders, promise);
365 } finally {
366 handshakeRequest.release();
367 }
368 return;
369 }
370
371 if (httpObject instanceof HttpRequest) {
372 HttpRequest httpRequest = (HttpRequest) httpObject;
373 fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
374 httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
375 if (httpRequest.decoderResult().isFailure()) {
376 fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
377 }
378 }
379 }
380
381 private void releaseFullHttpRequest() {
382 if (fullHttpRequest != null) {
383 fullHttpRequest.release();
384 fullHttpRequest = null;
385 }
386 }
387 });
388 try {
389 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
390 } catch (Throwable cause) {
391 promise.setFailure(cause);
392 }
393 return promise;
394 }
395
396
397
398
399 protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
400 HttpHeaders responseHeaders);
401
402
403
404
405
406
407
408
409
410
411
412 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
413 ObjectUtil.checkNotNull(channel, "channel");
414 return close(channel, frame, channel.newPromise());
415 }
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
431 return close0(channel, frame, promise);
432 }
433
434
435
436
437
438
439
440
441
442 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
443 ObjectUtil.checkNotNull(ctx, "ctx");
444 return close(ctx, frame, ctx.newPromise());
445 }
446
447
448
449
450
451
452
453
454
455
456
457 public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
458 ObjectUtil.checkNotNull(ctx, "ctx");
459 return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
460 }
461
462 private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
463 return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
464 }
465
466
467
468
469
470
471
472
473 protected String selectSubprotocol(String requestedSubprotocols) {
474 if (requestedSubprotocols == null || subprotocols.length == 0) {
475 return null;
476 }
477
478 String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
479 for (String p: requestedSubprotocolArray) {
480 String requestedSubprotocol = p.trim();
481
482 for (String supportedSubprotocol: subprotocols) {
483 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
484 || requestedSubprotocol.equals(supportedSubprotocol)) {
485 selectedSubprotocol = requestedSubprotocol;
486 return requestedSubprotocol;
487 }
488 }
489 }
490
491
492 return null;
493 }
494
495
496
497
498
499
500
501 public String selectedSubprotocol() {
502 return selectedSubprotocol;
503 }
504
505
506
507
508 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
509
510
511
512
513 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
514
515 }