1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.ChannelFuture;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.channel.ChannelInboundHandlerAdapter;
21 import io.netty.channel.ChannelPipeline;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.http.HttpHeaderNames;
24 import io.netty.handler.codec.http.HttpObject;
25 import io.netty.handler.codec.http.HttpRequest;
26 import io.netty.handler.codec.http.HttpResponse;
27 import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
28 import io.netty.handler.ssl.SslHandler;
29 import io.netty.util.ReferenceCountUtil;
30 import io.netty.util.concurrent.Future;
31
32 import java.util.concurrent.TimeUnit;
33
34 import static io.netty.handler.codec.http.HttpUtil.*;
35 import static io.netty.util.internal.ObjectUtil.*;
36
37
38
39
40 class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
41
42 private final WebSocketServerProtocolConfig serverConfig;
43 private ChannelHandlerContext ctx;
44 private ChannelPromise handshakePromise;
45 private boolean isWebSocketPath;
46
47 WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
48 this.serverConfig = checkNotNull(serverConfig, "serverConfig");
49 }
50
51 @Override
52 public void handlerAdded(ChannelHandlerContext ctx) {
53 this.ctx = ctx;
54 handshakePromise = ctx.newPromise();
55 }
56
57 @Override
58 public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
59 final HttpObject httpObject = (HttpObject) msg;
60
61 if (httpObject instanceof HttpRequest) {
62 final HttpRequest req = (HttpRequest) httpObject;
63 isWebSocketPath = isWebSocketPath(req);
64 if (!isWebSocketPath) {
65 ctx.fireChannelRead(msg);
66 return;
67 }
68
69 try {
70 final WebSocketServerHandshaker handshaker = WebSocketServerHandshakerFactory.resolveHandshaker(
71 req,
72 getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
73 serverConfig.subprotocols(), serverConfig.decoderConfig());
74 final ChannelPromise localHandshakePromise = handshakePromise;
75 if (handshaker == null) {
76 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
77 } else {
78
79
80
81
82
83 WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
84 ctx.pipeline().remove(this);
85
86 final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
87 handshakeFuture.addListener(future -> {
88 if (!future.isSuccess()) {
89 localHandshakePromise.tryFailure(future.cause());
90 ctx.fireExceptionCaught(future.cause());
91 } else {
92 localHandshakePromise.trySuccess();
93
94 ctx.fireUserEventTriggered(
95 ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
96 ctx.fireUserEventTriggered(
97 new WebSocketServerProtocolHandler.HandshakeComplete(
98 req.uri(), req.headers(), handshaker.selectedSubprotocol()));
99 }
100 });
101 applyHandshakeTimeout();
102 }
103 } finally {
104 ReferenceCountUtil.release(req);
105 }
106 } else if (!isWebSocketPath) {
107 ctx.fireChannelRead(msg);
108 } else {
109 ReferenceCountUtil.release(msg);
110 }
111 }
112
113 private boolean isWebSocketPath(HttpRequest req) {
114 String websocketPath = serverConfig.websocketPath();
115 String uri = req.uri();
116 return serverConfig.checkStartsWith()
117 ? uri.startsWith(websocketPath) && ("/".equals(websocketPath) || checkNextUri(uri, websocketPath))
118 : uri.equals(websocketPath);
119 }
120
121 private boolean checkNextUri(String uri, String websocketPath) {
122 int len = websocketPath.length();
123 if (uri.length() > len) {
124 char nextUri = uri.charAt(len);
125 return nextUri == '/' || nextUri == '?';
126 }
127 return true;
128 }
129
130 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
131 String protocol = "ws";
132 if (cp.get(SslHandler.class) != null) {
133
134 protocol = "wss";
135 }
136 String host = req.headers().get(HttpHeaderNames.HOST);
137 return protocol + "://" + host + path;
138 }
139
140 private void applyHandshakeTimeout() {
141 final ChannelPromise localHandshakePromise = handshakePromise;
142 final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis();
143 if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
144 return;
145 }
146
147 final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
148 @Override
149 public void run() {
150 if (!localHandshakePromise.isDone() &&
151 localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) {
152 ctx.flush()
153 .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT)
154 .close();
155 }
156 }
157 }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);
158
159
160 localHandshakePromise.addListener(f -> timeoutFuture.cancel(false));
161 }
162 }