View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpContentCompressor;
28  import io.netty.handler.codec.http.HttpHeaders;
29  import io.netty.handler.codec.http.HttpObjectAggregator;
30  import io.netty.handler.codec.http.HttpRequest;
31  import io.netty.handler.codec.http.HttpRequestDecoder;
32  import io.netty.handler.codec.http.HttpResponseEncoder;
33  import io.netty.handler.codec.http.HttpServerCodec;
34  import io.netty.util.ReferenceCountUtil;
35  import io.netty.util.internal.EmptyArrays;
36  import io.netty.util.internal.StringUtil;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  
40  import java.nio.channels.ClosedChannelException;
41  import java.util.Collections;
42  import java.util.LinkedHashSet;
43  import java.util.Set;
44  
45  /**
46   * Base class for server side web socket opening and closing handshakes
47   */
48  public abstract class WebSocketServerHandshaker {
49      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
50      private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
51  
52      static {
53          CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
54      }
55  
56      private final String uri;
57  
58      private final String[] subprotocols;
59  
60      private final WebSocketVersion version;
61  
62      private final int maxFramePayloadLength;
63  
64      private String selectedSubprotocol;
65  
66      /**
67       * Use this as wildcard to support all requested sub-protocols
68       */
69      public static final String SUB_PROTOCOL_WILDCARD = "*";
70  
71      /**
72       * Constructor specifying the destination web socket location
73       *
74       * @param version
75       *            the protocol version
76       * @param uri
77       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
78       *            sent to this URL.
79       * @param subprotocols
80       *            CSV of supported protocols. Null if sub protocols not supported.
81       * @param maxFramePayloadLength
82       *            Maximum length of a frame's payload
83       */
84      protected WebSocketServerHandshaker(
85              WebSocketVersion version, String uri, String subprotocols,
86              int maxFramePayloadLength) {
87          this.version = version;
88          this.uri = uri;
89          if (subprotocols != null) {
90              String[] subprotocolArray = StringUtil.split(subprotocols, ',');
91              for (int i = 0; i < subprotocolArray.length; i++) {
92                  subprotocolArray[i] = subprotocolArray[i].trim();
93              }
94              this.subprotocols = subprotocolArray;
95          } else {
96              this.subprotocols = EmptyArrays.EMPTY_STRINGS;
97          }
98          this.maxFramePayloadLength = maxFramePayloadLength;
99      }
100 
101     /**
102      * Returns the URL of the web socket
103      */
104     public String uri() {
105         return uri;
106     }
107 
108     /**
109      * Returns the CSV of supported sub protocols
110      */
111     public Set<String> subprotocols() {
112         Set<String> ret = new LinkedHashSet<String>();
113         Collections.addAll(ret, subprotocols);
114         return ret;
115     }
116 
117     /**
118      * Returns the version of the specification being supported
119      */
120     public WebSocketVersion version() {
121         return version;
122     }
123 
124     /**
125      * Gets the maximum length for any frame's payload.
126      *
127      * @return The maximum length for a frame's payload
128      */
129     public int maxFramePayloadLength() {
130         return maxFramePayloadLength;
131     }
132 
133     /**
134      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
135      * {@link FullHttpRequest} which is passed in.
136      *
137      * @param channel
138      *              Channel
139      * @param req
140      *              HTTP Request
141      * @return future
142      *              The {@link ChannelFuture} which is notified once the opening handshake completes
143      */
144     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
145         return handshake(channel, req, null, channel.newPromise());
146     }
147 
148     /**
149      * Performs the opening handshake
150      *
151      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
152      *
153      * @param channel
154      *            Channel
155      * @param req
156      *            HTTP Request
157      * @param responseHeaders
158      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
159      * @param promise
160      *            the {@link ChannelPromise} to be notified when the opening handshake is done
161      * @return future
162      *            the {@link ChannelFuture} which is notified when the opening handshake is done
163      */
164     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
165                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
166 
167         if (logger.isDebugEnabled()) {
168             logger.debug("{} WebSocket version {} server handshake", channel, version());
169         }
170         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
171         ChannelPipeline p = channel.pipeline();
172         if (p.get(HttpObjectAggregator.class) != null) {
173             p.remove(HttpObjectAggregator.class);
174         }
175         if (p.get(HttpContentCompressor.class) != null) {
176             p.remove(HttpContentCompressor.class);
177         }
178         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
179         final String encoderName;
180         if (ctx == null) {
181             // this means the user use a HttpServerCodec
182             ctx = p.context(HttpServerCodec.class);
183             if (ctx == null) {
184                 promise.setFailure(
185                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
186                 return promise;
187             }
188             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
189             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
190             encoderName = ctx.name();
191         } else {
192             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
193 
194             encoderName = p.context(HttpResponseEncoder.class).name();
195             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
196         }
197         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
198             @Override
199             public void operationComplete(ChannelFuture future) throws Exception {
200                 if (future.isSuccess()) {
201                     ChannelPipeline p = future.channel().pipeline();
202                     p.remove(encoderName);
203                     promise.setSuccess();
204                 } else {
205                     promise.setFailure(future.cause());
206                 }
207             }
208         });
209         return promise;
210     }
211 
212     /**
213      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
214      * {@link FullHttpRequest} which is passed in.
215      *
216      * @param channel
217      *              Channel
218      * @param req
219      *              HTTP Request
220      * @return future
221      *              The {@link ChannelFuture} which is notified once the opening handshake completes
222      */
223     public ChannelFuture handshake(Channel channel, HttpRequest req) {
224         return handshake(channel, req, null, channel.newPromise());
225     }
226 
227     /**
228      * Performs the opening handshake
229      *
230      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
231      *
232      * @param channel
233      *            Channel
234      * @param req
235      *            HTTP Request
236      * @param responseHeaders
237      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
238      * @param promise
239      *            the {@link ChannelPromise} to be notified when the opening handshake is done
240      * @return future
241      *            the {@link ChannelFuture} which is notified when the opening handshake is done
242      */
243     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
244                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
245 
246         if (req instanceof FullHttpRequest) {
247             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
248         }
249         if (logger.isDebugEnabled()) {
250             logger.debug("{} WebSocket version {} server handshake", channel, version());
251         }
252         ChannelPipeline p = channel.pipeline();
253         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
254         if (ctx == null) {
255             // this means the user use a HttpServerCodec
256             ctx = p.context(HttpServerCodec.class);
257             if (ctx == null) {
258                 promise.setFailure(
259                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
260                 return promise;
261             }
262         }
263         // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
264         // enough for the websockets handshake payload.
265         //
266         // TODO: Make handshake work without HttpObjectAggregator at all.
267         String aggregatorName = "httpAggregator";
268         p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
269         p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
270             @Override
271             protected void messageReceived(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
272                 // Remove ourself and do the actual handshake
273                 ctx.pipeline().remove(this);
274                 handshake(channel, msg, responseHeaders, promise);
275             }
276 
277             @Override
278             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
279                 // Remove ourself and fail the handshake promise.
280                 ctx.pipeline().remove(this);
281                 promise.tryFailure(cause);
282                 ctx.fireExceptionCaught(cause);
283             }
284 
285             @Override
286             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
287                 // Fail promise if Channel was closed
288                 promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
289                 ctx.fireChannelInactive();
290             }
291         });
292         try {
293             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
294         } catch (Throwable cause) {
295             promise.setFailure(cause);
296         }
297         return promise;
298     }
299 
300     /**
301      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
302      */
303     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
304                                          HttpHeaders responseHeaders);
305     /**
306      * Performs the closing handshake
307      *
308      * @param channel
309      *            Channel
310      * @param frame
311      *            Closing Frame that was received
312      */
313     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
314         if (channel == null) {
315             throw new NullPointerException("channel");
316         }
317         return close(channel, frame, channel.newPromise());
318     }
319 
320     /**
321      * Performs the closing handshake
322      *
323      * @param channel
324      *            Channel
325      * @param frame
326      *            Closing Frame that was received
327      * @param promise
328      *            the {@link ChannelPromise} to be notified when the closing handshake is done
329      */
330     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
331         if (channel == null) {
332             throw new NullPointerException("channel");
333         }
334         return channel.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
335     }
336 
337     /**
338      * Selects the first matching supported sub protocol
339      *
340      * @param requestedSubprotocols
341      *            CSV of protocols to be supported. e.g. "chat, superchat"
342      * @return First matching supported sub protocol. Null if not found.
343      */
344     protected String selectSubprotocol(String requestedSubprotocols) {
345         if (requestedSubprotocols == null || subprotocols.length == 0) {
346             return null;
347         }
348 
349         String[] requestedSubprotocolArray = StringUtil.split(requestedSubprotocols, ',');
350         for (String p: requestedSubprotocolArray) {
351             String requestedSubprotocol = p.trim();
352 
353             for (String supportedSubprotocol: subprotocols) {
354                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
355                         || requestedSubprotocol.equals(supportedSubprotocol)) {
356                     selectedSubprotocol = requestedSubprotocol;
357                     return requestedSubprotocol;
358                 }
359             }
360         }
361 
362         // No match found
363         return null;
364     }
365 
366     /**
367      * Returns the selected subprotocol. Null if no subprotocol has been selected.
368      * <p>
369      * This is only available AFTER <tt>handshake()</tt> has been called.
370      * </p>
371      */
372     public String selectedSubprotocol() {
373         return selectedSubprotocol;
374     }
375 
376     /**
377      * Returns the decoder to use after handshake is complete.
378      */
379     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
380 
381     /**
382      * Returns the encoder to use after the handshake is complete.
383      */
384     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
385 
386 }