1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.Channel;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.channel.SimpleChannelInboundHandler;
25 import io.netty.handler.codec.http.FullHttpRequest;
26 import io.netty.handler.codec.http.FullHttpResponse;
27 import io.netty.handler.codec.http.HttpContentCompressor;
28 import io.netty.handler.codec.http.HttpHeaders;
29 import io.netty.handler.codec.http.HttpObjectAggregator;
30 import io.netty.handler.codec.http.HttpRequest;
31 import io.netty.handler.codec.http.HttpRequestDecoder;
32 import io.netty.handler.codec.http.HttpResponseEncoder;
33 import io.netty.handler.codec.http.HttpServerCodec;
34 import io.netty.util.ReferenceCountUtil;
35 import io.netty.util.internal.EmptyArrays;
36 import io.netty.util.internal.ThrowableUtil;
37 import io.netty.util.internal.logging.InternalLogger;
38 import io.netty.util.internal.logging.InternalLoggerFactory;
39
40 import java.nio.channels.ClosedChannelException;
41 import java.util.Collections;
42 import java.util.LinkedHashSet;
43 import java.util.Set;
44
45
46
47
48 public abstract class WebSocketServerHandshaker {
49 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
50 private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = ThrowableUtil.unknownStackTrace(
51 new ClosedChannelException(), WebSocketServerHandshaker.class, "handshake(...)");
52
53 private final String uri;
54
55 private final String[] subprotocols;
56
57 private final WebSocketVersion version;
58
59 private final int maxFramePayloadLength;
60
61 private String selectedSubprotocol;
62
63
64
65
66 public static final String SUB_PROTOCOL_WILDCARD = "*";
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 protected WebSocketServerHandshaker(
82 WebSocketVersion version, String uri, String subprotocols,
83 int maxFramePayloadLength) {
84 this.version = version;
85 this.uri = uri;
86 if (subprotocols != null) {
87 String[] subprotocolArray = subprotocols.split(",");
88 for (int i = 0; i < subprotocolArray.length; i++) {
89 subprotocolArray[i] = subprotocolArray[i].trim();
90 }
91 this.subprotocols = subprotocolArray;
92 } else {
93 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
94 }
95 this.maxFramePayloadLength = maxFramePayloadLength;
96 }
97
98
99
100
101 public String uri() {
102 return uri;
103 }
104
105
106
107
108 public Set<String> subprotocols() {
109 Set<String> ret = new LinkedHashSet<String>();
110 Collections.addAll(ret, subprotocols);
111 return ret;
112 }
113
114
115
116
117 public WebSocketVersion version() {
118 return version;
119 }
120
121
122
123
124
125
126 public int maxFramePayloadLength() {
127 return maxFramePayloadLength;
128 }
129
130
131
132
133
134
135
136
137
138
139
140
141 public ChannelFuture handshake(Channel channel, FullHttpRequest req) {
142 return handshake(channel, req, null, channel.newPromise());
143 }
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161 public final ChannelFuture handshake(Channel channel, FullHttpRequest req,
162 HttpHeaders responseHeaders, final ChannelPromise promise) {
163
164 if (logger.isDebugEnabled()) {
165 logger.debug("{} WebSocket version {} server handshake", channel, version());
166 }
167 FullHttpResponse response = newHandshakeResponse(req, responseHeaders);
168 ChannelPipeline p = channel.pipeline();
169 if (p.get(HttpObjectAggregator.class) != null) {
170 p.remove(HttpObjectAggregator.class);
171 }
172 if (p.get(HttpContentCompressor.class) != null) {
173 p.remove(HttpContentCompressor.class);
174 }
175 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
176 final String encoderName;
177 if (ctx == null) {
178
179 ctx = p.context(HttpServerCodec.class);
180 if (ctx == null) {
181 promise.setFailure(
182 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
183 return promise;
184 }
185 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
186 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
187 encoderName = ctx.name();
188 } else {
189 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
190
191 encoderName = p.context(HttpResponseEncoder.class).name();
192 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
193 }
194 channel.writeAndFlush(response).addListener(new ChannelFutureListener() {
195 @Override
196 public void operationComplete(ChannelFuture future) throws Exception {
197 if (future.isSuccess()) {
198 ChannelPipeline p = future.channel().pipeline();
199 p.remove(encoderName);
200 promise.setSuccess();
201 } else {
202 promise.setFailure(future.cause());
203 }
204 }
205 });
206 return promise;
207 }
208
209
210
211
212
213
214
215
216
217
218
219
220 public ChannelFuture handshake(Channel channel, HttpRequest req) {
221 return handshake(channel, req, null, channel.newPromise());
222 }
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240 public final ChannelFuture handshake(final Channel channel, HttpRequest req,
241 final HttpHeaders responseHeaders, final ChannelPromise promise) {
242
243 if (req instanceof FullHttpRequest) {
244 return handshake(channel, (FullHttpRequest) req, responseHeaders, promise);
245 }
246 if (logger.isDebugEnabled()) {
247 logger.debug("{} WebSocket version {} server handshake", channel, version());
248 }
249 ChannelPipeline p = channel.pipeline();
250 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
251 if (ctx == null) {
252
253 ctx = p.context(HttpServerCodec.class);
254 if (ctx == null) {
255 promise.setFailure(
256 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
257 return promise;
258 }
259 }
260
261
262
263
264 String aggregatorName = "httpAggregator";
265 p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
266 p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
267 @Override
268 protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
269
270 ctx.pipeline().remove(this);
271 handshake(channel, msg, responseHeaders, promise);
272 }
273
274 @Override
275 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
276
277 ctx.pipeline().remove(this);
278 promise.tryFailure(cause);
279 ctx.fireExceptionCaught(cause);
280 }
281
282 @Override
283 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
284
285 promise.tryFailure(CLOSED_CHANNEL_EXCEPTION);
286 ctx.fireChannelInactive();
287 }
288 });
289 try {
290 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
291 } catch (Throwable cause) {
292 promise.setFailure(cause);
293 }
294 return promise;
295 }
296
297
298
299
300 protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req,
301 HttpHeaders responseHeaders);
302
303
304
305
306
307
308
309
310 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
311 if (channel == null) {
312 throw new NullPointerException("channel");
313 }
314 return close(channel, frame, channel.newPromise());
315 }
316
317
318
319
320
321
322
323
324
325
326
327 public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
328 if (channel == null) {
329 throw new NullPointerException("channel");
330 }
331 return channel.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE);
332 }
333
334
335
336
337
338
339
340
341 protected String selectSubprotocol(String requestedSubprotocols) {
342 if (requestedSubprotocols == null || subprotocols.length == 0) {
343 return null;
344 }
345
346 String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
347 for (String p: requestedSubprotocolArray) {
348 String requestedSubprotocol = p.trim();
349
350 for (String supportedSubprotocol: subprotocols) {
351 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
352 || requestedSubprotocol.equals(supportedSubprotocol)) {
353 selectedSubprotocol = requestedSubprotocol;
354 return requestedSubprotocol;
355 }
356 }
357 }
358
359
360 return null;
361 }
362
363
364
365
366
367
368
369 public String selectedSubprotocol() {
370 return selectedSubprotocol;
371 }
372
373
374
375
376 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
377
378
379
380
381 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
382
383 }