View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import java.nio.channels.ClosedChannelException;
19  import java.util.Collections;
20  import java.util.LinkedHashSet;
21  import java.util.Set;
22  
23  import io.netty.buffer.Unpooled;
24  import io.netty.channel.Channel;
25  import io.netty.channel.ChannelFuture;
26  import io.netty.channel.ChannelFutureListener;
27  import io.netty.channel.ChannelHandler;
28  import io.netty.channel.ChannelHandlerContext;
29  import io.netty.channel.ChannelInboundHandlerAdapter;
30  import io.netty.channel.ChannelOutboundInvoker;
31  import io.netty.channel.ChannelPipeline;
32  import io.netty.channel.ChannelPromise;
33  import io.netty.handler.codec.http.DefaultFullHttpRequest;
34  import io.netty.handler.codec.http.EmptyHttpHeaders;
35  import io.netty.handler.codec.http.FullHttpRequest;
36  import io.netty.handler.codec.http.FullHttpResponse;
37  import io.netty.handler.codec.http.HttpContentCompressor;
38  import io.netty.handler.codec.http.HttpHeaders;
39  import io.netty.handler.codec.http.HttpObject;
40  import io.netty.handler.codec.http.HttpObjectAggregator;
41  import io.netty.handler.codec.http.HttpRequest;
42  import io.netty.handler.codec.http.HttpRequestDecoder;
43  import io.netty.handler.codec.http.HttpResponseEncoder;
44  import io.netty.handler.codec.http.HttpServerCodec;
45  import io.netty.handler.codec.http.HttpUtil;
46  import io.netty.handler.codec.http.LastHttpContent;
47  import io.netty.util.ReferenceCountUtil;
48  import io.netty.util.internal.EmptyArrays;
49  import io.netty.util.internal.ObjectUtil;
50  import io.netty.util.internal.logging.InternalLogger;
51  import io.netty.util.internal.logging.InternalLoggerFactory;
52  
53  /**
54   * Base class for server side web socket opening and closing handshakes
55   */
56  public abstract class WebSocketServerHandshaker {
57      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58  
59      private final String uri;
60  
61      private final String[] subprotocols;
62  
63      private final WebSocketVersion version;
64  
65      private final WebSocketDecoderConfig decoderConfig;
66  
67      private String selectedSubprotocol;
68  
69      /**
70       * Use this as wildcard to support all requested sub-protocols
71       */
72      public static final String SUB_PROTOCOL_WILDCARD = "*";
73  
74      /**
75       * Constructor specifying the destination web socket location
76       *
77       * @param version
78       *            the protocol version
79       * @param uri
80       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
81       *            sent to this URL.
82       * @param subprotocols
83       *            CSV of supported protocols. Null if sub protocols not supported.
84       * @param maxFramePayloadLength
85       *            Maximum length of a frame's payload
86       */
87      protected WebSocketServerHandshaker(
88              WebSocketVersion version, String uri, String subprotocols,
89              int maxFramePayloadLength) {
90          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91              .maxFramePayloadLength(maxFramePayloadLength)
92              .build());
93      }
94  
95      /**
96       * Constructor specifying the destination web socket location
97       *
98       * @param version
99       *            the protocol version
100      * @param uri
101      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
102      *            sent to this URL.
103      * @param subprotocols
104      *            CSV of supported protocols. Null if sub protocols not supported.
105      * @param decoderConfig
106      *            Frames decoder configuration.
107      */
108     protected WebSocketServerHandshaker(
109             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110         this.version = version;
111         this.uri = uri;
112         if (subprotocols != null) {
113             String[] subprotocolArray = subprotocols.split(",");
114             for (int i = 0; i < subprotocolArray.length; i++) {
115                 subprotocolArray[i] = subprotocolArray[i].trim();
116             }
117             this.subprotocols = subprotocolArray;
118         } else {
119             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120         }
121         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122     }
123 
124     /**
125      * Returns the URL of the web socket
126      */
127     public String uri() {
128         return uri;
129     }
130 
131     /**
132      * Returns the CSV of supported sub protocols
133      */
134     public Set<String> subprotocols() {
135         Set<String> ret = new LinkedHashSet<String>();
136         Collections.addAll(ret, subprotocols);
137         return ret;
138     }
139 
140     /**
141      * Returns the version of the specification being supported
142      */
143     public WebSocketVersion version() {
144         return version;
145     }
146 
147     /**
148      * Gets the maximum length for any frame's payload.
149      *
150      * @return The maximum length for a frame's payload
151      */
152     public int maxFramePayloadLength() {
153         return decoderConfig.maxFramePayloadLength();
154     }
155 
156     /**
157      * Gets this decoder configuration.
158      *
159      * @return This decoder configuration.
160      */
161     public WebSocketDecoderConfig decoderConfig() {
162         return decoderConfig;
163     }
164 
165     /**
166      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
167      * {@link FullHttpRequest} which is passed in.
168      *
169      * @param channel
170      *              Channel
171      * @param req
172      *              HTTP Request
173      * @return future
174      *              The {@link ChannelFuture} which is notified once the opening handshake completes
175      */
176     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
177         return handshake(channel, req, null, channel.newPromise());
178     }
179 
180     /**
181      * Performs the opening handshake
182      *
183      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
184      *
185      * @param channel
186      *            Channel
187      * @param req
188      *            HTTP Request
189      * @param responseHeaders
190      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
191      * @param promise
192      *            the {@link ChannelPromise} to be notified when the opening handshake is done
193      * @return future
194      *            the {@link ChannelFuture} which is notified when the opening handshake is done
195      */
196     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
197                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
198 
199         if (logger.isDebugEnabled()) {
200             logger.debug("{} WebSocket version {} server handshake", channel, version());
201         }
202         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
203         ChannelPipeline p = channel.pipeline();
204         if (p.get(HttpObjectAggregator.class) != null) {
205             p.remove(HttpObjectAggregator.class);
206         }
207         if (p.get(HttpContentCompressor.class) != null) {
208             p.remove(HttpContentCompressor.class);
209         }
210         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
211         final String encoderName;
212         if (ctx == null) {
213             // this means the user use an HttpServerCodec
214             ctx = p.context(HttpServerCodec.class);
215             if (ctx == null) {
216                 promise.setFailure(
217                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
218                 response.release();
219                 return promise;
220             }
221             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
222             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
223             encoderName = ctx.name();
224         } else {
225             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
226 
227             encoderName = p.context(HttpResponseEncoder.class).name();
228             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
229         }
230         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
231             @Override
232             public void operationComplete(ChannelFuture future) throws Exception {
233                 if (future.isSuccess()) {
234                     ChannelPipeline p = future.channel().pipeline();
235                     p.remove(encoderName);
236                     promise.setSuccess();
237                 } else {
238                     promise.setFailure(future.cause());
239                 }
240             }
241         });
242         return promise;
243     }
244 
245     /**
246      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
247      * {@link FullHttpRequest} which is passed in.
248      *
249      * @param channel
250      *              Channel
251      * @param req
252      *              HTTP Request
253      * @return future
254      *              The {@link ChannelFuture} which is notified once the opening handshake completes
255      */
256     public ChannelFuture handshake(Channel channel, HttpRequest req) {
257         return handshake(channel, req, null, channel.newPromise());
258     }
259 
260     /**
261      * Performs the opening handshake
262      *
263      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
264      *
265      * @param channel
266      *            Channel
267      * @param req
268      *            HTTP Request
269      * @param responseHeaders
270      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
271      * @param promise
272      *            the {@link ChannelPromise} to be notified when the opening handshake is done
273      * @return future
274      *            the {@link ChannelFuture} which is notified when the opening handshake is done
275      */
276     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
277                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
278         if (req instanceof FullHttpRequest) {
279             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
280         }
281 
282         if (logger.isDebugEnabled()) {
283             logger.debug("{} WebSocket version {} server handshake", channel, version());
284         }
285 
286         ChannelPipeline p = channel.pipeline();
287         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
288         if (ctx == null) {
289             // this means the user use an HttpServerCodec
290             ctx = p.context(HttpServerCodec.class);
291             if (ctx == null) {
292                 promise.setFailure(
293                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
294                 return promise;
295             }
296         }
297 
298         String aggregatorCtx = ctx.name();
299         if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
300             version == WebSocketVersion.V00) {
301             // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit of 8192 should be
302             // more then enough for the websockets handshake payload.
303             aggregatorCtx = "httpAggregator";
304             p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
305         }
306 
307         p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
308 
309             private FullHttpRequest fullHttpRequest;
310 
311             @Override
312             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
313                 if (msg instanceof HttpObject) {
314                     try {
315                         handleHandshakeRequest(ctx, (HttpObject) msg);
316                     } finally {
317                         ReferenceCountUtil.release(msg);
318                     }
319                 } else {
320                     super.channelRead(ctx, msg);
321                 }
322             }
323 
324             @Override
325             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
326                 // Remove ourself and fail the handshake promise.
327                 ctx.pipeline().remove(this);
328                 promise.tryFailure(cause);
329                 ctx.fireExceptionCaught(cause);
330             }
331 
332             @Override
333             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
334                 try {
335                     // Fail promise if Channel was closed
336                     if (!promise.isDone()) {
337                         promise.tryFailure(new ClosedChannelException());
338                     }
339                     ctx.fireChannelInactive();
340                 } finally {
341                     releaseFullHttpRequest();
342                 }
343             }
344 
345             @Override
346             public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
347                 releaseFullHttpRequest();
348             }
349 
350             private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
351                 if (httpObject instanceof FullHttpRequest) {
352                     ctx.pipeline().remove(this);
353                     handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
354                     return;
355                 }
356 
357                 if (httpObject instanceof LastHttpContent) {
358                     assert fullHttpRequest != null;
359                     FullHttpRequest handshakeRequest = fullHttpRequest;
360                     fullHttpRequest = null;
361                     try {
362                         ctx.pipeline().remove(this);
363                         handshake(channel, handshakeRequest, responseHeaders, promise);
364                     } finally {
365                         handshakeRequest.release();
366                     }
367                     return;
368                 }
369 
370                 if (httpObject instanceof HttpRequest) {
371                     HttpRequest httpRequest = (HttpRequest) httpObject;
372                     fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
373                         httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
374                     if (httpRequest.decoderResult().isFailure()) {
375                         fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
376                     }
377                 }
378             }
379 
380             private void releaseFullHttpRequest() {
381                 if (fullHttpRequest != null) {
382                     fullHttpRequest.release();
383                     fullHttpRequest = null;
384                 }
385             }
386         });
387         try {
388             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
389         } catch (Throwable cause) {
390             promise.setFailure(cause);
391         }
392         return promise;
393     }
394 
395     /**
396      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
397      */
398     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
399                                          HttpHeaders responseHeaders);
400     /**
401      * Performs the closing handshake.
402      *
403      * When called from within a {@link ChannelHandler} you most likely want to use
404      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
405      *
406      * @param channel
407      *            the {@link Channel} to use.
408      * @param frame
409      *            Closing Frame that was received.
410      */
411     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
412         ObjectUtil.checkNotNull(channel, "channel");
413         return close(channel, frame, channel.newPromise());
414     }
415 
416     /**
417      * Performs the closing handshake.
418      *
419      * When called from within a {@link ChannelHandler} you most likely want to use
420      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}.
421      *
422      * @param channel
423      *            the {@link Channel} to use.
424      * @param frame
425      *            Closing Frame that was received.
426      * @param promise
427      *            the {@link ChannelPromise} to be notified when the closing handshake is done
428      */
429     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
430         return close0(channel, frame, promise);
431     }
432 
433     /**
434      * Performs the closing handshake.
435      *
436      * @param ctx
437      *            the {@link ChannelHandlerContext} to use.
438      * @param frame
439      *            Closing Frame that was received.
440      */
441     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
442         ObjectUtil.checkNotNull(ctx, "ctx");
443         return close(ctx, frame, ctx.newPromise());
444     }
445 
446     /**
447      * Performs the closing handshake.
448      *
449      * @param ctx
450      *            the {@link ChannelHandlerContext} to use.
451      * @param frame
452      *            Closing Frame that was received.
453      * @param promise
454      *            the {@link ChannelPromise} to be notified when the closing handshake is done.
455      */
456     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
457         ObjectUtil.checkNotNull(ctx, "ctx");
458         return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
459     }
460 
461     private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
462         return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
463     }
464 
465     /**
466      * Selects the first matching supported sub protocol
467      *
468      * @param requestedSubprotocols
469      *            CSV of protocols to be supported. e.g. "chat, superchat"
470      * @return First matching supported sub protocol. Null if not found.
471      */
472     protected String selectSubprotocol(String requestedSubprotocols) {
473         if (requestedSubprotocols == null || subprotocols.length == 0) {
474             return null;
475         }
476 
477         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
478         for (String p: requestedSubprotocolArray) {
479             String requestedSubprotocol = p.trim();
480 
481             for (String supportedSubprotocol: subprotocols) {
482                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
483                         || requestedSubprotocol.equals(supportedSubprotocol)) {
484                     selectedSubprotocol = requestedSubprotocol;
485                     return requestedSubprotocol;
486                 }
487             }
488         }
489 
490         // No match found
491         return null;
492     }
493 
494     /**
495      * Returns the selected subprotocol. Null if no subprotocol has been selected.
496      * <p>
497      * This is only available AFTER <tt>handshake()</tt> has been called.
498      * </p>
499      */
500     public String selectedSubprotocol() {
501         return selectedSubprotocol;
502     }
503 
504     /**
505      * Returns the decoder to use after handshake is complete.
506      */
507     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
508 
509     /**
510      * Returns the encoder to use after the handshake is complete.
511      */
512     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
513 
514 }