1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  package io.netty.handler.codec.http.websocketx;
55  
56  import io.netty.buffer.ByteBuf;
57  import io.netty.buffer.Unpooled;
58  import io.netty.channel.ChannelFutureListener;
59  import io.netty.channel.ChannelHandlerContext;
60  import io.netty.handler.codec.ByteToMessageDecoder;
61  import io.netty.handler.codec.TooLongFrameException;
62  import io.netty.util.internal.ObjectUtil;
63  import io.netty.util.internal.logging.InternalLogger;
64  import io.netty.util.internal.logging.InternalLoggerFactory;
65  
66  import java.nio.ByteOrder;
67  import java.util.List;
68  
69  import static io.netty.buffer.ByteBufUtil.readBytes;
70  
71  
72  
73  
74  
75  public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76          implements WebSocketFrameDecoder {
77  
78      enum State {
79          READING_FIRST,
80          READING_SECOND,
81          READING_SIZE,
82          MASKING_KEY,
83          PAYLOAD,
84          CORRUPT
85      }
86  
87      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88  
89      private static final byte OPCODE_CONT = 0x0;
90      private static final byte OPCODE_TEXT = 0x1;
91      private static final byte OPCODE_BINARY = 0x2;
92      private static final byte OPCODE_CLOSE = 0x8;
93      private static final byte OPCODE_PING = 0x9;
94      private static final byte OPCODE_PONG = 0xA;
95  
96      private final WebSocketDecoderConfig config;
97  
98      private int fragmentedFramesCount;
99      private boolean frameFinalFlag;
100     private boolean frameMasked;
101     private int frameRsv;
102     private int frameOpcode;
103     private long framePayloadLength;
104     private int mask;
105     private int framePayloadLen1;
106     private boolean receivedClosingHandshake;
107     private State state = State.READING_FIRST;
108 
109     
110 
111 
112 
113 
114 
115 
116 
117 
118 
119 
120 
121     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122         this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123     }
124 
125     
126 
127 
128 
129 
130 
131 
132 
133 
134 
135 
136 
137 
138 
139 
140     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141                                    boolean allowMaskMismatch) {
142         this(WebSocketDecoderConfig.newBuilder()
143             .expectMaskedFrames(expectMaskedFrames)
144             .allowExtensions(allowExtensions)
145             .maxFramePayloadLength(maxFramePayloadLength)
146             .allowMaskMismatch(allowMaskMismatch)
147             .build());
148     }
149 
150     
151 
152 
153 
154 
155 
156     public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157         this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158     }
159 
160     @Override
161     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162         
163         if (receivedClosingHandshake) {
164             in.skipBytes(actualReadableBytes());
165             return;
166         }
167 
168         switch (state) {
169         case READING_FIRST:
170             if (!in.isReadable()) {
171                 return;
172             }
173 
174             framePayloadLength = 0;
175 
176             
177             byte b = in.readByte();
178             frameFinalFlag = (b & 0x80) != 0;
179             frameRsv = (b & 0x70) >> 4;
180             frameOpcode = b & 0x0F;
181 
182             if (logger.isTraceEnabled()) {
183                 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184             }
185 
186             state = State.READING_SECOND;
187         case READING_SECOND:
188             if (!in.isReadable()) {
189                 return;
190             }
191             
192             b = in.readByte();
193             frameMasked = (b & 0x80) != 0;
194             framePayloadLen1 = b & 0x7F;
195 
196             if (frameRsv != 0 && !config.allowExtensions()) {
197                 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198                 return;
199             }
200 
201             if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202                 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203                 return;
204             }
205 
206             if (frameOpcode > 7) { 
207 
208                 
209                 if (!frameFinalFlag) {
210                     protocolViolation(ctx, in, "fragmented control frame");
211                     return;
212                 }
213 
214                 
215                 if (framePayloadLen1 > 125) {
216                     protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217                     return;
218                 }
219 
220                 
221                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222                       || frameOpcode == OPCODE_PONG)) {
223                     protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224                     return;
225                 }
226 
227                 
228                 
229                 
230                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231                     protocolViolation(ctx, in, "received close control frame with payload len 1");
232                     return;
233                 }
234             } else { 
235                 
236                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237                       || frameOpcode == OPCODE_BINARY)) {
238                     protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239                     return;
240                 }
241 
242                 
243                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244                     protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245                     return;
246                 }
247 
248                 
249                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
250                     protocolViolation(ctx, in,
251                                       "received non-continuation data frame while inside fragmented message");
252                     return;
253                 }
254             }
255 
256             state = State.READING_SIZE;
257         case READING_SIZE:
258 
259             
260             if (framePayloadLen1 == 126) {
261                 if (in.readableBytes() < 2) {
262                     return;
263                 }
264                 framePayloadLength = in.readUnsignedShort();
265                 if (framePayloadLength < 126) {
266                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267                     return;
268                 }
269             } else if (framePayloadLen1 == 127) {
270                 if (in.readableBytes() < 8) {
271                     return;
272                 }
273                 framePayloadLength = in.readLong();
274                 if (framePayloadLength < 0) {
275                     protocolViolation(ctx, in, "invalid data frame length (negative length)");
276                     return;
277                 }
278 
279                 if (framePayloadLength < 65536) {
280                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
281                     return;
282                 }
283             } else {
284                 framePayloadLength = framePayloadLen1;
285             }
286 
287             if (framePayloadLength > config.maxFramePayloadLength()) {
288                 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
289                     "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
290                 return;
291             }
292 
293             if (logger.isTraceEnabled()) {
294                 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
295             }
296 
297             state = State.MASKING_KEY;
298         case MASKING_KEY:
299             if (frameMasked) {
300                 if (in.readableBytes() < 4) {
301                     return;
302                 }
303                 mask = in.readInt();
304             }
305             state = State.PAYLOAD;
306         case PAYLOAD:
307             if (in.readableBytes() < framePayloadLength) {
308                 return;
309             }
310 
311             ByteBuf payloadBuffer = Unpooled.EMPTY_BUFFER;
312             try {
313                 if (framePayloadLength > 0) {
314                     payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315                 }
316 
317                 
318                 
319                 state = State.READING_FIRST;
320 
321                 
322                 if (frameMasked & framePayloadLength > 0) {
323                     unmask(payloadBuffer);
324                 }
325 
326                 
327                 
328                 if (frameOpcode == OPCODE_PING) {
329                     out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
330                     payloadBuffer = null;
331                     return;
332                 }
333                 if (frameOpcode == OPCODE_PONG) {
334                     out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
335                     payloadBuffer = null;
336                     return;
337                 }
338                 if (frameOpcode == OPCODE_CLOSE) {
339                     receivedClosingHandshake = true;
340                     checkCloseFrameBody(ctx, payloadBuffer);
341                     out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
342                     payloadBuffer = null;
343                     return;
344                 }
345 
346                 
347                 
348                 if (frameFinalFlag) {
349                     
350                     
351                     fragmentedFramesCount = 0;
352                 } else {
353                     
354                     fragmentedFramesCount++;
355                 }
356 
357                 
358                 if (frameOpcode == OPCODE_TEXT) {
359                     out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
360                     payloadBuffer = null;
361                     return;
362                 } else if (frameOpcode == OPCODE_BINARY) {
363                     out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
364                     payloadBuffer = null;
365                     return;
366                 } else if (frameOpcode == OPCODE_CONT) {
367                     out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
368                                                            payloadBuffer));
369                     payloadBuffer = null;
370                     return;
371                 } else {
372                     throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
373                                                             + frameOpcode);
374                 }
375             } finally {
376                 if (payloadBuffer != null) {
377                     payloadBuffer.release();
378                 }
379             }
380         case CORRUPT:
381             if (in.isReadable()) {
382                 
383                 
384                 in.readByte();
385             }
386             return;
387         default:
388             throw new Error("Shouldn't reach here (state: " + state + ")");
389         }
390     }
391 
392     private void unmask(ByteBuf frame) {
393         int i = frame.readerIndex();
394         int end = frame.writerIndex();
395 
396         ByteOrder order = frame.order();
397 
398         int intMask = mask;
399         if (intMask == 0) {
400             
401             return;
402         }
403         
404         long longMask = intMask & 0xFFFFFFFFL;
405         longMask |= longMask << 32;
406 
407         for (int lim = end - 7; i < lim; i += 8) {
408             frame.setLong(i, frame.getLong(i) ^ longMask);
409         }
410 
411         if (i < end - 3) {
412             frame.setInt(i, frame.getInt(i) ^ (int) longMask);
413             i += 4;
414         }
415 
416         if (order == ByteOrder.LITTLE_ENDIAN) {
417             intMask = Integer.reverseBytes(intMask);
418         }
419 
420         int maskOffset = 0;
421         for (; i < end; i++) {
422             frame.setByte(i, frame.getByte(i) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3));
423         }
424     }
425 
426     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
427         protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
428     }
429 
430     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
431         protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
432     }
433 
434     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
435         state = State.CORRUPT;
436         int readableBytes = in.readableBytes();
437         if (readableBytes > 0) {
438             
439             
440             in.skipBytes(readableBytes);
441         }
442         if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
443             Object closeMessage;
444             if (receivedClosingHandshake) {
445                 closeMessage = Unpooled.EMPTY_BUFFER;
446             } else {
447                 WebSocketCloseStatus closeStatus = ex.closeStatus();
448                 String reasonText = ex.getMessage();
449                 if (reasonText == null) {
450                     reasonText = closeStatus.reasonText();
451                 }
452                 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
453             }
454             ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
455         }
456         throw ex;
457     }
458 
459     private static int toFrameLength(long l) {
460         if (l > Integer.MAX_VALUE) {
461             throw new TooLongFrameException("frame length exceeds " + Integer.MAX_VALUE + ": " + l);
462         } else {
463             return (int) l;
464         }
465     }
466 
467     
468     protected void checkCloseFrameBody(
469             ChannelHandlerContext ctx, ByteBuf buffer) {
470         if (buffer == null || !buffer.isReadable()) {
471             return;
472         }
473         if (buffer.readableBytes() < 2) {
474             protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
475         }
476 
477         
478         int statusCode = buffer.getShort(buffer.readerIndex());
479         if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
480             protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
481         }
482 
483         
484         if (buffer.readableBytes() > 2) {
485             try {
486                 new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
487             } catch (CorruptedWebSocketFrameException ex) {
488                 protocolViolation(ctx, buffer, ex);
489             }
490         }
491     }
492 }