View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: http://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package io.netty.handler.codec.http.websocketx;
55  
56  import io.netty.buffer.ByteBuf;
57  import io.netty.buffer.Unpooled;
58  import io.netty.channel.ChannelFutureListener;
59  import io.netty.channel.ChannelHandlerContext;
60  import io.netty.handler.codec.CorruptedFrameException;
61  import io.netty.handler.codec.ReplayingDecoder;
62  import io.netty.handler.codec.TooLongFrameException;
63  import io.netty.util.internal.logging.InternalLogger;
64  import io.netty.util.internal.logging.InternalLoggerFactory;
65  
66  import java.nio.ByteOrder;
67  import java.util.List;
68  
69  /**
70   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
71   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
72   */
73  public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State>
74          implements WebSocketFrameDecoder {
75  
76      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
77  
78      private static final byte OPCODE_CONT = 0x0;
79      private static final byte OPCODE_TEXT = 0x1;
80      private static final byte OPCODE_BINARY = 0x2;
81      private static final byte OPCODE_CLOSE = 0x8;
82      private static final byte OPCODE_PING = 0x9;
83      private static final byte OPCODE_PONG = 0xA;
84  
85      private int fragmentedFramesCount;
86      private final long maxFramePayloadLength;
87      private boolean frameFinalFlag;
88      private int frameRsv;
89      private int frameOpcode;
90      private long framePayloadLength;
91      private ByteBuf framePayload;
92      private int framePayloadBytesRead;
93      private byte[] maskingKey;
94      private ByteBuf payloadBuffer;
95      private final boolean allowExtensions;
96      private final boolean maskedPayload;
97      private boolean receivedClosingHandshake;
98      private Utf8Validator utf8Validator;
99  
100     enum State {
101         FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
102     }
103 
104     /**
105      * Constructor
106      *
107      * @param maskedPayload
108      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
109      *            must set this to false.
110      * @param allowExtensions
111      *            Flag to allow reserved extension bits to be used or not
112      * @param maxFramePayloadLength
113      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
114      *            helps check for denial of services attacks.
115      */
116     public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, int maxFramePayloadLength) {
117         super(State.FRAME_START);
118         this.maskedPayload = maskedPayload;
119         this.allowExtensions = allowExtensions;
120         this.maxFramePayloadLength = maxFramePayloadLength;
121     }
122 
123     @Override
124     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
125 
126         // Discard all data received if closing handshake was received before.
127         if (receivedClosingHandshake) {
128             in.skipBytes(actualReadableBytes());
129             return;
130         }
131 
132         try {
133             switch (state()) {
134                 case FRAME_START:
135                     framePayloadBytesRead = 0;
136                     framePayloadLength = -1;
137                     framePayload = null;
138                     payloadBuffer = null;
139 
140                     // FIN, RSV, OPCODE
141                     byte b = in.readByte();
142                     frameFinalFlag = (b & 0x80) != 0;
143                     frameRsv = (b & 0x70) >> 4;
144                     frameOpcode = b & 0x0F;
145 
146                     if (logger.isDebugEnabled()) {
147                         logger.debug("Decoding WebSocket Frame opCode={}", frameOpcode);
148                     }
149 
150                     // MASK, PAYLOAD LEN 1
151                     b = in.readByte();
152                     boolean frameMasked = (b & 0x80) != 0;
153                     int framePayloadLen1 = b & 0x7F;
154 
155                     if (frameRsv != 0 && !allowExtensions) {
156                         protocolViolation(ctx, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
157                         return;
158                     }
159 
160                     if (maskedPayload && !frameMasked) {
161                         protocolViolation(ctx, "unmasked client to server frame");
162                         return;
163                     }
164                     if (frameOpcode > 7) { // control frame (have MSB in opcode set)
165 
166                         // control frames MUST NOT be fragmented
167                         if (!frameFinalFlag) {
168                             protocolViolation(ctx, "fragmented control frame");
169                             return;
170                         }
171 
172                         // control frames MUST have payload 125 octets or less
173                         if (framePayloadLen1 > 125) {
174                             protocolViolation(ctx, "control frame with payload length > 125 octets");
175                             return;
176                         }
177 
178                         // check for reserved control frame opcodes
179                         if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
180                                 || frameOpcode == OPCODE_PONG)) {
181                             protocolViolation(ctx, "control frame using reserved opcode " + frameOpcode);
182                             return;
183                         }
184 
185                         // close frame : if there is a body, the first two bytes of the
186                         // body MUST be a 2-byte unsigned integer (in network byte
187                         // order) representing a getStatus code
188                         if (frameOpcode == 8 && framePayloadLen1 == 1) {
189                             protocolViolation(ctx, "received close control frame with payload len 1");
190                             return;
191                         }
192                     } else { // data frame
193                         // check for reserved data frame opcodes
194                         if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
195                                 || frameOpcode == OPCODE_BINARY)) {
196                             protocolViolation(ctx, "data frame using reserved opcode " + frameOpcode);
197                             return;
198                         }
199 
200                         // check opcode vs message fragmentation state 1/2
201                         if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
202                             protocolViolation(ctx, "received continuation data frame outside fragmented message");
203                             return;
204                         }
205 
206                         // check opcode vs message fragmentation state 2/2
207                         if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
208                             protocolViolation(ctx,
209                                     "received non-continuation data frame while inside fragmented message");
210                             return;
211                         }
212                     }
213 
214                     // Read frame payload length
215                     if (framePayloadLen1 == 126) {
216                         framePayloadLength = in.readUnsignedShort();
217                         if (framePayloadLength < 126) {
218                             protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
219                             return;
220                         }
221                     } else if (framePayloadLen1 == 127) {
222                         framePayloadLength = in.readLong();
223                         // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe
224                         // just check if it's negative?
225 
226                         if (framePayloadLength < 65536) {
227                             protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
228                             return;
229                         }
230                     } else {
231                         framePayloadLength = framePayloadLen1;
232                     }
233 
234                     if (framePayloadLength > maxFramePayloadLength) {
235                         protocolViolation(ctx, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
236                         return;
237                     }
238 
239                     if (logger.isDebugEnabled()) {
240                         logger.debug("Decoding WebSocket Frame length={}", framePayloadLength);
241                     }
242 
243                     checkpoint(State.MASKING_KEY);
244                 case MASKING_KEY:
245                     if (maskedPayload) {
246                         if (maskingKey == null) {
247                             maskingKey = new byte[4];
248                         }
249                         in.readBytes(maskingKey);
250                     }
251                     checkpoint(State.PAYLOAD);
252                 case PAYLOAD:
253                     // Sometimes, the payload may not be delivered in 1 nice packet
254                     // We need to accumulate the data until we have it all
255                     int rbytes = actualReadableBytes();
256 
257                     long willHaveReadByteCount = framePayloadBytesRead + rbytes;
258                     // logger.debug("Frame rbytes=" + rbytes + " willHaveReadByteCount="
259                     // + willHaveReadByteCount + " framePayloadLength=" +
260                     // framePayloadLength);
261                     if (willHaveReadByteCount == framePayloadLength) {
262                         // We have all our content so proceed to process
263                         payloadBuffer = ctx.alloc().buffer(rbytes);
264                         payloadBuffer.writeBytes(in, rbytes);
265                     } else if (willHaveReadByteCount < framePayloadLength) {
266 
267                         // We don't have all our content so accumulate payload.
268                         // Returning null means we will get called back
269                         if (framePayload == null) {
270                             framePayload = ctx.alloc().buffer(toFrameLength(framePayloadLength));
271                         }
272                         framePayload.writeBytes(in, rbytes);
273                         framePayloadBytesRead += rbytes;
274 
275                         // Return null to wait for more bytes to arrive
276                         return;
277                     } else if (willHaveReadByteCount > framePayloadLength) {
278                         // We have more than what we need so read up to the end of frame
279                         // Leave the remainder in the buffer for next frame
280                         if (framePayload == null) {
281                             framePayload = ctx.alloc().buffer(toFrameLength(framePayloadLength));
282                         }
283                         framePayload.writeBytes(in, toFrameLength(framePayloadLength - framePayloadBytesRead));
284                     }
285 
286                     // Now we have all the data, the next checkpoint must be the next
287                     // frame
288                     checkpoint(State.FRAME_START);
289 
290                     // Take the data that we have in this packet
291                     if (framePayload == null) {
292                         framePayload = payloadBuffer;
293                         payloadBuffer = null;
294                     } else if (payloadBuffer != null) {
295                         framePayload.writeBytes(payloadBuffer);
296                         payloadBuffer.release();
297                         payloadBuffer = null;
298                     }
299 
300                     // Unmask data if needed
301                     if (maskedPayload) {
302                         unmask(framePayload);
303                     }
304 
305                     // Processing ping/pong/close frames because they cannot be
306                     // fragmented
307                     if (frameOpcode == OPCODE_PING) {
308                         out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
309                         framePayload = null;
310                         return;
311                     }
312                     if (frameOpcode == OPCODE_PONG) {
313                         out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
314                         framePayload = null;
315                         return;
316                     }
317                     if (frameOpcode == OPCODE_CLOSE) {
318                         receivedClosingHandshake = true;
319                         checkCloseFrameBody(ctx, framePayload);
320                         out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
321                         framePayload = null;
322                         return;
323                     }
324 
325                     // Processing for possible fragmented messages for text and binary
326                     // frames
327                     if (frameFinalFlag) {
328                         // Final frame of the sequence. Apparently ping frames are
329                         // allowed in the middle of a fragmented message
330                         if (frameOpcode != OPCODE_PING) {
331                             fragmentedFramesCount = 0;
332 
333                             // Check text for UTF8 correctness
334                             if (frameOpcode == OPCODE_TEXT ||
335                                     utf8Validator != null && utf8Validator.isChecking()) {
336                                 // Check UTF-8 correctness for this payload
337                                 checkUTF8String(ctx, framePayload);
338 
339                                 // This does a second check to make sure UTF-8
340                                 // correctness for entire text message
341                                 utf8Validator.finish();
342                             }
343                         }
344                     } else {
345                         // Not final frame so we can expect more frames in the
346                         // fragmented sequence
347                         if (fragmentedFramesCount == 0) {
348                             // First text or binary frame for a fragmented set
349                             if (frameOpcode == OPCODE_TEXT) {
350                                 checkUTF8String(ctx, framePayload);
351                             }
352                         } else {
353                             // Subsequent frames - only check if init frame is text
354                             if (utf8Validator != null && utf8Validator.isChecking()) {
355                                 checkUTF8String(ctx, framePayload);
356                             }
357                         }
358 
359                         // Increment counter
360                         fragmentedFramesCount++;
361                     }
362 
363                     // Return the frame
364                     if (frameOpcode == OPCODE_TEXT) {
365                         out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
366                         framePayload = null;
367                         return;
368                     } else if (frameOpcode == OPCODE_BINARY) {
369                         out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
370                         framePayload = null;
371                         return;
372                     } else if (frameOpcode == OPCODE_CONT) {
373                         out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
374                         framePayload = null;
375                         return;
376                     } else {
377                         throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
378                                 + frameOpcode);
379                     }
380                 case CORRUPT:
381                     // If we don't keep reading Netty will throw an exception saying
382                     // we can't return null if no bytes read and state not changed.
383                     in.readByte();
384                     return;
385                 default:
386                     throw new Error("Shouldn't reach here.");
387             }
388         } catch (Exception e) {
389             if (payloadBuffer != null) {
390                 if (payloadBuffer.refCnt() > 0) {
391                     payloadBuffer.release();
392                 }
393                 payloadBuffer = null;
394             }
395             if (framePayload != null) {
396                 if (framePayload.refCnt() > 0) {
397                     framePayload.release();
398                 }
399                 framePayload = null;
400             }
401             throw e;
402         }
403     }
404 
405     private void unmask(ByteBuf frame) {
406         int i = frame.readerIndex();
407         int end = frame.writerIndex();
408 
409         ByteOrder order = frame.order();
410 
411         // Remark: & 0xFF is necessary because Java will do signed expansion from
412         // byte to int which we don't want.
413         int intMask = ((maskingKey[0] & 0xFF) << 24)
414                     | ((maskingKey[1] & 0xFF) << 16)
415                     | ((maskingKey[2] & 0xFF) << 8)
416                     | (maskingKey[3] & 0xFF);
417 
418         // If the byte order of our buffers it little endian we have to bring our mask
419         // into the same format, because getInt() and writeInt() will use a reversed byte order
420         if (order == ByteOrder.LITTLE_ENDIAN) {
421             intMask = Integer.reverseBytes(intMask);
422         }
423 
424         for (; i + 3 < end; i += 4) {
425             int unmasked = frame.getInt(i) ^ intMask;
426             frame.setInt(i, unmasked);
427         }
428         for (; i < end; i++) {
429             frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
430         }
431     }
432 
433     private void protocolViolation(ChannelHandlerContext ctx, String reason) {
434         protocolViolation(ctx, new CorruptedFrameException(reason));
435     }
436 
437     private void protocolViolation(ChannelHandlerContext ctx, CorruptedFrameException ex) {
438         checkpoint(State.CORRUPT);
439         if (ctx.channel().isActive()) {
440             Object closeMessage;
441             if (receivedClosingHandshake) {
442                 closeMessage = Unpooled.EMPTY_BUFFER;
443             } else {
444                 closeMessage = new CloseWebSocketFrame(1002, null);
445             }
446             ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
447         }
448         throw ex;
449     }
450 
451     private static int toFrameLength(long l) {
452         if (l > Integer.MAX_VALUE) {
453             throw new TooLongFrameException("Length:" + l);
454         } else {
455             return (int) l;
456         }
457     }
458 
459     private void checkUTF8String(ChannelHandlerContext ctx, ByteBuf buffer) {
460         try {
461             if (utf8Validator == null) {
462                 utf8Validator = new Utf8Validator();
463             }
464             utf8Validator.check(buffer);
465         } catch (CorruptedFrameException ex) {
466             protocolViolation(ctx, ex);
467         }
468     }
469 
470     /** */
471     protected void checkCloseFrameBody(
472             ChannelHandlerContext ctx, ByteBuf buffer) {
473         if (buffer == null || !buffer.isReadable()) {
474             return;
475         }
476         if (buffer.readableBytes() == 1) {
477             protocolViolation(ctx, "Invalid close frame body");
478         }
479 
480         // Save reader index
481         int idx = buffer.readerIndex();
482         buffer.readerIndex(0);
483 
484         // Must have 2 byte integer within the valid range
485         int statusCode = buffer.readShort();
486         if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
487                 || statusCode >= 1012 && statusCode <= 2999) {
488             protocolViolation(ctx, "Invalid close frame getStatus code: " + statusCode);
489         }
490 
491         // May have UTF-8 message
492         if (buffer.isReadable()) {
493             try {
494                 new Utf8Validator().check(buffer);
495             } catch (CorruptedFrameException ex) {
496                 protocolViolation(ctx, ex);
497             }
498         }
499 
500         // Restore reader index
501         buffer.readerIndex(idx);
502     }
503 
504     @Override
505     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
506         super.channelInactive(ctx);
507 
508         // release all not complete frames data to prevent leaks.
509         // https://github.com/netty/netty/issues/1874
510         if (framePayload != null) {
511             framePayload.release();
512         }
513         if (payloadBuffer != null) {
514             payloadBuffer.release();
515         }
516     }
517 }