View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx;
17  
18  import io.netty.channel.ChannelFuture;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.channel.ChannelInboundHandlerAdapter;
21  import io.netty.channel.ChannelPipeline;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.http.HttpHeaderNames;
24  import io.netty.handler.codec.http.HttpObject;
25  import io.netty.handler.codec.http.HttpRequest;
26  import io.netty.handler.codec.http.HttpResponse;
27  import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
28  import io.netty.handler.ssl.SslHandler;
29  import io.netty.util.ReferenceCountUtil;
30  import io.netty.util.concurrent.Future;
31  
32  import java.util.concurrent.TimeUnit;
33  
34  import static io.netty.handler.codec.http.HttpUtil.*;
35  import static io.netty.util.internal.ObjectUtil.*;
36  
37  /**
38   * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}.
39   */
40  class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
41  
42      private final WebSocketServerProtocolConfig serverConfig;
43      private ChannelHandlerContext ctx;
44      private ChannelPromise handshakePromise;
45      private boolean isWebSocketPath;
46  
47      WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
48          this.serverConfig = checkNotNull(serverConfig, "serverConfig");
49      }
50  
51      @Override
52      public void handlerAdded(ChannelHandlerContext ctx) {
53          this.ctx = ctx;
54          handshakePromise = ctx.newPromise();
55      }
56  
57      @Override
58      public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
59          final HttpObject httpObject = (HttpObject) msg;
60  
61          if (httpObject instanceof HttpRequest) {
62              final HttpRequest req = (HttpRequest) httpObject;
63              isWebSocketPath = isWebSocketPath(req);
64              if (!isWebSocketPath) {
65                  ctx.fireChannelRead(msg);
66                  return;
67              }
68  
69              try {
70                  final WebSocketServerHandshaker handshaker = WebSocketServerHandshakerFactory.resolveHandshaker(
71                          req,
72                          getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
73                          serverConfig.subprotocols(), serverConfig.decoderConfig());
74                  final ChannelPromise localHandshakePromise = handshakePromise;
75                  if (handshaker == null) {
76                      WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
77                  } else {
78                      // Ensure we set the handshaker and replace this handler before we
79                      // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
80                      // before we had a chance to replace it.
81                      //
82                      // See https://github.com/netty/netty/issues/9471.
83                      WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
84                      ctx.pipeline().remove(this);
85  
86                      final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
87                      handshakeFuture.addListener(future -> {
88                          if (!future.isSuccess()) {
89                              localHandshakePromise.tryFailure(future.cause());
90                              ctx.fireExceptionCaught(future.cause());
91                          } else {
92                              localHandshakePromise.trySuccess();
93                              // Kept for compatibility
94                              ctx.fireUserEventTriggered(
95                                      ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
96                              ctx.fireUserEventTriggered(
97                                      new WebSocketServerProtocolHandler.HandshakeComplete(
98                                              req.uri(), req.headers(), handshaker.selectedSubprotocol()));
99                          }
100                     });
101                     applyHandshakeTimeout();
102                 }
103             } finally {
104                 ReferenceCountUtil.release(req);
105             }
106         } else if (!isWebSocketPath) {
107             ctx.fireChannelRead(msg);
108         } else {
109             ReferenceCountUtil.release(msg);
110         }
111     }
112 
113     private boolean isWebSocketPath(HttpRequest req) {
114         String websocketPath = serverConfig.websocketPath();
115         String uri = req.uri();
116         return serverConfig.checkStartsWith()
117                 ? uri.startsWith(websocketPath) && ("/".equals(websocketPath) || checkNextUri(uri, websocketPath))
118                 : uri.equals(websocketPath);
119     }
120 
121     private boolean checkNextUri(String uri, String websocketPath) {
122         int len = websocketPath.length();
123         if (uri.length() > len) {
124             char nextUri = uri.charAt(len);
125             return nextUri == '/' || nextUri == '?';
126         }
127         return true;
128     }
129 
130     private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
131         String protocol = "ws";
132         if (cp.get(SslHandler.class) != null) {
133             // SSL in use so use Secure WebSockets
134             protocol = "wss";
135         }
136         String host = req.headers().get(HttpHeaderNames.HOST);
137         return protocol + "://" + host + path;
138     }
139 
140     private void applyHandshakeTimeout() {
141         final ChannelPromise localHandshakePromise = handshakePromise;
142         final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
143         if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
144             return;
145         }
146 
147         final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
148             @Override
149             public void run() {
150                 if (!localHandshakePromise.isDone() &&
151                     localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) {
152                     ctx.flush()
153                        .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
154                        .close();
155                 }
156             }
157         }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
158 
159         // Cancel the handshake timeout when handshake is finished.
160         localHandshakePromise.addListener(f -> timeoutFuture.cancel(false));
161     }
162 }