View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpContentCompressor;
28  import io.netty.handler.codec.http.HttpHeaders;
29  import io.netty.handler.codec.http.HttpObjectAggregator;
30  import io.netty.handler.codec.http.HttpRequest;
31  import io.netty.handler.codec.http.HttpRequestDecoder;
32  import io.netty.handler.codec.http.HttpResponseEncoder;
33  import io.netty.handler.codec.http.HttpServerCodec;
34  import io.netty.util.ReferenceCountUtil;
35  import io.netty.util.internal.EmptyArrays;
36  import io.netty.util.internal.ObjectUtil;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  
40  import java.nio.channels.ClosedChannelException;
41  import java.util.Collections;
42  import java.util.LinkedHashSet;
43  import java.util.Set;
44  
45  /**
46   * Base class for server side web socket opening and closing handshakes
47   */
48  public abstract class WebSocketServerHandshaker {
49      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
50  
51      private final String uri;
52  
53      private final String[] subprotocols;
54  
55      private final WebSocketVersion version;
56  
57      private final WebSocketDecoderConfig decoderConfig;
58  
59      private String selectedSubprotocol;
60  
61      /**
62       * Use this as wildcard to support all requested sub-protocols
63       */
64      public static final String SUB_PROTOCOL_WILDCARD = "*";
65  
66      /**
67       * Constructor specifying the destination web socket location
68       *
69       * @param version
70       *            the protocol version
71       * @param uri
72       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
73       *            sent to this URL.
74       * @param subprotocols
75       *            CSV of supported protocols. Null if sub protocols not supported.
76       * @param maxFramePayloadLength
77       *            Maximum length of a frame's payload
78       */
79      protected WebSocketServerHandshaker(
80              WebSocketVersion version, String uri, String subprotocols,
81              int maxFramePayloadLength) {
82          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
83              .maxFramePayloadLength(maxFramePayloadLength)
84              .build());
85      }
86  
87      /**
88       * Constructor specifying the destination web socket location
89       *
90       * @param version
91       *            the protocol version
92       * @param uri
93       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
94       *            sent to this URL.
95       * @param subprotocols
96       *            CSV of supported protocols. Null if sub protocols not supported.
97       * @param decoderConfig
98       *            Frames decoder configuration.
99       */
100     protected WebSocketServerHandshaker(
101             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
102         this.version = version;
103         this.uri = uri;
104         if (subprotocols != null) {
105             String[] subprotocolArray = subprotocols.split(",");
106             for (int i = 0; i < subprotocolArray.length; i++) {
107                 subprotocolArray[i] = subprotocolArray[i].trim();
108             }
109             this.subprotocols = subprotocolArray;
110         } else {
111             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
112         }
113         this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
114     }
115 
116     /**
117      * Returns the URL of the web socket
118      */
119     public String uri() {
120         return uri;
121     }
122 
123     /**
124      * Returns the CSV of supported sub protocols
125      */
126     public Set<String> subprotocols() {
127         Set<String> ret = new LinkedHashSet<String>();
128         Collections.addAll(ret, subprotocols);
129         return ret;
130     }
131 
132     /**
133      * Returns the version of the specification being supported
134      */
135     public WebSocketVersion version() {
136         return version;
137     }
138 
139     /**
140      * Gets the maximum length for any frame's payload.
141      *
142      * @return The maximum length for a frame's payload
143      */
144     public int maxFramePayloadLength() {
145         return decoderConfig.maxFramePayloadLength();
146     }
147 
148     /**
149      * Gets this decoder configuration.
150      *
151      * @return This decoder configuration.
152      */
153     public WebSocketDecoderConfig decoderConfig() {
154         return decoderConfig;
155     }
156 
157     /**
158      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
159      * {@link FullHttpRequest} which is passed in.
160      *
161      * @param channel
162      *              Channel
163      * @param req
164      *              HTTP Request
165      * @return future
166      *              The {@link ChannelFuture} which is notified once the opening handshake completes
167      */
168     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
169         return handshake(channel, req, null, channel.newPromise());
170     }
171 
172     /**
173      * Performs the opening handshake
174      *
175      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
176      *
177      * @param channel
178      *            Channel
179      * @param req
180      *            HTTP Request
181      * @param responseHeaders
182      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
183      * @param promise
184      *            the {@link ChannelPromise} to be notified when the opening handshake is done
185      * @return future
186      *            the {@link ChannelFuture} which is notified when the opening handshake is done
187      */
188     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
189                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
190 
191         if (logger.isDebugEnabled()) {
192             logger.debug("{} WebSocket version {} server handshake", channel, version());
193         }
194         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
195         ChannelPipeline p = channel.pipeline();
196         if (p.get(HttpObjectAggregator.class) != null) {
197             p.remove(HttpObjectAggregator.class);
198         }
199         if (p.get(HttpContentCompressor.class) != null) {
200             p.remove(HttpContentCompressor.class);
201         }
202         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
203         final String encoderName;
204         if (ctx == null) {
205             // this means the user use an HttpServerCodec
206             ctx = p.context(HttpServerCodec.class);
207             if (ctx == null) {
208                 promise.setFailure(
209                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
210                 return promise;
211             }
212             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
213             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
214             encoderName = ctx.name();
215         } else {
216             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
217 
218             encoderName = p.context(HttpResponseEncoder.class).name();
219             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
220         }
221         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
222             @Override
223             public void operationComplete(ChannelFuture future) throws Exception {
224                 if (future.isSuccess()) {
225                     ChannelPipeline p = future.channel().pipeline();
226                     p.remove(encoderName);
227                     promise.setSuccess();
228                 } else {
229                     promise.setFailure(future.cause());
230                 }
231             }
232         });
233         return promise;
234     }
235 
236     /**
237      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
238      * {@link FullHttpRequest} which is passed in.
239      *
240      * @param channel
241      *              Channel
242      * @param req
243      *              HTTP Request
244      * @return future
245      *              The {@link ChannelFuture} which is notified once the opening handshake completes
246      */
247     public ChannelFuture handshake(Channel channel, HttpRequest req) {
248         return handshake(channel, req, null, channel.newPromise());
249     }
250 
251     /**
252      * Performs the opening handshake
253      *
254      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
255      *
256      * @param channel
257      *            Channel
258      * @param req
259      *            HTTP Request
260      * @param responseHeaders
261      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
262      * @param promise
263      *            the {@link ChannelPromise} to be notified when the opening handshake is done
264      * @return future
265      *            the {@link ChannelFuture} which is notified when the opening handshake is done
266      */
267     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
268                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
269 
270         if (req instanceof FullHttpRequest) {
271             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
272         }
273         if (logger.isDebugEnabled()) {
274             logger.debug("{} WebSocket version {} server handshake", channel, version());
275         }
276         ChannelPipeline p = channel.pipeline();
277         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
278         if (ctx == null) {
279             // this means the user use an HttpServerCodec
280             ctx = p.context(HttpServerCodec.class);
281             if (ctx == null) {
282                 promise.setFailure(
283                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
284                 return promise;
285             }
286         }
287         // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
288         // enough for the websockets handshake payload.
289         //
290         // TODO: Make handshake work without HttpObjectAggregator at all.
291         String aggregatorName = "httpAggregator";
292         p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
293         p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
294             @Override
295             protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
296                 // Remove ourself and do the actual handshake
297                 ctx.pipeline().remove(this);
298                 handshake(channel, msg, responseHeaders, promise);
299             }
300 
301             @Override
302             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
303                 // Remove ourself and fail the handshake promise.
304                 ctx.pipeline().remove(this);
305                 promise.tryFailure(cause);
306                 ctx.fireExceptionCaught(cause);
307             }
308 
309             @Override
310             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
311                 // Fail promise if Channel was closed
312                 if (!promise.isDone()) {
313                     promise.tryFailure(new ClosedChannelException());
314                 }
315                 ctx.fireChannelInactive();
316             }
317         });
318         try {
319             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
320         } catch (Throwable cause) {
321             promise.setFailure(cause);
322         }
323         return promise;
324     }
325 
326     /**
327      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
328      */
329     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
330                                          HttpHeaders responseHeaders);
331     /**
332      * Performs the closing handshake
333      *
334      * @param channel
335      *            Channel
336      * @param frame
337      *            Closing Frame that was received
338      */
339     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
340         ObjectUtil.checkNotNull(channel, "channel");
341         return close(channel, frame, channel.newPromise());
342     }
343 
344     /**
345      * Performs the closing handshake
346      *
347      * @param channel
348      *            Channel
349      * @param frame
350      *            Closing Frame that was received
351      * @param promise
352      *            the {@link ChannelPromise} to be notified when the closing handshake is done
353      */
354     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
355         ObjectUtil.checkNotNull(channel, "channel");
356         return channel.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
357     }
358 
359     /**
360      * Selects the first matching supported sub protocol
361      *
362      * @param requestedSubprotocols
363      *            CSV of protocols to be supported. e.g. "chat, superchat"
364      * @return First matching supported sub protocol. Null if not found.
365      */
366     protected String selectSubprotocol(String requestedSubprotocols) {
367         if (requestedSubprotocols == null || subprotocols.length == 0) {
368             return null;
369         }
370 
371         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
372         for (String p: requestedSubprotocolArray) {
373             String requestedSubprotocol = p.trim();
374 
375             for (String supportedSubprotocol: subprotocols) {
376                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
377                         || requestedSubprotocol.equals(supportedSubprotocol)) {
378                     selectedSubprotocol = requestedSubprotocol;
379                     return requestedSubprotocol;
380                 }
381             }
382         }
383 
384         // No match found
385         return null;
386     }
387 
388     /**
389      * Returns the selected subprotocol. Null if no subprotocol has been selected.
390      * <p>
391      * This is only available AFTER <tt>handshake()</tt> has been called.
392      * </p>
393      */
394     public String selectedSubprotocol() {
395         return selectedSubprotocol;
396     }
397 
398     /**
399      * Returns the decoder to use after handshake is complete.
400      */
401     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
402 
403     /**
404      * Returns the encoder to use after the handshake is complete.
405      */
406     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
407 
408 }