View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.Channel;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.channel.ChannelPromise;
24  import io.netty.channel.SimpleChannelInboundHandler;
25  import io.netty.handler.codec.http.FullHttpRequest;
26  import io.netty.handler.codec.http.FullHttpResponse;
27  import io.netty.handler.codec.http.HttpClientCodec;
28  import io.netty.handler.codec.http.HttpContentDecompressor;
29  import io.netty.handler.codec.http.HttpHeaderNames;
30  import io.netty.handler.codec.http.HttpHeaders;
31  import io.netty.handler.codec.http.HttpObjectAggregator;
32  import io.netty.handler.codec.http.HttpRequestEncoder;
33  import io.netty.handler.codec.http.HttpResponse;
34  import io.netty.handler.codec.http.HttpResponseDecoder;
35  import io.netty.util.ReferenceCountUtil;
36  import io.netty.util.internal.EmptyArrays;
37  import io.netty.util.internal.StringUtil;
38  
39  import java.net.URI;
40  import java.nio.channels.ClosedChannelException;
41  
42  /**
43   * Base class for web socket client handshake implementations
44   */
45  public abstract class WebSocketClientHandshaker {
46      private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
47  
48      static {
49          CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
50      }
51  
52      private final URI uri;
53  
54      private final WebSocketVersion version;
55  
56      private volatile boolean handshakeComplete;
57  
58      private final String expectedSubprotocol;
59  
60      private volatile String actualSubprotocol;
61  
62      protected final HttpHeaders customHeaders;
63  
64      private final int maxFramePayloadLength;
65  
66      /**
67       * Base constructor
68       *
69       * @param uri
70       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
71       *            sent to this URL.
72       * @param version
73       *            Version of web socket specification to use to connect to the server
74       * @param subprotocol
75       *            Sub protocol request sent to the server.
76       * @param customHeaders
77       *            Map of custom headers to add to the client request
78       * @param maxFramePayloadLength
79       *            Maximum length of a frame's payload
80       */
81      protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
82                                          HttpHeaders customHeaders, int maxFramePayloadLength) {
83          this.uri = uri;
84          this.version = version;
85          expectedSubprotocol = subprotocol;
86          this.customHeaders = customHeaders;
87          this.maxFramePayloadLength = maxFramePayloadLength;
88      }
89  
90      /**
91       * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
92       */
93      public URI uri() {
94          return uri;
95      }
96  
97      /**
98       * Version of the web socket specification that is being used
99       */
100     public WebSocketVersion version() {
101         return version;
102     }
103 
104     /**
105      * Returns the max length for any frame's payload
106      */
107     public int maxFramePayloadLength() {
108         return maxFramePayloadLength;
109     }
110 
111     /**
112      * Flag to indicate if the opening handshake is complete
113      */
114     public boolean isHandshakeComplete() {
115         return handshakeComplete;
116     }
117 
118     private void setHandshakeComplete() {
119         handshakeComplete = true;
120     }
121 
122     /**
123      * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
124      */
125     public String expectedSubprotocol() {
126         return expectedSubprotocol;
127     }
128 
129     /**
130      * Returns the subprotocol response sent by the server. Only available after end of handshake.
131      * Null if no subprotocol was requested or confirmed by the server.
132      */
133     public String actualSubprotocol() {
134         return actualSubprotocol;
135     }
136 
137     private void setActualSubprotocol(String actualSubprotocol) {
138         this.actualSubprotocol = actualSubprotocol;
139     }
140 
141     /**
142      * Begins the opening handshake
143      *
144      * @param channel
145      *            Channel
146      */
147     public ChannelFuture handshake(Channel channel) {
148         if (channel == null) {
149             throw new NullPointerException("channel");
150         }
151         return handshake(channel, channel.newPromise());
152     }
153 
154     /**
155      * Begins the opening handshake
156      *
157      * @param channel
158      *            Channel
159      * @param promise
160      *            the {@link ChannelPromise} to be notified when the opening handshake is sent
161      */
162     public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
163         FullHttpRequest request =  newHandshakeRequest();
164 
165         HttpResponseDecoder decoder = channel.pipeline().get(HttpResponseDecoder.class);
166         if (decoder == null) {
167             HttpClientCodec codec = channel.pipeline().get(HttpClientCodec.class);
168             if (codec == null) {
169                promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
170                        "a HttpResponseDecoder or HttpClientCodec"));
171                return promise;
172             }
173         }
174 
175         channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
176             @Override
177             public void operationComplete(ChannelFuture future) {
178                 if (future.isSuccess()) {
179                     ChannelPipeline p = future.channel().pipeline();
180                     ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
181                     if (ctx == null) {
182                         ctx = p.context(HttpClientCodec.class);
183                     }
184                     if (ctx == null) {
185                         promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
186                                 "a HttpRequestEncoder or HttpClientCodec"));
187                         return;
188                     }
189                     p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());
190 
191                     promise.setSuccess();
192                 } else {
193                     promise.setFailure(future.cause());
194                 }
195             }
196         });
197         return promise;
198     }
199 
200     /**
201      * Returns a new {@link FullHttpRequest) which will be used for the handshake.
202      */
203     protected abstract FullHttpRequest newHandshakeRequest();
204 
205     /**
206      * Validates and finishes the opening handshake initiated by {@link #handshake}}.
207      *
208      * @param channel
209      *            Channel
210      * @param response
211      *            HTTP response containing the closing handshake details
212      */
213     public final void finishHandshake(Channel channel, FullHttpResponse response) {
214         verify(response);
215 
216         // Verify the subprotocol that we received from the server.
217         // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
218         String receivedProtocol = response.headers().getAndConvert(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
219         receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
220         String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
221         boolean protocolValid = false;
222 
223         if (expectedProtocol.isEmpty() && receivedProtocol == null) {
224             // No subprotocol required and none received
225             protocolValid = true;
226             setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
227         } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
228             // We require a subprotocol and received one -> verify it
229             for (String protocol : StringUtil.split(expectedSubprotocol, ',')) {
230                 if (protocol.trim().equals(receivedProtocol)) {
231                     protocolValid = true;
232                     setActualSubprotocol(receivedProtocol);
233                     break;
234                 }
235             }
236         } // else mixed cases - which are all errors
237 
238         if (!protocolValid) {
239             throw new WebSocketHandshakeException(String.format(
240                     "Invalid subprotocol. Actual: %s. Expected one of: %s",
241                     receivedProtocol, expectedSubprotocol));
242         }
243 
244         setHandshakeComplete();
245 
246         ChannelPipeline p = channel.pipeline();
247         // Remove decompressor from pipeline if its in use
248         HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
249         if (decompressor != null) {
250             p.remove(decompressor);
251         }
252 
253         // Remove aggregator if present before
254         HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
255         if (aggregator != null) {
256             p.remove(aggregator);
257         }
258 
259         ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
260         if (ctx == null) {
261             ctx = p.context(HttpClientCodec.class);
262             if (ctx == null) {
263                 throw new IllegalStateException("ChannelPipeline does not contain " +
264                         "a HttpRequestEncoder or HttpClientCodec");
265             }
266             p.replace(ctx.name(), "ws-decoder", newWebsocketDecoder());
267         } else {
268             if (p.get(HttpRequestEncoder.class) != null) {
269                 p.remove(HttpRequestEncoder.class);
270             }
271             p.replace(ctx.name(),
272                     "ws-decoder", newWebsocketDecoder());
273         }
274     }
275 
276     /**
277      * Process the opening handshake initiated by {@link #handshake}}.
278      *
279      * @param channel
280      *            Channel
281      * @param response
282      *            HTTP response containing the closing handshake details
283      * @return future
284      *            the {@link ChannelFuture} which is notified once the handshake completes.
285      */
286     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
287         return processHandshake(channel, response, channel.newPromise());
288     }
289 
290     /**
291      * Process the opening handshake initiated by {@link #handshake}}.
292      *
293      * @param channel
294      *            Channel
295      * @param response
296      *            HTTP response containing the closing handshake details
297      * @param promise
298      *            the {@link ChannelPromise} to notify once the handshake completes.
299      * @return future
300      *            the {@link ChannelFuture} which is notified once the handshake completes.
301      */
302     public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
303                                                 final ChannelPromise promise) {
304         if (response instanceof FullHttpResponse) {
305             try {
306                 finishHandshake(channel, (FullHttpResponse) response);
307                 promise.setSuccess();
308             } catch (Throwable cause) {
309                 promise.setFailure(cause);
310             }
311         } else {
312             ChannelPipeline p = channel.pipeline();
313             ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
314             if (ctx == null) {
315                 ctx = p.context(HttpClientCodec.class);
316                 if (ctx == null) {
317                     return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
318                             "a HttpResponseDecoder or HttpClientCodec"));
319                 }
320             }
321             // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be more
322             // then enough for the websockets handshake payload.
323             //
324             // TODO: Make handshake work without HttpObjectAggregator at all.
325             String aggregatorName = "httpAggregator";
326             p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
327             p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpResponse>() {
328                 @Override
329                 protected void messageReceived(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception {
330                     // Remove ourself and do the actual handshake
331                     ctx.pipeline().remove(this);
332                     try {
333                         finishHandshake(channel, msg);
334                         promise.setSuccess();
335                     } catch (Throwable cause) {
336                         promise.setFailure(cause);
337                     }
338                 }
339 
340                 @Override
341                 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
342                     // Remove ourself and fail the handshake promise.
343                     ctx.pipeline().remove(this);
344                     promise.setFailure(cause);
345                 }
346 
347                 @Override
348                 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
349                     // Fail promise if Channel was closed
350                     promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
351                     ctx.fireChannelInactive();
352                 }
353             });
354             try {
355                 ctx.fireChannelRead(ReferenceCountUtil.retain(response));
356             } catch (Throwable cause) {
357                 promise.setFailure(cause);
358             }
359         }
360         return promise;
361     }
362 
363     /**
364      * Verfiy the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
365      */
366     protected abstract void verify(FullHttpResponse response);
367 
368     /**
369      * Returns the decoder to use after handshake is complete.
370      */
371     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
372 
373     /**
374      * Returns the encoder to use after the handshake is complete.
375      */
376     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
377 
378     /**
379      * Performs the closing handshake
380      *
381      * @param channel
382      *            Channel
383      * @param frame
384      *            Closing Frame that was received
385      */
386     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
387         if (channel == null) {
388             throw new NullPointerException("channel");
389         }
390         return close(channel, frame, channel.newPromise());
391     }
392 
393     /**
394      * Performs the closing handshake
395      *
396      * @param channel
397      *            Channel
398      * @param frame
399      *            Closing Frame that was received
400      * @param promise
401      *            the {@link ChannelPromise} to be notified when the closing handshake is done
402      */
403     public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
404         if (channel == null) {
405             throw new NullPointerException("channel");
406         }
407         return channel.writeAndFlush(frame, promise);
408     }
409 }