View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpClientCodec;
28  import io.netty.handler.codec.http.HttpContentDecompressor;
29  import io.netty.handler.codec.http.HttpHeaderNames;
30  import io.netty.handler.codec.http.HttpHeaders;
31  import io.netty.handler.codec.http.HttpObjectAggregator;
32  import io.netty.handler.codec.http.HttpRequestEncoder;
33  import io.netty.handler.codec.http.HttpResponse;
34  import io.netty.handler.codec.http.HttpResponseDecoder;
35  import io.netty.handler.codec.http.HttpScheme;
36  import io.netty.util.NetUtil;
37  import io.netty.util.ReferenceCountUtil;
38  import io.netty.util.internal.ObjectUtil;
39  
40  import java.net.URI;
41  import java.nio.channels.ClosedChannelException;
42  import java.util.Locale;
43  import java.util.concurrent.Future;
44  import java.util.concurrent.TimeUnit;
45  import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
46  
47  /**
48   * Base class for web socket client handshake implementations
49   */
50  public abstract class WebSocketClientHandshaker {
51  
52      private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
53      private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
54      protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
55  
56      private final URI uri;
57  
58      private final WebSocketVersion version;
59  
60      private volatile boolean handshakeComplete;
61  
62      private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;
63  
64      private volatile int forceCloseInit;
65  
66      private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
67              AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
68  
69      private volatile boolean forceCloseComplete;
70  
71      private final String expectedSubprotocol;
72  
73      private volatile String actualSubprotocol;
74  
75      protected final HttpHeaders customHeaders;
76  
77      private final int maxFramePayloadLength;
78  
79      private final boolean absoluteUpgradeUrl;
80  
81      /**
82       * Base constructor
83       *
84       * @param uri
85       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
86       *            sent to this URL.
87       * @param version
88       *            Version of web socket specification to use to connect to the server
89       * @param subprotocol
90       *            Sub protocol request sent to the server.
91       * @param customHeaders
92       *            Map of custom headers to add to the client request
93       * @param maxFramePayloadLength
94       *            Maximum length of a frame's payload
95       */
96      protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
97                                          HttpHeaders customHeaders, int maxFramePayloadLength) {
98          this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
99      }
100 
101     /**
102      * Base constructor
103      *
104      * @param uri
105      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
106      *            sent to this URL.
107      * @param version
108      *            Version of web socket specification to use to connect to the server
109      * @param subprotocol
110      *            Sub protocol request sent to the server.
111      * @param customHeaders
112      *            Map of custom headers to add to the client request
113      * @param maxFramePayloadLength
114      *            Maximum length of a frame's payload
115      * @param forceCloseTimeoutMillis
116      *            Close the connection if it was not closed by the server after timeout specified
117      */
118     protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
119                                         HttpHeaders customHeaders, int maxFramePayloadLength,
120                                         long forceCloseTimeoutMillis) {
121         this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
122     }
123 
124     /**
125      * Base constructor
126      *
127      * @param uri
128      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
129      *            sent to this URL.
130      * @param version
131      *            Version of web socket specification to use to connect to the server
132      * @param subprotocol
133      *            Sub protocol request sent to the server.
134      * @param customHeaders
135      *            Map of custom headers to add to the client request
136      * @param maxFramePayloadLength
137      *            Maximum length of a frame's payload
138      * @param forceCloseTimeoutMillis
139      *            Close the connection if it was not closed by the server after timeout specified
140      * @param  absoluteUpgradeUrl
141      *            Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
142      *            clear HTTP
143      */
144     protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
145                                         HttpHeaders customHeaders, int maxFramePayloadLength,
146                                         long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
147         this.uri = uri;
148         this.version = version;
149         expectedSubprotocol = subprotocol;
150         this.customHeaders = customHeaders;
151         this.maxFramePayloadLength = maxFramePayloadLength;
152         this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
153         this.absoluteUpgradeUrl = absoluteUpgradeUrl;
154     }
155 
156     /**
157      * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
158      */
159     public URI uri() {
160         return uri;
161     }
162 
163     /**
164      * Version of the web socket specification that is being used
165      */
166     public WebSocketVersion version() {
167         return version;
168     }
169 
170     /**
171      * Returns the max length for any frame's payload
172      */
173     public int maxFramePayloadLength() {
174         return maxFramePayloadLength;
175     }
176 
177     /**
178      * Flag to indicate if the opening handshake is complete
179      */
180     public boolean isHandshakeComplete() {
181         return handshakeComplete;
182     }
183 
184     private void setHandshakeComplete() {
185         handshakeComplete = true;
186     }
187 
188     /**
189      * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
190      */
191     public String expectedSubprotocol() {
192         return expectedSubprotocol;
193     }
194 
195     /**
196      * Returns the subprotocol response sent by the server. Only available after end of handshake.
197      * Null if no subprotocol was requested or confirmed by the server.
198      */
199     public String actualSubprotocol() {
200         return actualSubprotocol;
201     }
202 
203     private void setActualSubprotocol(String actualSubprotocol) {
204         this.actualSubprotocol = actualSubprotocol;
205     }
206 
207     public long forceCloseTimeoutMillis() {
208         return forceCloseTimeoutMillis;
209     }
210 
211     /**
212      * Flag to indicate if the closing handshake was initiated because of timeout.
213      * For testing only.
214      */
215     protected boolean isForceCloseComplete() {
216         return forceCloseComplete;
217     }
218 
219     /**
220      * Sets timeout to close the connection if it was not closed by the server.
221      *
222      * @param forceCloseTimeoutMillis
223      *            Close the connection if it was not closed by the server after timeout specified
224      */
225     public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
226         this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
227         return this;
228     }
229 
230     /**
231      * Begins the opening handshake
232      *
233      * @param channel
234      *            Channel
235      */
236     public ChannelFuture handshake(Channel channel) {
237         ObjectUtil.checkNotNull(channel, "channel");
238         return handshake(channel, channel.newPromise());
239     }
240 
241     /**
242      * Begins the opening handshake
243      *
244      * @param channel
245      *            Channel
246      * @param promise
247      *            the {@link ChannelPromise} to be notified when the opening handshake is sent
248      */
249     public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
250         ChannelPipeline pipeline = channel.pipeline();
251         HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
252         if (decoder == null) {
253             HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
254             if (codec == null) {
255                promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
256                        "an HttpResponseDecoder or HttpClientCodec"));
257                return promise;
258             }
259         }
260 
261         FullHttpRequest request = newHandshakeRequest();
262 
263         channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
264             @Override
265             public void operationComplete(ChannelFuture future) {
266                 if (future.isSuccess()) {
267                     ChannelPipeline p = future.channel().pipeline();
268                     ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
269                     if (ctx == null) {
270                         ctx = p.context(HttpClientCodec.class);
271                     }
272                     if (ctx == null) {
273                         promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
274                                 "an HttpRequestEncoder or HttpClientCodec"));
275                         return;
276                     }
277                     p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
278 
279                     promise.setSuccess();
280                 } else {
281                     promise.setFailure(future.cause());
282                 }
283             }
284         });
285         return promise;
286     }
287 
288     /**
289      * Returns a new {@link FullHttpRequest) which will be used for the handshake.
290      */
291     protected abstract FullHttpRequest newHandshakeRequest();
292 
293     /**
294      * Validates and finishes the opening handshake initiated by {@link #handshake}}.
295      *
296      * @param channel
297      *            Channel
298      * @param response
299      *            HTTP response containing the closing handshake details
300      */
301     public final void finishHandshake(Channel channel, FullHttpResponse response) {
302         verify(response);
303 
304         // Verify the subprotocol that we received from the server.
305         // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
306         String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
307         receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
308         String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
309         boolean protocolValid = false;
310 
311         if (expectedProtocol.isEmpty() && receivedProtocol == null) {
312             // No subprotocol required and none received
313             protocolValid = true;
314             setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
315         } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
316             // We require a subprotocol and received one -> verify it
317             for (String protocol : expectedProtocol.split(",")) {
318                 if (protocol.trim().equals(receivedProtocol)) {
319                     protocolValid = true;
320                     setActualSubprotocol(receivedProtocol);
321                     break;
322                 }
323             }
324         } // else mixed cases - which are all errors
325 
326         if (!protocolValid) {
327             throw new WebSocketHandshakeException(String.format(
328                     "Invalid subprotocol. Actual: %s. Expected one of: %s",
329                     receivedProtocol, expectedSubprotocol));
330         }
331 
332         setHandshakeComplete();
333 
334         final ChannelPipeline p = channel.pipeline();
335         // Remove decompressor from pipeline if its in use
336         HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
337         if (decompressor != null) {
338             p.remove(decompressor);
339         }
340 
341         // Remove aggregator if present before
342         HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
343         if (aggregator != null) {
344             p.remove(aggregator);
345         }
346 
347         ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
348         if (ctx == null) {
349             ctx = p.context(HttpClientCodec.class);
350             if (ctx == null) {
351                 throw new IllegalStateException("ChannelPipeline does not contain " +
352                         "an HttpRequestEncoder or HttpClientCodec");
353             }
354             final HttpClientCodec codec =  (HttpClientCodec) ctx.handler();
355             // Remove the encoder part of the codec as the user may start writing frames after this method returns.
356             codec.removeOutboundHandler();
357 
358             p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
359 
360             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
361             // WebSocketFrame messages.
362             // See https://github.com/netty/netty/issues/4533
363             channel.eventLoop().execute(new Runnable() {
364                 @Override
365                 public void run() {
366                     p.remove(codec);
367                 }
368             });
369         } else {
370             if (p.get(HttpRequestEncoder.class) != null) {
371                 // Remove the encoder part of the codec as the user may start writing frames after this method returns.
372                 p.remove(HttpRequestEncoder.class);
373             }
374             final ChannelHandlerContext context = ctx;
375             p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
376 
377             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
378             // WebSocketFrame messages.
379             // See https://github.com/netty/netty/issues/4533
380             channel.eventLoop().execute(new Runnable() {
381                 @Override
382                 public void run() {
383                     p.remove(context.handler());
384                 }
385             });
386         }
387     }
388 
389     /**
390      * Process the opening handshake initiated by {@link #handshake}}.
391      *
392      * @param channel
393      *            Channel
394      * @param response
395      *            HTTP response containing the closing handshake details
396      * @return future
397      *            the {@link ChannelFuture} which is notified once the handshake completes.
398      */
399     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
400         return processHandshake(channel, response, channel.newPromise());
401     }
402 
403     /**
404      * Process the opening handshake initiated by {@link #handshake}}.
405      *
406      * @param channel
407      *            Channel
408      * @param response
409      *            HTTP response containing the closing handshake details
410      * @param promise
411      *            the {@link ChannelPromise} to notify once the handshake completes.
412      * @return future
413      *            the {@link ChannelFuture} which is notified once the handshake completes.
414      */
415     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
416                                                 final ChannelPromise promise) {
417         if (response instanceof FullHttpResponse) {
418             try {
419                 finishHandshake(channel, (FullHttpResponse) response);
420                 promise.setSuccess();
421             } catch (Throwable cause) {
422                 promise.setFailure(cause);
423             }
424         } else {
425             ChannelPipeline p = channel.pipeline();
426             ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
427             if (ctx == null) {
428                 ctx = p.context(HttpClientCodec.class);
429                 if (ctx == null) {
430                     return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
431                             "an HttpResponseDecoder or HttpClientCodec"));
432                 }
433             }
434             // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
435             // then enough for the websockets handshake payload.
436             //
437             // TODO: Make handshake work without HttpObjectAggregator at all.
438             String aggregatorName = "httpAggregator";
439             p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
440             p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
441                 @Override
442                 protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
443                     // Remove ourself and do the actual handshake
444                     ctx.pipeline().remove(this);
445                     try {
446                         finishHandshake(channel, msg);
447                         promise.setSuccess();
448                     } catch (Throwable cause) {
449                         promise.setFailure(cause);
450                     }
451                 }
452 
453                 @Override
454                 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
455                     // Remove ourself and fail the handshake promise.
456                     ctx.pipeline().remove(this);
457                     promise.setFailure(cause);
458                 }
459 
460                 @Override
461                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
462                     // Fail promise if Channel was closed
463                     if (!promise.isDone()) {
464                         promise.tryFailure(new ClosedChannelException());
465                     }
466                     ctx.fireChannelInactive();
467                 }
468             });
469             try {
470                 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
471             } catch (Throwable cause) {
472                 promise.setFailure(cause);
473             }
474         }
475         return promise;
476     }
477 
478     /**
479      * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
480      */
481     protected abstract void verify(FullHttpResponse response);
482 
483     /**
484      * Returns the decoder to use after handshake is complete.
485      */
486     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
487 
488     /**
489      * Returns the encoder to use after the handshake is complete.
490      */
491     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
492 
493     /**
494      * Performs the closing handshake
495      *
496      * @param channel
497      *            Channel
498      * @param frame
499      *            Closing Frame that was received
500      */
501     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
502         ObjectUtil.checkNotNull(channel, "channel");
503         return close(channel, frame, channel.newPromise());
504     }
505 
506     /**
507      * Performs the closing handshake
508      *
509      * @param channel
510      *            Channel
511      * @param frame
512      *            Closing Frame that was received
513      * @param promise
514      *            the {@link ChannelPromise} to be notified when the closing handshake is done
515      */
516     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
517         ObjectUtil.checkNotNull(channel, "channel");
518         channel.writeAndFlush(frame, promise);
519         applyForceCloseTimeout(channel, promise);
520         return promise;
521     }
522 
523     private void applyForceCloseTimeout(final Channel channel, ChannelFuture flushFuture) {
524         final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
525         final WebSocketClientHandshaker handshaker = this;
526         if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
527             return;
528         }
529 
530         flushFuture.addListener(new ChannelFutureListener() {
531             @Override
532             public void operationComplete(ChannelFuture future) throws Exception {
533                 // If flush operation failed, there is no reason to expect
534                 // a server to receive CloseFrame. Thus this should be handled
535                 // by the application separately.
536                 // Also, close might be called twice from different threads.
537                 if (future.isSuccess() && channel.isActive() &&
538                         FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
539                     final Future<?> forceCloseFuture = channel.eventLoop().schedule(new Runnable() {
540                         @Override
541                         public void run() {
542                             if (channel.isActive()) {
543                                 channel.close();
544                                 forceCloseComplete = true;
545                             }
546                         }
547                     }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
548 
549                     channel.closeFuture().addListener(new ChannelFutureListener() {
550                         @Override
551                         public void operationComplete(ChannelFuture future) throws Exception {
552                             forceCloseFuture.cancel(false);
553                         }
554                     });
555                 }
556             }
557         });
558     }
559 
560     /**
561      * Return the constructed raw path for the give {@link URI}.
562      */
563     protected String upgradeUrl(URI wsURL) {
564         if (absoluteUpgradeUrl) {
565             return wsURL.toString();
566         }
567 
568         String path = wsURL.getRawPath();
569         path = path == null || path.isEmpty() ? "/" : path;
570         String query = wsURL.getRawQuery();
571         return query != null && !query.isEmpty() ? path + '?' + query : path;
572     }
573 
574     static CharSequence websocketHostValue(URI wsURL) {
575         int port = wsURL.getPort();
576         if (port == -1) {
577             return wsURL.getHost();
578         }
579         String host = wsURL.getHost();
580         String scheme = wsURL.getScheme();
581         if (port == HttpScheme.HTTP.port()) {
582             return HttpScheme.HTTP.name().contentEquals(scheme)
583                     || WebSocketScheme.WS.name().contentEquals(scheme) ?
584                     host : NetUtil.toSocketAddressString(host, port);
585         }
586         if (port == HttpScheme.HTTPS.port()) {
587             return HttpScheme.HTTPS.name().contentEquals(scheme)
588                     || WebSocketScheme.WSS.name().contentEquals(scheme) ?
589                     host : NetUtil.toSocketAddressString(host, port);
590         }
591 
592         // if the port is not standard (80/443) its needed to add the port to the header.
593         // See http://tools.ietf.org/html/rfc6454#section-6.2
594         return NetUtil.toSocketAddressString(host, port);
595     }
596 
597     static CharSequence websocketOriginValue(URI wsURL) {
598         String scheme = wsURL.getScheme();
599         final String schemePrefix;
600         int port = wsURL.getPort();
601         final int defaultPort;
602         if (WebSocketScheme.WSS.name().contentEquals(scheme)
603             || HttpScheme.HTTPS.name().contentEquals(scheme)
604             || (scheme == null && port == WebSocketScheme.WSS.port())) {
605 
606             schemePrefix = HTTPS_SCHEME_PREFIX;
607             defaultPort = WebSocketScheme.WSS.port();
608         } else {
609             schemePrefix = HTTP_SCHEME_PREFIX;
610             defaultPort = WebSocketScheme.WS.port();
611         }
612 
613         // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
614         String host = wsURL.getHost().toLowerCase(Locale.US);
615 
616         if (port != defaultPort && port != -1) {
617             // if the port is not standard (80/443) its needed to add the port to the header.
618             // See http://tools.ietf.org/html/rfc6454#section-6.2
619             return schemePrefix + NetUtil.toSocketAddressString(host, port);
620         }
621         return schemePrefix + host;
622     }
623 }