View Javadoc
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.ChannelFuture;
19  import io.netty.channel.ChannelFutureListener;
20  import io.netty.channel.ChannelHandlerContext;
21  import io.netty.channel.ChannelInboundHandlerAdapter;
22  import io.netty.channel.ChannelPipeline;
23  import io.netty.handler.codec.http.DefaultFullHttpResponse;
24  import io.netty.handler.codec.http.FullHttpRequest;
25  import io.netty.handler.codec.http.HttpHeaderNames;
26  import io.netty.handler.codec.http.HttpRequest;
27  import io.netty.handler.codec.http.HttpResponse;
28  import io.netty.handler.ssl.SslHandler;
29  
30  import static io.netty.handler.codec.http.HttpUtil.*;
31  import static io.netty.handler.codec.http.HttpMethod.*;
32  import static io.netty.handler.codec.http.HttpResponseStatus.*;
33  import static io.netty.handler.codec.http.HttpVersion.*;
34  
35  /**
36   * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
37   */
38  class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
39  
40      private final String websocketPath;
41      private final String subprotocols;
42      private final boolean allowExtensions;
43      private final int maxFramePayloadSize;
44      private final boolean allowMaskMismatch;
45      private final boolean checkStartsWith;
46  
47      WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
48              boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
49          this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false);
50      }
51  
52      WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
53              boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
54          this.websocketPath = websocketPath;
55          this.subprotocols = subprotocols;
56          this.allowExtensions = allowExtensions;
57          maxFramePayloadSize = maxFrameSize;
58          this.allowMaskMismatch = allowMaskMismatch;
59          this.checkStartsWith = checkStartsWith;
60      }
61  
62      @Override
63      public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
64          final FullHttpRequest req = (FullHttpRequest) msg;
65          if (isNotWebSocketPath(req)) {
66              ctx.fireChannelRead(msg);
67              return;
68          }
69  
70          try {
71              if (req.method() != GET) {
72                  sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
73                  return;
74              }
75  
76              final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
77                      getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
78                              allowExtensions, maxFramePayloadSize, allowMaskMismatch);
79              final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
80              if (handshaker == null) {
81                  WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
82              } else {
83                  final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
84                  handshakeFuture.addListener(new ChannelFutureListener() {
85                      @Override
86                      public void operationComplete(ChannelFuture future) throws Exception {
87                          if (!future.isSuccess()) {
88                              ctx.fireExceptionCaught(future.cause());
89                          } else {
90                              // Kept for compatibility
91                              ctx.fireUserEventTriggered(
92                                      WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
93                              ctx.fireUserEventTriggered(
94                                      new WebSocketServerProtocolHandler.HandshakeComplete(
95                                              req.uri(), req.headers(), handshaker.selectedSubprotocol()));
96                          }
97                      }
98                  });
99                  WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
100                 ctx.pipeline().replace(this, "WS403Responder",
101                         WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
102             }
103         } finally {
104             req.release();
105         }
106     }
107 
108     private boolean isNotWebSocketPath(FullHttpRequest req) {
109         return checkStartsWith ? !req.uri().startsWith(websocketPath) : !req.uri().equals(websocketPath);
110     }
111 
112     private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
113         ChannelFuture f = ctx.channel().writeAndFlush(res);
114         if (!isKeepAlive(req) || res.status().code() != 200) {
115             f.addListener(ChannelFutureListener.CLOSE);
116         }
117     }
118 
119     private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
120         String protocol = "ws";
121         if (cp.get(SslHandler.class) != null) {
122             // SSL in use so use Secure WebSockets
123             protocol = "wss";
124         }
125         String host = req.headers().get(HttpHeaderNames.HOST);
126         return protocol + "://" + host + path;
127     }
128 }