1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package io.netty.handler.codec.http.websocketx;
55
56 import io.netty.buffer.ByteBuf;
57 import io.netty.buffer.Unpooled;
58 import io.netty.channel.ChannelFutureListener;
59 import io.netty.channel.ChannelHandlerContext;
60 import io.netty.handler.codec.CorruptedFrameException;
61 import io.netty.handler.codec.ReplayingDecoder;
62 import io.netty.handler.codec.TooLongFrameException;
63 import io.netty.util.internal.logging.InternalLogger;
64 import io.netty.util.internal.logging.InternalLoggerFactory;
65
66 import java.nio.ByteOrder;
67 import java.util.List;
68
69
70
71
72
73 public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State>
74 implements WebSocketFrameDecoder {
75
76 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
77
78 private static final byte OPCODE_CONT = 0x0;
79 private static final byte OPCODE_TEXT = 0x1;
80 private static final byte OPCODE_BINARY = 0x2;
81 private static final byte OPCODE_CLOSE = 0x8;
82 private static final byte OPCODE_PING = 0x9;
83 private static final byte OPCODE_PONG = 0xA;
84
85 private int fragmentedFramesCount;
86 private final long maxFramePayloadLength;
87 private boolean frameFinalFlag;
88 private int frameRsv;
89 private int frameOpcode;
90 private long framePayloadLength;
91 private ByteBuf framePayload;
92 private int framePayloadBytesRead;
93 private byte[] maskingKey;
94 private ByteBuf payloadBuffer;
95 private final boolean allowExtensions;
96 private final boolean maskedPayload;
97 private boolean receivedClosingHandshake;
98 private Utf8Validator utf8Validator;
99
100 enum State {
101 FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
102 }
103
104
105
106
107
108
109
110
111
112
113
114
115
116 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, int maxFramePayloadLength) {
117 super(State.FRAME_START);
118 this.maskedPayload = maskedPayload;
119 this.allowExtensions = allowExtensions;
120 this.maxFramePayloadLength = maxFramePayloadLength;
121 }
122
123 @Override
124 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
125
126
127 if (receivedClosingHandshake) {
128 in.skipBytes(actualReadableBytes());
129 return;
130 }
131
132 try {
133 switch (state()) {
134 case FRAME_START:
135 framePayloadBytesRead = 0;
136 framePayloadLength = -1;
137 framePayload = null;
138 payloadBuffer = null;
139
140
141 byte b = in.readByte();
142 frameFinalFlag = (b & 0x80) != 0;
143 frameRsv = (b & 0x70) >> 4;
144 frameOpcode = b & 0x0F;
145
146 if (logger.isDebugEnabled()) {
147 logger.debug("Decoding WebSocket Frame opCode={}", frameOpcode);
148 }
149
150
151 b = in.readByte();
152 boolean frameMasked = (b & 0x80) != 0;
153 int framePayloadLen1 = b & 0x7F;
154
155 if (frameRsv != 0 && !allowExtensions) {
156 protocolViolation(ctx, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
157 return;
158 }
159
160 if (maskedPayload && !frameMasked) {
161 protocolViolation(ctx, "unmasked client to server frame");
162 return;
163 }
164 if (frameOpcode > 7) {
165
166
167 if (!frameFinalFlag) {
168 protocolViolation(ctx, "fragmented control frame");
169 return;
170 }
171
172
173 if (framePayloadLen1 > 125) {
174 protocolViolation(ctx, "control frame with payload length > 125 octets");
175 return;
176 }
177
178
179 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
180 || frameOpcode == OPCODE_PONG)) {
181 protocolViolation(ctx, "control frame using reserved opcode " + frameOpcode);
182 return;
183 }
184
185
186
187
188 if (frameOpcode == 8 && framePayloadLen1 == 1) {
189 protocolViolation(ctx, "received close control frame with payload len 1");
190 return;
191 }
192 } else {
193
194 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
195 || frameOpcode == OPCODE_BINARY)) {
196 protocolViolation(ctx, "data frame using reserved opcode " + frameOpcode);
197 return;
198 }
199
200
201 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
202 protocolViolation(ctx, "received continuation data frame outside fragmented message");
203 return;
204 }
205
206
207 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
208 protocolViolation(ctx,
209 "received non-continuation data frame while inside fragmented message");
210 return;
211 }
212 }
213
214
215 if (framePayloadLen1 == 126) {
216 framePayloadLength = in.readUnsignedShort();
217 if (framePayloadLength < 126) {
218 protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
219 return;
220 }
221 } else if (framePayloadLen1 == 127) {
222 framePayloadLength = in.readLong();
223
224
225
226 if (framePayloadLength < 65536) {
227 protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
228 return;
229 }
230 } else {
231 framePayloadLength = framePayloadLen1;
232 }
233
234 if (framePayloadLength > maxFramePayloadLength) {
235 protocolViolation(ctx, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
236 return;
237 }
238
239 if (logger.isDebugEnabled()) {
240 logger.debug("Decoding WebSocket Frame length={}", framePayloadLength);
241 }
242
243 checkpoint(State.MASKING_KEY);
244 case MASKING_KEY:
245 if (maskedPayload) {
246 if (maskingKey == null) {
247 maskingKey = new byte[4];
248 }
249 in.readBytes(maskingKey);
250 }
251 checkpoint(State.PAYLOAD);
252 case PAYLOAD:
253
254
255 int rbytes = actualReadableBytes();
256
257 long willHaveReadByteCount = framePayloadBytesRead + rbytes;
258
259
260
261 if (willHaveReadByteCount == framePayloadLength) {
262
263 payloadBuffer = ctx.alloc().buffer(rbytes);
264 payloadBuffer.writeBytes(in, rbytes);
265 } else if (willHaveReadByteCount < framePayloadLength) {
266
267
268
269 if (framePayload == null) {
270 framePayload = ctx.alloc().buffer(toFrameLength(framePayloadLength));
271 }
272 framePayload.writeBytes(in, rbytes);
273 framePayloadBytesRead += rbytes;
274
275
276 return;
277 } else if (willHaveReadByteCount > framePayloadLength) {
278
279
280 if (framePayload == null) {
281 framePayload = ctx.alloc().buffer(toFrameLength(framePayloadLength));
282 }
283 framePayload.writeBytes(in, toFrameLength(framePayloadLength - framePayloadBytesRead));
284 }
285
286
287
288 checkpoint(State.FRAME_START);
289
290
291 if (framePayload == null) {
292 framePayload = payloadBuffer;
293 payloadBuffer = null;
294 } else if (payloadBuffer != null) {
295 framePayload.writeBytes(payloadBuffer);
296 payloadBuffer.release();
297 payloadBuffer = null;
298 }
299
300
301 if (maskedPayload) {
302 unmask(framePayload);
303 }
304
305
306
307 if (frameOpcode == OPCODE_PING) {
308 out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
309 framePayload = null;
310 return;
311 }
312 if (frameOpcode == OPCODE_PONG) {
313 out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
314 framePayload = null;
315 return;
316 }
317 if (frameOpcode == OPCODE_CLOSE) {
318 receivedClosingHandshake = true;
319 checkCloseFrameBody(ctx, framePayload);
320 out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
321 framePayload = null;
322 return;
323 }
324
325
326
327 if (frameFinalFlag) {
328
329
330 if (frameOpcode != OPCODE_PING) {
331 fragmentedFramesCount = 0;
332
333
334 if (frameOpcode == OPCODE_TEXT ||
335 utf8Validator != null && utf8Validator.isChecking()) {
336
337 checkUTF8String(ctx, framePayload);
338
339
340
341 utf8Validator.finish();
342 }
343 }
344 } else {
345
346
347 if (fragmentedFramesCount == 0) {
348
349 if (frameOpcode == OPCODE_TEXT) {
350 checkUTF8String(ctx, framePayload);
351 }
352 } else {
353
354 if (utf8Validator != null && utf8Validator.isChecking()) {
355 checkUTF8String(ctx, framePayload);
356 }
357 }
358
359
360 fragmentedFramesCount++;
361 }
362
363
364 if (frameOpcode == OPCODE_TEXT) {
365 out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
366 framePayload = null;
367 return;
368 } else if (frameOpcode == OPCODE_BINARY) {
369 out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
370 framePayload = null;
371 return;
372 } else if (frameOpcode == OPCODE_CONT) {
373 out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
374 framePayload = null;
375 return;
376 } else {
377 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
378 + frameOpcode);
379 }
380 case CORRUPT:
381
382
383 in.readByte();
384 return;
385 default:
386 throw new Error("Shouldn't reach here.");
387 }
388 } catch (Exception e) {
389 if (payloadBuffer != null) {
390 if (payloadBuffer.refCnt() > 0) {
391 payloadBuffer.release();
392 }
393 payloadBuffer = null;
394 }
395 if (framePayload != null) {
396 if (framePayload.refCnt() > 0) {
397 framePayload.release();
398 }
399 framePayload = null;
400 }
401 throw e;
402 }
403 }
404
405 private void unmask(ByteBuf frame) {
406 int i = frame.readerIndex();
407 int end = frame.writerIndex();
408
409 ByteOrder order = frame.order();
410
411
412
413 int intMask = ((maskingKey[0] & 0xFF) << 24)
414 | ((maskingKey[1] & 0xFF) << 16)
415 | ((maskingKey[2] & 0xFF) << 8)
416 | (maskingKey[3] & 0xFF);
417
418
419
420 if (order == ByteOrder.LITTLE_ENDIAN) {
421 intMask = Integer.reverseBytes(intMask);
422 }
423
424 for (; i + 3 < end; i += 4) {
425 int unmasked = frame.getInt(i) ^ intMask;
426 frame.setInt(i, unmasked);
427 }
428 for (; i < end; i++) {
429 frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
430 }
431 }
432
433 private void protocolViolation(ChannelHandlerContext ctx, String reason) {
434 protocolViolation(ctx, new CorruptedFrameException(reason));
435 }
436
437 private void protocolViolation(ChannelHandlerContext ctx, CorruptedFrameException ex) {
438 checkpoint(State.CORRUPT);
439 if (ctx.channel().isActive()) {
440 Object closeMessage;
441 if (receivedClosingHandshake) {
442 closeMessage = Unpooled.EMPTY_BUFFER;
443 } else {
444 closeMessage = new CloseWebSocketFrame(1002, null);
445 }
446 ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
447 }
448 throw ex;
449 }
450
451 private static int toFrameLength(long l) {
452 if (l > Integer.MAX_VALUE) {
453 throw new TooLongFrameException("Length:" + l);
454 } else {
455 return (int) l;
456 }
457 }
458
459 private void checkUTF8String(ChannelHandlerContext ctx, ByteBuf buffer) {
460 try {
461 if (utf8Validator == null) {
462 utf8Validator = new Utf8Validator();
463 }
464 utf8Validator.check(buffer);
465 } catch (CorruptedFrameException ex) {
466 protocolViolation(ctx, ex);
467 }
468 }
469
470
471 protected void checkCloseFrameBody(
472 ChannelHandlerContext ctx, ByteBuf buffer) {
473 if (buffer == null || !buffer.isReadable()) {
474 return;
475 }
476 if (buffer.readableBytes() == 1) {
477 protocolViolation(ctx, "Invalid close frame body");
478 }
479
480
481 int idx = buffer.readerIndex();
482 buffer.readerIndex(0);
483
484
485 int statusCode = buffer.readShort();
486 if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
487 || statusCode >= 1012 && statusCode <= 2999) {
488 protocolViolation(ctx, "Invalid close frame getStatus code: " + statusCode);
489 }
490
491
492 if (buffer.isReadable()) {
493 try {
494 new Utf8Validator().check(buffer);
495 } catch (CorruptedFrameException ex) {
496 protocolViolation(ctx, ex);
497 }
498 }
499
500
501 buffer.readerIndex(idx);
502 }
503
504 @Override
505 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
506 super.channelInactive(ctx);
507
508
509
510 if (framePayload != null) {
511 framePayload.release();
512 }
513 if (payloadBuffer != null) {
514 payloadBuffer.release();
515 }
516 }
517 }