View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpClientCodec;
28  import io.netty.handler.codec.http.HttpContentDecompressor;
29  import io.netty.handler.codec.http.HttpHeaders;
30  import io.netty.handler.codec.http.HttpObjectAggregator;
31  import io.netty.handler.codec.http.HttpRequestEncoder;
32  import io.netty.handler.codec.http.HttpResponse;
33  import io.netty.handler.codec.http.HttpResponseDecoder;
34  import io.netty.util.NetUtil;
35  import io.netty.util.ReferenceCountUtil;
36  import io.netty.util.internal.ThrowableUtil;
37  
38  import java.net.URI;
39  import java.nio.channels.ClosedChannelException;
40  import java.util.Locale;
41  
42  /**
43   * Base class for web socket client handshake implementations
44   */
45  public abstract class WebSocketClientHandshaker {
46      private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
47              new ClosedChannelException(), WebSocketClientHandshaker.class, "processHandshake(...)");
48  
49      private static final String HTTP_SCHEME_PREFIX = "http://";
50      private static final String HTTPS_SCHEME_PREFIX = "https://";
51  
52      private final URI uri;
53  
54      private final WebSocketVersion version;
55  
56      private volatile boolean handshakeComplete;
57  
58      private final String expectedSubprotocol;
59  
60      private volatile String actualSubprotocol;
61  
62      protected final HttpHeaders customHeaders;
63  
64      private final int maxFramePayloadLength;
65  
66      /**
67       * Base constructor
68       *
69       * @param uri
70       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
71       *            sent to this URL.
72       * @param version
73       *            Version of web socket specification to use to connect to the server
74       * @param subprotocol
75       *            Sub protocol request sent to the server.
76       * @param customHeaders
77       *            Map of custom headers to add to the client request
78       * @param maxFramePayloadLength
79       *            Maximum length of a frame's payload
80       */
81      protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
82                                          HttpHeaders customHeaders, int maxFramePayloadLength) {
83          this.uri = uri;
84          this.version = version;
85          expectedSubprotocol = subprotocol;
86          this.customHeaders = customHeaders;
87          this.maxFramePayloadLength = maxFramePayloadLength;
88      }
89  
90      /**
91       * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
92       */
93      public URI uri() {
94          return uri;
95      }
96  
97      /**
98       * Version of the web socket specification that is being used
99       */
100     public WebSocketVersion version() {
101         return version;
102     }
103 
104     /**
105      * Returns the max length for any frame's payload
106      */
107     public int maxFramePayloadLength() {
108         return maxFramePayloadLength;
109     }
110 
111     /**
112      * Flag to indicate if the opening handshake is complete
113      */
114     public boolean isHandshakeComplete() {
115         return handshakeComplete;
116     }
117 
118     private void setHandshakeComplete() {
119         handshakeComplete = true;
120     }
121 
122     /**
123      * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
124      */
125     public String expectedSubprotocol() {
126         return expectedSubprotocol;
127     }
128 
129     /**
130      * Returns the subprotocol response sent by the server. Only available after end of handshake.
131      * Null if no subprotocol was requested or confirmed by the server.
132      */
133     public String actualSubprotocol() {
134         return actualSubprotocol;
135     }
136 
137     private void setActualSubprotocol(String actualSubprotocol) {
138         this.actualSubprotocol = actualSubprotocol;
139     }
140 
141     /**
142      * Begins the opening handshake
143      *
144      * @param channel
145      *            Channel
146      */
147     public ChannelFuture handshake(Channel channel) {
148         if (channel == null) {
149             throw new NullPointerException("channel");
150         }
151         return handshake(channel, channel.newPromise());
152     }
153 
154     /**
155      * Begins the opening handshake
156      *
157      * @param channel
158      *            Channel
159      * @param promise
160      *            the {@link ChannelPromise} to be notified when the opening handshake is sent
161      */
162     public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
163         FullHttpRequest request =  newHandshakeRequest();
164 
165         HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
166         if (decoder == null) {
167             HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
168             if (codec == null) {
169                promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
170                        "a HttpResponseDecoder or HttpClientCodec"));
171                return promise;
172             }
173         }
174 
175         channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
176             @Override
177             public void operationComplete(ChannelFuture future) {
178                 if (future.isSuccess()) {
179                     ChannelPipeline p = future.channel().pipeline();
180                     ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
181                     if (ctx == null) {
182                         ctx = p.context(HttpClientCodec.class);
183                     }
184                     if (ctx == null) {
185                         promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
186                                 "a HttpRequestEncoder or HttpClientCodec"));
187                         return;
188                     }
189                     p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
190 
191                     promise.setSuccess();
192                 } else {
193                     promise.setFailure(future.cause());
194                 }
195             }
196         });
197         return promise;
198     }
199 
200     /**
201      * Returns a new {@link FullHttpRequest) which will be used for the handshake.
202      */
203     protected abstract FullHttpRequest newHandshakeRequest();
204 
205     /**
206      * Validates and finishes the opening handshake initiated by {@link #handshake}}.
207      *
208      * @param channel
209      *            Channel
210      * @param response
211      *            HTTP response containing the closing handshake details
212      */
213     public final void finishHandshake(Channel channel, FullHttpResponse response) {
214         verify(response);
215 
216         // Verify the subprotocol that we received from the server.
217         // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
218         String receivedProtocol = response.headers().get(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
219         receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
220         String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
221         boolean protocolValid = false;
222 
223         if (expectedProtocol.isEmpty() && receivedProtocol == null) {
224             // No subprotocol required and none received
225             protocolValid = true;
226             setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
227         } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
228             // We require a subprotocol and received one -> verify it
229             for (String protocol : expectedProtocol.split(",")) {
230                 if (protocol.trim().equals(receivedProtocol)) {
231                     protocolValid = true;
232                     setActualSubprotocol(receivedProtocol);
233                     break;
234                 }
235             }
236         } // else mixed cases - which are all errors
237 
238         if (!protocolValid) {
239             throw new WebSocketHandshakeException(String.format(
240                     "Invalid subprotocol. Actual: %s. Expected one of: %s",
241                     receivedProtocol, expectedSubprotocol));
242         }
243 
244         setHandshakeComplete();
245 
246         final ChannelPipeline p = channel.pipeline();
247         // Remove decompressor from pipeline if its in use
248         HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
249         if (decompressor != null) {
250             p.remove(decompressor);
251         }
252 
253         // Remove aggregator if present before
254         HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
255         if (aggregator != null) {
256             p.remove(aggregator);
257         }
258 
259         ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
260         if (ctx == null) {
261             ctx = p.context(HttpClientCodec.class);
262             if (ctx == null) {
263                 throw new IllegalStateException("ChannelPipeline does not contain " +
264                         "a HttpRequestEncoder or HttpClientCodec");
265             }
266             final HttpClientCodec codec =  (HttpClientCodec) ctx.handler();
267             // Remove the encoder part of the codec as the user may start writing frames after this method returns.
268             codec.removeOutboundHandler();
269 
270             p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
271 
272             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
273             // WebSocketFrame messages.
274             // See https://github.com/netty/netty/issues/4533
275             channel.eventLoop().execute(new Runnable() {
276                 @Override
277                 public void run() {
278                     p.remove(codec);
279                 }
280             });
281         } else {
282             if (p.get(HttpRequestEncoder.class) != null) {
283                 // Remove the encoder part of the codec as the user may start writing frames after this method returns.
284                 p.remove(HttpRequestEncoder.class);
285             }
286             final ChannelHandlerContext context = ctx;
287             p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
288 
289             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
290             // WebSocketFrame messages.
291             // See https://github.com/netty/netty/issues/4533
292             channel.eventLoop().execute(new Runnable() {
293                 @Override
294                 public void run() {
295                     p.remove(context.handler());
296                 }
297             });
298         }
299     }
300 
301     /**
302      * Process the opening handshake initiated by {@link #handshake}}.
303      *
304      * @param channel
305      *            Channel
306      * @param response
307      *            HTTP response containing the closing handshake details
308      * @return future
309      *            the {@link ChannelFuture} which is notified once the handshake completes.
310      */
311     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
312         return processHandshake(channel, response, channel.newPromise());
313     }
314 
315     /**
316      * Process the opening handshake initiated by {@link #handshake}}.
317      *
318      * @param channel
319      *            Channel
320      * @param response
321      *            HTTP response containing the closing handshake details
322      * @param promise
323      *            the {@link ChannelPromise} to notify once the handshake completes.
324      * @return future
325      *            the {@link ChannelFuture} which is notified once the handshake completes.
326      */
327     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
328                                                 final ChannelPromise promise) {
329         if (response instanceof FullHttpResponse) {
330             try {
331                 finishHandshake(channel, (FullHttpResponse) response);
332                 promise.setSuccess();
333             } catch (Throwable cause) {
334                 promise.setFailure(cause);
335             }
336         } else {
337             ChannelPipeline p = channel.pipeline();
338             ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
339             if (ctx == null) {
340                 ctx = p.context(HttpClientCodec.class);
341                 if (ctx == null) {
342                     return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
343                             "a HttpResponseDecoder or HttpClientCodec"));
344                 }
345             }
346             // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
347             // then enough for the websockets handshake payload.
348             //
349             // TODO: Make handshake work without HttpObjectAggregator at all.
350             String aggregatorName = "httpAggregator";
351             p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
352             p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
353                 @Override
354                 protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
355                     // Remove ourself and do the actual handshake
356                     ctx.pipeline().remove(this);
357                     try {
358                         finishHandshake(channel, msg);
359                         promise.setSuccess();
360                     } catch (Throwable cause) {
361                         promise.setFailure(cause);
362                     }
363                 }
364 
365                 @Override
366                 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
367                     // Remove ourself and fail the handshake promise.
368                     ctx.pipeline().remove(this);
369                     promise.setFailure(cause);
370                 }
371 
372                 @Override
373                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
374                     // Fail promise if Channel was closed
375                     promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
376                     ctx.fireChannelInactive();
377                 }
378             });
379             try {
380                 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
381             } catch (Throwable cause) {
382                 promise.setFailure(cause);
383             }
384         }
385         return promise;
386     }
387 
388     /**
389      * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
390      */
391     protected abstract void verify(FullHttpResponse response);
392 
393     /**
394      * Returns the decoder to use after handshake is complete.
395      */
396     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
397 
398     /**
399      * Returns the encoder to use after the handshake is complete.
400      */
401     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
402 
403     /**
404      * Performs the closing handshake
405      *
406      * @param channel
407      *            Channel
408      * @param frame
409      *            Closing Frame that was received
410      */
411     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
412         if (channel == null) {
413             throw new NullPointerException("channel");
414         }
415         return close(channel, frame, channel.newPromise());
416     }
417 
418     /**
419      * Performs the closing handshake
420      *
421      * @param channel
422      *            Channel
423      * @param frame
424      *            Closing Frame that was received
425      * @param promise
426      *            the {@link ChannelPromise} to be notified when the closing handshake is done
427      */
428     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
429         if (channel == null) {
430             throw new NullPointerException("channel");
431         }
432         return channel.writeAndFlush(frame, promise);
433     }
434 
435     /**
436      * Return the constructed raw path for the give {@link URI}.
437      */
438     static String rawPath(URI wsURL) {
439         String path = wsURL.getRawPath();
440         String query = wsURL.getRawQuery();
441         if (query != null && !query.isEmpty()) {
442             path = path + '?' + query;
443         }
444 
445         return path == null || path.isEmpty() ? "/" : path;
446     }
447 
448     static CharSequence websocketHostValue(URI wsURL) {
449         int port = wsURL.getPort();
450         if (port == -1) {
451             return wsURL.getHost();
452         }
453         String host = wsURL.getHost();
454         if (port == 80) {
455             return "http".equals(wsURL.getScheme())
456                     || "ws".equals(wsURL.getScheme()) ?
457                     host : NetUtil.toSocketAddressString(host, 80);
458         }
459         if (port == 443) {
460             return "https".equals(wsURL.getScheme())
461                     || "wss".equals(wsURL.getScheme()) ?
462                     host : NetUtil.toSocketAddressString(host, 443);
463         }
464 
465         // if the port is not standard (80/443) its needed to add the port to the header.
466         // See http://tools.ietf.org/html/rfc6454#section-6.2
467         return NetUtil.toSocketAddressString(host, port);
468     }
469 
470     static CharSequence websocketOriginValue(URI wsURL) {
471         String scheme = wsURL.getScheme();
472         final String schemePrefix;
473         int port = wsURL.getPort();
474         final int defaultPort;
475         if ("wss".equals(scheme)
476             || "https".equals(scheme)
477             || (scheme == null && port == 443)) {
478 
479             schemePrefix = HTTPS_SCHEME_PREFIX;
480             defaultPort = 443;
481         } else {
482             schemePrefix = HTTP_SCHEME_PREFIX;
483             defaultPort = 80;
484         }
485 
486         // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
487         String host = wsURL.getHost().toLowerCase(Locale.US);
488 
489         if (port != defaultPort && port != -1) {
490             // if the port is not standard (80/443) its needed to add the port to the header.
491             // See http://tools.ietf.org/html/rfc6454#section-6.2
492             return schemePrefix + NetUtil.toSocketAddressString(host, port);
493         }
494         return schemePrefix + host;
495     }
496 }