View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: https://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package io.netty.handler.codec.http.websocketx;
55  
56  import io.netty.buffer.ByteBuf;
57  import io.netty.buffer.Unpooled;
58  import io.netty.channel.ChannelFutureListener;
59  import io.netty.channel.ChannelHandlerContext;
60  import io.netty.handler.codec.ByteToMessageDecoder;
61  import io.netty.handler.codec.TooLongFrameException;
62  import io.netty.util.internal.ObjectUtil;
63  import io.netty.util.internal.logging.InternalLogger;
64  import io.netty.util.internal.logging.InternalLoggerFactory;
65  
66  import java.nio.ByteOrder;
67  import java.util.List;
68  
69  import static io.netty.buffer.ByteBufUtil.readBytes;
70  
71  /**
72   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
73   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
74   */
75  public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76          implements WebSocketFrameDecoder {
77  
78      enum State {
79          READING_FIRST,
80          READING_SECOND,
81          READING_SIZE,
82          MASKING_KEY,
83          PAYLOAD,
84          CORRUPT
85      }
86  
87      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88  
89      private static final byte OPCODE_CONT = 0x0;
90      private static final byte OPCODE_TEXT = 0x1;
91      private static final byte OPCODE_BINARY = 0x2;
92      private static final byte OPCODE_CLOSE = 0x8;
93      private static final byte OPCODE_PING = 0x9;
94      private static final byte OPCODE_PONG = 0xA;
95  
96      private final WebSocketDecoderConfig config;
97  
98      private int fragmentedFramesCount;
99      private boolean frameFinalFlag;
100     private boolean frameMasked;
101     private int frameRsv;
102     private int frameOpcode;
103     private long framePayloadLength;
104     private byte[] maskingKey;
105     private int framePayloadLen1;
106     private boolean receivedClosingHandshake;
107     private State state = State.READING_FIRST;
108 
109     /**
110      * Constructor
111      *
112      * @param expectMaskedFrames
113      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
114      *            must set this to false.
115      * @param allowExtensions
116      *            Flag to allow reserved extension bits to be used or not
117      * @param maxFramePayloadLength
118      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
119      *            helps check for denial of services attacks.
120      */
121     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122         this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123     }
124 
125     /**
126      * Constructor
127      *
128      * @param expectMaskedFrames
129      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
130      *            must set this to false.
131      * @param allowExtensions
132      *            Flag to allow reserved extension bits to be used or not
133      * @param maxFramePayloadLength
134      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
135      *            helps check for denial of services attacks.
136      * @param allowMaskMismatch
137      *            When set to true, frames which are not masked properly according to the standard will still be
138      *            accepted.
139      */
140     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141                                    boolean allowMaskMismatch) {
142         this(WebSocketDecoderConfig.newBuilder()
143             .expectMaskedFrames(expectMaskedFrames)
144             .allowExtensions(allowExtensions)
145             .maxFramePayloadLength(maxFramePayloadLength)
146             .allowMaskMismatch(allowMaskMismatch)
147             .build());
148     }
149 
150     /**
151      * Constructor
152      *
153      * @param decoderConfig
154      *            Frames decoder configuration.
155      */
156     public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157         this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158     }
159 
160     @Override
161     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162         // Discard all data received if closing handshake was received before.
163         if (receivedClosingHandshake) {
164             in.skipBytes(actualReadableBytes());
165             return;
166         }
167 
168         switch (state) {
169         case READING_FIRST:
170             if (!in.isReadable()) {
171                 return;
172             }
173 
174             framePayloadLength = 0;
175 
176             // FIN, RSV, OPCODE
177             byte b = in.readByte();
178             frameFinalFlag = (b & 0x80) != 0;
179             frameRsv = (b & 0x70) >> 4;
180             frameOpcode = b & 0x0F;
181 
182             if (logger.isTraceEnabled()) {
183                 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184             }
185 
186             state = State.READING_SECOND;
187         case READING_SECOND:
188             if (!in.isReadable()) {
189                 return;
190             }
191             // MASK, PAYLOAD LEN 1
192             b = in.readByte();
193             frameMasked = (b & 0x80) != 0;
194             framePayloadLen1 = b & 0x7F;
195 
196             if (frameRsv != 0 && !config.allowExtensions()) {
197                 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198                 return;
199             }
200 
201             if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202                 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203                 return;
204             }
205 
206             if (frameOpcode > 7) { // control frame (have MSB in opcode set)
207 
208                 // control frames MUST NOT be fragmented
209                 if (!frameFinalFlag) {
210                     protocolViolation(ctx, in, "fragmented control frame");
211                     return;
212                 }
213 
214                 // control frames MUST have payload 125 octets or less
215                 if (framePayloadLen1 > 125) {
216                     protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217                     return;
218                 }
219 
220                 // check for reserved control frame opcodes
221                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222                       || frameOpcode == OPCODE_PONG)) {
223                     protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224                     return;
225                 }
226 
227                 // close frame : if there is a body, the first two bytes of the
228                 // body MUST be a 2-byte unsigned integer (in network byte
229                 // order) representing a getStatus code
230                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231                     protocolViolation(ctx, in, "received close control frame with payload len 1");
232                     return;
233                 }
234             } else { // data frame
235                 // check for reserved data frame opcodes
236                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237                       || frameOpcode == OPCODE_BINARY)) {
238                     protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239                     return;
240                 }
241 
242                 // check opcode vs message fragmentation state 1/2
243                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244                     protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245                     return;
246                 }
247 
248                 // check opcode vs message fragmentation state 2/2
249                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
250                     protocolViolation(ctx, in,
251                                       "received non-continuation data frame while inside fragmented message");
252                     return;
253                 }
254             }
255 
256             state = State.READING_SIZE;
257         case READING_SIZE:
258 
259             // Read frame payload length
260             if (framePayloadLen1 == 126) {
261                 if (in.readableBytes() < 2) {
262                     return;
263                 }
264                 framePayloadLength = in.readUnsignedShort();
265                 if (framePayloadLength < 126) {
266                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267                     return;
268                 }
269             } else if (framePayloadLen1 == 127) {
270                 if (in.readableBytes() < 8) {
271                     return;
272                 }
273                 framePayloadLength = in.readLong();
274                 // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe
275                 // just check if it's negative?
276 
277                 if (framePayloadLength < 65536) {
278                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
279                     return;
280                 }
281             } else {
282                 framePayloadLength = framePayloadLen1;
283             }
284 
285             if (framePayloadLength > config.maxFramePayloadLength()) {
286                 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
287                     "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
288                 return;
289             }
290 
291             if (logger.isTraceEnabled()) {
292                 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
293             }
294 
295             state = State.MASKING_KEY;
296         case MASKING_KEY:
297             if (frameMasked) {
298                 if (in.readableBytes() < 4) {
299                     return;
300                 }
301                 if (maskingKey == null) {
302                     maskingKey = new byte[4];
303                 }
304                 in.readBytes(maskingKey);
305             }
306             state = State.PAYLOAD;
307         case PAYLOAD:
308             if (in.readableBytes() < framePayloadLength) {
309                 return;
310             }
311 
312             ByteBuf payloadBuffer = null;
313             try {
314                 payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315 
316                 // Now we have all the data, the next checkpoint must be the next
317                 // frame
318                 state = State.READING_FIRST;
319 
320                 // Unmask data if needed
321                 if (frameMasked) {
322                     unmask(payloadBuffer);
323                 }
324 
325                 // Processing ping/pong/close frames because they cannot be
326                 // fragmented
327                 if (frameOpcode == OPCODE_PING) {
328                     out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
329                     payloadBuffer = null;
330                     return;
331                 }
332                 if (frameOpcode == OPCODE_PONG) {
333                     out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
334                     payloadBuffer = null;
335                     return;
336                 }
337                 if (frameOpcode == OPCODE_CLOSE) {
338                     receivedClosingHandshake = true;
339                     checkCloseFrameBody(ctx, payloadBuffer);
340                     out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
341                     payloadBuffer = null;
342                     return;
343                 }
344 
345                 // Processing for possible fragmented messages for text and binary
346                 // frames
347                 if (frameFinalFlag) {
348                     // Final frame of the sequence. Apparently ping frames are
349                     // allowed in the middle of a fragmented message
350                     fragmentedFramesCount = 0;
351                 } else {
352                     // Increment counter
353                     fragmentedFramesCount++;
354                 }
355 
356                 // Return the frame
357                 if (frameOpcode == OPCODE_TEXT) {
358                     out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
359                     payloadBuffer = null;
360                     return;
361                 } else if (frameOpcode == OPCODE_BINARY) {
362                     out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
363                     payloadBuffer = null;
364                     return;
365                 } else if (frameOpcode == OPCODE_CONT) {
366                     out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
367                                                            payloadBuffer));
368                     payloadBuffer = null;
369                     return;
370                 } else {
371                     throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
372                                                             + frameOpcode);
373                 }
374             } finally {
375                 if (payloadBuffer != null) {
376                     payloadBuffer.release();
377                 }
378             }
379         case CORRUPT:
380             if (in.isReadable()) {
381                 // If we don't keep reading Netty will throw an exception saying
382                 // we can't return null if no bytes read and state not changed.
383                 in.readByte();
384             }
385             return;
386         default:
387             throw new Error("Shouldn't reach here.");
388         }
389     }
390 
391     private void unmask(ByteBuf frame) {
392         int i = frame.readerIndex();
393         int end = frame.writerIndex();
394 
395         ByteOrder order = frame.order();
396 
397         // Remark: & 0xFF is necessary because Java will do signed expansion from
398         // byte to int which we don't want.
399         int intMask = ((maskingKey[0] & 0xFF) << 24)
400                     | ((maskingKey[1] & 0xFF) << 16)
401                     | ((maskingKey[2] & 0xFF) << 8)
402                     | (maskingKey[3] & 0xFF);
403 
404         // If the byte order of our buffers it little endian we have to bring our mask
405         // into the same format, because getInt() and writeInt() will use a reversed byte order
406         if (order == ByteOrder.LITTLE_ENDIAN) {
407             intMask = Integer.reverseBytes(intMask);
408         }
409 
410         for (; i + 3 < end; i += 4) {
411             int unmasked = frame.getInt(i) ^ intMask;
412             frame.setInt(i, unmasked);
413         }
414         for (; i < end; i++) {
415             frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
416         }
417     }
418 
419     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
420         protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
421     }
422 
423     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
424         protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
425     }
426 
427     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
428         state = State.CORRUPT;
429         int readableBytes = in.readableBytes();
430         if (readableBytes > 0) {
431             // Fix for memory leak, caused by ByteToMessageDecoder#channelRead:
432             // buffer 'cumulation' is released ONLY when no more readable bytes available.
433             in.skipBytes(readableBytes);
434         }
435         if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
436             Object closeMessage;
437             if (receivedClosingHandshake) {
438                 closeMessage = Unpooled.EMPTY_BUFFER;
439             } else {
440                 WebSocketCloseStatus closeStatus = ex.closeStatus();
441                 String reasonText = ex.getMessage();
442                 if (reasonText == null) {
443                     reasonText = closeStatus.reasonText();
444                 }
445                 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
446             }
447             ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
448         }
449         throw ex;
450     }
451 
452     private static int toFrameLength(long l) {
453         if (l > Integer.MAX_VALUE) {
454             throw new TooLongFrameException("Length:" + l);
455         } else {
456             return (int) l;
457         }
458     }
459 
460     /** */
461     protected void checkCloseFrameBody(
462             ChannelHandlerContext ctx, ByteBuf buffer) {
463         if (buffer == null || !buffer.isReadable()) {
464             return;
465         }
466         if (buffer.readableBytes() < 2) {
467             protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
468         }
469 
470         // Must have 2 byte integer within the valid range
471         int statusCode = buffer.getShort(buffer.readerIndex());
472         if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
473             protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
474         }
475 
476         // May have UTF-8 message
477         if (buffer.readableBytes() > 2) {
478             try {
479                 new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
480             } catch (CorruptedWebSocketFrameException ex) {
481                 protocolViolation(ctx, buffer, ex);
482             }
483         }
484     }
485 }