1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http;
16
17 import io.netty.buffer.Unpooled;
18 import io.netty.channel.ChannelFuture;
19 import io.netty.channel.ChannelFutureListener;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.util.ReferenceCountUtil;
22 import io.netty.util.ReferenceCounted;
23
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.List;
27
28 import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
29 import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
30 import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase;
31 import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase;
32 import static io.netty.util.internal.ObjectUtil.checkNotNull;
33 import static io.netty.util.internal.StringUtil.COMMA;
34
35
36
37
38
39
40 public class HttpServerUpgradeHandler extends HttpObjectAggregator {
41
42
43
44
45 public interface SourceCodec {
46
47
48
49 void upgradeFrom(ChannelHandlerContext ctx);
50 }
51
52
53
54
55 public interface UpgradeCodec {
56
57
58
59
60 Collection<CharSequence> requiredUpgradeHeaders();
61
62
63
64
65
66
67
68
69
70 boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest,
71 HttpHeaders upgradeHeaders);
72
73
74
75
76
77
78
79
80 void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest);
81 }
82
83
84
85
86 public interface UpgradeCodecFactory {
87
88
89
90
91
92
93
94 UpgradeCodec newUpgradeCodec(CharSequence protocol);
95 }
96
97
98
99
100
101
102 public static final class UpgradeEvent implements ReferenceCounted {
103 private final CharSequence protocol;
104 private final FullHttpRequest upgradeRequest;
105
106 UpgradeEvent(CharSequence protocol, FullHttpRequest upgradeRequest) {
107 this.protocol = protocol;
108 this.upgradeRequest = upgradeRequest;
109 }
110
111
112
113
114 public CharSequence protocol() {
115 return protocol;
116 }
117
118
119
120
121 public FullHttpRequest upgradeRequest() {
122 return upgradeRequest;
123 }
124
125 @Override
126 public int refCnt() {
127 return upgradeRequest.refCnt();
128 }
129
130 @Override
131 public UpgradeEvent retain() {
132 upgradeRequest.retain();
133 return this;
134 }
135
136 @Override
137 public UpgradeEvent retain(int increment) {
138 upgradeRequest.retain(increment);
139 return this;
140 }
141
142 @Override
143 public UpgradeEvent touch() {
144 upgradeRequest.touch();
145 return this;
146 }
147
148 @Override
149 public UpgradeEvent touch(Object hint) {
150 upgradeRequest.touch(hint);
151 return this;
152 }
153
154 @Override
155 public boolean release() {
156 return upgradeRequest.release();
157 }
158
159 @Override
160 public boolean release(int decrement) {
161 return upgradeRequest.release(decrement);
162 }
163
164 @Override
165 public String toString() {
166 return "UpgradeEvent [protocol=" + protocol + ", upgradeRequest=" + upgradeRequest + ']';
167 }
168 }
169
170 private final SourceCodec sourceCodec;
171 private final UpgradeCodecFactory upgradeCodecFactory;
172 private final HttpHeadersFactory headersFactory;
173 private final HttpHeadersFactory trailersFactory;
174 private boolean handlingUpgrade;
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/**
 * Creates a handler with a maximum content length of {@code 0} and the default header and
 * trailer factories.
 *
 * <p>NOTE(review): with a limit of {@code 0} an upgrade request that carries a body presumably
 * cannot be aggregated; confirm against {@link HttpObjectAggregator}.
 *
 * @param sourceCodec the codec that is being used initially
 * @param upgradeCodecFactory the factory that resolves requested protocols to codecs
 */
public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory) {
    // The three-argument overload supplies the same default header/trailer factories.
    this(sourceCodec, upgradeCodecFactory, 0);
}
194
195
196
197
198
199
200
201
202
/**
 * Creates a handler with the given content-length limit and the default header and trailer
 * factories.
 *
 * @param sourceCodec the codec that is being used initially
 * @param upgradeCodecFactory the factory that resolves requested protocols to codecs
 * @param maxContentLength the limit passed on to {@link HttpObjectAggregator} for the
 *                         aggregated upgrade request
 */
public HttpServerUpgradeHandler(
        SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength) {
    this(
            sourceCodec,
            upgradeCodecFactory,
            maxContentLength,
            DefaultHttpHeadersFactory.headersFactory(),
            DefaultHttpHeadersFactory.trailersFactory());
}
208
209
210
211
212
213
214
215
216
217
/**
 * Creates a handler whose default header and trailer factories have validation switched on or
 * off according to {@code validateHeaders}.
 *
 * @param sourceCodec the codec that is being used initially
 * @param upgradeCodecFactory the factory that resolves requested protocols to codecs
 * @param maxContentLength the limit passed on to {@link HttpObjectAggregator} for the
 *                         aggregated upgrade request
 * @param validateHeaders the validation setting applied to both default factories
 */
public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory,
        int maxContentLength, boolean validateHeaders) {
    this(
            sourceCodec,
            upgradeCodecFactory,
            maxContentLength,
            DefaultHttpHeadersFactory.headersFactory().withValidation(validateHeaders),
            DefaultHttpHeadersFactory.trailersFactory().withValidation(validateHeaders));
}
224
225
226
227
228
229
230
231
232
233
234
235
236
/**
 * Primary constructor; every other constructor delegates here.
 *
 * @param sourceCodec the codec that is being used initially
 * @param upgradeCodecFactory the factory that resolves requested protocols to codecs
 * @param maxContentLength the limit passed on to {@link HttpObjectAggregator} for the
 *                         aggregated upgrade request
 * @param headersFactory the factory used for the headers of the upgrade response
 * @param trailersFactory the factory used for the trailers of the upgrade response
 * @throws NullPointerException if any of the object arguments is {@code null}
 */
public HttpServerUpgradeHandler(
        SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength,
        HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) {
    super(maxContentLength);

    // Arguments are validated in declaration order so the first offending one is reported.
    this.sourceCodec =
            checkNotNull(sourceCodec, "sourceCodec");
    this.upgradeCodecFactory =
            checkNotNull(upgradeCodecFactory, "upgradeCodecFactory");
    this.headersFactory =
            checkNotNull(headersFactory, "headersFactory");
    this.trailersFactory =
            checkNotNull(trailersFactory, "trailersFactory");
}
247
248 @Override
249 protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
250 throws Exception {
251
252 if (!handlingUpgrade) {
253
254 if (msg instanceof HttpRequest) {
255 HttpRequest req = (HttpRequest) msg;
256 if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
257 shouldHandleUpgradeRequest(req)) {
258 handlingUpgrade = true;
259 } else {
260 ReferenceCountUtil.retain(msg);
261 ctx.fireChannelRead(msg);
262 return;
263 }
264 } else {
265 ReferenceCountUtil.retain(msg);
266 ctx.fireChannelRead(msg);
267 return;
268 }
269 }
270
271 FullHttpRequest fullRequest;
272 if (msg instanceof FullHttpRequest) {
273 fullRequest = (FullHttpRequest) msg;
274 ReferenceCountUtil.retain(msg);
275 out.add(msg);
276 } else {
277
278 super.decode(ctx, msg, out);
279 if (out.isEmpty()) {
280 if (msg instanceof LastHttpContent) {
281
282 handlingUpgrade = false;
283 releaseCurrentMessage();
284 }
285
286
287 return;
288 }
289
290
291 assert out.size() == 1;
292 handlingUpgrade = false;
293 fullRequest = (FullHttpRequest) out.get(0);
294 }
295
296 if (upgrade(ctx, fullRequest)) {
297
298
299
300 out.clear();
301 }
302
303
304
305 }
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320 protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
321 return true;
322 }
323
324
325
326
327
328
329
330
331
332 private boolean upgrade(final ChannelHandlerContext ctx, final FullHttpRequest request) {
333
334 final List<CharSequence> requestedProtocols = splitHeader(request.headers().get(HttpHeaderNames.UPGRADE));
335 final int numRequestedProtocols = requestedProtocols.size();
336 UpgradeCodec upgradeCodec = null;
337 CharSequence upgradeProtocol = null;
338 for (int i = 0; i < numRequestedProtocols; i ++) {
339 final CharSequence p = requestedProtocols.get(i);
340 final UpgradeCodec c = upgradeCodecFactory.newUpgradeCodec(p);
341 if (c != null) {
342 upgradeProtocol = p;
343 upgradeCodec = c;
344 break;
345 }
346 }
347
348 if (upgradeCodec == null) {
349
350 return false;
351 }
352
353
354 List<String> connectionHeaderValues = request.headers().getAll(HttpHeaderNames.CONNECTION);
355
356 if (connectionHeaderValues == null || connectionHeaderValues.isEmpty()) {
357 return false;
358 }
359
360 final StringBuilder concatenatedConnectionValue = new StringBuilder(connectionHeaderValues.size() * 10);
361 for (CharSequence connectionHeaderValue : connectionHeaderValues) {
362 concatenatedConnectionValue.append(connectionHeaderValue).append(COMMA);
363 }
364 concatenatedConnectionValue.setLength(concatenatedConnectionValue.length() - 1);
365
366
367 Collection<CharSequence> requiredHeaders = upgradeCodec.requiredUpgradeHeaders();
368 List<CharSequence> values = splitHeader(concatenatedConnectionValue);
369 if (!containsContentEqualsIgnoreCase(values, HttpHeaderNames.UPGRADE) ||
370 !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) {
371 return false;
372 }
373
374
375 for (CharSequence requiredHeader : requiredHeaders) {
376 if (!request.headers().contains(requiredHeader)) {
377 return false;
378 }
379 }
380
381
382
383 final FullHttpResponse upgradeResponse = createUpgradeResponse(upgradeProtocol);
384 if (!upgradeCodec.prepareUpgradeResponse(ctx, request, upgradeResponse.headers())) {
385 return false;
386 }
387
388
389 final UpgradeEvent event = new UpgradeEvent(upgradeProtocol, request);
390
391
392
393
394
395 try {
396 final ChannelFuture writeComplete = ctx.writeAndFlush(upgradeResponse);
397
398 sourceCodec.upgradeFrom(ctx);
399 upgradeCodec.upgradeTo(ctx, request);
400
401
402 ctx.pipeline().remove(HttpServerUpgradeHandler.this);
403
404
405
406 ctx.fireUserEventTriggered(event.retain());
407
408
409
410
411 writeComplete.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
412 } finally {
413
414 event.release();
415 }
416 return true;
417 }
418
419
420
421
422 private FullHttpResponse createUpgradeResponse(CharSequence upgradeProtocol) {
423 DefaultFullHttpResponse res = new DefaultFullHttpResponse(
424 HTTP_1_1, SWITCHING_PROTOCOLS, Unpooled.EMPTY_BUFFER, headersFactory, trailersFactory);
425 res.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
426 res.headers().add(HttpHeaderNames.UPGRADE, upgradeProtocol);
427 return res;
428 }
429
430
431
432
433
434 private static List<CharSequence> splitHeader(CharSequence header) {
435 final StringBuilder builder = new StringBuilder(header.length());
436 final List<CharSequence> protocols = new ArrayList<CharSequence>(4);
437 for (int i = 0; i < header.length(); ++i) {
438 char c = header.charAt(i);
439 if (Character.isWhitespace(c)) {
440
441 continue;
442 }
443 if (c == ',') {
444
445 protocols.add(builder.toString());
446 builder.setLength(0);
447 } else {
448 builder.append(c);
449 }
450 }
451
452
453 if (builder.length() > 0) {
454 protocols.add(builder.toString());
455 }
456
457 return protocols;
458 }
459 }