1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.ChannelFuture;
19 import io.netty.channel.ChannelFutureListener;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelInboundHandlerAdapter;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.handler.codec.http.DefaultFullHttpResponse;
25 import io.netty.handler.codec.http.HttpHeaderNames;
26 import io.netty.handler.codec.http.HttpObject;
27 import io.netty.handler.codec.http.HttpRequest;
28 import io.netty.handler.codec.http.HttpResponse;
29 import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
30 import io.netty.handler.ssl.SslHandler;
31 import io.netty.util.ReferenceCountUtil;
32 import io.netty.util.concurrent.Future;
33 import io.netty.util.concurrent.FutureListener;
34
35 import java.util.concurrent.TimeUnit;
36
37 import static io.netty.handler.codec.http.HttpMethod.*;
38 import static io.netty.handler.codec.http.HttpResponseStatus.*;
39 import static io.netty.handler.codec.http.HttpUtil.*;
40 import static io.netty.handler.codec.http.HttpVersion.*;
41 import static io.netty.util.internal.ObjectUtil.*;
42
43
44
45
46 class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
47
48 private final WebSocketServerProtocolConfig serverConfig;
49 private ChannelHandlerContext ctx;
50 private ChannelPromise handshakePromise;
51 private boolean isWebSocketPath;
52
53 WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
54 this.serverConfig = checkNotNull(serverConfig, "serverConfig");
55 }
56
57 @Override
58 public void handlerAdded(ChannelHandlerContext ctx) {
59 this.ctx = ctx;
60 handshakePromise = ctx.newPromise();
61 }
62
63 @Override
64 public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
65 final HttpObject httpObject = (HttpObject) msg;
66
67 if (httpObject instanceof HttpRequest) {
68 final HttpRequest req = (HttpRequest) httpObject;
69 isWebSocketPath = isWebSocketPath(req);
70 if (!isWebSocketPath) {
71 ctx.fireChannelRead(msg);
72 return;
73 }
74
75 try {
76 if (!GET.equals(req.method())) {
77 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
78 return;
79 }
80
81 final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
82 getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
83 serverConfig.subprotocols(), serverConfig.decoderConfig());
84 final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
85 final ChannelPromise localHandshakePromise = handshakePromise;
86 if (handshaker == null) {
87 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
88 } else {
89
90
91
92
93
94 WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
95 ctx.pipeline().remove(this);
96
97 final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
98 handshakeFuture.addListener(new ChannelFutureListener() {
99 @Override
100 public void operationComplete(ChannelFuture future) {
101 if (!future.isSuccess()) {
102 localHandshakePromise.tryFailure(future.cause());
103 ctx.fireExceptionCaught(future.cause());
104 } else {
105 localHandshakePromise.trySuccess();
106
107 ctx.fireUserEventTriggered(
108 WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
109 ctx.fireUserEventTriggered(
110 new WebSocketServerProtocolHandler.HandshakeComplete(
111 req.uri(), req.headers(), handshaker.selectedSubprotocol()));
112 }
113 }
114 });
115 applyHandshakeTimeout();
116 }
117 } finally {
118 ReferenceCountUtil.release(req);
119 }
120 } else if (!isWebSocketPath) {
121 ctx.fireChannelRead(msg);
122 } else {
123 ReferenceCountUtil.release(msg);
124 }
125 }
126
127 private boolean isWebSocketPath(HttpRequest req) {
128 String websocketPath = serverConfig.websocketPath();
129 String uri = req.uri();
130 boolean checkStartUri = uri.startsWith(websocketPath);
131 boolean checkNextUri = "/".equals(websocketPath) || checkNextUri(uri, websocketPath);
132 return serverConfig.checkStartsWith() ? (checkStartUri && checkNextUri) : uri.equals(websocketPath);
133 }
134
135 private boolean checkNextUri(String uri, String websocketPath) {
136 int len = websocketPath.length();
137 if (uri.length() > len) {
138 char nextUri = uri.charAt(len);
139 return nextUri == '/' || nextUri == '?';
140 }
141 return true;
142 }
143
144 private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
145 ChannelFuture f = ctx.channel().writeAndFlush(res);
146 if (!isKeepAlive(req) || res.status().code() != 200) {
147 f.addListener(ChannelFutureListener.CLOSE);
148 }
149 }
150
151 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
152 String protocol = "ws";
153 if (cp.get(SslHandler.class) != null) {
154
155 protocol = "wss";
156 }
157 String host = req.headers().get(HttpHeaderNames.HOST);
158 return protocol + "://" + host + path;
159 }
160
161 private void applyHandshakeTimeout() {
162 final ChannelPromise localHandshakePromise = handshakePromise;
163 final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
164 if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
165 return;
166 }
167
168 final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
169 @Override
170 public void run() {
171 if (!localHandshakePromise.isDone() &&
172 localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) {
173 ctx.flush()
174 .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
175 .close();
176 }
177 }
178 }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
179
180
181 localHandshakePromise.addListener(new FutureListener<Void>() {
182 @Override
183 public void operationComplete(Future<Void> f) {
184 timeoutFuture.cancel(false);
185 }
186 });
187 }
188 }