1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package io.netty.handler.codec.http.websocketx;
55
56 import io.netty.buffer.ByteBuf;
57 import io.netty.buffer.Unpooled;
58 import io.netty.channel.ChannelFutureListener;
59 import io.netty.channel.ChannelHandlerContext;
60 import io.netty.handler.codec.ByteToMessageDecoder;
61 import io.netty.handler.codec.TooLongFrameException;
62 import io.netty.util.internal.ObjectUtil;
63 import io.netty.util.internal.logging.InternalLogger;
64 import io.netty.util.internal.logging.InternalLoggerFactory;
65
66 import java.nio.ByteOrder;
67 import java.util.List;
68
69 import static io.netty.buffer.ByteBufUtil.readBytes;
70
71
72
73
74
75 public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76 implements WebSocketFrameDecoder {
77
78 enum State {
79 READING_FIRST,
80 READING_SECOND,
81 READING_SIZE,
82 MASKING_KEY,
83 PAYLOAD,
84 CORRUPT
85 }
86
87 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88
89 private static final byte OPCODE_CONT = 0x0;
90 private static final byte OPCODE_TEXT = 0x1;
91 private static final byte OPCODE_BINARY = 0x2;
92 private static final byte OPCODE_CLOSE = 0x8;
93 private static final byte OPCODE_PING = 0x9;
94 private static final byte OPCODE_PONG = 0xA;
95
96 private final WebSocketDecoderConfig config;
97
98 private int fragmentedFramesCount;
99 private boolean frameFinalFlag;
100 private boolean frameMasked;
101 private int frameRsv;
102 private int frameOpcode;
103 private long framePayloadLength;
104 private byte[] maskingKey;
105 private int framePayloadLen1;
106 private boolean receivedClosingHandshake;
107 private State state = State.READING_FIRST;
108
109
110
111
112
113
114
115
116
117
118
119
120
121 public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122 this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123 }
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140 public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141 boolean allowMaskMismatch) {
142 this(WebSocketDecoderConfig.newBuilder()
143 .expectMaskedFrames(expectMaskedFrames)
144 .allowExtensions(allowExtensions)
145 .maxFramePayloadLength(maxFramePayloadLength)
146 .allowMaskMismatch(allowMaskMismatch)
147 .build());
148 }
149
150
151
152
153
154
155
156 public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157 this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158 }
159
160 @Override
161 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162
163 if (receivedClosingHandshake) {
164 in.skipBytes(actualReadableBytes());
165 return;
166 }
167
168 switch (state) {
169 case READING_FIRST:
170 if (!in.isReadable()) {
171 return;
172 }
173
174 framePayloadLength = 0;
175
176
177 byte b = in.readByte();
178 frameFinalFlag = (b & 0x80) != 0;
179 frameRsv = (b & 0x70) >> 4;
180 frameOpcode = b & 0x0F;
181
182 if (logger.isTraceEnabled()) {
183 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184 }
185
186 state = State.READING_SECOND;
187 case READING_SECOND:
188 if (!in.isReadable()) {
189 return;
190 }
191
192 b = in.readByte();
193 frameMasked = (b & 0x80) != 0;
194 framePayloadLen1 = b & 0x7F;
195
196 if (frameRsv != 0 && !config.allowExtensions()) {
197 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198 return;
199 }
200
201 if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203 return;
204 }
205
206 if (frameOpcode > 7) {
207
208
209 if (!frameFinalFlag) {
210 protocolViolation(ctx, in, "fragmented control frame");
211 return;
212 }
213
214
215 if (framePayloadLen1 > 125) {
216 protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217 return;
218 }
219
220
221 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222 || frameOpcode == OPCODE_PONG)) {
223 protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224 return;
225 }
226
227
228
229
230 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231 protocolViolation(ctx, in, "received close control frame with payload len 1");
232 return;
233 }
234 } else {
235
236 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237 || frameOpcode == OPCODE_BINARY)) {
238 protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239 return;
240 }
241
242
243 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244 protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245 return;
246 }
247
248
249 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
250 protocolViolation(ctx, in,
251 "received non-continuation data frame while inside fragmented message");
252 return;
253 }
254 }
255
256 state = State.READING_SIZE;
257 case READING_SIZE:
258
259
260 if (framePayloadLen1 == 126) {
261 if (in.readableBytes() < 2) {
262 return;
263 }
264 framePayloadLength = in.readUnsignedShort();
265 if (framePayloadLength < 126) {
266 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267 return;
268 }
269 } else if (framePayloadLen1 == 127) {
270 if (in.readableBytes() < 8) {
271 return;
272 }
273 framePayloadLength = in.readLong();
274
275
276
277 if (framePayloadLength < 65536) {
278 protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
279 return;
280 }
281 } else {
282 framePayloadLength = framePayloadLen1;
283 }
284
285 if (framePayloadLength > config.maxFramePayloadLength()) {
286 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
287 "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
288 return;
289 }
290
291 if (logger.isTraceEnabled()) {
292 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
293 }
294
295 state = State.MASKING_KEY;
296 case MASKING_KEY:
297 if (frameMasked) {
298 if (in.readableBytes() < 4) {
299 return;
300 }
301 if (maskingKey == null) {
302 maskingKey = new byte[4];
303 }
304 in.readBytes(maskingKey);
305 }
306 state = State.PAYLOAD;
307 case PAYLOAD:
308 if (in.readableBytes() < framePayloadLength) {
309 return;
310 }
311
312 ByteBuf payloadBuffer = null;
313 try {
314 payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315
316
317
318 state = State.READING_FIRST;
319
320
321 if (frameMasked) {
322 unmask(payloadBuffer);
323 }
324
325
326
327 if (frameOpcode == OPCODE_PING) {
328 out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
329 payloadBuffer = null;
330 return;
331 }
332 if (frameOpcode == OPCODE_PONG) {
333 out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
334 payloadBuffer = null;
335 return;
336 }
337 if (frameOpcode == OPCODE_CLOSE) {
338 receivedClosingHandshake = true;
339 checkCloseFrameBody(ctx, payloadBuffer);
340 out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
341 payloadBuffer = null;
342 return;
343 }
344
345
346
347 if (frameFinalFlag) {
348
349
350 fragmentedFramesCount = 0;
351 } else {
352
353 fragmentedFramesCount++;
354 }
355
356
357 if (frameOpcode == OPCODE_TEXT) {
358 out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
359 payloadBuffer = null;
360 return;
361 } else if (frameOpcode == OPCODE_BINARY) {
362 out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
363 payloadBuffer = null;
364 return;
365 } else if (frameOpcode == OPCODE_CONT) {
366 out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
367 payloadBuffer));
368 payloadBuffer = null;
369 return;
370 } else {
371 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
372 + frameOpcode);
373 }
374 } finally {
375 if (payloadBuffer != null) {
376 payloadBuffer.release();
377 }
378 }
379 case CORRUPT:
380 if (in.isReadable()) {
381
382
383 in.readByte();
384 }
385 return;
386 default:
387 throw new Error("Shouldn't reach here.");
388 }
389 }
390
391 private void unmask(ByteBuf frame) {
392 int i = frame.readerIndex();
393 int end = frame.writerIndex();
394
395 ByteOrder order = frame.order();
396
397
398
399 int intMask = ((maskingKey[0] & 0xFF) << 24)
400 | ((maskingKey[1] & 0xFF) << 16)
401 | ((maskingKey[2] & 0xFF) << 8)
402 | (maskingKey[3] & 0xFF);
403
404
405
406 if (order == ByteOrder.LITTLE_ENDIAN) {
407 intMask = Integer.reverseBytes(intMask);
408 }
409
410 for (; i + 3 < end; i += 4) {
411 int unmasked = frame.getInt(i) ^ intMask;
412 frame.setInt(i, unmasked);
413 }
414 for (; i < end; i++) {
415 frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
416 }
417 }
418
419 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
420 protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
421 }
422
423 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
424 protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
425 }
426
427 private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
428 state = State.CORRUPT;
429 int readableBytes = in.readableBytes();
430 if (readableBytes > 0) {
431
432
433 in.skipBytes(readableBytes);
434 }
435 if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
436 Object closeMessage;
437 if (receivedClosingHandshake) {
438 closeMessage = Unpooled.EMPTY_BUFFER;
439 } else {
440 WebSocketCloseStatus closeStatus = ex.closeStatus();
441 String reasonText = ex.getMessage();
442 if (reasonText == null) {
443 reasonText = closeStatus.reasonText();
444 }
445 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
446 }
447 ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
448 }
449 throw ex;
450 }
451
452 private static int toFrameLength(long l) {
453 if (l > Integer.MAX_VALUE) {
454 throw new TooLongFrameException("Length:" + l);
455 } else {
456 return (int) l;
457 }
458 }
459
460
461 protected void checkCloseFrameBody(
462 ChannelHandlerContext ctx, ByteBuf buffer) {
463 if (buffer == null || !buffer.isReadable()) {
464 return;
465 }
466 if (buffer.readableBytes() < 2) {
467 protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
468 }
469
470
471 int statusCode = buffer.getShort(buffer.readerIndex());
472 if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
473 protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
474 }
475
476
477 if (buffer.readableBytes() > 2) {
478 try {
479 new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
480 } catch (CorruptedWebSocketFrameException ex) {
481 protocolViolation(ctx, buffer, ex);
482 }
483 }
484 }
485 }