1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package io.netty5.handler.codec.http.websocketx;
55
56 import io.netty5.buffer.api.Buffer;
57 import io.netty5.channel.ChannelFutureListeners;
58 import io.netty5.channel.ChannelHandlerContext;
59 import io.netty5.handler.codec.ByteToMessageDecoder;
60 import io.netty5.handler.codec.TooLongFrameException;
61 import io.netty5.util.internal.logging.InternalLogger;
62 import io.netty5.util.internal.logging.InternalLoggerFactory;
63
64 import java.util.Objects;
65
66
67
68
69 public class WebSocket13FrameDecoder extends ByteToMessageDecoder implements WebSocketFrameDecoder {
70
71 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket13FrameDecoder.class);
72 private static final byte OPCODE_CONT = 0x0;
73 private static final byte OPCODE_TEXT = 0x1;
74 private static final byte OPCODE_BINARY = 0x2;
75 private static final byte OPCODE_CLOSE = 0x8;
76 private static final byte OPCODE_PING = 0x9;
77 private static final byte OPCODE_PONG = 0xA;
78 private final WebSocketDecoderConfig config;
79 private int fragmentedFramesCount;
80 private boolean frameFinalFlag;
81 private boolean frameMasked;
82 private int frameRsv;
83 private int frameOpcode;
84 private long framePayloadLength;
85 private byte[] maskingKey;
86 private int framePayloadLen1;
87 private boolean receivedClosingHandshake;
88 private State state = State.READING_FIRST;
89
90
91
92
93
94
95
96
97
98
99
100
101
102 public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
103 this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
104 }
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121 public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
122 boolean allowMaskMismatch) {
123 this(WebSocketDecoderConfig.newBuilder()
124 .expectMaskedFrames(expectMaskedFrames)
125 .allowExtensions(allowExtensions)
126 .maxFramePayloadLength(maxFramePayloadLength)
127 .allowMaskMismatch(allowMaskMismatch)
128 .build());
129 }
130
131
132
133
134
135
136
137 public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) {
138 config = Objects.requireNonNull(decoderConfig, "decoderConfig");
139 }
140
141 private static int toFrameLength(long length) {
142 if (length > Integer.MAX_VALUE) {
143 throw new TooLongFrameException("Length:" + length);
144 } else {
145 return (int) length;
146 }
147 }
148
149 @Override
150 protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
151
152 if (receivedClosingHandshake) {
153 in.skipReadableBytes(actualReadableBytes());
154 return;
155 }
156
157 switch (state) {
158 case READING_FIRST: {
159 if (in.readableBytes() == 0) {
160 return;
161 }
162
163 framePayloadLength = 0;
164
165
166 byte b = in.readByte();
167 frameFinalFlag = (b & 0x80) != 0;
168 frameRsv = (b & 0x70) >> 4;
169 frameOpcode = b & 0x0F;
170
171 if (logger.isTraceEnabled()) {
172 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
173 }
174
175 state = State.READING_SECOND;
176 }
177 case READING_SECOND: {
178 if (in.readableBytes() == 0) {
179 return;
180 }
181
182 byte b = in.readByte();
183 frameMasked = (b & 0x80) != 0;
184 framePayloadLen1 = b & 0x7F;
185
186 if (frameRsv != 0 && !config.allowExtensions()) {
187 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
188 return;
189 }
190
191 if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
192 protocolViolation(ctx, in, "received a frame that is not masked as expected");
193 return;
194 }
195
196 if (frameOpcode > 7) {
197
198
199 if (!frameFinalFlag) {
200 protocolViolation(ctx, in, "fragmented control frame");
201 return;
202 }
203
204
205 if (framePayloadLen1 > 125) {
206 protocolViolation(ctx, in, "control frame with payload length > 125 octets");
207 return;
208 }
209
210
211 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
212 || frameOpcode == OPCODE_PONG)) {
213 protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
214 return;
215 }
216
217
218
219
220 if (frameOpcode == 8 && framePayloadLen1 == 1) {
221 protocolViolation(ctx, in, "received close control frame with payload len 1");
222 return;
223 }
224 } else {
225
226 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
227 || frameOpcode == OPCODE_BINARY)) {
228 protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
229 return;
230 }
231
232
233 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
234 protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
235 return;
236 }
237
238
239 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
240 protocolViolation(ctx, in,
241 "received non-continuation data frame while inside fragmented message");
242 return;
243 }
244 }
245
246 state = State.READING_SIZE;
247 }
248 case READING_SIZE: {
249
250 if (framePayloadLen1 == 126) {
251 if (in.readableBytes() < 2) {
252 return;
253 }
254 framePayloadLength = in.readUnsignedShort();
255 if (framePayloadLength < 126) {
256 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
257 return;
258 }
259 } else if (framePayloadLen1 == 127) {
260 if (in.readableBytes() < 8) {
261 return;
262 }
263 framePayloadLength = in.readLong();
264
265
266
267 if (framePayloadLength < 65536) {
268 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
269 return;
270 }
271 } else {
272 framePayloadLength = framePayloadLen1;
273 }
274
275 if (framePayloadLength > config.maxFramePayloadLength()) {
276 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
277 "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
278 return;
279 }
280
281 if (logger.isTraceEnabled()) {
282 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
283 }
284
285 state = State.MASKING_KEY;
286 }
287 case MASKING_KEY: {
288 if (frameMasked) {
289 if (in.readableBytes() < 4) {
290 return;
291 }
292 if (maskingKey == null) {
293 maskingKey = new byte[4];
294 }
295 in.readBytes(maskingKey, 0, maskingKey.length);
296 }
297 state = State.PAYLOAD;
298 }
299 case PAYLOAD: {
300 if (in.readableBytes() < framePayloadLength) {
301 return;
302 }
303
304 Buffer payloadBuffer = null;
305 try {
306 payloadBuffer = in.readSplit(toFrameLength(framePayloadLength));
307
308
309
310 state = State.READING_FIRST;
311
312
313 if (frameMasked) {
314 unmask(payloadBuffer);
315 }
316
317
318
319 if (frameOpcode == OPCODE_PING) {
320 WebSocketFrame frame = new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
321 payloadBuffer = null;
322 ctx.fireChannelRead(frame);
323 return;
324 }
325 if (frameOpcode == OPCODE_PONG) {
326 WebSocketFrame frame = new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
327 payloadBuffer = null;
328 ctx.fireChannelRead(frame);
329 return;
330 }
331 if (frameOpcode == OPCODE_CLOSE) {
332 receivedClosingHandshake = true;
333 checkCloseFrameBody(ctx, payloadBuffer);
334 WebSocketFrame frame = new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
335 payloadBuffer = null;
336 ctx.fireChannelRead(frame);
337 return;
338 }
339
340
341
342 if (frameFinalFlag) {
343
344
345 fragmentedFramesCount = 0;
346 } else {
347
348 fragmentedFramesCount++;
349 }
350
351
352 if (frameOpcode == OPCODE_TEXT) {
353 WebSocketFrame frame = new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
354 payloadBuffer = null;
355 ctx.fireChannelRead(frame);
356 return;
357 } else if (frameOpcode == OPCODE_BINARY) {
358 WebSocketFrame frame = new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
359 payloadBuffer = null;
360 ctx.fireChannelRead(frame);
361 return;
362 } else if (frameOpcode == OPCODE_CONT) {
363 WebSocketFrame frame = new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
364 payloadBuffer = null;
365 ctx.fireChannelRead(frame);
366 return;
367 } else {
368 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
369 + frameOpcode);
370 }
371 } finally {
372 if (payloadBuffer != null) {
373 payloadBuffer.close();
374 }
375 }
376 }
377 case CORRUPT: {
378 if (in.readableBytes() > 0) {
379
380
381 in.readByte();
382 }
383 return;
384 }
385 default:
386 throw new Error("Shouldn't reach here.");
387 }
388 }
389
390 private void unmask(Buffer frame) {
391 int base = frame.readerOffset();
392 int len = frame.readableBytes();
393 int index = 0;
394
395
396
397 int intMask = (maskingKey[0] & 0xFF) << 24
398 | (maskingKey[1] & 0xFF) << 16
399 | (maskingKey[2] & 0xFF) << 8
400 | maskingKey[3] & 0xFF;
401
402 for (; index + 3 < len; index += Integer.BYTES) {
403 int off = base + index;
404 frame.setInt(off, frame.getInt(off) ^ intMask);
405 }
406 for (; index < len; index++) {
407 int off = base + index;
408 frame.setByte(off, (byte) (frame.getByte(off) ^ maskingKey[index % 4]));
409 }
410 }
411
412 private void protocolViolation(ChannelHandlerContext ctx, Buffer in, String reason) {
413 protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
414 }
415
416 private void protocolViolation(ChannelHandlerContext ctx, Buffer in, WebSocketCloseStatus status, String reason) {
417 protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
418 }
419
420 private void protocolViolation(ChannelHandlerContext ctx, Buffer in, CorruptedWebSocketFrameException ex) {
421 state = State.CORRUPT;
422 int readableBytes = in.readableBytes();
423 if (readableBytes > 0) {
424
425
426 in.skipReadableBytes(readableBytes);
427 }
428 if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
429 Object closeMessage;
430 if (receivedClosingHandshake) {
431 closeMessage = ctx.bufferAllocator().allocate(0);
432 } else {
433 WebSocketCloseStatus closeStatus = ex.closeStatus();
434 String reasonText = ex.getMessage();
435 if (reasonText == null) {
436 reasonText = closeStatus.reasonText();
437 }
438 closeMessage = new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus, reasonText);
439 }
440 ctx.writeAndFlush(closeMessage).addListener(ctx, ChannelFutureListeners.CLOSE);
441 }
442 throw ex;
443 }
444
445
446 protected void checkCloseFrameBody(ChannelHandlerContext ctx, Buffer buffer) {
447 if (buffer == null || buffer.readableBytes() <= 0) {
448 return;
449 }
450 if (buffer.readableBytes() == 1) {
451 protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
452 }
453
454
455 int offset = buffer.readerOffset();
456 try {
457
458 int statusCode = buffer.readShort();
459 if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
460 protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
461 }
462
463
464 if (buffer.readableBytes() > 0) {
465 try {
466 new Utf8Validator().check(buffer);
467 } catch (CorruptedWebSocketFrameException ex) {
468 protocolViolation(ctx, buffer, ex);
469 }
470 }
471 } finally {
472
473 buffer.readerOffset(offset);
474 }
475 }
476
477 enum State {
478 READING_FIRST,
479 READING_SECOND,
480 READING_SIZE,
481 MASKING_KEY,
482 PAYLOAD,
483 CORRUPT
484 }
485 }