1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.channel.ChannelFuture;
19 import io.netty.channel.ChannelFutureListener;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelInboundHandlerAdapter;
22 import io.netty.channel.ChannelPipeline;
23 import io.netty.handler.codec.http.DefaultFullHttpResponse;
24 import io.netty.handler.codec.http.FullHttpRequest;
25 import io.netty.handler.codec.http.HttpRequest;
26 import io.netty.handler.codec.http.HttpResponse;
27 import io.netty.handler.ssl.SslHandler;
28
29 import static io.netty.handler.codec.http.HttpHeaders.*;
30 import static io.netty.handler.codec.http.HttpMethod.*;
31 import static io.netty.handler.codec.http.HttpResponseStatus.*;
32 import static io.netty.handler.codec.http.HttpVersion.*;
33
34
35
36
37 class WebSocketServerProtocolHandshakeHandler
38 extends ChannelInboundHandlerAdapter {
39
40 private final String websocketPath;
41 private final String subprotocols;
42 private final boolean allowExtensions;
43 private final int maxFramePayloadSize;
44 private final boolean checkStartsWith;
45
46 WebSocketServerProtocolHandshakeHandler(String websocketPath, String subprotocols,
47 boolean allowExtensions, int maxFrameSize, boolean checkStartsWith) {
48 this.websocketPath = websocketPath;
49 this.subprotocols = subprotocols;
50 this.allowExtensions = allowExtensions;
51 maxFramePayloadSize = maxFrameSize;
52 this.checkStartsWith = checkStartsWith;
53 }
54
55 @Override
56 public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
57 final FullHttpRequest req = (FullHttpRequest) msg;
58 if (checkStartsWith) {
59 if (!req.getUri().startsWith(websocketPath)) {
60 ctx.fireChannelRead(msg);
61 return;
62 }
63 } else if (!req.getUri().equals(websocketPath)) {
64 ctx.fireChannelRead(msg);
65 return;
66 }
67
68 try {
69 if (req.getMethod() != GET) {
70 sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN));
71 return;
72 }
73
74 final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
75 getWebSocketLocation(ctx.pipeline(), req, websocketPath), subprotocols,
76 allowExtensions, maxFramePayloadSize);
77 final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
78 if (handshaker == null) {
79 WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
80 } else {
81 final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
82 handshakeFuture.addListener(new ChannelFutureListener() {
83 @Override
84 public void operationComplete(ChannelFuture future) throws Exception {
85 if (!future.isSuccess()) {
86 ctx.fireExceptionCaught(future.cause());
87 } else {
88 ctx.fireUserEventTriggered(
89 WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
90 }
91 }
92 });
93 WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
94 ctx.pipeline().replace(this, "WS403Responder",
95 WebSocketServerProtocolHandler.forbiddenHttpRequestResponder());
96 }
97 } finally {
98 req.release();
99 }
100 }
101
102 private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
103 ChannelFuture f = ctx.channel().writeAndFlush(res);
104 if (!isKeepAlive(req) || res.getStatus().code() != 200) {
105 f.addListener(ChannelFutureListener.CLOSE);
106 }
107 }
108
109 private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
110 String protocol = "ws";
111 if (cp.get(SslHandler.class) != null) {
112
113 protocol = "wss";
114 }
115 return protocol + "://" + req.headers().get(Names.HOST) + path;
116 }
117
118 }