1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.http.websocketx;
17
18 import io.netty5.buffer.api.BufferAllocator;
19 import io.netty5.channel.Channel;
20 import io.netty5.channel.ChannelFutureListeners;
21 import io.netty5.channel.ChannelHandler;
22 import io.netty5.channel.ChannelHandlerContext;
23 import io.netty5.channel.ChannelOutboundInvoker;
24 import io.netty5.channel.ChannelPipeline;
25 import io.netty5.channel.SimpleChannelInboundHandler;
26 import io.netty5.handler.codec.http.FullHttpRequest;
27 import io.netty5.handler.codec.http.FullHttpResponse;
28 import io.netty5.handler.codec.http.HttpContentCompressor;
29 import io.netty5.handler.codec.http.HttpHeaders;
30 import io.netty5.handler.codec.http.HttpObjectAggregator;
31 import io.netty5.handler.codec.http.HttpRequest;
32 import io.netty5.handler.codec.http.HttpRequestDecoder;
33 import io.netty5.handler.codec.http.HttpResponseEncoder;
34 import io.netty5.handler.codec.http.HttpServerCodec;
35 import io.netty5.util.ReferenceCountUtil;
36 import io.netty5.util.concurrent.Future;
37 import io.netty5.util.concurrent.Promise;
38 import io.netty5.util.internal.EmptyArrays;
39 import io.netty5.util.internal.logging.InternalLogger;
40 import io.netty5.util.internal.logging.InternalLoggerFactory;
41
42 import java.nio.channels.ClosedChannelException;
43 import java.util.Collections;
44 import java.util.LinkedHashSet;
45 import java.util.Set;
46
47 import static java.util.Objects.requireNonNull;
48
49
50
51
52 public abstract class WebSocketServerHandshaker {
53 protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
54
55 private final String uri;
56
57 private final String[] subprotocols;
58
59 private final WebSocketVersion version;
60
61 private final WebSocketDecoderConfig decoderConfig;
62
63 private String selectedSubprotocol;
64
65
66
67
68 public static final String SUB_PROTOCOL_WILDCARD = "*";
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83 protected WebSocketServerHandshaker(
84 WebSocketVersion version, String uri, String subprotocols,
85 int maxFramePayloadLength) {
86 this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
87 .maxFramePayloadLength(maxFramePayloadLength)
88 .build());
89 }
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104 protected WebSocketServerHandshaker(
105 WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
106 this.version = version;
107 this.uri = uri;
108 if (subprotocols != null) {
109 String[] subprotocolArray = subprotocols.split(",");
110 for (int i = 0; i < subprotocolArray.length; i++) {
111 subprotocolArray[i] = subprotocolArray[i].trim();
112 }
113 this.subprotocols = subprotocolArray;
114 } else {
115 this.subprotocols = EmptyArrays.EMPTY_STRINGS;
116 }
117 this.decoderConfig = requireNonNull(decoderConfig, "decoderConfig");
118 }
119
120
121
122
123 public String uri() {
124 return uri;
125 }
126
127
128
129
130 public Set<String> subprotocols() {
131 Set<String> ret = new LinkedHashSet<>();
132 Collections.addAll(ret, subprotocols);
133 return ret;
134 }
135
136
137
138
139 public WebSocketVersion version() {
140 return version;
141 }
142
143
144
145
146
147
148 public int maxFramePayloadLength() {
149 return decoderConfig.maxFramePayloadLength();
150 }
151
152
153
154
155
156
157 public WebSocketDecoderConfig decoderConfig() {
158 return decoderConfig;
159 }
160
161
162
163
164
165
166
167
168
169
170
171
172 public Future<Void> handshake(Channel channel, FullHttpRequest req) {
173 return handshake(channel, req, null);
174 }
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190 public final Future<Void> handshake(Channel channel, FullHttpRequest req, HttpHeaders responseHeaders) {
191
192 if (logger.isDebugEnabled()) {
193 logger.debug("{} WebSocket version {} server handshake", channel, version());
194 }
195 FullHttpResponse response = newHandshakeResponse(channel.bufferAllocator(), req, responseHeaders);
196 ChannelPipeline p = channel.pipeline();
197 if (p.get(HttpObjectAggregator.class) != null) {
198 p.remove(HttpObjectAggregator.class);
199 }
200 if (p.get(HttpContentCompressor.class) != null) {
201 p.remove(HttpContentCompressor.class);
202 }
203 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
204 final String encoderName;
205 if (ctx == null) {
206
207 ctx = p.context(HttpServerCodec.class);
208 if (ctx == null) {
209 return channel.newFailedFuture(
210 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
211 }
212 p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
213 p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
214 encoderName = ctx.name();
215 } else {
216 p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
217
218 encoderName = p.context(HttpResponseEncoder.class).name();
219 p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
220 }
221 return channel.writeAndFlush(response).addListener(channel, (ch, future) -> {
222 if (future.isSuccess()) {
223 ChannelPipeline p1 = ch.pipeline();
224 p1.remove(encoderName);
225 }
226 });
227 }
228
229
230
231
232
233
234
235
236
237
238
239
240 public Future<Void> handshake(Channel channel, HttpRequest req) {
241 return handshake(channel, req, null);
242 }
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258 public final Future<Void> handshake(final Channel channel, HttpRequest req,
259 final HttpHeaders responseHeaders) {
260
261 if (req instanceof FullHttpRequest) {
262 return handshake(channel, (FullHttpRequest) req, responseHeaders);
263 }
264 if (logger.isDebugEnabled()) {
265 logger.debug("{} WebSocket version {} server handshake", channel, version());
266 }
267 ChannelPipeline p = channel.pipeline();
268 ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
269 if (ctx == null) {
270
271 ctx = p.context(HttpServerCodec.class);
272 if (ctx == null) {
273 return channel.newFailedFuture(
274 new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
275 }
276 }
277
278 Promise<Void> promise = channel.newPromise();
279
280
281
282
283 String aggregatorName = "httpAggregator";
284 p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
285 p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
286 @Override
287 protected void messageReceived(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
288
289 ctx.pipeline().remove(this);
290 handshake(channel, msg, responseHeaders);
291 }
292
293 @Override
294 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
295
296 promise.tryFailure(cause);
297 ctx.fireChannelExceptionCaught(cause);
298 ctx.pipeline().remove(this);
299 }
300
301 @Override
302 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
303
304 if (!promise.isDone()) {
305 promise.tryFailure(new ClosedChannelException());
306 }
307 ctx.fireChannelInactive();
308 }
309 });
310 try {
311 ctx.fireChannelRead(ReferenceCountUtil.retain(req));
312 } catch (Throwable cause) {
313 promise.setFailure(cause);
314 }
315 return promise.asFuture();
316 }
317
318
319
320
321 protected abstract FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
322 HttpHeaders responseHeaders);
323
324
325
326
327
328
329
330
331
332
333
334 public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
335 requireNonNull(channel, "channel");
336 return close0(channel, frame);
337 }
338
339
340
341
342
343
344
345
346
347 public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
348 requireNonNull(ctx, "ctx");
349 return close0(ctx, frame);
350 }
351
352 private static Future<Void> close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame) {
353 return invoker.writeAndFlush(frame).addListener(invoker, ChannelFutureListeners.CLOSE);
354 }
355
356
357
358
359
360
361
362
363 protected String selectSubprotocol(String requestedSubprotocols) {
364 if (requestedSubprotocols == null || subprotocols.length == 0) {
365 return null;
366 }
367
368 String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
369 for (String p: requestedSubprotocolArray) {
370 String requestedSubprotocol = p.trim();
371
372 for (String supportedSubprotocol: subprotocols) {
373 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
374 || requestedSubprotocol.equals(supportedSubprotocol)) {
375 selectedSubprotocol = requestedSubprotocol;
376 return requestedSubprotocol;
377 }
378 }
379 }
380
381
382 return null;
383 }
384
385
386
387
388
389
390
391 public String selectedSubprotocol() {
392 return selectedSubprotocol;
393 }
394
395
396
397
398 protected abstract WebSocketFrameDecoder newWebsocketDecoder();
399
400
401
402
403 protected abstract WebSocketFrameEncoder newWebSocketEncoder();
404
405 }