1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http;
16
17 import io.netty.buffer.ByteBuf;
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.util.ReferenceCountUtil;
23 import io.netty.util.ReferenceCounted;
24
25 import java.util.ArrayList;
26 import java.util.Collection;
27 import java.util.List;
28
29 import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
30 import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
31 import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase;
32 import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase;
33 import static io.netty.util.internal.ObjectUtil.checkNotNull;
34 import static io.netty.util.internal.StringUtil.COMMA;
35
36
37
38
39
40
41 public class HttpServerUpgradeHandler extends HttpObjectAggregator {
42
43
44
45
46 public interface SourceCodec {
47
48
49
50 void upgradeFrom(ChannelHandlerContext ctx);
51 }
52
53
54
55
56 public interface UpgradeCodec {
57
58
59
60
61 Collection<CharSequence> requiredUpgradeHeaders();
62
63
64
65
66
67
68
69
70
71 boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest,
72 HttpHeaders upgradeHeaders);
73
74
75
76
77
78
79
80
81 void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest);
82 }
83
84
85
86
87 public interface UpgradeCodecFactory {
88
89
90
91
92
93
94
95 UpgradeCodec newUpgradeCodec(CharSequence protocol);
96 }
97
98
99
100
101
102
103 public static final class UpgradeEvent implements ReferenceCounted {
104 private final CharSequence protocol;
105 private final FullHttpRequest upgradeRequest;
106
107 UpgradeEvent(CharSequence protocol, FullHttpRequest upgradeRequest) {
108 this.protocol = protocol;
109 this.upgradeRequest = upgradeRequest;
110 }
111
112
113
114
115 public CharSequence protocol() {
116 return protocol;
117 }
118
119
120
121
122 public FullHttpRequest upgradeRequest() {
123 return upgradeRequest;
124 }
125
126 @Override
127 public int refCnt() {
128 return upgradeRequest.refCnt();
129 }
130
131 @Override
132 public UpgradeEvent retain() {
133 upgradeRequest.retain();
134 return this;
135 }
136
137 @Override
138 public UpgradeEvent retain(int increment) {
139 upgradeRequest.retain(increment);
140 return this;
141 }
142
143 @Override
144 public UpgradeEvent touch() {
145 upgradeRequest.touch();
146 return this;
147 }
148
149 @Override
150 public UpgradeEvent touch(Object hint) {
151 upgradeRequest.touch(hint);
152 return this;
153 }
154
155 @Override
156 public boolean release() {
157 return upgradeRequest.release();
158 }
159
160 @Override
161 public boolean release(int decrement) {
162 return upgradeRequest.release(decrement);
163 }
164
165 @Override
166 public String toString() {
167 return "UpgradeEvent [protocol=" + protocol + ", upgradeRequest=" + upgradeRequest + ']';
168 }
169 }
170
171 private final SourceCodec sourceCodec;
172 private final UpgradeCodecFactory upgradeCodecFactory;
173 private final HttpHeadersFactory headersFactory;
174 private final HttpHeadersFactory trailersFactory;
175 private final boolean removeAfterFirstRequest;
176 private boolean handlingUpgrade;
177 private boolean failedAggregationStart;
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193 public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory) {
194 this(sourceCodec, upgradeCodecFactory, 0,
195 DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
196 }
197
198
199
200
201
202
203
204
205
206 public HttpServerUpgradeHandler(
207 SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength) {
208 this(sourceCodec, upgradeCodecFactory, maxContentLength,
209 DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
210 }
211
212
213
214
215
216
217
218
219
220
221 public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory,
222 int maxContentLength, boolean validateHeaders) {
223 this(sourceCodec, upgradeCodecFactory, maxContentLength,
224 DefaultHttpHeadersFactory.headersFactory().withValidation(validateHeaders),
225 DefaultHttpHeadersFactory.trailersFactory().withValidation(validateHeaders));
226 }
227
228
229
230
231
232
233
234
235
236
237
238
239
240 public HttpServerUpgradeHandler(
241 SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength,
242 HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) {
243 this(sourceCodec, upgradeCodecFactory, maxContentLength, headersFactory, trailersFactory, false);
244 }
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260 public HttpServerUpgradeHandler(
261 SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength,
262 HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory, boolean removeAfterFirstRequest) {
263 super(maxContentLength);
264
265 this.sourceCodec = checkNotNull(sourceCodec, "sourceCodec");
266 this.upgradeCodecFactory = checkNotNull(upgradeCodecFactory, "upgradeCodecFactory");
267 this.headersFactory = checkNotNull(headersFactory, "headersFactory");
268 this.trailersFactory = checkNotNull(trailersFactory, "trailersFactory");
269 this.removeAfterFirstRequest = removeAfterFirstRequest;
270 }
271
272 @Override
273 protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
274 throws Exception {
275
276 if (!handlingUpgrade) {
277
278 if (msg instanceof HttpRequest) {
279 HttpRequest req = (HttpRequest) msg;
280 if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
281 shouldHandleUpgradeRequest(req)) {
282 handlingUpgrade = true;
283 failedAggregationStart = true;
284 } else {
285 if (removeAfterFirstRequest) {
286
287 ctx.pipeline().remove(this);
288 }
289 ReferenceCountUtil.retain(msg);
290 ctx.fireChannelRead(msg);
291 return;
292 }
293 } else {
294 ReferenceCountUtil.retain(msg);
295 ctx.fireChannelRead(msg);
296 return;
297 }
298 }
299
300 FullHttpRequest fullRequest;
301 if (msg instanceof FullHttpRequest) {
302 fullRequest = (FullHttpRequest) msg;
303 ReferenceCountUtil.retain(msg);
304 out.add(msg);
305 } else {
306
307 super.decode(ctx, msg, out);
308 if (out.isEmpty()) {
309 if (msg instanceof LastHttpContent || failedAggregationStart) {
310
311 handlingUpgrade = false;
312 releaseCurrentMessage();
313 }
314
315
316 return;
317 }
318
319
320 assert out.size() == 1;
321 handlingUpgrade = false;
322 fullRequest = (FullHttpRequest) out.get(0);
323 }
324
325 if (upgrade(ctx, fullRequest)) {
326
327
328
329 out.clear();
330 } else if (removeAfterFirstRequest) {
331
332 ctx.pipeline().remove(this);
333 }
334
335
336
337 }
338
339 @Override
340 protected FullHttpMessage beginAggregation(HttpMessage start, ByteBuf content) throws Exception {
341 failedAggregationStart = false;
342 return super.beginAggregation(start, content);
343 }
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358 protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
359 return true;
360 }
361
362
363
364
365
366
367
368
369
370 private boolean upgrade(final ChannelHandlerContext ctx, final FullHttpRequest request) {
371
372 final List<CharSequence> requestedProtocols = splitHeader(request.headers().get(HttpHeaderNames.UPGRADE));
373 final int numRequestedProtocols = requestedProtocols.size();
374 UpgradeCodec upgradeCodec = null;
375 CharSequence upgradeProtocol = null;
376 for (int i = 0; i < numRequestedProtocols; i ++) {
377 final CharSequence p = requestedProtocols.get(i);
378 final UpgradeCodec c = upgradeCodecFactory.newUpgradeCodec(p);
379 if (c != null) {
380 upgradeProtocol = p;
381 upgradeCodec = c;
382 break;
383 }
384 }
385
386 if (upgradeCodec == null) {
387
388 return false;
389 }
390
391
392 List<String> connectionHeaderValues = request.headers().getAll(HttpHeaderNames.CONNECTION);
393
394 if (connectionHeaderValues == null || connectionHeaderValues.isEmpty()) {
395 return false;
396 }
397
398 final StringBuilder concatenatedConnectionValue = new StringBuilder(connectionHeaderValues.size() * 10);
399 for (CharSequence connectionHeaderValue : connectionHeaderValues) {
400 concatenatedConnectionValue.append(connectionHeaderValue).append(COMMA);
401 }
402 concatenatedConnectionValue.setLength(concatenatedConnectionValue.length() - 1);
403
404
405 Collection<CharSequence> requiredHeaders = upgradeCodec.requiredUpgradeHeaders();
406 List<CharSequence> values = splitHeader(concatenatedConnectionValue);
407 if (!containsContentEqualsIgnoreCase(values, HttpHeaderNames.UPGRADE) ||
408 !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) {
409 return false;
410 }
411
412
413 for (CharSequence requiredHeader : requiredHeaders) {
414 if (!request.headers().contains(requiredHeader)) {
415 return false;
416 }
417 }
418
419
420
421 final FullHttpResponse upgradeResponse = createUpgradeResponse(upgradeProtocol);
422 if (!upgradeCodec.prepareUpgradeResponse(ctx, request, upgradeResponse.headers())) {
423 return false;
424 }
425
426
427 final UpgradeEvent event = new UpgradeEvent(upgradeProtocol, request);
428
429
430
431
432
433 try {
434 final ChannelFuture writeComplete = ctx.writeAndFlush(upgradeResponse);
435
436 sourceCodec.upgradeFrom(ctx);
437 upgradeCodec.upgradeTo(ctx, request);
438
439
440 ctx.pipeline().remove(HttpServerUpgradeHandler.this);
441
442
443
444 ctx.fireUserEventTriggered(event.retain());
445
446
447
448
449 writeComplete.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
450 } finally {
451
452 event.release();
453 }
454 return true;
455 }
456
457
458
459
460 private FullHttpResponse createUpgradeResponse(CharSequence upgradeProtocol) {
461 DefaultFullHttpResponse res = new DefaultFullHttpResponse(
462 HTTP_1_1, SWITCHING_PROTOCOLS, Unpooled.EMPTY_BUFFER, headersFactory, trailersFactory);
463 res.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
464 res.headers().add(HttpHeaderNames.UPGRADE, upgradeProtocol);
465 return res;
466 }
467
468
469
470
471
472 private static List<CharSequence> splitHeader(CharSequence header) {
473 final StringBuilder builder = new StringBuilder(header.length());
474 final List<CharSequence> protocols = new ArrayList<CharSequence>(4);
475 for (int i = 0; i < header.length(); ++i) {
476 char c = header.charAt(i);
477 if (Character.isWhitespace(c)) {
478
479 continue;
480 }
481 if (c == ',') {
482
483 protocols.add(builder.toString());
484 builder.setLength(0);
485 } else {
486 builder.append(c);
487 }
488 }
489
490
491 if (builder.length() > 0) {
492 protocols.add(builder.toString());
493 }
494
495 return protocols;
496 }
497 }