1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http;
16
17 import io.netty.buffer.ByteBuf;
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.util.ReferenceCountUtil;
23 import io.netty.util.ReferenceCounted;
24
25 import java.util.ArrayList;
26 import java.util.Collection;
27 import java.util.List;
28
29 import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
30 import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
31 import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase;
32 import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase;
33 import static io.netty.util.internal.ObjectUtil.checkNotNull;
34 import static io.netty.util.internal.StringUtil.COMMA;
35
36
37
38
39
40
41 public class HttpServerUpgradeHandler extends HttpObjectAggregator {
42
43
44
45
46 public interface SourceCodec {
47
48
49
50 void upgradeFrom(ChannelHandlerContext ctx);
51 }
52
53
54
55
56 public interface UpgradeCodec {
57
58
59
60
61 Collection<CharSequence> requiredUpgradeHeaders();
62
63
64
65
66
67
68
69
70
71 boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest,
72 HttpHeaders upgradeHeaders);
73
74
75
76
77
78
79
80
81 void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest);
82 }
83
84
85
86
87 public interface UpgradeCodecFactory {
88
89
90
91
92
93
94
95 UpgradeCodec newUpgradeCodec(CharSequence protocol);
96 }
97
98
99
100
101
102
103 public static final class UpgradeEvent implements ReferenceCounted {
104 private final CharSequence protocol;
105 private final FullHttpRequest upgradeRequest;
106
107 UpgradeEvent(CharSequence protocol, FullHttpRequest upgradeRequest) {
108 this.protocol = protocol;
109 this.upgradeRequest = upgradeRequest;
110 }
111
112
113
114
115 public CharSequence protocol() {
116 return protocol;
117 }
118
119
120
121
122 public FullHttpRequest upgradeRequest() {
123 return upgradeRequest;
124 }
125
126 @Override
127 public int refCnt() {
128 return upgradeRequest.refCnt();
129 }
130
131 @Override
132 public UpgradeEvent retain() {
133 upgradeRequest.retain();
134 return this;
135 }
136
137 @Override
138 public UpgradeEvent retain(int increment) {
139 upgradeRequest.retain(increment);
140 return this;
141 }
142
143 @Override
144 public UpgradeEvent touch() {
145 upgradeRequest.touch();
146 return this;
147 }
148
149 @Override
150 public UpgradeEvent touch(Object hint) {
151 upgradeRequest.touch(hint);
152 return this;
153 }
154
155 @Override
156 public boolean release() {
157 return upgradeRequest.release();
158 }
159
160 @Override
161 public boolean release(int decrement) {
162 return upgradeRequest.release(decrement);
163 }
164
165 @Override
166 public String toString() {
167 return "UpgradeEvent [protocol=" + protocol + ", upgradeRequest=" + upgradeRequest + ']';
168 }
169 }
170
171 private final SourceCodec sourceCodec;
172 private final UpgradeCodecFactory upgradeCodecFactory;
173 private final HttpHeadersFactory headersFactory;
174 private final HttpHeadersFactory trailersFactory;
175 private boolean handlingUpgrade;
176 private boolean failedAggregationStart;
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192 public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory) {
193 this(sourceCodec, upgradeCodecFactory, 0,
194 DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
195 }
196
197
198
199
200
201
202
203
204
205 public HttpServerUpgradeHandler(
206 SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength) {
207 this(sourceCodec, upgradeCodecFactory, maxContentLength,
208 DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
209 }
210
211
212
213
214
215
216
217
218
219
220 public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory,
221 int maxContentLength, boolean validateHeaders) {
222 this(sourceCodec, upgradeCodecFactory, maxContentLength,
223 DefaultHttpHeadersFactory.headersFactory().withValidation(validateHeaders),
224 DefaultHttpHeadersFactory.trailersFactory().withValidation(validateHeaders));
225 }
226
227
228
229
230
231
232
233
234
235
236
237
238
239 public HttpServerUpgradeHandler(
240 SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength,
241 HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) {
242 super(maxContentLength);
243
244 this.sourceCodec = checkNotNull(sourceCodec, "sourceCodec");
245 this.upgradeCodecFactory = checkNotNull(upgradeCodecFactory, "upgradeCodecFactory");
246 this.headersFactory = checkNotNull(headersFactory, "headersFactory");
247 this.trailersFactory = checkNotNull(trailersFactory, "trailersFactory");
248 }
249
250 @Override
251 protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
252 throws Exception {
253
254 if (!handlingUpgrade) {
255
256 if (msg instanceof HttpRequest) {
257 HttpRequest req = (HttpRequest) msg;
258 if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
259 shouldHandleUpgradeRequest(req)) {
260 handlingUpgrade = true;
261 failedAggregationStart = true;
262 } else {
263 ReferenceCountUtil.retain(msg);
264 ctx.fireChannelRead(msg);
265 return;
266 }
267 } else {
268 ReferenceCountUtil.retain(msg);
269 ctx.fireChannelRead(msg);
270 return;
271 }
272 }
273
274 FullHttpRequest fullRequest;
275 if (msg instanceof FullHttpRequest) {
276 fullRequest = (FullHttpRequest) msg;
277 ReferenceCountUtil.retain(msg);
278 out.add(msg);
279 } else {
280
281 super.decode(ctx, msg, out);
282 if (out.isEmpty()) {
283 if (msg instanceof LastHttpContent || failedAggregationStart) {
284
285 handlingUpgrade = false;
286 releaseCurrentMessage();
287 }
288
289
290 return;
291 }
292
293
294 assert out.size() == 1;
295 handlingUpgrade = false;
296 fullRequest = (FullHttpRequest) out.get(0);
297 }
298
299 if (upgrade(ctx, fullRequest)) {
300
301
302
303 out.clear();
304 }
305
306
307
308 }
309
310 @Override
311 protected FullHttpMessage beginAggregation(HttpMessage start, ByteBuf content) throws Exception {
312 failedAggregationStart = false;
313 return super.beginAggregation(start, content);
314 }
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329 protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
330 return true;
331 }
332
333
334
335
336
337
338
339
340
341 private boolean upgrade(final ChannelHandlerContext ctx, final FullHttpRequest request) {
342
343 final List<CharSequence> requestedProtocols = splitHeader(request.headers().get(HttpHeaderNames.UPGRADE));
344 final int numRequestedProtocols = requestedProtocols.size();
345 UpgradeCodec upgradeCodec = null;
346 CharSequence upgradeProtocol = null;
347 for (int i = 0; i < numRequestedProtocols; i ++) {
348 final CharSequence p = requestedProtocols.get(i);
349 final UpgradeCodec c = upgradeCodecFactory.newUpgradeCodec(p);
350 if (c != null) {
351 upgradeProtocol = p;
352 upgradeCodec = c;
353 break;
354 }
355 }
356
357 if (upgradeCodec == null) {
358
359 return false;
360 }
361
362
363 List<String> connectionHeaderValues = request.headers().getAll(HttpHeaderNames.CONNECTION);
364
365 if (connectionHeaderValues == null || connectionHeaderValues.isEmpty()) {
366 return false;
367 }
368
369 final StringBuilder concatenatedConnectionValue = new StringBuilder(connectionHeaderValues.size() * 10);
370 for (CharSequence connectionHeaderValue : connectionHeaderValues) {
371 concatenatedConnectionValue.append(connectionHeaderValue).append(COMMA);
372 }
373 concatenatedConnectionValue.setLength(concatenatedConnectionValue.length() - 1);
374
375
376 Collection<CharSequence> requiredHeaders = upgradeCodec.requiredUpgradeHeaders();
377 List<CharSequence> values = splitHeader(concatenatedConnectionValue);
378 if (!containsContentEqualsIgnoreCase(values, HttpHeaderNames.UPGRADE) ||
379 !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) {
380 return false;
381 }
382
383
384 for (CharSequence requiredHeader : requiredHeaders) {
385 if (!request.headers().contains(requiredHeader)) {
386 return false;
387 }
388 }
389
390
391
392 final FullHttpResponse upgradeResponse = createUpgradeResponse(upgradeProtocol);
393 if (!upgradeCodec.prepareUpgradeResponse(ctx, request, upgradeResponse.headers())) {
394 return false;
395 }
396
397
398 final UpgradeEvent event = new UpgradeEvent(upgradeProtocol, request);
399
400
401
402
403
404 try {
405 final ChannelFuture writeComplete = ctx.writeAndFlush(upgradeResponse);
406
407 sourceCodec.upgradeFrom(ctx);
408 upgradeCodec.upgradeTo(ctx, request);
409
410
411 ctx.pipeline().remove(HttpServerUpgradeHandler.this);
412
413
414
415 ctx.fireUserEventTriggered(event.retain());
416
417
418
419
420 writeComplete.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
421 } finally {
422
423 event.release();
424 }
425 return true;
426 }
427
428
429
430
431 private FullHttpResponse createUpgradeResponse(CharSequence upgradeProtocol) {
432 DefaultFullHttpResponse res = new DefaultFullHttpResponse(
433 HTTP_1_1, SWITCHING_PROTOCOLS, Unpooled.EMPTY_BUFFER, headersFactory, trailersFactory);
434 res.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
435 res.headers().add(HttpHeaderNames.UPGRADE, upgradeProtocol);
436 return res;
437 }
438
439
440
441
442
443 private static List<CharSequence> splitHeader(CharSequence header) {
444 final StringBuilder builder = new StringBuilder(header.length());
445 final List<CharSequence> protocols = new ArrayList<CharSequence>(4);
446 for (int i = 0; i < header.length(); ++i) {
447 char c = header.charAt(i);
448 if (Character.isWhitespace(c)) {
449
450 continue;
451 }
452 if (c == ',') {
453
454 protocols.add(builder.toString());
455 builder.setLength(0);
456 } else {
457 builder.append(c);
458 }
459 }
460
461
462 if (builder.length() > 0) {
463 protocols.add(builder.toString());
464 }
465
466 return protocols;
467 }
468 }