View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpClientCodec;
28  import io.netty.handler.codec.http.HttpContentDecompressor;
29  import io.netty.handler.codec.http.HttpHeaderNames;
30  import io.netty.handler.codec.http.HttpHeaders;
31  import io.netty.handler.codec.http.HttpObjectAggregator;
32  import io.netty.handler.codec.http.HttpRequestEncoder;
33  import io.netty.handler.codec.http.HttpResponse;
34  import io.netty.handler.codec.http.HttpResponseDecoder;
35  import io.netty.handler.codec.http.HttpScheme;
36  import io.netty.util.NetUtil;
37  import io.netty.util.ReferenceCountUtil;
38  import io.netty.util.internal.ThrowableUtil;
39  
40  import java.net.URI;
41  import java.nio.channels.ClosedChannelException;
42  import java.util.Locale;
43  
44  /**
45   * Base class for web socket client handshake implementations
46   */
47  public abstract class WebSocketClientHandshaker {
48      private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
49              new ClosedChannelException(), WebSocketClientHandshaker.class, "processHandshake(...)");
50  
51      private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
52      private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
53  
54      private final URI uri;
55  
56      private final WebSocketVersion version;
57  
58      private volatile boolean handshakeComplete;
59  
60      private final String expectedSubprotocol;
61  
62      private volatile String actualSubprotocol;
63  
64      protected final HttpHeaders customHeaders;
65  
66      private final int maxFramePayloadLength;
67  
68      /**
69       * Base constructor
70       *
71       * @param uri
72       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
73       *            sent to this URL.
74       * @param version
75       *            Version of web socket specification to use to connect to the server
76       * @param subprotocol
77       *            Sub protocol request sent to the server.
78       * @param customHeaders
79       *            Map of custom headers to add to the client request
80       * @param maxFramePayloadLength
81       *            Maximum length of a frame's payload
82       */
83      protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
84                                          HttpHeaders customHeaders, int maxFramePayloadLength) {
85          this.uri = uri;
86          this.version = version;
87          expectedSubprotocol = subprotocol;
88          this.customHeaders = customHeaders;
89          this.maxFramePayloadLength = maxFramePayloadLength;
90      }
91  
92      /**
93       * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
94       */
95      public URI uri() {
96          return uri;
97      }
98  
99      /**
100      * Version of the web socket specification that is being used
101      */
102     public WebSocketVersion version() {
103         return version;
104     }
105 
106     /**
107      * Returns the max length for any frame's payload
108      */
109     public int maxFramePayloadLength() {
110         return maxFramePayloadLength;
111     }
112 
113     /**
114      * Flag to indicate if the opening handshake is complete
115      */
116     public boolean isHandshakeComplete() {
117         return handshakeComplete;
118     }
119 
120     private void setHandshakeComplete() {
121         handshakeComplete = true;
122     }
123 
124     /**
125      * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
126      */
127     public String expectedSubprotocol() {
128         return expectedSubprotocol;
129     }
130 
131     /**
132      * Returns the subprotocol response sent by the server. Only available after end of handshake.
133      * Null if no subprotocol was requested or confirmed by the server.
134      */
135     public String actualSubprotocol() {
136         return actualSubprotocol;
137     }
138 
139     private void setActualSubprotocol(String actualSubprotocol) {
140         this.actualSubprotocol = actualSubprotocol;
141     }
142 
143     /**
144      * Begins the opening handshake
145      *
146      * @param channel
147      *            Channel
148      */
149     public ChannelFuture handshake(Channel channel) {
150         if (channel == null) {
151             throw new NullPointerException("channel");
152         }
153         return handshake(channel, channel.newPromise());
154     }
155 
156     /**
157      * Begins the opening handshake
158      *
159      * @param channel
160      *            Channel
161      * @param promise
162      *            the {@link ChannelPromise} to be notified when the opening handshake is sent
163      */
164     public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
165         FullHttpRequest request =  newHandshakeRequest();
166 
167         HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
168         if (decoder == null) {
169             HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
170             if (codec == null) {
171                promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
172                        "a HttpResponseDecoder or HttpClientCodec"));
173                return promise;
174             }
175         }
176 
177         channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
178             @Override
179             public void operationComplete(ChannelFuture future) {
180                 if (future.isSuccess()) {
181                     ChannelPipeline p = future.channel().pipeline();
182                     ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
183                     if (ctx == null) {
184                         ctx = p.context(HttpClientCodec.class);
185                     }
186                     if (ctx == null) {
187                         promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
188                                 "a HttpRequestEncoder or HttpClientCodec"));
189                         return;
190                     }
191                     p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
192 
193                     promise.setSuccess();
194                 } else {
195                     promise.setFailure(future.cause());
196                 }
197             }
198         });
199         return promise;
200     }
201 
202     /**
203      * Returns a new {@link FullHttpRequest) which will be used for the handshake.
204      */
205     protected abstract FullHttpRequest newHandshakeRequest();
206 
207     /**
208      * Validates and finishes the opening handshake initiated by {@link #handshake}}.
209      *
210      * @param channel
211      *            Channel
212      * @param response
213      *            HTTP response containing the closing handshake details
214      */
215     public final void finishHandshake(Channel channel, FullHttpResponse response) {
216         verify(response);
217 
218         // Verify the subprotocol that we received from the server.
219         // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
220         String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
221         receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
222         String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
223         boolean protocolValid = false;
224 
225         if (expectedProtocol.isEmpty() && receivedProtocol == null) {
226             // No subprotocol required and none received
227             protocolValid = true;
228             setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
229         } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
230             // We require a subprotocol and received one -> verify it
231             for (String protocol : expectedProtocol.split(",")) {
232                 if (protocol.trim().equals(receivedProtocol)) {
233                     protocolValid = true;
234                     setActualSubprotocol(receivedProtocol);
235                     break;
236                 }
237             }
238         } // else mixed cases - which are all errors
239 
240         if (!protocolValid) {
241             throw new WebSocketHandshakeException(String.format(
242                     "Invalid subprotocol. Actual: %s. Expected one of: %s",
243                     receivedProtocol, expectedSubprotocol));
244         }
245 
246         setHandshakeComplete();
247 
248         final ChannelPipeline p = channel.pipeline();
249         // Remove decompressor from pipeline if its in use
250         HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
251         if (decompressor != null) {
252             p.remove(decompressor);
253         }
254 
255         // Remove aggregator if present before
256         HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
257         if (aggregator != null) {
258             p.remove(aggregator);
259         }
260 
261         ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
262         if (ctx == null) {
263             ctx = p.context(HttpClientCodec.class);
264             if (ctx == null) {
265                 throw new IllegalStateException("ChannelPipeline does not contain " +
266                         "a HttpRequestEncoder or HttpClientCodec");
267             }
268             final HttpClientCodec codec =  (HttpClientCodec) ctx.handler();
269             // Remove the encoder part of the codec as the user may start writing frames after this method returns.
270             codec.removeOutboundHandler();
271 
272             p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
273 
274             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
275             // WebSocketFrame messages.
276             // See https://github.com/netty/netty/issues/4533
277             channel.eventLoop().execute(new Runnable() {
278                 @Override
279                 public void run() {
280                     p.remove(codec);
281                 }
282             });
283         } else {
284             if (p.get(HttpRequestEncoder.class) != null) {
285                 // Remove the encoder part of the codec as the user may start writing frames after this method returns.
286                 p.remove(HttpRequestEncoder.class);
287             }
288             final ChannelHandlerContext context = ctx;
289             p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
290 
291             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
292             // WebSocketFrame messages.
293             // See https://github.com/netty/netty/issues/4533
294             channel.eventLoop().execute(new Runnable() {
295                 @Override
296                 public void run() {
297                     p.remove(context.handler());
298                 }
299             });
300         }
301     }
302 
303     /**
304      * Process the opening handshake initiated by {@link #handshake}}.
305      *
306      * @param channel
307      *            Channel
308      * @param response
309      *            HTTP response containing the closing handshake details
310      * @return future
311      *            the {@link ChannelFuture} which is notified once the handshake completes.
312      */
313     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
314         return processHandshake(channel, response, channel.newPromise());
315     }
316 
317     /**
318      * Process the opening handshake initiated by {@link #handshake}}.
319      *
320      * @param channel
321      *            Channel
322      * @param response
323      *            HTTP response containing the closing handshake details
324      * @param promise
325      *            the {@link ChannelPromise} to notify once the handshake completes.
326      * @return future
327      *            the {@link ChannelFuture} which is notified once the handshake completes.
328      */
329     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
330                                                 final ChannelPromise promise) {
331         if (response instanceof FullHttpResponse) {
332             try {
333                 finishHandshake(channel, (FullHttpResponse) response);
334                 promise.setSuccess();
335             } catch (Throwable cause) {
336                 promise.setFailure(cause);
337             }
338         } else {
339             ChannelPipeline p = channel.pipeline();
340             ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
341             if (ctx == null) {
342                 ctx = p.context(HttpClientCodec.class);
343                 if (ctx == null) {
344                     return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
345                             "a HttpResponseDecoder or HttpClientCodec"));
346                 }
347             }
348             // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
349             // then enough for the websockets handshake payload.
350             //
351             // TODO: Make handshake work without HttpObjectAggregator at all.
352             String aggregatorName = "httpAggregator";
353             p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
354             p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
355                 @Override
356                 protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
357                     // Remove ourself and do the actual handshake
358                     ctx.pipeline().remove(this);
359                     try {
360                         finishHandshake(channel, msg);
361                         promise.setSuccess();
362                     } catch (Throwable cause) {
363                         promise.setFailure(cause);
364                     }
365                 }
366 
367                 @Override
368                 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
369                     // Remove ourself and fail the handshake promise.
370                     ctx.pipeline().remove(this);
371                     promise.setFailure(cause);
372                 }
373 
374                 @Override
375                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
376                     // Fail promise if Channel was closed
377                     promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
378                     ctx.fireChannelInactive();
379                 }
380             });
381             try {
382                 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
383             } catch (Throwable cause) {
384                 promise.setFailure(cause);
385             }
386         }
387         return promise;
388     }
389 
390     /**
391      * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
392      */
393     protected abstract void verify(FullHttpResponse response);
394 
395     /**
396      * Returns the decoder to use after handshake is complete.
397      */
398     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
399 
400     /**
401      * Returns the encoder to use after the handshake is complete.
402      */
403     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
404 
405     /**
406      * Performs the closing handshake
407      *
408      * @param channel
409      *            Channel
410      * @param frame
411      *            Closing Frame that was received
412      */
413     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
414         if (channel == null) {
415             throw new NullPointerException("channel");
416         }
417         return close(channel, frame, channel.newPromise());
418     }
419 
420     /**
421      * Performs the closing handshake
422      *
423      * @param channel
424      *            Channel
425      * @param frame
426      *            Closing Frame that was received
427      * @param promise
428      *            the {@link ChannelPromise} to be notified when the closing handshake is done
429      */
430     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
431         if (channel == null) {
432             throw new NullPointerException("channel");
433         }
434         return channel.writeAndFlush(frame, promise);
435     }
436 
437     /**
438      * Return the constructed raw path for the give {@link URI}.
439      */
440     static String rawPath(URI wsURL) {
441         String path = wsURL.getRawPath();
442         String query = wsURL.getRawQuery();
443         if (query != null && !query.isEmpty()) {
444             path = path + '?' + query;
445         }
446 
447         return path == null || path.isEmpty() ? "/" : path;
448     }
449 
450     static CharSequence websocketHostValue(URI wsURL) {
451         int port = wsURL.getPort();
452         if (port == -1) {
453             return wsURL.getHost();
454         }
455         String host = wsURL.getHost();
456         if (port == HttpScheme.HTTP.port()) {
457             return HttpScheme.HTTP.name().contentEquals(wsURL.getScheme())
458                     || WebSocketScheme.WS.name().contentEquals(wsURL.getScheme()) ?
459                     host : NetUtil.toSocketAddressString(host, port);
460         }
461         if (port == HttpScheme.HTTPS.port()) {
462             return HttpScheme.HTTPS.name().contentEquals(wsURL.getScheme())
463                     || WebSocketScheme.WSS.name().contentEquals(wsURL.getScheme()) ?
464                     host : NetUtil.toSocketAddressString(host, port);
465         }
466 
467         // if the port is not standard (80/443) its needed to add the port to the header.
468         // See http://tools.ietf.org/html/rfc6454#section-6.2
469         return NetUtil.toSocketAddressString(host, port);
470     }
471 
472     static CharSequence websocketOriginValue(URI wsURL) {
473         String scheme = wsURL.getScheme();
474         final String schemePrefix;
475         int port = wsURL.getPort();
476         final int defaultPort;
477         if (WebSocketScheme.WSS.name().contentEquals(scheme)
478             || HttpScheme.HTTPS.name().contentEquals(scheme)
479             || (scheme == null && port == WebSocketScheme.WSS.port())) {
480 
481             schemePrefix = HTTPS_SCHEME_PREFIX;
482             defaultPort = WebSocketScheme.WSS.port();
483         } else {
484             schemePrefix = HTTP_SCHEME_PREFIX;
485             defaultPort = WebSocketScheme.WS.port();
486         }
487 
488         // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
489         String host = wsURL.getHost().toLowerCase(Locale.US);
490 
491         if (port != defaultPort && port != -1) {
492             // if the port is not standard (80/443) its needed to add the port to the header.
493             // See http://tools.ietf.org/html/rfc6454#section-6.2
494             return schemePrefix + NetUtil.toSocketAddressString(host, port);
495         }
496         return schemePrefix + host;
497     }
498 }