1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.http.websocketx;
17
import io.netty5.channel.Channel;
import io.netty5.channel.ChannelFutureListeners;
import io.netty5.channel.ChannelHandler;
import io.netty5.channel.ChannelHandlerContext;
import io.netty5.channel.ChannelPipeline;
import io.netty5.handler.codec.http.DefaultFullHttpResponse;
import io.netty5.handler.codec.http.FullHttpResponse;
import io.netty5.handler.codec.http.HttpResponseStatus;
import io.netty5.util.AttributeKey;
import io.netty5.util.concurrent.Promise;

import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static io.netty5.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty5.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51 public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {
52
53 private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
54 AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");
55
56 private final WebSocketServerProtocolConfig serverConfig;
57
58
59
60
61
62
63
64 public WebSocketServerProtocolHandler(WebSocketServerProtocolConfig serverConfig) {
65 super(Objects.requireNonNull(serverConfig, "serverConfig").dropPongFrames(),
66 serverConfig.sendCloseFrame(),
67 serverConfig.forceCloseTimeoutMillis()
68 );
69 this.serverConfig = serverConfig;
70 }
71
72 public WebSocketServerProtocolHandler(String websocketPath) {
73 this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
74 }
75
76 public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) {
77 this(websocketPath, false, handshakeTimeoutMillis);
78 }
79
80 public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) {
81 this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
82 }
83
84 public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) {
85 this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis);
86 }
87
88 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) {
89 this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
90 }
91
92 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) {
93 this(websocketPath, subprotocols, false, handshakeTimeoutMillis);
94 }
95
96 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) {
97 this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
98 }
99
100 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
101 long handshakeTimeoutMillis) {
102 this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis);
103 }
104
105 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
106 boolean allowExtensions, int maxFrameSize) {
107 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
108 }
109
110 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
111 boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) {
112 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis);
113 }
114
115 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
116 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) {
117 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch,
118 DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
119 }
120
121 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
122 int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) {
123 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false,
124 handshakeTimeoutMillis);
125 }
126
127 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
128 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) {
129 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
130 DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
131 }
132
133 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
134 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
135 boolean checkStartsWith, long handshakeTimeoutMillis) {
136 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true,
137 handshakeTimeoutMillis);
138 }
139
140 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols,
141 boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch,
142 boolean checkStartsWith, boolean dropPongFrames) {
143 this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith,
144 dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
145 }
146
147 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions,
148 int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith,
149 boolean dropPongFrames, long handshakeTimeoutMillis) {
150 this(websocketPath, subprotocols, checkStartsWith, dropPongFrames, handshakeTimeoutMillis,
151 WebSocketDecoderConfig.newBuilder()
152 .maxFramePayloadLength(maxFrameSize)
153 .allowMaskMismatch(allowMaskMismatch)
154 .allowExtensions(allowExtensions)
155 .build());
156 }
157
158 public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean checkStartsWith,
159 boolean dropPongFrames, long handshakeTimeoutMillis,
160 WebSocketDecoderConfig decoderConfig) {
161 this(WebSocketServerProtocolConfig.newBuilder()
162 .websocketPath(websocketPath)
163 .subprotocols(subprotocols)
164 .checkStartsWith(checkStartsWith)
165 .handshakeTimeoutMillis(handshakeTimeoutMillis)
166 .dropPongFrames(dropPongFrames)
167 .decoderConfig(decoderConfig)
168 .build());
169 }
170
171 @Override
172 public void handlerAdded(ChannelHandlerContext ctx) {
173 ChannelPipeline cp = ctx.pipeline();
174 if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
175
176 cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
177 new WebSocketServerProtocolHandshakeHandler(serverConfig));
178 }
179 if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
180
181 cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
182 new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation()));
183 }
184 }
185
186 @Override
187 protected void decodeAndClose(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
188 if (serverConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
189 WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
190 if (handshaker != null) {
191 Promise<Void> promise = ctx.newPromise();
192 closeSent(promise);
193 handshaker.close(ctx, (CloseWebSocketFrame) frame).cascadeTo(promise);
194 } else {
195 frame.close();
196 ctx.writeAndFlush(ctx.bufferAllocator().allocate(0)).addListener(ctx, ChannelFutureListeners.CLOSE);
197 }
198 return;
199 }
200 super.decodeAndClose(ctx, frame);
201 }
202
203 @Override
204 protected WebSocketServerHandshakeException buildHandshakeException(String message) {
205 return new WebSocketServerHandshakeException(message);
206 }
207
208 @Override
209 public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
210 if (cause instanceof WebSocketHandshakeException) {
211 final byte[] bytes = cause.getMessage().getBytes();
212 FullHttpResponse response = new DefaultFullHttpResponse(
213 HTTP_1_1, HttpResponseStatus.BAD_REQUEST,
214 ctx.bufferAllocator().allocate(bytes.length).writeBytes(bytes));
215 ctx.channel().writeAndFlush(response).addListener(ctx, ChannelFutureListeners.CLOSE);
216 } else {
217 ctx.fireChannelExceptionCaught(cause);
218 ctx.close();
219 }
220 }
221
222 static WebSocketServerHandshaker getHandshaker(Channel channel) {
223 return channel.attr(HANDSHAKER_ATTR_KEY).get();
224 }
225
226 static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
227 channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
228 }
229 }