View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: http://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package org.jboss.netty.handler.codec.http.websocketx;
55  
56  import org.jboss.netty.buffer.ChannelBuffer;
57  import org.jboss.netty.buffer.ChannelBuffers;
58  import org.jboss.netty.channel.Channel;
59  import org.jboss.netty.channel.ChannelFutureListener;
60  import org.jboss.netty.channel.ChannelHandlerContext;
61  import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
62  import org.jboss.netty.handler.codec.frame.TooLongFrameException;
63  import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
64  import org.jboss.netty.logging.InternalLogger;
65  import org.jboss.netty.logging.InternalLoggerFactory;
66  
67  /**
68   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
69   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
70   */
71  public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State> {
72  
73      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
74  
75      private static final byte OPCODE_CONT = 0x0;
76      private static final byte OPCODE_TEXT = 0x1;
77      private static final byte OPCODE_BINARY = 0x2;
78      private static final byte OPCODE_CLOSE = 0x8;
79      private static final byte OPCODE_PING = 0x9;
80      private static final byte OPCODE_PONG = 0xA;
81  
82      private UTF8Output fragmentedFramesText;
83      private int fragmentedFramesCount;
84  
85      private final long maxFramePayloadLength;
86      private boolean frameFinalFlag;
87      private int frameRsv;
88      private int frameOpcode;
89      private long framePayloadLength;
90      private ChannelBuffer framePayload;
91      private int framePayloadBytesRead;
92      private ChannelBuffer maskingKey;
93  
94      private final boolean allowExtensions;
95      private final boolean maskedPayload;
96      private boolean receivedClosingHandshake;
97  
98      public enum State {
99          FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
100     }
101 
102     /**
103      * Constructor with default values
104      *
105      * @param maskedPayload
106      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
107      *            must set this to false.
108      * @param allowExtensions
109      *            Flag to allow reserved extension bits to be used or not
110      */
111     public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) {
112         this(maskedPayload, allowExtensions, Long.MAX_VALUE);
113     }
114 
115     /**
116      * Constructor
117      *
118      * @param maskedPayload
119      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
120      *            must set this to false.
121      * @param allowExtensions
122      *            Flag to allow reserved extension bits to be used or not
123      * @param maxFramePayloadLength
124      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
125      *            helps check for denial of services attacks.
126      */
127     public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
128         super(State.FRAME_START);
129         this.maskedPayload = maskedPayload;
130         this.allowExtensions = allowExtensions;
131         this.maxFramePayloadLength = maxFramePayloadLength;
132     }
133 
134     @Override
135     protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state)
136             throws Exception {
137 
138         // Discard all data received if closing handshake was received before.
139         if (receivedClosingHandshake) {
140             buffer.skipBytes(actualReadableBytes());
141             return null;
142         }
143 
144         switch (state) {
145         case FRAME_START:
146             framePayloadBytesRead = 0;
147             framePayloadLength = -1;
148             framePayload = null;
149 
150             // FIN, RSV, OPCODE
151             byte b = buffer.readByte();
152             frameFinalFlag = (b & 0x80) != 0;
153             frameRsv = (b & 0x70) >> 4;
154             frameOpcode = b & 0x0F;
155 
156             if (logger.isDebugEnabled()) {
157                 logger.debug("Decoding WebSocket Frame opCode=" + frameOpcode);
158             }
159 
160             // MASK, PAYLOAD LEN 1
161             b = buffer.readByte();
162             boolean frameMasked = (b & 0x80) != 0;
163             int framePayloadLen1 = b & 0x7F;
164 
165             if (frameRsv != 0 && !allowExtensions) {
166                 protocolViolation(channel, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
167                 return null;
168             }
169 
170             if (maskedPayload && !frameMasked) {
171                 protocolViolation(channel, "unmasked client to server frame");
172                 return null;
173             }
174             if (frameOpcode > 7) { // control frame (have MSB in opcode set)
175 
176                 // control frames MUST NOT be fragmented
177                 if (!frameFinalFlag) {
178                     protocolViolation(channel, "fragmented control frame");
179                     return null;
180                 }
181 
182                 // control frames MUST have payload 125 octets or less
183                 if (framePayloadLen1 > 125) {
184                     protocolViolation(channel, "control frame with payload length > 125 octets");
185                     return null;
186                 }
187 
188                 // check for reserved control frame opcodes
189                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
190                     protocolViolation(channel, "control frame using reserved opcode " + frameOpcode);
191                     return null;
192                 }
193 
194                 // close frame : if there is a body, the first two bytes of the
195                 // body MUST be a 2-byte unsigned integer (in network byte
196                 // order) representing a status code
197                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
198                     protocolViolation(channel, "received close control frame with payload len 1");
199                     return null;
200                 }
201             } else { // data frame
202                 // check for reserved data frame opcodes
203                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
204                     protocolViolation(channel, "data frame using reserved opcode " + frameOpcode);
205                     return null;
206                 }
207 
208                 // check opcode vs message fragmentation state 1/2
209                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
210                     protocolViolation(channel, "received continuation data frame outside fragmented message");
211                     return null;
212                 }
213 
214                 // check opcode vs message fragmentation state 2/2
215                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
216                     protocolViolation(channel, "received non-continuation data frame while inside fragmented message");
217                     return null;
218                 }
219             }
220 
221             // Read frame payload length
222             if (framePayloadLen1 == 126) {
223                 framePayloadLength = buffer.readUnsignedShort();
224                 if (framePayloadLength < 126) {
225                     protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
226                     return null;
227                 }
228             } else if (framePayloadLen1 == 127) {
229                 framePayloadLength = buffer.readLong();
230                 // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe
231                 // just check if it's negative?
232 
233                 if (framePayloadLength < 65536) {
234                     protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
235                     return null;
236                 }
237             } else {
238                 framePayloadLength = framePayloadLen1;
239             }
240 
241             if (framePayloadLength > maxFramePayloadLength) {
242                 protocolViolation(channel, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
243                 return null;
244             }
245             if (logger.isDebugEnabled()) {
246                 logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
247             }
248 
249             checkpoint(State.MASKING_KEY);
250         case MASKING_KEY:
251             if (maskedPayload) {
252                 maskingKey = buffer.readBytes(4);
253             }
254             checkpoint(State.PAYLOAD);
255         case PAYLOAD:
256             // Sometimes, the payload may not be delivered in 1 nice packet
257             // We need to accumulate the data until we have it all
258             int rbytes = actualReadableBytes();
259             ChannelBuffer payloadBuffer = null;
260 
261             long  willHaveReadByteCount = framePayloadBytesRead + rbytes;
262             // logger.debug("Frame rbytes=" + rbytes + " willHaveReadByteCount="
263             // + willHaveReadByteCount + " framePayloadLength=" +
264             // framePayloadLength);
265             if (willHaveReadByteCount == framePayloadLength) {
266                 // We have all our content so proceed to process
267                 payloadBuffer = buffer.readBytes(rbytes);
268             } else if (willHaveReadByteCount < framePayloadLength) {
269                 // We don't have all our content so accumulate payload.
270                 // Returning null means we will get called back
271                 payloadBuffer = buffer.readBytes(rbytes);
272                 if (framePayload == null) {
273                     framePayload = channel.getConfig().getBufferFactory().getBuffer(toFrameLength(framePayloadLength));
274                 }
275                 framePayload.writeBytes(payloadBuffer);
276                 framePayloadBytesRead += rbytes;
277 
278                 // Return null to wait for more bytes to arrive
279                 return null;
280             } else if (willHaveReadByteCount > framePayloadLength) {
281                 // We have more than what we need so read up to the end of frame
282                 // Leave the remainder in the buffer for next frame
283                 payloadBuffer = buffer.readBytes(toFrameLength(framePayloadLength - framePayloadBytesRead));
284             }
285 
286             // Now we have all the data, the next checkpoint must be the next
287             // frame
288             checkpoint(State.FRAME_START);
289 
290             // Take the data that we have in this packet
291             if (framePayload == null) {
292                 framePayload = payloadBuffer;
293             } else {
294                 framePayload.writeBytes(payloadBuffer);
295             }
296 
297             // Unmask data if needed
298             if (maskedPayload) {
299                 unmask(framePayload);
300             }
301 
302             // Processing ping/pong/close frames because they cannot be
303             // fragmented
304             if (frameOpcode == OPCODE_PING) {
305                 return new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
306             } else if (frameOpcode == OPCODE_PONG) {
307                 return new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
308             } else if (frameOpcode == OPCODE_CLOSE) {
309                 checkCloseFrameBody(channel, framePayload);
310                 receivedClosingHandshake = true;
311                 return new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
312             }
313 
314             // Processing for possible fragmented messages for text and binary
315             // frames
316             String aggregatedText = null;
317             if (frameFinalFlag) {
318                 // Final frame of the sequence. Apparently ping frames are
319                 // allowed in the middle of a fragmented message
320                 if (frameOpcode != OPCODE_PING) {
321                     fragmentedFramesCount = 0;
322 
323                     // Check text for UTF8 correctness
324                     if (frameOpcode == OPCODE_TEXT || fragmentedFramesText != null) {
325                         // Check UTF-8 correctness for this payload
326                         checkUTF8String(channel, framePayload.array());
327 
328                         // This does a second check to make sure UTF-8
329                         // correctness for entire text message
330                         aggregatedText = fragmentedFramesText.toString();
331 
332                         fragmentedFramesText = null;
333                     }
334                 }
335             } else {
336                 // Not final frame so we can expect more frames in the
337                 // fragmented sequence
338                 if (fragmentedFramesCount == 0) {
339                     // First text or binary frame for a fragmented set
340                     fragmentedFramesText = null;
341                     if (frameOpcode == OPCODE_TEXT) {
342                         checkUTF8String(channel, framePayload.array());
343                     }
344                 } else {
345                     // Subsequent frames - only check if init frame is text
346                     if (fragmentedFramesText != null) {
347                         checkUTF8String(channel, framePayload.array());
348                     }
349                 }
350 
351                 // Increment counter
352                 fragmentedFramesCount++;
353             }
354 
355             // Return the frame
356             if (frameOpcode == OPCODE_TEXT) {
357                 return new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
358             } else if (frameOpcode == OPCODE_BINARY) {
359                 return new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
360             } else if (frameOpcode == OPCODE_CONT) {
361                 return new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload, aggregatedText);
362             } else {
363                 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
364             }
365         case CORRUPT:
366             // If we don't keep reading Netty will throw an exception saying
367             // we can't return null if no bytes read and state not changed.
368             buffer.readByte();
369             return null;
370         default:
371             throw new Error("Shouldn't reach here.");
372         }
373     }
374 
375     private void unmask(ChannelBuffer frame) {
376         byte[] bytes = frame.array();
377         for (int i = 0; i < bytes.length; i++) {
378             frame.setByte(i, frame.getByte(i) ^ maskingKey.getByte(i % 4));
379         }
380     }
381 
382     private void protocolViolation(Channel channel, String reason) throws CorruptedFrameException {
383         checkpoint(State.CORRUPT);
384         if (channel.isConnected()) {
385             channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
386         }
387         throw new CorruptedFrameException(reason);
388     }
389 
390     private static int toFrameLength(long l) throws TooLongFrameException {
391         if (l > Integer.MAX_VALUE) {
392             throw new TooLongFrameException("Length:" + l);
393         } else {
394             return (int) l;
395         }
396     }
397 
398     private void checkUTF8String(Channel channel, byte[] bytes) throws CorruptedFrameException {
399         try {
400             // StringBuilder sb = new StringBuilder("UTF8 " + bytes.length +
401             // " bytes: ");
402             // for (byte b : bytes) {
403             // sb.append(Integer.toHexString(b)).append(" ");
404             // }
405             // logger.debug(sb.toString());
406 
407             if (fragmentedFramesText == null) {
408                 fragmentedFramesText = new UTF8Output(bytes);
409             } else {
410                 fragmentedFramesText.write(bytes);
411             }
412         } catch (UTF8Exception ex) {
413             protocolViolation(channel, "invalid UTF-8 bytes");
414         }
415     }
416 
417     protected void checkCloseFrameBody(Channel channel, ChannelBuffer buffer) throws CorruptedFrameException {
418         if (buffer == null || buffer.capacity() == 0) {
419             return;
420         }
421         if (buffer.capacity() == 1) {
422             protocolViolation(channel, "Invalid close frame body");
423         }
424 
425         // Save reader index
426         int idx = buffer.readerIndex();
427         buffer.readerIndex(0);
428 
429         // Must have 2 byte integer within the valid range
430         int statusCode = buffer.readShort();
431         if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
432                 || statusCode >= 1012 && statusCode <= 2999) {
433             protocolViolation(channel, "Invalid close frame status code: " + statusCode);
434         }
435 
436         // May have UTF-8 message
437         if (buffer.readableBytes() > 0) {
438             byte[] b = new byte[buffer.readableBytes()];
439             buffer.readBytes(b);
440             try {
441                 new UTF8Output(b);
442             } catch (UTF8Exception ex) {
443                 protocolViolation(channel, "Invalid close frame reason text. Invalid UTF-8 bytes");
444             }
445         }
446 
447         // Restore reader index
448         buffer.readerIndex(idx);
449     }
450 }