1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.Channel;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.FullHttpRequest;
26 import io.netty.handler.codec.http.FullHttpResponse;
27 import io.netty.handler.codec.http.HttpContentCompressor;
28 import io.netty.handler.codec.http.HttpHeaders;
29 import io.netty.handler.codec.http.HttpObjectAggregator;
30 import io.netty.handler.codec.http.HttpRequest;
31 import io.netty.handler.codec.http.HttpRequestDecoder;
32 import io.netty.handler.codec.http.HttpResponseEncoder;
33 import io.netty.handler.codec.http.HttpServerCodec;
34 import io.netty.util.ReferenceCountUtil;
35 import io.netty.util.internal.EmptyArrays;
36 import io.netty.util.internal.StringUtil;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39
40 import java.nio.channels.ClosedChannelException;
41 import java.util.Collections;
42 import java.util.LinkedHashSet;
43 import java.util.Set;
44
45
46
47
48 public abstract class WebSocketServerHandshaker {
49 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
50 private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();
51
52 static {
53 CLOSED_CHANNEL_EXCEPTION.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
54 }
55
56 private final String uri;
57
58 private final String[] subprotocols;
59
60 private final WebSocketVersion version;
61
62 private final int maxFramePayloadLength;
63
64 private String selectedSubprotocol;
65
66
67
68
69 public static final String SUB_PROTOCOL_WILDCARD = "*";
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84 protected WebSocketServerHandshaker(
85 WebSocketVersion version, String uri, String subprotocols,
86 int maxFramePayloadLength) {
87 this.version = version;
88 this.uri = uri;
89 if (subprotocols != null) {
90 String[] subprotocolArray = StringUtil.split(subprotocols, ',');
91 for (int i = 0; i < subprotocolArray.length; i++) {
92 subprotocolArray[i] = subprotocolArray[i].trim();
93 }
94 this.subprotocols = subprotocolArray;
95 } else {
96 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
97 }
98 this.maxFramePayloadLength = maxFramePayloadLength;
99 }
100
101
102
103
104 public String uri() {
105 return uri;
106 }
107
108
109
110
111 public Set<String> subprotocols() {
112 Set<String> ret = new LinkedHashSet<String>();
113 Collections.addAll(ret, subprotocols);
114 return ret;
115 }
116
117
118
119
120 public WebSocketVersion version() {
121 return version;
122 }
123
124
125
126
127
128
129 public int maxFramePayloadLength() {
130 return maxFramePayloadLength;
131 }
132
133
134
135
136
137
138
139
140
141
142
143
144 public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
145 return handshake(channel, req, null, channel.newPromise());
146 }
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164 public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
165 HttpHeaders responseHeaders, final ChannelPromise promise) {
166
167 if (logger.isDebugEnabled()) {
168 logger.debug("{} WebSocket version {} server handshake", channel, version());
169 }
170 FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
171 ChannelPipeline p = channel.pipeline();
172 if (p.get(HttpObjectAggregator.class) != null) {
173 p.remove(HttpObjectAggregator.class);
174 }
175 if (p.get(HttpContentCompressor.class) != null) {
176 p.remove(HttpContentCompressor.class);
177 }
178 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
179 final String encoderName;
180 if (ctx == null) {
181
182 ctx = p.context(HttpServerCodec.class);
183 if (ctx == null) {
184 promise.setFailure(
185 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
186 return promise;
187 }
188 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
189 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
190 encoderName = ctx.name();
191 } else {
192 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
193
194 encoderName = p.context(HttpResponseEncoder.class).name();
195 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
196 }
197 channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
198 @Override
199 public void operationComplete(ChannelFuture future) throws Exception {
200 if (future.isSuccess()) {
201 ChannelPipeline p = future.channel().pipeline();
202 p.remove(encoderName);
203 promise.setSuccess();
204 } else {
205 promise.setFailure(future.cause());
206 }
207 }
208 });
209 return promise;
210 }
211
212
213
214
215
216
217
218
219
220
221
222
223 public ChannelFuture handshake(Channel channel, HttpRequest req) {
224 return handshake(channel, req, null, channel.newPromise());
225 }
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243 public final ChannelFuture handshake(final Channel channel, HttpRequest req,
244 final HttpHeaders responseHeaders, final ChannelPromise promise) {
245
246 if (req instanceof FullHttpRequest) {
247 return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
248 }
249 if (logger.isDebugEnabled()) {
250 logger.debug("{} WebSocket version {} server handshake", channel, version());
251 }
252 ChannelPipeline p = channel.pipeline();
253 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
254 if (ctx == null) {
255
256 ctx = p.context(HttpServerCodec.class);
257 if (ctx == null) {
258 promise.setFailure(
259 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
260 return promise;
261 }
262 }
263
264
265
266
267 String aggregatorName = "httpAggregator";
268 p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
269 p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
270 @Override
271 protected void messageReceived(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
272
273 ctx.pipeline().remove(this);
274 handshake(channel, msg, responseHeaders, promise);
275 }
276
277 @Override
278 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
279
280 ctx.pipeline().remove(this);
281 promise.tryFailure(cause);
282 ctx.fireExceptionCaught(cause);
283 }
284
285 @Override
286 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
287
288 promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
289 ctx.fireChannelInactive();
290 }
291 });
292 try {
293 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
294 } catch (Throwable cause) {
295 promise.setFailure(cause);
296 }
297 return promise;
298 }
299
300
301
302
303 protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
304 HttpHeaders responseHeaders);
305
306
307
308
309
310
311
312
313 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
314 if (channel == null) {
315 throw new NullPointerException("channel");
316 }
317 return close(channel, frame, channel.newPromise());
318 }
319
320
321
322
323
324
325
326
327
328
329
330 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
331 if (channel == null) {
332 throw new NullPointerException("channel");
333 }
334 return channel.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
335 }
336
337
338
339
340
341
342
343
344 protected String selectSubprotocol(String requestedSubprotocols) {
345 if (requestedSubprotocols == null || subprotocols.length == 0) {
346 return null;
347 }
348
349 String[] requestedSubprotocolArray = StringUtil.split(requestedSubprotocols, ',');
350 for (String p: requestedSubprotocolArray) {
351 String requestedSubprotocol = p.trim();
352
353 for (String supportedSubprotocol: subprotocols) {
354 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
355 || requestedSubprotocol.equals(supportedSubprotocol)) {
356 selectedSubprotocol = requestedSubprotocol;
357 return requestedSubprotocol;
358 }
359 }
360 }
361
362
363 return null;
364 }
365
366
367
368
369
370
371
372 public String selectedSubprotocol() {
373 return selectedSubprotocol;
374 }
375
376
377
378
379 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
380
381
382
383
384 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
385
386 }