View Javadoc

1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: http://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package org.jboss.netty.handler.codec.http.websocketx;
55  
56  import org.jboss.netty.buffer.ChannelBuffer;
57  import org.jboss.netty.buffer.ChannelBuffers;
58  import org.jboss.netty.channel.Channel;
59  import org.jboss.netty.channel.ChannelFutureListener;
60  import org.jboss.netty.channel.ChannelHandlerContext;
61  import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
62  import org.jboss.netty.handler.codec.frame.TooLongFrameException;
63  import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
64  import org.jboss.netty.logging.InternalLogger;
65  import org.jboss.netty.logging.InternalLoggerFactory;
66  
67  /**
68   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
69   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
70   */
71  public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State> {
72  
73      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
74  
75      private static final byte OPCODE_CONT = 0x0;
76      private static final byte OPCODE_TEXT = 0x1;
77      private static final byte OPCODE_BINARY = 0x2;
78      private static final byte OPCODE_CLOSE = 0x8;
79      private static final byte OPCODE_PING = 0x9;
80      private static final byte OPCODE_PONG = 0xA;
81  
82      private Utf8Validator utf8Validator;
83      private int fragmentedFramesCount;
84  
85      private final long maxFramePayloadLength;
86      private boolean frameFinalFlag;
87      private int frameRsv;
88      private int frameOpcode;
89      private long framePayloadLength;
90      private ChannelBuffer framePayload;
91      private int framePayloadBytesRead;
92      private ChannelBuffer maskingKey;
93  
94      private final boolean allowExtensions;
95      private final boolean maskedPayload;
96      private boolean receivedClosingHandshake;
97  
98      public enum State {
99          FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
100     }
101 
102     /**
103      * Constructor with default values
104      *
105      * @param maskedPayload
106      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
107      *            must set this to false.
108      * @param allowExtensions
109      *            Flag to allow reserved extension bits to be used or not
110      */
111     public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) {
112         this(maskedPayload, allowExtensions, Long.MAX_VALUE);
113     }
114 
115     /**
116      * Constructor
117      *
118      * @param maskedPayload
119      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
120      *            must set this to false.
121      * @param allowExtensions
122      *            Flag to allow reserved extension bits to be used or not
123      * @param maxFramePayloadLength
124      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
125      *            helps check for denial of services attacks.
126      */
127     public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
128         super(State.FRAME_START);
129         this.maskedPayload = maskedPayload;
130         this.allowExtensions = allowExtensions;
131         this.maxFramePayloadLength = maxFramePayloadLength;
132     }
133 
134     @Override
135     protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state)
136             throws Exception {
137 
138         // Discard all data received if closing handshake was received before.
139         if (receivedClosingHandshake) {
140             buffer.skipBytes(actualReadableBytes());
141             return null;
142         }
143 
144         switch (state) {
145         case FRAME_START:
146             framePayloadBytesRead = 0;
147             framePayloadLength = -1;
148             framePayload = null;
149 
150             // FIN, RSV, OPCODE
151             byte b = buffer.readByte();
152             frameFinalFlag = (b & 0x80) != 0;
153             frameRsv = (b & 0x70) >> 4;
154             frameOpcode = b & 0x0F;
155 
156             if (logger.isDebugEnabled()) {
157                 logger.debug("Decoding WebSocket Frame opCode=" + frameOpcode);
158             }
159 
160             // MASK, PAYLOAD LEN 1
161             b = buffer.readByte();
162             boolean frameMasked = (b & 0x80) != 0;
163             int framePayloadLen1 = b & 0x7F;
164 
165             if (frameRsv != 0 && !allowExtensions) {
166                 protocolViolation(channel, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
167                 return null;
168             }
169 
170             if (maskedPayload && !frameMasked) {
171                 protocolViolation(channel, "unmasked client to server frame");
172                 return null;
173             }
174             if (frameOpcode > 7) { // control frame (have MSB in opcode set)
175 
176                 // control frames MUST NOT be fragmented
177                 if (!frameFinalFlag) {
178                     protocolViolation(channel, "fragmented control frame");
179                     return null;
180                 }
181 
182                 // control frames MUST have payload 125 octets or less
183                 if (framePayloadLen1 > 125) {
184                     protocolViolation(channel, "control frame with payload length > 125 octets");
185                     return null;
186                 }
187 
188                 // check for reserved control frame opcodes
189                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
190                     protocolViolation(channel, "control frame using reserved opcode " + frameOpcode);
191                     return null;
192                 }
193 
194                 // close frame : if there is a body, the first two bytes of the
195                 // body MUST be a 2-byte unsigned integer (in network byte
196                 // order) representing a status code
197                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
198                     protocolViolation(channel, "received close control frame with payload len 1");
199                     return null;
200                 }
201             } else { // data frame
202                 // check for reserved data frame opcodes
203                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
204                     protocolViolation(channel, "data frame using reserved opcode " + frameOpcode);
205                     return null;
206                 }
207 
208                 // check opcode vs message fragmentation state 1/2
209                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
210                     protocolViolation(channel, "received continuation data frame outside fragmented message");
211                     return null;
212                 }
213 
214                 // check opcode vs message fragmentation state 2/2
215                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
216                     protocolViolation(channel, "received non-continuation data frame while inside fragmented message");
217                     return null;
218                 }
219             }
220 
221             // Read frame payload length
222             if (framePayloadLen1 == 126) {
223                 framePayloadLength = buffer.readUnsignedShort();
224                 if (framePayloadLength < 126) {
225                     protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
226                     return null;
227                 }
228             } else if (framePayloadLen1 == 127) {
229                 framePayloadLength = buffer.readLong();
230                 // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe
231                 // just check if it's negative?
232 
233                 if (framePayloadLength < 65536) {
234                     protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
235                     return null;
236                 }
237             } else {
238                 framePayloadLength = framePayloadLen1;
239             }
240 
241             if (framePayloadLength > maxFramePayloadLength) {
242                 protocolViolation(channel, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
243                 return null;
244             }
245             if (logger.isDebugEnabled()) {
246                 logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
247             }
248 
249             checkpoint(State.MASKING_KEY);
250         case MASKING_KEY:
251             if (maskedPayload) {
252                 maskingKey = buffer.readBytes(4);
253             }
254             checkpoint(State.PAYLOAD);
255         case PAYLOAD:
256             // Sometimes, the payload may not be delivered in 1 nice packet
257             // We need to accumulate the data until we have it all
258             int rbytes = actualReadableBytes();
259             ChannelBuffer payloadBuffer = null;
260 
261             long  willHaveReadByteCount = framePayloadBytesRead + rbytes;
262             // logger.debug("Frame rbytes=" + rbytes + " willHaveReadByteCount="
263             // + willHaveReadByteCount + " framePayloadLength=" +
264             // framePayloadLength);
265             if (willHaveReadByteCount == framePayloadLength) {
266                 // We have all our content so proceed to process
267                 payloadBuffer = buffer.readBytes(rbytes);
268             } else if (willHaveReadByteCount < framePayloadLength) {
269                 // We don't have all our content so accumulate payload.
270                 // Returning null means we will get called back
271                 payloadBuffer = buffer.readBytes(rbytes);
272                 if (framePayload == null) {
273                     framePayload = channel.getConfig().getBufferFactory().getBuffer(toFrameLength(framePayloadLength));
274                 }
275                 framePayload.writeBytes(payloadBuffer);
276                 framePayloadBytesRead += rbytes;
277 
278                 // Return null to wait for more bytes to arrive
279                 return null;
280             } else if (willHaveReadByteCount > framePayloadLength) {
281                 // We have more than what we need so read up to the end of frame
282                 // Leave the remainder in the buffer for next frame
283                 payloadBuffer = buffer.readBytes(toFrameLength(framePayloadLength - framePayloadBytesRead));
284             }
285 
286             // Now we have all the data, the next checkpoint must be the next
287             // frame
288             checkpoint(State.FRAME_START);
289 
290             // Take the data that we have in this packet
291             if (framePayload == null) {
292                 framePayload = payloadBuffer;
293             } else {
294                 framePayload.writeBytes(payloadBuffer);
295             }
296 
297             // Unmask data if needed
298             if (maskedPayload) {
299                 unmask(framePayload);
300             }
301 
302             // Processing ping/pong/close frames because they cannot be
303             // fragmented
304             if (frameOpcode == OPCODE_PING) {
305                 return new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
306             }
307             if (frameOpcode == OPCODE_PONG) {
308                 return new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
309             }
310             if (frameOpcode == OPCODE_CLOSE) {
311                 checkCloseFrameBody(channel, framePayload);
312                 receivedClosingHandshake = true;
313                 return new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
314             }
315 
316             // Processing for possible fragmented messages for text and binary
317             // frames
318             if (frameFinalFlag) {
319                 // Final frame of the sequence. Apparently ping frames are
320                 // allowed in the middle of a fragmented message
321                 if (frameOpcode != OPCODE_PING) {
322                     fragmentedFramesCount = 0;
323 
324                     // Check text for UTF8 correctness
325                     if (frameOpcode == OPCODE_TEXT || (utf8Validator != null && utf8Validator.isChecking())) {
326                         // Check UTF-8 correctness for this payload
327                         checkUTF8String(channel, framePayload.array());
328 
329                         // This does a second check to make sure UTF-8
330                         // correctness for entire text message
331                         utf8Validator.finish();
332                     }
333                 }
334             } else {
335                 // Not final frame so we can expect more frames in the
336                 // fragmented sequence
337                 if (fragmentedFramesCount == 0) {
338                     // First text or binary frame for a fragmented set
339                     if (frameOpcode == OPCODE_TEXT) {
340                         checkUTF8String(channel, framePayload.array());
341                     }
342                 } else {
343                     // Subsequent frames - only check if init frame is text
344                     if (utf8Validator != null && utf8Validator.isChecking()) {
345                         checkUTF8String(channel, framePayload.array());
346                     }
347                 }
348 
349                 // Increment counter
350                 fragmentedFramesCount++;
351             }
352 
353             // Return the frame
354             if (frameOpcode == OPCODE_TEXT) {
355                 return new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
356             } else if (frameOpcode == OPCODE_BINARY) {
357                 return new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
358             } else if (frameOpcode == OPCODE_CONT) {
359                 return new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
360             } else {
361                 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
362             }
363         case CORRUPT:
364             // If we don't keep reading Netty will throw an exception saying
365             // we can't return null if no bytes read and state not changed.
366             buffer.readByte();
367             return null;
368         default:
369             throw new Error("Shouldn't reach here.");
370         }
371     }
372 
373     private void unmask(ChannelBuffer frame) {
374         byte[] bytes = frame.array();
375         for (int i = 0; i < bytes.length; i++) {
376             frame.setByte(i, frame.getByte(i) ^ maskingKey.getByte(i % 4));
377         }
378     }
379 
380     private void protocolViolation(Channel channel, String reason) throws CorruptedFrameException {
381         protocolViolation(channel, new CorruptedFrameException(reason));
382     }
383 
384     private void protocolViolation(Channel channel, CorruptedFrameException ex) throws CorruptedFrameException {
385         checkpoint(State.CORRUPT);
386         if (channel.isConnected()) {
387             channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
388         }
389         throw ex;
390     }
391 
392     private static int toFrameLength(long l) throws TooLongFrameException {
393         if (l > Integer.MAX_VALUE) {
394             throw new TooLongFrameException("Length:" + l);
395         } else {
396             return (int) l;
397         }
398     }
399 
400     private void checkUTF8String(Channel channel, byte[] bytes) throws CorruptedFrameException {
401         try {
402             if (utf8Validator == null) {
403                 utf8Validator = new Utf8Validator();
404             }
405             utf8Validator.check(bytes);
406         } catch (CorruptedFrameException ex) {
407             protocolViolation(channel, ex);
408         }
409     }
410 
411     protected void checkCloseFrameBody(Channel channel, ChannelBuffer buffer) throws CorruptedFrameException {
412         if (buffer == null || buffer.capacity() == 0) {
413             return;
414         }
415         if (buffer.capacity() == 1) {
416             protocolViolation(channel, "Invalid close frame body");
417         }
418 
419         // Save reader index
420         int idx = buffer.readerIndex();
421         buffer.readerIndex(0);
422 
423         // Must have 2 byte integer within the valid range
424         int statusCode = buffer.readShort();
425         if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
426                 || statusCode >= 1012 && statusCode <= 2999) {
427             protocolViolation(channel, "Invalid close frame status code: " + statusCode);
428         }
429 
430         // May have UTF-8 message
431         if (buffer.readableBytes() > 0) {
432             byte[] b = new byte[buffer.readableBytes()];
433             buffer.readBytes(b);
434             try {
435                 Utf8Validator validator = new Utf8Validator();
436                 validator.check(b);
437             } catch (CorruptedFrameException ex) {
438                 protocolViolation(channel, ex);
439             }
440         }
441 
442         // Restore reader index
443         buffer.readerIndex(idx);
444     }
445 }