1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;

import java.nio.charset.StandardCharsets;
import java.util.List;

import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
55
56
57
58
59 public enum ServerHandshakeStateEvent {
60
61
62
63
64
65
66 @Deprecated
67 HANDSHAKE_COMPLETE,
68
69
70
71
72 HANDSHAKE_TIMEOUT
73 }
74
75
76
77
78 public static final class HandshakeComplete {
79 private final String requestUri;
80 private final HttpHeaders requestHeaders;
81 private final String selectedSubprotocol;
82
83 public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) {
84 this.requestUri = requestUri;
85 this.requestHeaders = requestHeaders;
86 this.selectedSubprotocol = selectedSubprotocol;
87 }
88
89 public String requestUri() {
90 return requestUri;
91 }
92
93 public HttpHeaders requestHeaders() {
94 return requestHeaders;
95 }
96
97 public String selectedSubprotocol() {
98 return selectedSubprotocol;
99 }
100 }
101
102 private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
103 AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
104
105 private final WebSocketServerProtocolConfig serverConfig;
106
107
108
109
110
111
112
113 public WebSocketServerProtocolHandler(WebSocketServerProtocolConfig serverConfig) {
114 super(checkNotNull(serverConfig, "serverConfig").dropPongFrames(),
115 serverConfig.sendCloseFrame(),
116 serverConfig.forceCloseTimeoutMillis()
117 );
118 this.serverConfig = serverConfig;
119 }
120
121 public WebSocketServerProtocolHandler(String websocketPath) {
122 this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
123 }
124
125 public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) {
126 this(websocketPath, false, handshakeTimeoutMillis);
127 }
128
129 public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
130 this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
131 }
132
133 public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) {
134 this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis);
135 }
136
137 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
138 this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
139 }
140
141 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) {
142 this(websocketPath, subprotocols, false, handshakeTimeoutMillis);
143 }
144
145 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
146 this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
147 }
148
149 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
150 long handshakeTimeoutMillis) {
151 this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis);
152 }
153
154 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
155 boolean allowExtensions, int maxFrameSize) {
156 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
157 }
158
159 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
160 boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) {
161 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis);
162 }
163
164 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
165 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
166 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
167 DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
168 }
169
170 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
171 int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) {
172 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,
173 handshakeTimeoutMillis);
174 }
175
176 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
177 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
178 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
179 DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
180 }
181
182 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
183 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
184 boolean checkStartsWith, long handshakeTimeoutMillis) {
185 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true,
186 handshakeTimeoutMillis);
187 }
188
189 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
190 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
191 boolean checkStartsWith, boolean dropPongFrames) {
192 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
193 dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
194 }
195
196 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
197 int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith,
198 boolean dropPongFrames, long handshakeTimeoutMillis) {
199 this(websocketPath, subprotocols, checkStartsWith, dropPongFrames, handshakeTimeoutMillis,
200 WebSocketDecoderConfig.newBuilder()
201 .maxFramePayloadLength(maxFrameSize)
202 .allowMaskMismatch(allowMaskMismatch)
203 .allowExtensions(allowExtensions)
204 .build());
205 }
206
207 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean checkStartsWith,
208 boolean dropPongFrames, long handshakeTimeoutMillis,
209 WebSocketDecoderConfig decoderConfig) {
210 this(WebSocketServerProtocolConfig.newBuilder()
211 .websocketPath(websocketPath)
212 .subprotocols(subprotocols)
213 .checkStartsWith(checkStartsWith)
214 .handshakeTimeoutMillis(handshakeTimeoutMillis)
215 .dropPongFrames(dropPongFrames)
216 .decoderConfig(decoderConfig)
217 .build());
218 }
219
220 @Override
221 public void handlerAdded(ChannelHandlerContext ctx) {
222 ChannelPipeline cp = ctx.pipeline();
223 if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
224
225 cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
226 new WebSocketServerProtocolHandshakeHandler(serverConfig));
227 }
228 if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
229
230 cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
231 new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation()));
232 }
233 }
234
235 @Override
236 protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
237 if (serverConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
238 WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
239 if (handshaker != null) {
240 frame.retain();
241 ChannelPromise promise = ctx.newPromise();
242 closeSent(promise);
243 handshaker.close(ctx, (CloseWebSocketFrame) frame, promise);
244 } else {
245 ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
246 }
247 return;
248 }
249 super.decode(ctx, frame, out);
250 }
251
252 @Override
253 protected WebSocketServerHandshakeException buildHandshakeException(String message) {
254 return new WebSocketServerHandshakeException(message);
255 }
256
257 @Override
258 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
259 if (cause instanceof WebSocketHandshakeException) {
260 FullHttpResponse response = new DefaultFullHttpResponse(
261 HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
262 ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
263 } else {
264 ctx.fireExceptionCaught(cause);
265 ctx.close();
266 }
267 }
268
269 static WebSocketServerHandshaker getHandshaker(Channel channel) {
270 return channel.attr(HANDSHAKER_ATTR_KEY).get();
271 }
272
273 static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
274 channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
275 }
276 }