View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import java.nio.channels.ClosedChannelException;
19  import java.util.Collections;
20  import java.util.LinkedHashSet;
21  import java.util.Set;
22  
23  import io.netty.buffer.Unpooled;
24  import io.netty.channel.Channel;
25  import io.netty.channel.ChannelFuture;
26  import io.netty.channel.ChannelFutureListener;
27  import io.netty.channel.ChannelHandler;
28  import io.netty.channel.ChannelHandlerContext;
29  import io.netty.channel.ChannelInboundHandlerAdapter;
30  import io.netty.channel.ChannelOutboundInvoker;
31  import io.netty.channel.ChannelPipeline;
32  import io.netty.channel.ChannelPromise;
33  import io.netty.handler.codec.http.DefaultFullHttpRequest;
34  import io.netty.handler.codec.http.EmptyHttpHeaders;
35  import io.netty.handler.codec.http.FullHttpRequest;
36  import io.netty.handler.codec.http.FullHttpResponse;
37  import io.netty.handler.codec.http.HttpContentCompressor;
38  import io.netty.handler.codec.http.HttpHeaders;
39  import io.netty.handler.codec.http.HttpObject;
40  import io.netty.handler.codec.http.HttpObjectAggregator;
41  import io.netty.handler.codec.http.HttpRequest;
42  import io.netty.handler.codec.http.HttpRequestDecoder;
43  import io.netty.handler.codec.http.HttpResponseEncoder;
44  import io.netty.handler.codec.http.HttpServerCodec;
45  import io.netty.handler.codec.http.HttpUtil;
46  import io.netty.handler.codec.http.LastHttpContent;
47  import io.netty.util.ReferenceCountUtil;
48  import io.netty.util.internal.EmptyArrays;
49  import io.netty.util.internal.ObjectUtil;
50  import io.netty.util.internal.logging.InternalLogger;
51  import io.netty.util.internal.logging.InternalLoggerFactory;
52  
53  /**
54   * Base class for server side web socket opening and closing handshakes
55   */
56  public abstract class WebSocketServerHandshaker {
57      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58  
59      private final String uri;
60  
61      private final String[] subprotocols;
62  
63      private final WebSocketVersion version;
64  
65      private final WebSocketDecoderConfig decoderConfig;
66  
67      private String selectedSubprotocol;
68  
69      /**
70       * Use this as wildcard to support all requested sub-protocols
71       */
72      public static final String SUB_PROTOCOL_WILDCARD = "*";
73  
74      /**
75       * Constructor specifying the destination web socket location
76       *
77       * @param version
78       *            the protocol version
79       * @param uri
80       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
81       *            sent to this URL.
82       * @param subprotocols
83       *            CSV of supported protocols. Null if sub protocols not supported.
84       * @param maxFramePayloadLength
85       *            Maximum length of a frame's payload
86       */
87      protected WebSocketServerHandshaker(
88              WebSocketVersion version, String uri, String subprotocols,
89              int maxFramePayloadLength) {
90          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91              .maxFramePayloadLength(maxFramePayloadLength)
92              .build());
93      }
94  
95      /**
96       * Constructor specifying the destination web socket location
97       *
98       * @param version
99       *            the protocol version
100      * @param uri
101      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
102      *            sent to this URL.
103      * @param subprotocols
104      *            CSV of supported protocols. Null if sub protocols not supported.
105      * @param decoderConfig
106      *            Frames decoder configuration.
107      */
108     protected WebSocketServerHandshaker(
109             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110         this.version = version;
111         this.uri = uri;
112         if (subprotocols != null) {
113             String[] subprotocolArray = subprotocols.split(",");
114             for (int i = 0; i < subprotocolArray.length; i++) {
115                 subprotocolArray[i] = subprotocolArray[i].trim();
116             }
117             this.subprotocols = subprotocolArray;
118         } else {
119             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120         }
121         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122     }
123 
124     /**
125      * Returns the URL of the web socket
126      */
127     @Deprecated
128     public String uri() {
129         return uri;
130     }
131 
132     /**
133      * Returns the CSV of supported sub protocols
134      */
135     public Set<String> subprotocols() {
136         Set<String> ret = new LinkedHashSet<String>();
137         Collections.addAll(ret, subprotocols);
138         return ret;
139     }
140 
141     /**
142      * Returns the version of the specification being supported
143      */
144     public WebSocketVersion version() {
145         return version;
146     }
147 
148     /**
149      * Gets the maximum length for any frame's payload.
150      *
151      * @return The maximum length for a frame's payload
152      */
153     public int maxFramePayloadLength() {
154         return decoderConfig.maxFramePayloadLength();
155     }
156 
157     /**
158      * Gets this decoder configuration.
159      *
160      * @return This decoder configuration.
161      */
162     public WebSocketDecoderConfig decoderConfig() {
163         return decoderConfig;
164     }
165 
166     /**
167      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
168      * {@link FullHttpRequest} which is passed in.
169      *
170      * @param channel
171      *              Channel
172      * @param req
173      *              HTTP Request
174      * @return future
175      *              The {@link ChannelFuture} which is notified once the opening handshake completes
176      */
177     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
178         return handshake(channel, req, null, channel.newPromise());
179     }
180 
181     /**
182      * Performs the opening handshake
183      *
184      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
185      *
186      * @param channel
187      *            Channel
188      * @param req
189      *            HTTP Request
190      * @param responseHeaders
191      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
192      * @param promise
193      *            the {@link ChannelPromise} to be notified when the opening handshake is done
194      * @return future
195      *            the {@link ChannelFuture} which is notified when the opening handshake is done
196      */
197     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
198                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
199 
200         if (logger.isDebugEnabled()) {
201             logger.debug("{} WebSocket version {} server handshake", channel, version());
202         }
203         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
204         ChannelPipeline p = channel.pipeline();
205         if (p.get(HttpObjectAggregator.class) != null) {
206             p.remove(HttpObjectAggregator.class);
207         }
208         if (p.get(HttpContentCompressor.class) != null) {
209             p.remove(HttpContentCompressor.class);
210         }
211         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
212         final String encoderName;
213         if (ctx == null) {
214             // this means the user use an HttpServerCodec
215             ctx = p.context(HttpServerCodec.class);
216             if (ctx == null) {
217                 promise.setFailure(
218                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
219                 response.release();
220                 return promise;
221             }
222             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
223             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
224             encoderName = ctx.name();
225         } else {
226             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
227 
228             encoderName = p.context(HttpResponseEncoder.class).name();
229             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
230         }
231         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
232             @Override
233             public void operationComplete(ChannelFuture future) throws Exception {
234                 if (future.isSuccess()) {
235                     ChannelPipeline p = future.channel().pipeline();
236                     p.remove(encoderName);
237                     promise.setSuccess();
238                 } else {
239                     promise.setFailure(future.cause());
240                 }
241             }
242         });
243         return promise;
244     }
245 
246     /**
247      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
248      * {@link FullHttpRequest} which is passed in.
249      *
250      * @param channel
251      *              Channel
252      * @param req
253      *              HTTP Request
254      * @return future
255      *              The {@link ChannelFuture} which is notified once the opening handshake completes
256      */
257     public ChannelFuture handshake(Channel channel, HttpRequest req) {
258         return handshake(channel, req, null, channel.newPromise());
259     }
260 
261     /**
262      * Performs the opening handshake
263      *
264      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
265      *
266      * @param channel
267      *            Channel
268      * @param req
269      *            HTTP Request
270      * @param responseHeaders
271      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
272      * @param promise
273      *            the {@link ChannelPromise} to be notified when the opening handshake is done
274      * @return future
275      *            the {@link ChannelFuture} which is notified when the opening handshake is done
276      */
277     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
278                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
279         if (req instanceof FullHttpRequest) {
280             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
281         }
282 
283         if (logger.isDebugEnabled()) {
284             logger.debug("{} WebSocket version {} server handshake", channel, version());
285         }
286 
287         ChannelPipeline p = channel.pipeline();
288         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
289         if (ctx == null) {
290             // this means the user use an HttpServerCodec
291             ctx = p.context(HttpServerCodec.class);
292             if (ctx == null) {
293                 promise.setFailure(
294                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
295                 return promise;
296             }
297         }
298 
299         String aggregatorCtx = ctx.name();
300         if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
301             version == WebSocketVersion.V00) {
302             // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit of 8192 should be
303             // more then enough for the websockets handshake payload.
304             aggregatorCtx = "httpAggregator";
305             p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
306         }
307 
308         p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
309 
310             private FullHttpRequest fullHttpRequest;
311 
312             @Override
313             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
314                 if (msg instanceof HttpObject) {
315                     try {
316                         handleHandshakeRequest(ctx, (HttpObject) msg);
317                     } finally {
318                         ReferenceCountUtil.release(msg);
319                     }
320                 } else {
321                     super.channelRead(ctx, msg);
322                 }
323             }
324 
325             @Override
326             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
327                 // Remove ourself and fail the handshake promise.
328                 ctx.pipeline().remove(this);
329                 promise.tryFailure(cause);
330                 ctx.fireExceptionCaught(cause);
331             }
332 
333             @Override
334             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
335                 try {
336                     // Fail promise if Channel was closed
337                     if (!promise.isDone()) {
338                         promise.tryFailure(new ClosedChannelException());
339                     }
340                     ctx.fireChannelInactive();
341                 } finally {
342                     releaseFullHttpRequest();
343                 }
344             }
345 
346             @Override
347             public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
348                 releaseFullHttpRequest();
349             }
350 
351             private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
352                 if (httpObject instanceof FullHttpRequest) {
353                     ctx.pipeline().remove(this);
354                     handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
355                     return;
356                 }
357 
358                 if (httpObject instanceof LastHttpContent) {
359                     assert fullHttpRequest != null;
360                     FullHttpRequest handshakeRequest = fullHttpRequest;
361                     fullHttpRequest = null;
362                     try {
363                         ctx.pipeline().remove(this);
364                         handshake(channel, handshakeRequest, responseHeaders, promise);
365                     } finally {
366                         handshakeRequest.release();
367                     }
368                     return;
369                 }
370 
371                 if (httpObject instanceof HttpRequest) {
372                     HttpRequest httpRequest = (HttpRequest) httpObject;
373                     fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
374                         httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
375                     if (httpRequest.decoderResult().isFailure()) {
376                         fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
377                     }
378                 }
379             }
380 
381             private void releaseFullHttpRequest() {
382                 if (fullHttpRequest != null) {
383                     fullHttpRequest.release();
384                     fullHttpRequest = null;
385                 }
386             }
387         });
388         try {
389             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
390         } catch (Throwable cause) {
391             promise.setFailure(cause);
392         }
393         return promise;
394     }
395 
396     /**
397      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
398      */
399     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
400                                          HttpHeaders responseHeaders);
401     /**
402      * Performs the closing handshake.
403      *
404      * When called from within a {@link ChannelHandler} you most likely want to use
405      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
406      *
407      * @param channel
408      *            the {@link Channel} to use.
409      * @param frame
410      *            Closing Frame that was received.
411      */
412     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
413         ObjectUtil.checkNotNull(channel, "channel");
414         return close(channel, frame, channel.newPromise());
415     }
416 
417     /**
418      * Performs the closing handshake.
419      *
420      * When called from within a {@link ChannelHandler} you most likely want to use
421      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}.
422      *
423      * @param channel
424      *            the {@link Channel} to use.
425      * @param frame
426      *            Closing Frame that was received.
427      * @param promise
428      *            the {@link ChannelPromise} to be notified when the closing handshake is done
429      */
430     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
431         return close0(channel, frame, promise);
432     }
433 
434     /**
435      * Performs the closing handshake.
436      *
437      * @param ctx
438      *            the {@link ChannelHandlerContext} to use.
439      * @param frame
440      *            Closing Frame that was received.
441      */
442     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
443         ObjectUtil.checkNotNull(ctx, "ctx");
444         return close(ctx, frame, ctx.newPromise());
445     }
446 
447     /**
448      * Performs the closing handshake.
449      *
450      * @param ctx
451      *            the {@link ChannelHandlerContext} to use.
452      * @param frame
453      *            Closing Frame that was received.
454      * @param promise
455      *            the {@link ChannelPromise} to be notified when the closing handshake is done.
456      */
457     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
458         ObjectUtil.checkNotNull(ctx, "ctx");
459         return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
460     }
461 
462     private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
463         return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
464     }
465 
466     /**
467      * Selects the first matching supported sub protocol
468      *
469      * @param requestedSubprotocols
470      *            CSV of protocols to be supported. e.g. "chat, superchat"
471      * @return First matching supported sub protocol. Null if not found.
472      */
473     protected String selectSubprotocol(String requestedSubprotocols) {
474         if (requestedSubprotocols == null || subprotocols.length == 0) {
475             return null;
476         }
477 
478         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
479         for (String p: requestedSubprotocolArray) {
480             String requestedSubprotocol = p.trim();
481 
482             for (String supportedSubprotocol: subprotocols) {
483                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
484                         || requestedSubprotocol.equals(supportedSubprotocol)) {
485                     selectedSubprotocol = requestedSubprotocol;
486                     return requestedSubprotocol;
487                 }
488             }
489         }
490 
491         // No match found
492         return null;
493     }
494 
495     /**
496      * Returns the selected subprotocol. Null if no subprotocol has been selected.
497      * <p>
498      * This is only available AFTER <tt>handshake()</tt> has been called.
499      * </p>
500      */
501     public String selectedSubprotocol() {
502         return selectedSubprotocol;
503     }
504 
505     /**
506      * Returns the decoder to use after handshake is complete.
507      */
508     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
509 
510     /**
511      * Returns the encoder to use after the handshake is complete.
512      */
513     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
514 
515 }