View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpContentCompressor;
28  import io.netty.handler.codec.http.HttpHeaders;
29  import io.netty.handler.codec.http.HttpObjectAggregator;
30  import io.netty.handler.codec.http.HttpRequest;
31  import io.netty.handler.codec.http.HttpRequestDecoder;
32  import io.netty.handler.codec.http.HttpResponseEncoder;
33  import io.netty.handler.codec.http.HttpServerCodec;
34  import io.netty.util.ReferenceCountUtil;
35  import io.netty.util.internal.EmptyArrays;
36  import io.netty.util.internal.ThrowableUtil;
37  import io.netty.util.internal.logging.InternalLogger;
38  import io.netty.util.internal.logging.InternalLoggerFactory;
39  
40  import java.nio.channels.ClosedChannelException;
41  import java.util.Collections;
42  import java.util.LinkedHashSet;
43  import java.util.Set;
44  
45  /**
46   * Base class for server side web socket opening and closing handshakes
47   */
48  public abstract class WebSocketServerHandshaker {
49      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
50      private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
51              new ClosedChannelException(), WebSocketServerHandshaker.class, "handshake(...)");
52  
53      private final String uri;
54  
55      private final String[] subprotocols;
56  
57      private final WebSocketVersion version;
58  
59      private final int maxFramePayloadLength;
60  
61      private String selectedSubprotocol;
62  
63      /**
64       * Use this as wildcard to support all requested sub-protocols
65       */
66      public static final String SUB_PROTOCOL_WILDCARD = "*";
67  
68      /**
69       * Constructor specifying the destination web socket location
70       *
71       * @param version
72       *            the protocol version
73       * @param uri
74       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
75       *            sent to this URL.
76       * @param subprotocols
77       *            CSV of supported protocols. Null if sub protocols not supported.
78       * @param maxFramePayloadLength
79       *            Maximum length of a frame's payload
80       */
81      protected WebSocketServerHandshaker(
82              WebSocketVersion version, String uri, String subprotocols,
83              int maxFramePayloadLength) {
84          this.version = version;
85          this.uri = uri;
86          if (subprotocols != null) {
87              String[] subprotocolArray = subprotocols.split(",");
88              for (int i = 0; i < subprotocolArray.length; i++) {
89                  subprotocolArray[i] = subprotocolArray[i].trim();
90              }
91              this.subprotocols = subprotocolArray;
92          } else {
93              this.subprotocols = EmptyArrays.EMPTY_STRINGS;
94          }
95          this.maxFramePayloadLength = maxFramePayloadLength;
96      }
97  
98      /**
99       * Returns the URL of the web socket
100      */
101     public String uri() {
102         return uri;
103     }
104 
105     /**
106      * Returns the CSV of supported sub protocols
107      */
108     public Set<String> subprotocols() {
109         Set<String> ret = new LinkedHashSet<String>();
110         Collections.addAll(ret, subprotocols);
111         return ret;
112     }
113 
114     /**
115      * Returns the version of the specification being supported
116      */
117     public WebSocketVersion version() {
118         return version;
119     }
120 
121     /**
122      * Gets the maximum length for any frame's payload.
123      *
124      * @return The maximum length for a frame's payload
125      */
126     public int maxFramePayloadLength() {
127         return maxFramePayloadLength;
128     }
129 
130     /**
131      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
132      * {@link FullHttpRequest} which is passed in.
133      *
134      * @param channel
135      *              Channel
136      * @param req
137      *              HTTP Request
138      * @return future
139      *              The {@link ChannelFuture} which is notified once the opening handshake completes
140      */
141     public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
142         return handshake(channel, req, null, channel.newPromise());
143     }
144 
145     /**
146      * Performs the opening handshake
147      *
148      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
149      *
150      * @param channel
151      *            Channel
152      * @param req
153      *            HTTP Request
154      * @param responseHeaders
155      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
156      * @param promise
157      *            the {@link ChannelPromise} to be notified when the opening handshake is done
158      * @return future
159      *            the {@link ChannelFuture} which is notified when the opening handshake is done
160      */
161     public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
162                                             HttpHeaders responseHeaders, final ChannelPromise promise) {
163 
164         if (logger.isDebugEnabled()) {
165             logger.debug("{} WebSocket version {} server handshake", channel, version());
166         }
167         FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
168         ChannelPipeline p = channel.pipeline();
169         if (p.get(HttpObjectAggregator.class) != null) {
170             p.remove(HttpObjectAggregator.class);
171         }
172         if (p.get(HttpContentCompressor.class) != null) {
173             p.remove(HttpContentCompressor.class);
174         }
175         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
176         final String encoderName;
177         if (ctx == null) {
178             // this means the user use a HttpServerCodec
179             ctx = p.context(HttpServerCodec.class);
180             if (ctx == null) {
181                 promise.setFailure(
182                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
183                 return promise;
184             }
185             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
186             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
187             encoderName = ctx.name();
188         } else {
189             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
190 
191             encoderName = p.context(HttpResponseEncoder.class).name();
192             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
193         }
194         channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
195             @Override
196             public void operationComplete(ChannelFuture future) throws Exception {
197                 if (future.isSuccess()) {
198                     ChannelPipeline p = future.channel().pipeline();
199                     p.remove(encoderName);
200                     promise.setSuccess();
201                 } else {
202                     promise.setFailure(future.cause());
203                 }
204             }
205         });
206         return promise;
207     }
208 
209     /**
210      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
211      * {@link FullHttpRequest} which is passed in.
212      *
213      * @param channel
214      *              Channel
215      * @param req
216      *              HTTP Request
217      * @return future
218      *              The {@link ChannelFuture} which is notified once the opening handshake completes
219      */
220     public ChannelFuture handshake(Channel channel, HttpRequest req) {
221         return handshake(channel, req, null, channel.newPromise());
222     }
223 
224     /**
225      * Performs the opening handshake
226      *
227      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
228      *
229      * @param channel
230      *            Channel
231      * @param req
232      *            HTTP Request
233      * @param responseHeaders
234      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
235      * @param promise
236      *            the {@link ChannelPromise} to be notified when the opening handshake is done
237      * @return future
238      *            the {@link ChannelFuture} which is notified when the opening handshake is done
239      */
240     public final ChannelFuture handshake(final Channel channel, HttpRequest req,
241                                          final HttpHeaders responseHeaders, final ChannelPromise promise) {
242 
243         if (req instanceof FullHttpRequest) {
244             return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
245         }
246         if (logger.isDebugEnabled()) {
247             logger.debug("{} WebSocket version {} server handshake", channel, version());
248         }
249         ChannelPipeline p = channel.pipeline();
250         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
251         if (ctx == null) {
252             // this means the user use a HttpServerCodec
253             ctx = p.context(HttpServerCodec.class);
254             if (ctx == null) {
255                 promise.setFailure(
256                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
257                 return promise;
258             }
259         }
260         // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
261         // enough for the websockets handshake payload.
262         //
263         // TODO: Make handshake work without HttpObjectAggregator at all.
264         String aggregatorName = "httpAggregator";
265         p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
266         p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
267             @Override
268             protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
269                 // Remove ourself and do the actual handshake
270                 ctx.pipeline().remove(this);
271                 handshake(channel, msg, responseHeaders, promise);
272             }
273 
274             @Override
275             public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
276                 // Remove ourself and fail the handshake promise.
277                 ctx.pipeline().remove(this);
278                 promise.tryFailure(cause);
279                 ctx.fireExceptionCaught(cause);
280             }
281 
282             @Override
283             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
284                 // Fail promise if Channel was closed
285                 promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
286                 ctx.fireChannelInactive();
287             }
288         });
289         try {
290             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
291         } catch (Throwable cause) {
292             promise.setFailure(cause);
293         }
294         return promise;
295     }
296 
297     /**
298      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
299      */
300     protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
301                                          HttpHeaders responseHeaders);
302     /**
303      * Performs the closing handshake
304      *
305      * @param channel
306      *            Channel
307      * @param frame
308      *            Closing Frame that was received
309      */
310     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
311         if (channel == null) {
312             throw new NullPointerException("channel");
313         }
314         return close(channel, frame, channel.newPromise());
315     }
316 
317     /**
318      * Performs the closing handshake
319      *
320      * @param channel
321      *            Channel
322      * @param frame
323      *            Closing Frame that was received
324      * @param promise
325      *            the {@link ChannelPromise} to be notified when the closing handshake is done
326      */
327     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
328         if (channel == null) {
329             throw new NullPointerException("channel");
330         }
331         return channel.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
332     }
333 
334     /**
335      * Selects the first matching supported sub protocol
336      *
337      * @param requestedSubprotocols
338      *            CSV of protocols to be supported. e.g. "chat, superchat"
339      * @return First matching supported sub protocol. Null if not found.
340      */
341     protected String selectSubprotocol(String requestedSubprotocols) {
342         if (requestedSubprotocols == null || subprotocols.length == 0) {
343             return null;
344         }
345 
346         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
347         for (String p: requestedSubprotocolArray) {
348             String requestedSubprotocol = p.trim();
349 
350             for (String supportedSubprotocol: subprotocols) {
351                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
352                         || requestedSubprotocol.equals(supportedSubprotocol)) {
353                     selectedSubprotocol = requestedSubprotocol;
354                     return requestedSubprotocol;
355                 }
356             }
357         }
358 
359         // No match found
360         return null;
361     }
362 
363     /**
364      * Returns the selected subprotocol. Null if no subprotocol has been selected.
365      * <p>
366      * This is only available AFTER <tt>handshake()</tt> has been called.
367      * </p>
368      */
369     public String selectedSubprotocol() {
370         return selectedSubprotocol;
371     }
372 
373     /**
374      * Returns the decoder to use after handshake is complete.
375      */
376     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
377 
378     /**
379      * Returns the encoder to use after the handshake is complete.
380      */
381     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
382 
383 }