View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import java.nio.channels.ClosedChannelException;
19  import java.util.Collections;
20  import java.util.LinkedHashSet;
21  import java.util.Set;
22  
23  import io.netty.buffer.Unpooled;
24  import io.netty.channel.Channel;
25  import io.netty.channel.ChannelFuture;
26  import io.netty.channel.ChannelFutureListener;
27  import io.netty.channel.ChannelHandler;
28  import io.netty.channel.ChannelHandlerContext;
29  import io.netty.channel.ChannelInboundHandlerAdapter;
30  import io.netty.channel.ChannelOutboundInvoker;
31  import io.netty.channel.ChannelPipeline;
32  import io.netty.channel.ChannelPromise;
33  import io.netty.handler.codec.http.DefaultFullHttpRequest;
34  import io.netty.handler.codec.http.EmptyHttpHeaders;
35  import io.netty.handler.codec.http.FullHttpRequest;
36  import io.netty.handler.codec.http.FullHttpResponse;
37  import io.netty.handler.codec.http.HttpContentCompressor;
38  import io.netty.handler.codec.http.HttpHeaders;
39  import io.netty.handler.codec.http.HttpObject;
40  import io.netty.handler.codec.http.HttpObjectAggregator;
41  import io.netty.handler.codec.http.HttpRequest;
42  import io.netty.handler.codec.http.HttpRequestDecoder;
43  import io.netty.handler.codec.http.HttpResponseEncoder;
44  import io.netty.handler.codec.http.HttpServerCodec;
45  import io.netty.handler.codec.http.HttpUtil;
46  import io.netty.handler.codec.http.LastHttpContent;
47  import io.netty.util.ReferenceCountUtil;
48  import io.netty.util.internal.EmptyArrays;
49  import io.netty.util.internal.ObjectUtil;
50  import io.netty.util.internal.logging.InternalLogger;
51  import io.netty.util.internal.logging.InternalLoggerFactory;
52  
53  /**
54   * Base class for server side web socket opening and closing handshakes
55   */
56  public abstract class WebSocketServerHandshaker {
57      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58  
59      private final String uri;
60  
61      private final String[] subprotocols;
62  
63      private final WebSocketVersion version;
64  
65      private final WebSocketDecoderConfig decoderConfig;
66  
67      private String selectedSubprotocol;
68  
69      /**
70       * Use this as wildcard to support all requested sub-protocols
71       */
72      public static final String SUB_PROTOCOL_WILDCARD = "*";
73  
74      /**
75       * Constructor specifying the destination web socket location
76       *
77       * @param version
78       *            the protocol version
79       * @param uri
80       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
81       *            sent to this URL.
82       * @param subprotocols
83       *            CSV of supported protocols. Null if sub protocols not supported.
84       * @param maxFramePayloadLength
85       *            Maximum length of a frame's payload
86       */
87      protected WebSocketServerHandshaker(
88              WebSocketVersion version, String uri, String subprotocols,
89              int maxFramePayloadLength) {
90          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91              .maxFramePayloadLength(maxFramePayloadLength)
92              .build());
93      }
94  
95      /**
96       * Constructor specifying the destination web socket location
97       *
98       * @param version
99       *            the protocol version
100      * @param uri
101      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
102      *            sent to this URL.
103      * @param subprotocols
104      *            CSV of supported protocols. Null if sub protocols not supported.
105      * @param decoderConfig
106      *            Frames decoder configuration.
107      */
108     protected WebSocketServerHandshaker(
109             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110         this.version = version;
111         this.uri = uri;
112         if (subprotocols != null) {
113             String[] subprotocolArray = subprotocols.split(",");
114             for (int i = 0; i < subprotocolArray.length; i++) {
115                 subprotocolArray[i] = subprotocolArray[i].trim();
116             }
117             this.subprotocols = subprotocolArray;
118         } else {
119             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120         }
121         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122     }
123 
124     /**
125      * Returns the URL of the web socket
126      */
127     @Deprecated
128     public String uri() {
129         return uri;
130     }
131 
132     /**
133      * Returns the CSV of supported sub protocols
134      */
135     public Set<String> subprotocols() {
136         Set<String> ret = new LinkedHashSet<String>();
137         Collections.addAll(ret, subprotocols);
138         return ret;
139     }
140 
141     /**
142      * Returns the version of the specification being supported
143      */
144     public WebSocketVersion version() {
145         return version;
146     }
147 
148     /**
149      * Gets the maximum length for any frame's payload.
150      *
151      * @return The maximum length for a frame's payload
152      */
153     public int maxFramePayloadLength() {
154         return decoderConfig.maxFramePayloadLength();
155     }
156 
157     /**
158      * Gets this decoder configuration.
159      *
160      * @return This decoder configuration.
161      */
162     public WebSocketDecoderConfig decoderConfig() {
163         return decoderConfig;
164     }
165 
166     /**
167      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
168      * {@link FullHttpRequest} which is passed in.
169      *
170      * @param channel
171      *              Channel
172      * @param req
173      *              HTTP Request
174      * @return future
175      *              The {@link ChannelFuture} which is notified once the opening handshake completes
176      */
177     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
178         return handshake(channel, req, null, channel.newPromise());
179     }
180 
181     /**
182      * Performs the opening handshake
183      *
184      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
185      *
186      * @param channel
187      *            Channel
188      * @param req
189      *            HTTP Request
190      * @param responseHeaders
191      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
192      * @param promise
193      *            the {@link ChannelPromise} to be notified when the opening handshake is done
194      * @return future
195      *            the {@link ChannelFuture} which is notified when the opening handshake is done
196      */
197     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
198                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
199 
200         if (logger.isDebugEnabled()) {
201             logger.debug("{} WebSocket version {} server handshake", channel, version());
202         }
203         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
204         ChannelPipeline p = channel.pipeline();
205         if (p.get(HttpObjectAggregator.class) != null) {
206             p.remove(HttpObjectAggregator.class);
207         }
208         if (p.get(HttpContentCompressor.class) != null) {
209             p.remove(HttpContentCompressor.class);
210         }
211         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
212         final String encoderName;
213         if (ctx == null) {
214             // this means the user use an HttpServerCodec
215             ctx = p.context(HttpServerCodec.class);
216             if (ctx == null) {
217                 promise.setFailure(
218                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
219                 response.release();
220                 return promise;
221             }
222             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
223             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
224             encoderName = ctx.name();
225         } else {
226             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
227 
228             encoderName = p.context(HttpResponseEncoder.class).name();
229             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
230         }
231         channel.writeAndFlush(response).addListener(future -> {
232             if (future.isSuccess()) {
233                 ChannelPipeline p1 = channel.pipeline();
234                 p1.remove(encoderName);
235                 promise.setSuccess();
236             } else {
237                 promise.setFailure(future.cause());
238             }
239         });
240         return promise;
241     }
242 
243     /**
244      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
245      * {@link FullHttpRequest} which is passed in.
246      *
247      * @param channel
248      *              Channel
249      * @param req
250      *              HTTP Request
251      * @return future
252      *              The {@link ChannelFuture} which is notified once the opening handshake completes
253      */
254     public ChannelFuture handshake(Channel channel, HttpRequest req) {
255         return handshake(channel, req, null, channel.newPromise());
256     }
257 
258     /**
259      * Performs the opening handshake
260      *
261      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
262      *
263      * @param channel
264      *            Channel
265      * @param req
266      *            HTTP Request
267      * @param responseHeaders
268      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
269      * @param promise
270      *            the {@link ChannelPromise} to be notified when the opening handshake is done
271      * @return future
272      *            the {@link ChannelFuture} which is notified when the opening handshake is done
273      */
274     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
275                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
276         if (req instanceof FullHttpRequest) {
277             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
278         }
279 
280         if (logger.isDebugEnabled()) {
281             logger.debug("{} WebSocket version {} server handshake", channel, version());
282         }
283 
284         ChannelPipeline p = channel.pipeline();
285         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
286         if (ctx == null) {
287             // this means the user use an HttpServerCodec
288             ctx = p.context(HttpServerCodec.class);
289             if (ctx == null) {
290                 promise.setFailure(
291                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
292                 return promise;
293             }
294         }
295 
296         String aggregatorCtx = ctx.name();
297         if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
298             version == WebSocketVersion.V00) {
299             // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit of 8192 should be
300             // more then enough for the websockets handshake payload.
301             aggregatorCtx = "httpAggregator";
302             p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
303         }
304 
305         p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
306 
307             private FullHttpRequest fullHttpRequest;
308 
309             @Override
310             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
311                 if (msg instanceof HttpObject) {
312                     try {
313                         handleHandshakeRequest(ctx, (HttpObject) msg);
314                     } finally {
315                         ReferenceCountUtil.release(msg);
316                     }
317                 } else {
318                     super.channelRead(ctx, msg);
319                 }
320             }
321 
322             @Override
323             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
324                 // Remove ourself and fail the handshake promise.
325                 ctx.pipeline().remove(this);
326                 promise.tryFailure(cause);
327                 ctx.fireExceptionCaught(cause);
328             }
329 
330             @Override
331             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
332                 try {
333                     // Fail promise if Channel was closed
334                     if (!promise.isDone()) {
335                         promise.tryFailure(new ClosedChannelException());
336                     }
337                     ctx.fireChannelInactive();
338                 } finally {
339                     releaseFullHttpRequest();
340                 }
341             }
342 
343             @Override
344             public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
345                 releaseFullHttpRequest();
346             }
347 
348             private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
349                 if (httpObject instanceof FullHttpRequest) {
350                     ctx.pipeline().remove(this);
351                     handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
352                     return;
353                 }
354 
355                 if (httpObject instanceof LastHttpContent) {
356                     assert fullHttpRequest != null;
357                     FullHttpRequest handshakeRequest = fullHttpRequest;
358                     fullHttpRequest = null;
359                     try {
360                         ctx.pipeline().remove(this);
361                         handshake(channel, handshakeRequest, responseHeaders, promise);
362                     } finally {
363                         handshakeRequest.release();
364                     }
365                     return;
366                 }
367 
368                 if (httpObject instanceof HttpRequest) {
369                     HttpRequest httpRequest = (HttpRequest) httpObject;
370                     fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
371                         httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
372                     if (httpRequest.decoderResult().isFailure()) {
373                         fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
374                     }
375                 }
376             }
377 
378             private void releaseFullHttpRequest() {
379                 if (fullHttpRequest != null) {
380                     fullHttpRequest.release();
381                     fullHttpRequest = null;
382                 }
383             }
384         });
385         try {
386             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
387         } catch (Throwable cause) {
388             promise.setFailure(cause);
389         }
390         return promise;
391     }
392 
393     /**
394      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
395      */
396     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
397                                          HttpHeaders responseHeaders);
398     /**
399      * Performs the closing handshake.
400      *
401      * When called from within a {@link ChannelHandler} you most likely want to use
402      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
403      *
404      * @param channel
405      *            the {@link Channel} to use.
406      * @param frame
407      *            Closing Frame that was received.
408      */
409     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
410         ObjectUtil.checkNotNull(channel, "channel");
411         return close(channel, frame, channel.newPromise());
412     }
413 
414     /**
415      * Performs the closing handshake.
416      *
417      * When called from within a {@link ChannelHandler} you most likely want to use
418      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}.
419      *
420      * @param channel
421      *            the {@link Channel} to use.
422      * @param frame
423      *            Closing Frame that was received.
424      * @param promise
425      *            the {@link ChannelPromise} to be notified when the closing handshake is done
426      */
427     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
428         return close0(channel, frame, promise);
429     }
430 
431     /**
432      * Performs the closing handshake.
433      *
434      * @param ctx
435      *            the {@link ChannelHandlerContext} to use.
436      * @param frame
437      *            Closing Frame that was received.
438      */
439     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
440         ObjectUtil.checkNotNull(ctx, "ctx");
441         return close(ctx, frame, ctx.newPromise());
442     }
443 
444     /**
445      * Performs the closing handshake.
446      *
447      * @param ctx
448      *            the {@link ChannelHandlerContext} to use.
449      * @param frame
450      *            Closing Frame that was received.
451      * @param promise
452      *            the {@link ChannelPromise} to be notified when the closing handshake is done.
453      */
454     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
455         ObjectUtil.checkNotNull(ctx, "ctx");
456         return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
457     }
458 
459     private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
460         return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
461     }
462 
463     /**
464      * Selects the first matching supported sub protocol
465      *
466      * @param requestedSubprotocols
467      *            CSV of protocols to be supported. e.g. "chat, superchat"
468      * @return First matching supported sub protocol. Null if not found.
469      */
470     protected String selectSubprotocol(String requestedSubprotocols) {
471         if (requestedSubprotocols == null || subprotocols.length == 0) {
472             return null;
473         }
474 
475         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
476         for (String p: requestedSubprotocolArray) {
477             String requestedSubprotocol = p.trim();
478 
479             for (String supportedSubprotocol: subprotocols) {
480                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
481                         || requestedSubprotocol.equals(supportedSubprotocol)) {
482                     selectedSubprotocol = requestedSubprotocol;
483                     return requestedSubprotocol;
484                 }
485             }
486         }
487 
488         // No match found
489         return null;
490     }
491 
492     /**
493      * Returns the selected subprotocol. Null if no subprotocol has been selected.
494      * <p>
495      * This is only available AFTER <tt>handshake()</tt> has been called.
496      * </p>
497      */
498     public String selectedSubprotocol() {
499         return selectedSubprotocol;
500     }
501 
502     /**
503      * Returns the decoder to use after handshake is complete.
504      */
505     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
506 
507     /**
508      * Returns the encoder to use after the handshake is complete.
509      */
510     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
511 
512 }