/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import java.nio.channels.ClosedChannelException;
19  import java.util.Collections;
20  import java.util.LinkedHashSet;
21  import java.util.Set;
22  
23  import io.netty.buffer.Unpooled;
24  import io.netty.channel.Channel;
25  import io.netty.channel.ChannelFuture;
26  import io.netty.channel.ChannelFutureListener;
27  import io.netty.channel.ChannelHandler;
28  import io.netty.channel.ChannelHandlerContext;
29  import io.netty.channel.ChannelInboundHandlerAdapter;
30  import io.netty.channel.ChannelOutboundInvoker;
31  import io.netty.channel.ChannelPipeline;
32  import io.netty.channel.ChannelPromise;
33  import io.netty.handler.codec.http.DefaultFullHttpRequest;
34  import io.netty.handler.codec.http.EmptyHttpHeaders;
35  import io.netty.handler.codec.http.FullHttpRequest;
36  import io.netty.handler.codec.http.FullHttpResponse;
37  import io.netty.handler.codec.http.HttpContentCompressor;
38  import io.netty.handler.codec.http.HttpHeaders;
39  import io.netty.handler.codec.http.HttpObject;
40  import io.netty.handler.codec.http.HttpObjectAggregator;
41  import io.netty.handler.codec.http.HttpRequest;
42  import io.netty.handler.codec.http.HttpRequestDecoder;
43  import io.netty.handler.codec.http.HttpResponseEncoder;
44  import io.netty.handler.codec.http.HttpServerCodec;
45  import io.netty.handler.codec.http.HttpUtil;
46  import io.netty.handler.codec.http.LastHttpContent;
47  import io.netty.util.ReferenceCountUtil;
48  import io.netty.util.internal.EmptyArrays;
49  import io.netty.util.internal.ObjectUtil;
50  import io.netty.util.internal.logging.InternalLogger;
51  import io.netty.util.internal.logging.InternalLoggerFactory;
52  
53  /**
54   * Base class for server side web socket opening and closing handshakes
55   */
56  public abstract class WebSocketServerHandshaker {
57      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
58  
59      private final String uri;
60  
61      private final String[] subprotocols;
62  
63      private final WebSocketVersion version;
64  
65      private final WebSocketDecoderConfig decoderConfig;
66  
67      private String selectedSubprotocol;
68  
69      /**
70       * Use this as wildcard to support all requested sub-protocols
71       */
72      public static final String SUB_PROTOCOL_WILDCARD = "*";
73  
74      /**
75       * Constructor specifying the destination web socket location
76       *
77       * @param version
78       *            the protocol version
79       * @param uri
80       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
81       *            sent to this URL.
82       * @param subprotocols
83       *            CSV of supported protocols. Null if sub protocols not supported.
84       * @param maxFramePayloadLength
85       *            Maximum length of a frame's payload
86       */
87      protected WebSocketServerHandshaker(
88              WebSocketVersion version, String uri, String subprotocols,
89              int maxFramePayloadLength) {
90          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
91              .maxFramePayloadLength(maxFramePayloadLength)
92              .build());
93      }
94  
95      /**
96       * Constructor specifying the destination web socket location
97       *
98       * @param version
99       *            the protocol version
100      * @param uri
101      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
102      *            sent to this URL.
103      * @param subprotocols
104      *            CSV of supported protocols. Null if sub protocols not supported.
105      * @param decoderConfig
106      *            Frames decoder configuration.
107      */
108     protected WebSocketServerHandshaker(
109             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
110         this.version = version;
111         this.uri = uri;
112         if (subprotocols != null) {
113             String[] subprotocolArray = subprotocols.split(",");
114             for (int i = 0; i < subprotocolArray.length; i++) {
115                 subprotocolArray[i] = subprotocolArray[i].trim();
116             }
117             this.subprotocols = subprotocolArray;
118         } else {
119             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
120         }
121         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
122     }
123 
124     /**
125      * Returns the URL of the web socket
126      */
127     public String uri() {
128         return uri;
129     }
130 
131     /**
132      * Returns the CSV of supported sub protocols
133      */
134     public Set<String> subprotocols() {
135         Set<String> ret = new LinkedHashSet<String>();
136         Collections.addAll(ret, subprotocols);
137         return ret;
138     }
139 
140     /**
141      * Returns the version of the specification being supported
142      */
143     public WebSocketVersion version() {
144         return version;
145     }
146 
147     /**
148      * Gets the maximum length for any frame's payload.
149      *
150      * @return The maximum length for a frame's payload
151      */
152     public int maxFramePayloadLength() {
153         return decoderConfig.maxFramePayloadLength();
154     }
155 
156     /**
157      * Gets this decoder configuration.
158      *
159      * @return This decoder configuration.
160      */
161     public WebSocketDecoderConfig decoderConfig() {
162         return decoderConfig;
163     }
164 
165     /**
166      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
167      * {@link FullHttpRequest} which is passed in.
168      *
169      * @param channel
170      *              Channel
171      * @param req
172      *              HTTP Request
173      * @return future
174      *              The {@link ChannelFuture} which is notified once the opening handshake completes
175      */
176     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
177         return handshake(channel, req, null, channel.newPromise());
178     }
179 
180     /**
181      * Performs the opening handshake
182      *
183      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
184      *
185      * @param channel
186      *            Channel
187      * @param req
188      *            HTTP Request
189      * @param responseHeaders
190      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
191      * @param promise
192      *            the {@link ChannelPromise} to be notified when the opening handshake is done
193      * @return future
194      *            the {@link ChannelFuture} which is notified when the opening handshake is done
195      */
196     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
197                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
198 
199         if (logger.isDebugEnabled()) {
200             logger.debug("{} WebSocket version {} server handshake", channel, version());
201         }
202         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
203         ChannelPipeline p = channel.pipeline();
204         if (p.get(HttpObjectAggregator.class) != null) {
205             p.remove(HttpObjectAggregator.class);
206         }
207         if (p.get(HttpContentCompressor.class) != null) {
208             p.remove(HttpContentCompressor.class);
209         }
210         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
211         final String encoderName;
212         if (ctx == null) {
213             // this means the user use an HttpServerCodec
214             ctx = p.context(HttpServerCodec.class);
215             if (ctx == null) {
216                 promise.setFailure(
217                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
218                 return promise;
219             }
220             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
221             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
222             encoderName = ctx.name();
223         } else {
224             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
225 
226             encoderName = p.context(HttpResponseEncoder.class).name();
227             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
228         }
229         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
230             @Override
231             public void operationComplete(ChannelFuture future) throws Exception {
232                 if (future.isSuccess()) {
233                     ChannelPipeline p = future.channel().pipeline();
234                     p.remove(encoderName);
235                     promise.setSuccess();
236                 } else {
237                     promise.setFailure(future.cause());
238                 }
239             }
240         });
241         return promise;
242     }
243 
244     /**
245      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
246      * {@link FullHttpRequest} which is passed in.
247      *
248      * @param channel
249      *              Channel
250      * @param req
251      *              HTTP Request
252      * @return future
253      *              The {@link ChannelFuture} which is notified once the opening handshake completes
254      */
255     public ChannelFuture handshake(Channel channel, HttpRequest req) {
256         return handshake(channel, req, null, channel.newPromise());
257     }
258 
259     /**
260      * Performs the opening handshake
261      *
262      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
263      *
264      * @param channel
265      *            Channel
266      * @param req
267      *            HTTP Request
268      * @param responseHeaders
269      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
270      * @param promise
271      *            the {@link ChannelPromise} to be notified when the opening handshake is done
272      * @return future
273      *            the {@link ChannelFuture} which is notified when the opening handshake is done
274      */
275     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
276                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
277         if (req instanceof FullHttpRequest) {
278             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
279         }
280 
281         if (logger.isDebugEnabled()) {
282             logger.debug("{} WebSocket version {} server handshake", channel, version());
283         }
284 
285         ChannelPipeline p = channel.pipeline();
286         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
287         if (ctx == null) {
288             // this means the user use an HttpServerCodec
289             ctx = p.context(HttpServerCodec.class);
290             if (ctx == null) {
291                 promise.setFailure(
292                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
293                 return promise;
294             }
295         }
296 
297         String aggregatorCtx = ctx.name();
298         if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) ||
299             version == WebSocketVersion.V00) {
300             // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit of 8192 should be
301             // more then enough for the websockets handshake payload.
302             aggregatorCtx = "httpAggregator";
303             p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
304         }
305 
306         p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {
307 
308             private FullHttpRequest fullHttpRequest;
309 
310             @Override
311             public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
312                 if (msg instanceof HttpObject) {
313                     try {
314                         handleHandshakeRequest(ctx, (HttpObject) msg);
315                     } finally {
316                         ReferenceCountUtil.release(msg);
317                     }
318                 } else {
319                     super.channelRead(ctx, msg);
320                 }
321             }
322 
323             @Override
324             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
325                 // Remove ourself and fail the handshake promise.
326                 ctx.pipeline().remove(this);
327                 promise.tryFailure(cause);
328                 ctx.fireExceptionCaught(cause);
329             }
330 
331             @Override
332             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
333                 try {
334                     // Fail promise if Channel was closed
335                     if (!promise.isDone()) {
336                         promise.tryFailure(new ClosedChannelException());
337                     }
338                     ctx.fireChannelInactive();
339                 } finally {
340                     releaseFullHttpRequest();
341                 }
342             }
343 
344             @Override
345             public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
346                 releaseFullHttpRequest();
347             }
348 
349             private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
350                 if (httpObject instanceof FullHttpRequest) {
351                     ctx.pipeline().remove(this);
352                     handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise);
353                     return;
354                 }
355 
356                 if (httpObject instanceof LastHttpContent) {
357                     assert fullHttpRequest != null;
358                     FullHttpRequest handshakeRequest = fullHttpRequest;
359                     fullHttpRequest = null;
360                     try {
361                         ctx.pipeline().remove(this);
362                         handshake(channel, handshakeRequest, responseHeaders, promise);
363                     } finally {
364                         handshakeRequest.release();
365                     }
366                     return;
367                 }
368 
369                 if (httpObject instanceof HttpRequest) {
370                     HttpRequest httpRequest = (HttpRequest) httpObject;
371                     fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
372                         httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE);
373                     if (httpRequest.decoderResult().isFailure()) {
374                         fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
375                     }
376                 }
377             }
378 
379             private void releaseFullHttpRequest() {
380                 if (fullHttpRequest != null) {
381                     fullHttpRequest.release();
382                     fullHttpRequest = null;
383                 }
384             }
385         });
386         try {
387             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
388         } catch (Throwable cause) {
389             promise.setFailure(cause);
390         }
391         return promise;
392     }
393 
394     /**
395      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
396      */
397     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
398                                          HttpHeaders responseHeaders);
399     /**
400      * Performs the closing handshake.
401      *
402      * When called from within a {@link ChannelHandler} you most likely want to use
403      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
404      *
405      * @param channel
406      *            the {@link Channel} to use.
407      * @param frame
408      *            Closing Frame that was received.
409      */
410     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
411         ObjectUtil.checkNotNull(channel, "channel");
412         return close(channel, frame, channel.newPromise());
413     }
414 
415     /**
416      * Performs the closing handshake.
417      *
418      * When called from within a {@link ChannelHandler} you most likely want to use
419      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}.
420      *
421      * @param channel
422      *            the {@link Channel} to use.
423      * @param frame
424      *            Closing Frame that was received.
425      * @param promise
426      *            the {@link ChannelPromise} to be notified when the closing handshake is done
427      */
428     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
429         return close0(channel, frame, promise);
430     }
431 
432     /**
433      * Performs the closing handshake.
434      *
435      * @param ctx
436      *            the {@link ChannelHandlerContext} to use.
437      * @param frame
438      *            Closing Frame that was received.
439      */
440     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
441         ObjectUtil.checkNotNull(ctx, "ctx");
442         return close(ctx, frame, ctx.newPromise());
443     }
444 
445     /**
446      * Performs the closing handshake.
447      *
448      * @param ctx
449      *            the {@link ChannelHandlerContext} to use.
450      * @param frame
451      *            Closing Frame that was received.
452      * @param promise
453      *            the {@link ChannelPromise} to be notified when the closing handshake is done.
454      */
455     public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
456         ObjectUtil.checkNotNull(ctx, "ctx");
457         return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE);
458     }
459 
460     private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) {
461         return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
462     }
463 
464     /**
465      * Selects the first matching supported sub protocol
466      *
467      * @param requestedSubprotocols
468      *            CSV of protocols to be supported. e.g. "chat, superchat"
469      * @return First matching supported sub protocol. Null if not found.
470      */
471     protected String selectSubprotocol(String requestedSubprotocols) {
472         if (requestedSubprotocols == null || subprotocols.length == 0) {
473             return null;
474         }
475 
476         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
477         for (String p: requestedSubprotocolArray) {
478             String requestedSubprotocol = p.trim();
479 
480             for (String supportedSubprotocol: subprotocols) {
481                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
482                         || requestedSubprotocol.equals(supportedSubprotocol)) {
483                     selectedSubprotocol = requestedSubprotocol;
484                     return requestedSubprotocol;
485                 }
486             }
487         }
488 
489         // No match found
490         return null;
491     }
492 
493     /**
494      * Returns the selected subprotocol. Null if no subprotocol has been selected.
495      * <p>
496      * This is only available AFTER <tt>handshake()</tt> has been called.
497      * </p>
498      */
499     public String selectedSubprotocol() {
500         return selectedSubprotocol;
501     }
502 
503     /**
504      * Returns the decoder to use after handshake is complete.
505      */
506     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
507 
508     /**
509      * Returns the encoder to use after the handshake is complete.
510      */
511     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
512 
513 }