1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;

import java.util.List;

import static io.netty.handler.codec.http.HttpVersion.*;
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
53
54
55
56
57 public enum ServerHandshakeStateEvent {
58
59
60
61 HANDSHAKE_COMPLETE
62 }
63
64 private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
65 AttributeKey.valueOf(WebSocketServerHandshaker.class.getName() + ".HANDSHAKER");
66
67 private final String websocketPath;
68 private final String subprotocols;
69 private final boolean allowExtensions;
70 private final int maxFramePayloadLength;
71 private final boolean checkStartsWith;
72
73 public WebSocketServerProtocolHandler(String websocketPath) {
74 this(websocketPath, null, false);
75 }
76
77 public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
78 this(websocketPath, null, false, 65536, checkStartsWith);
79 }
80
81 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
82 this(websocketPath, subprotocols, false);
83 }
84
85 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
86 this(websocketPath, subprotocols, allowExtensions, 65536);
87 }
88
89 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
90 boolean allowExtensions, int maxFrameSize) {
91 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false);
92 }
93
94 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
95 boolean allowExtensions, int maxFrameSize, boolean checkStartsWith) {
96 this.websocketPath = websocketPath;
97 this.subprotocols = subprotocols;
98 this.allowExtensions = allowExtensions;
99 maxFramePayloadLength = maxFrameSize;
100 this.checkStartsWith = checkStartsWith;
101 }
102
103 @Override
104 public void handlerAdded(ChannelHandlerContext ctx) {
105 ChannelPipeline cp = ctx.pipeline();
106 if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
107
108 ctx.pipeline().addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
109 new WebSocketServerProtocolHandshakeHandler(websocketPath, subprotocols,
110 allowExtensions, maxFramePayloadLength, checkStartsWith));
111 }
112 }
113
114 @Override
115 protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
116 if (frame instanceof CloseWebSocketFrame) {
117 WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
118 if (handshaker != null) {
119 frame.retain();
120 handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame);
121 } else {
122 ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
123 }
124 return;
125 }
126 super.decode(ctx, frame, out);
127 }
128
129 @Override
130 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
131 if (cause instanceof WebSocketHandshakeException) {
132 FullHttpResponse response = new DefaultFullHttpResponse(
133 HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes()));
134 ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
135 } else {
136 ctx.fireExceptionCaught(cause);
137 ctx.close();
138 }
139 }
140
141 static WebSocketServerHandshaker getHandshaker(Channel channel) {
142 return channel.attr(HANDSHAKER_ATTR_KEY).get();
143 }
144
145 static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
146 channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
147 }
148
149 static ChannelHandler forbiddenHttpRequestResponder() {
150 return new ChannelInboundHandlerAdapter() {
151 @Override
152 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
153 if (msg instanceof FullHttpRequest) {
154 ((FullHttpRequest) msg).release();
155 FullHttpResponse response =
156 new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.FORBIDDEN);
157 ctx.channel().writeAndFlush(response);
158 } else {
159 ctx.fireChannelRead(msg);
160 }
161 }
162 };
163 }
164 }