View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.handler.codec.http.websocketx;
17  
18  import io.netty5.buffer.api.BufferAllocator;
19  import io.netty5.channel.Channel;
20  import io.netty5.channel.ChannelHandler;
21  import io.netty5.channel.ChannelHandlerContext;
22  import io.netty5.channel.ChannelOutboundInvoker;
23  import io.netty5.channel.ChannelPipeline;
24  import io.netty5.channel.SimpleChannelInboundHandler;
25  import io.netty5.handler.codec.http.FullHttpRequest;
26  import io.netty5.handler.codec.http.FullHttpResponse;
27  import io.netty5.handler.codec.http.HttpClientCodec;
28  import io.netty5.handler.codec.http.HttpContentDecompressor;
29  import io.netty5.handler.codec.http.HttpHeaderNames;
30  import io.netty5.handler.codec.http.HttpHeaders;
31  import io.netty5.handler.codec.http.HttpObjectAggregator;
32  import io.netty5.handler.codec.http.HttpRequestEncoder;
33  import io.netty5.handler.codec.http.HttpResponse;
34  import io.netty5.handler.codec.http.HttpResponseDecoder;
35  import io.netty5.handler.codec.http.HttpScheme;
36  import io.netty5.util.NetUtil;
37  import io.netty5.util.ReferenceCountUtil;
38  import io.netty5.util.concurrent.Future;
39  import io.netty5.util.concurrent.Promise;
40  
41  import java.net.URI;
42  import java.nio.channels.ClosedChannelException;
43  import java.util.concurrent.TimeUnit;
44  import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
45  
46  import static java.util.Objects.requireNonNull;
47  
48  /**
49   * Base class for web socket client handshake implementations
50   */
51  public abstract class WebSocketClientHandshaker {
52  
53      protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
54  
55      private final URI uri;
56  
57      private final WebSocketVersion version;
58  
59      private volatile boolean handshakeComplete;
60  
61      private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;
62  
63      private volatile int forceCloseInit;
64  
65      private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
66              AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
67  
68      private volatile boolean forceCloseComplete;
69  
70      private final String expectedSubprotocol;
71  
72      private volatile String actualSubprotocol;
73  
74      protected final HttpHeaders customHeaders;
75  
76      private final int maxFramePayloadLength;
77  
78      private final boolean absoluteUpgradeUrl;
79  
80      /**
81       * Base constructor
82       *
83       * @param uri
84       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
85       *            sent to this URL.
86       * @param version
87       *            Version of web socket specification to use to connect to the server
88       * @param subprotocol
89       *            Sub protocol request sent to the server.
90       * @param customHeaders
91       *            Map of custom headers to add to the client request
92       * @param maxFramePayloadLength
93       *            Maximum length of a frame's payload
94       */
95      protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
96                                          HttpHeaders customHeaders, int maxFramePayloadLength) {
97          this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
98      }
99  
100     /**
101      * Base constructor
102      *
103      * @param uri
104      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
105      *            sent to this URL.
106      * @param version
107      *            Version of web socket specification to use to connect to the server
108      * @param subprotocol
109      *            Sub protocol request sent to the server.
110      * @param customHeaders
111      *            Map of custom headers to add to the client request
112      * @param maxFramePayloadLength
113      *            Maximum length of a frame's payload
114      * @param forceCloseTimeoutMillis
115      *            Close the connection if it was not closed by the server after timeout specified
116      */
117     protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
118                                         HttpHeaders customHeaders, int maxFramePayloadLength,
119                                         long forceCloseTimeoutMillis) {
120         this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
121     }
122 
123     /**
124      * Base constructor
125      *
126      * @param uri
127      *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
128      *            sent to this URL.
129      * @param version
130      *            Version of web socket specification to use to connect to the server
131      * @param subprotocol
132      *            Sub protocol request sent to the server.
133      * @param customHeaders
134      *            Map of custom headers to add to the client request
135      * @param maxFramePayloadLength
136      *            Maximum length of a frame's payload
137      * @param forceCloseTimeoutMillis
138      *            Close the connection if it was not closed by the server after timeout specified
139      * @param  absoluteUpgradeUrl
140      *            Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
141      *            clear HTTP
142      */
143     protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
144                                         HttpHeaders customHeaders, int maxFramePayloadLength,
145                                         long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
146         this.uri = uri;
147         this.version = version;
148         expectedSubprotocol = subprotocol;
149         this.customHeaders = customHeaders;
150         this.maxFramePayloadLength = maxFramePayloadLength;
151         this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
152         this.absoluteUpgradeUrl = absoluteUpgradeUrl;
153     }
154 
155     /**
156      * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
157      */
158     public URI uri() {
159         return uri;
160     }
161 
162     /**
163      * Version of the web socket specification that is being used
164      */
165     public WebSocketVersion version() {
166         return version;
167     }
168 
169     /**
170      * Returns the max length for any frame's payload
171      */
172     public int maxFramePayloadLength() {
173         return maxFramePayloadLength;
174     }
175 
176     /**
177      * Flag to indicate if the opening handshake is complete
178      */
179     public boolean isHandshakeComplete() {
180         return handshakeComplete;
181     }
182 
183     private void setHandshakeComplete() {
184         handshakeComplete = true;
185     }
186 
187     /**
188      * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
189      */
190     public String expectedSubprotocol() {
191         return expectedSubprotocol;
192     }
193 
194     /**
195      * Returns the subprotocol response sent by the server. Only available after end of handshake.
196      * Null if no subprotocol was requested or confirmed by the server.
197      */
198     public String actualSubprotocol() {
199         return actualSubprotocol;
200     }
201 
202     private void setActualSubprotocol(String actualSubprotocol) {
203         this.actualSubprotocol = actualSubprotocol;
204     }
205 
206     public long forceCloseTimeoutMillis() {
207         return forceCloseTimeoutMillis;
208     }
209 
210     /**
211      * Flag to indicate if the closing handshake was initiated because of timeout.
212      * For testing only.
213      */
214     protected boolean isForceCloseComplete() {
215         return forceCloseComplete;
216     }
217 
218     /**
219      * Sets timeout to close the connection if it was not closed by the server.
220      *
221      * @param forceCloseTimeoutMillis
222      *            Close the connection if it was not closed by the server after timeout specified
223      */
224     public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
225         this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
226         return this;
227     }
228 
229     /**
230      * Begins the opening handshake
231      *
232      * @param channel
233      *            Channel
234      */
235     public Future<Void> handshake(Channel channel) {
236         requireNonNull(channel, "channel");
237         ChannelPipeline pipeline = channel.pipeline();
238         HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
239         if (decoder == null) {
240             HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
241             if (codec == null) {
242                return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " +
243                        "an HttpResponseDecoder or HttpClientCodec"));
244             }
245         }
246 
247         FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator());
248 
249         Promise<Void> promise = channel.newPromise();
250         channel.writeAndFlush(request).addListener(channel, (ch, future) -> {
251             if (future.isSuccess()) {
252                 ChannelPipeline p = ch.pipeline();
253                 ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
254                 if (ctx == null) {
255                     ctx = p.context(HttpClientCodec.class);
256                 }
257                 if (ctx == null) {
258                     promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
259                             "an HttpRequestEncoder or HttpClientCodec"));
260                     return;
261                 }
262                 p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
263 
264                 promise.setSuccess(null);
265             } else {
266                 promise.setFailure(future.cause());
267             }
268         });
269         return promise.asFuture();
270     }
271 
272     /**
273      * Returns a new {@link FullHttpRequest) which will be used for the handshake.
274      */
275     protected abstract FullHttpRequest newHandshakeRequest(BufferAllocator allocator);
276 
277     /**
278      * Validates and finishes the opening handshake initiated by {@link #handshake}}.
279      *
280      * @param channel
281      *            Channel
282      * @param response
283      *            HTTP response containing the closing handshake details
284      */
285     public final void finishHandshake(Channel channel, FullHttpResponse response) {
286         verify(response);
287 
288         // Verify the subprotocol that we received from the server.
289         // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
290         String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
291         receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
292         String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
293         boolean protocolValid = false;
294 
295         if (expectedProtocol.isEmpty() && receivedProtocol == null) {
296             // No subprotocol required and none received
297             protocolValid = true;
298             setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
299         } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
300             // We require a subprotocol and received one -> verify it
301             for (String protocol : expectedProtocol.split(",")) {
302                 if (protocol.trim().equals(receivedProtocol)) {
303                     protocolValid = true;
304                     setActualSubprotocol(receivedProtocol);
305                     break;
306                 }
307             }
308         } // else mixed cases - which are all errors
309 
310         if (!protocolValid) {
311             throw new WebSocketClientHandshakeException(String.format(
312                     "Invalid subprotocol. Actual: %s. Expected one of: %s",
313                     receivedProtocol, expectedSubprotocol), response);
314         }
315 
316         setHandshakeComplete();
317 
318         final ChannelPipeline p = channel.pipeline();
319         // Remove decompressor from pipeline if its in use
320         HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
321         if (decompressor != null) {
322             p.remove(decompressor);
323         }
324 
325         // Remove aggregator if present before
326         HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
327         if (aggregator != null) {
328             p.remove(aggregator);
329         }
330 
331         ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
332         if (ctx == null) {
333             ctx = p.context(HttpClientCodec.class);
334             if (ctx == null) {
335                 throw new IllegalStateException("ChannelPipeline does not contain " +
336                         "an HttpRequestEncoder or HttpClientCodec");
337             }
338             final HttpClientCodec codec =  (HttpClientCodec) ctx.handler();
339             // Remove the encoder part of the codec as the user may start writing frames after this method returns.
340             codec.removeOutboundHandler();
341 
342             p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
343 
344             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
345             // WebSocketFrame messages.
346             // See https://github.com/netty/netty/issues/4533
347             channel.executor().execute(() -> p.remove(codec));
348         } else {
349             if (p.get(HttpRequestEncoder.class) != null) {
350                 // Remove the encoder part of the codec as the user may start writing frames after this method returns.
351                 p.remove(HttpRequestEncoder.class);
352             }
353             final ChannelHandlerContext context = ctx;
354             p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
355 
356             // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
357             // WebSocketFrame messages.
358             // See https://github.com/netty/netty/issues/4533
359             channel.executor().execute(() -> p.remove(context.handler()));
360         }
361     }
362 
363     /**
364      * Process the opening handshake initiated by {@link #handshake}}.
365      *
366      * @param channel
367      *            Channel
368      * @param response
369      *            HTTP response containing the closing handshake details
370      * @return future
371      *            the {@link Future} which is notified once the handshake completes.
372      */
373     public final Future<Void> processHandshake(final Channel channel, HttpResponse response) {
374         if (response instanceof FullHttpResponse) {
375             try {
376                 finishHandshake(channel, (FullHttpResponse) response);
377                 return channel.newSucceededFuture();
378             } catch (Throwable cause) {
379                 return channel.newFailedFuture(cause);
380             }
381         } else {
382             ChannelPipeline p = channel.pipeline();
383             ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
384             if (ctx == null) {
385                 ctx = p.context(HttpClientCodec.class);
386                 if (ctx == null) {
387                     return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " +
388                             "an HttpResponseDecoder or HttpClientCodec"));
389                 }
390             }
391 
392             Promise<Void> promise = channel.newPromise();
393             // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
394             // then enough for the websockets handshake payload.
395             //
396             // TODO: Make handshake work without HttpObjectAggregator at all.
397             String aggregatorName = "httpAggregator";
398             p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
399             p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
400                 @Override
401                 protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
402                     // Remove ourself and do the actual handshake
403                     ctx.pipeline().remove(this);
404                     try {
405                         finishHandshake(channel, msg);
406                         promise.setSuccess(null);
407                     } catch (Throwable cause) {
408                         promise.setFailure(cause);
409                     }
410                 }
411 
412                 @Override
413                 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
414                     // Remove ourself and fail the handshake promise.
415                     ctx.pipeline().remove(this);
416                     promise.setFailure(cause);
417                 }
418 
419                 @Override
420                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
421                     // Fail promise if Channel was closed
422                     if (!promise.isDone()) {
423                         promise.tryFailure(new ClosedChannelException());
424                     }
425                     ctx.fireChannelInactive();
426                 }
427             });
428             try {
429                 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
430             } catch (Throwable cause) {
431                 promise.setFailure(cause);
432             }
433             return promise.asFuture();
434         }
435     }
436 
437     /**
438      * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
439      */
440     protected abstract void verify(FullHttpResponse response);
441 
442     /**
443      * Returns the decoder to use after handshake is complete.
444      */
445     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
446 
447     /**
448      * Returns the encoder to use after the handshake is complete.
449      */
450     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
451 
452     /**
453      * Performs the closing handshake.
454      *
455      * When called from within a {@link ChannelHandler} you most likely want to use
456      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
457      *
458      * @param channel
459      *            Channel
460      * @param frame
461      *            Closing Frame that was received
462      */
463     public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
464         requireNonNull(channel, "channel");
465         return close0(channel, channel, frame);
466     }
467 
468     /**
469      * Performs the closing handshake
470      *
471      * @param ctx
472      *            the {@link ChannelHandlerContext} to use.
473      * @param frame
474      *            Closing Frame that was received
475      */
476     public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
477         requireNonNull(ctx, "ctx");
478         return close0(ctx, ctx.channel(), frame);
479     }
480 
481     private Future<Void> close0(final ChannelOutboundInvoker invoker, final Channel channel,
482                                  CloseWebSocketFrame frame) {
483         Future<Void> f = invoker.writeAndFlush(frame);
484         final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
485         final WebSocketClientHandshaker handshaker = this;
486         if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
487             return f;
488         }
489 
490         f.addListener(future -> {
491             // If flush operation failed, there is no reason to expect
492             // a server to receive CloseFrame. Thus this should be handled
493             // by the application separately.
494             // Also, close might be called twice from different threads.
495             if (future.isSuccess() && channel.isActive() &&
496                     FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
497                 final Future<?> forceCloseFuture = channel.executor().schedule(() -> {
498                     if (channel.isActive()) {
499                         channel.close();
500                         forceCloseComplete = true;
501                     }
502                 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
503 
504                 channel.closeFuture().addListener(ignore -> forceCloseFuture.cancel());
505             }
506         });
507         return f;
508     }
509 
510     /**
511      * Return the constructed raw path for the give {@link URI}.
512      */
513     protected String upgradeUrl(URI wsURL) {
514         if (absoluteUpgradeUrl) {
515             return wsURL.toString();
516         }
517 
518         String path = wsURL.getRawPath();
519         path = path == null || path.isEmpty() ? "/" : path;
520         String query = wsURL.getRawQuery();
521         return query != null && !query.isEmpty() ? path + '?' + query : path;
522     }
523 
524     static CharSequence websocketHostValue(URI wsURL) {
525         int port = wsURL.getPort();
526         if (port == -1) {
527             return wsURL.getHost();
528         }
529         String host = wsURL.getHost();
530         String scheme = wsURL.getScheme();
531         if (port == HttpScheme.HTTP.port()) {
532             return HttpScheme.HTTP.name().contentEquals(scheme)
533                     || WebSocketScheme.WS.name().contentEquals(scheme) ?
534                     host : NetUtil.toSocketAddressString(host, port);
535         }
536         if (port == HttpScheme.HTTPS.port()) {
537             return HttpScheme.HTTPS.name().contentEquals(scheme)
538                     || WebSocketScheme.WSS.name().contentEquals(scheme) ?
539                     host : NetUtil.toSocketAddressString(host, port);
540         }
541 
542         // if the port is not standard (80/443) its needed to add the port to the header.
543         // See https://tools.ietf.org/html/rfc6454#section-6.2
544         return NetUtil.toSocketAddressString(host, port);
545     }
546 }