View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: http://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package io.netty.handler.codec.http.websocketx;
55  
56  import io.netty.buffer.ByteBuf;
57  import io.netty.buffer.Unpooled;
58  import io.netty.channel.ChannelFutureListener;
59  import io.netty.channel.ChannelHandlerContext;
60  import io.netty.handler.codec.ByteToMessageDecoder;
61  import io.netty.handler.codec.TooLongFrameException;
62  import io.netty.util.internal.ObjectUtil;
63  import io.netty.util.internal.logging.InternalLogger;
64  import io.netty.util.internal.logging.InternalLoggerFactory;
65  
66  import java.nio.ByteOrder;
67  import java.util.List;
68  
69  import static io.netty.buffer.ByteBufUtil.readBytes;
70  
71  /**
72   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
73   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
74   */
75  public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76          implements WebSocketFrameDecoder {
77  
78      enum State {
79          READING_FIRST,
80          READING_SECOND,
81          READING_SIZE,
82          MASKING_KEY,
83          PAYLOAD,
84          CORRUPT
85      }
86  
87      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88  
89      private static final byte OPCODE_CONT = 0x0;
90      private static final byte OPCODE_TEXT = 0x1;
91      private static final byte OPCODE_BINARY = 0x2;
92      private static final byte OPCODE_CLOSE = 0x8;
93      private static final byte OPCODE_PING = 0x9;
94      private static final byte OPCODE_PONG = 0xA;
95  
96      private final WebSocketDecoderConfig config;
97  
98      private int fragmentedFramesCount;
99      private boolean frameFinalFlag;
100     private boolean frameMasked;
101     private int frameRsv;
102     private int frameOpcode;
103     private long framePayloadLength;
104     private byte[] maskingKey;
105     private int framePayloadLen1;
106     private boolean receivedClosingHandshake;
107     private State state = State.READING_FIRST;
108 
109     /**
110      * Constructor
111      *
112      * @param expectMaskedFrames
113      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
114      *            must set this to false.
115      * @param allowExtensions
116      *            Flag to allow reserved extension bits to be used or not
117      * @param maxFramePayloadLength
118      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
119      *            helps check for denial of services attacks.
120      */
121     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122         this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123     }
124 
125     /**
126      * Constructor
127      *
128      * @param expectMaskedFrames
129      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
130      *            must set this to false.
131      * @param allowExtensions
132      *            Flag to allow reserved extension bits to be used or not
133      * @param maxFramePayloadLength
134      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
135      *            helps check for denial of services attacks.
136      * @param allowMaskMismatch
137      *            When set to true, frames which are not masked properly according to the standard will still be
138      *            accepted.
139      */
140     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141                                    boolean allowMaskMismatch) {
142         this(WebSocketDecoderConfig.newBuilder()
143             .expectMaskedFrames(expectMaskedFrames)
144             .allowExtensions(allowExtensions)
145             .maxFramePayloadLength(maxFramePayloadLength)
146             .allowMaskMismatch(allowMaskMismatch)
147             .build());
148     }
149 
150     /**
151      * Constructor
152      *
153      * @param decoderConfig
154      *            Frames decoder configuration.
155      */
156     public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157         this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158     }
159 
160     @Override
161     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162         // Discard all data received if closing handshake was received before.
163         if (receivedClosingHandshake) {
164             in.skipBytes(actualReadableBytes());
165             return;
166         }
167 
168         switch (state) {
169         case READING_FIRST:
170             if (!in.isReadable()) {
171                 return;
172             }
173 
174             framePayloadLength = 0;
175 
176             // FIN, RSV, OPCODE
177             byte b = in.readByte();
178             frameFinalFlag = (b & 0x80) != 0;
179             frameRsv = (b & 0x70) >> 4;
180             frameOpcode = b & 0x0F;
181 
182             if (logger.isTraceEnabled()) {
183                 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184             }
185 
186             state = State.READING_SECOND;
187         case READING_SECOND:
188             if (!in.isReadable()) {
189                 return;
190             }
191             // MASK, PAYLOAD LEN 1
192             b = in.readByte();
193             frameMasked = (b & 0x80) != 0;
194             framePayloadLen1 = b & 0x7F;
195 
196             if (frameRsv != 0 && !config.allowExtensions()) {
197                 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198                 return;
199             }
200 
201             if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202                 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203                 return;
204             }
205 
206             if (frameOpcode > 7) { // control frame (have MSB in opcode set)
207 
208                 // control frames MUST NOT be fragmented
209                 if (!frameFinalFlag) {
210                     protocolViolation(ctx, in, "fragmented control frame");
211                     return;
212                 }
213 
214                 // control frames MUST have payload 125 octets or less
215                 if (framePayloadLen1 > 125) {
216                     protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217                     return;
218                 }
219 
220                 // check for reserved control frame opcodes
221                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222                       || frameOpcode == OPCODE_PONG)) {
223                     protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224                     return;
225                 }
226 
227                 // close frame : if there is a body, the first two bytes of the
228                 // body MUST be a 2-byte unsigned integer (in network byte
229                 // order) representing a getStatus code
230                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231                     protocolViolation(ctx, in, "received close control frame with payload len 1");
232                     return;
233                 }
234             } else { // data frame
235                 // check for reserved data frame opcodes
236                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237                       || frameOpcode == OPCODE_BINARY)) {
238                     protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239                     return;
240                 }
241 
242                 // check opcode vs message fragmentation state 1/2
243                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244                     protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245                     return;
246                 }
247 
248                 // check opcode vs message fragmentation state 2/2
249                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
250                     protocolViolation(ctx, in,
251                                       "received non-continuation data frame while inside fragmented message");
252                     return;
253                 }
254             }
255 
256             state = State.READING_SIZE;
257         case READING_SIZE:
258 
259             // Read frame payload length
260             if (framePayloadLen1 == 126) {
261                 if (in.readableBytes() < 2) {
262                     return;
263                 }
264                 framePayloadLength = in.readUnsignedShort();
265                 if (framePayloadLength < 126) {
266                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267                     return;
268                 }
269             } else if (framePayloadLen1 == 127) {
270                 if (in.readableBytes() < 8) {
271                     return;
272                 }
273                 framePayloadLength = in.readLong();
274                 // TODO: check if it's bigger than 0x7FFFFFFFFFFFFFFF, Maybe
275                 // just check if it's negative?
276 
277                 if (framePayloadLength < 65536) {
278                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
279                     return;
280                 }
281             } else {
282                 framePayloadLength = framePayloadLen1;
283             }
284 
285             if (framePayloadLength > config.maxFramePayloadLength()) {
286                 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
287                     "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
288                 return;
289             }
290 
291             if (logger.isTraceEnabled()) {
292                 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
293             }
294 
295             state = State.MASKING_KEY;
296         case MASKING_KEY:
297             if (frameMasked) {
298                 if (in.readableBytes() < 4) {
299                     return;
300                 }
301                 if (maskingKey == null) {
302                     maskingKey = new byte[4];
303                 }
304                 in.readBytes(maskingKey);
305             }
306             state = State.PAYLOAD;
307         case PAYLOAD:
308             if (in.readableBytes() < framePayloadLength) {
309                 return;
310             }
311 
312             ByteBuf payloadBuffer = null;
313             try {
314                 payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315 
316                 // Now we have all the data, the next checkpoint must be the next
317                 // frame
318                 state = State.READING_FIRST;
319 
320                 // Unmask data if needed
321                 if (frameMasked) {
322                     unmask(payloadBuffer);
323                 }
324 
325                 // Processing ping/pong/close frames because they cannot be
326                 // fragmented
327                 if (frameOpcode == OPCODE_PING) {
328                     out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
329                     payloadBuffer = null;
330                     return;
331                 }
332                 if (frameOpcode == OPCODE_PONG) {
333                     out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
334                     payloadBuffer = null;
335                     return;
336                 }
337                 if (frameOpcode == OPCODE_CLOSE) {
338                     receivedClosingHandshake = true;
339                     checkCloseFrameBody(ctx, payloadBuffer);
340                     out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
341                     payloadBuffer = null;
342                     return;
343                 }
344 
345                 // Processing for possible fragmented messages for text and binary
346                 // frames
347                 if (frameFinalFlag) {
348                     // Final frame of the sequence. Apparently ping frames are
349                     // allowed in the middle of a fragmented message
350                     if (frameOpcode != OPCODE_PING) {
351                         fragmentedFramesCount = 0;
352                     }
353                 } else {
354                     // Increment counter
355                     fragmentedFramesCount++;
356                 }
357 
358                 // Return the frame
359                 if (frameOpcode == OPCODE_TEXT) {
360                     out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
361                     payloadBuffer = null;
362                     return;
363                 } else if (frameOpcode == OPCODE_BINARY) {
364                     out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
365                     payloadBuffer = null;
366                     return;
367                 } else if (frameOpcode == OPCODE_CONT) {
368                     out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
369                                                            payloadBuffer));
370                     payloadBuffer = null;
371                     return;
372                 } else {
373                     throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
374                                                             + frameOpcode);
375                 }
376             } finally {
377                 if (payloadBuffer != null) {
378                     payloadBuffer.release();
379                 }
380             }
381         case CORRUPT:
382             if (in.isReadable()) {
383                 // If we don't keep reading Netty will throw an exception saying
384                 // we can't return null if no bytes read and state not changed.
385                 in.readByte();
386             }
387             return;
388         default:
389             throw new Error("Shouldn't reach here.");
390         }
391     }
392 
393     private void unmask(ByteBuf frame) {
394         int i = frame.readerIndex();
395         int end = frame.writerIndex();
396 
397         ByteOrder order = frame.order();
398 
399         // Remark: & 0xFF is necessary because Java will do signed expansion from
400         // byte to int which we don't want.
401         int intMask = ((maskingKey[0] & 0xFF) << 24)
402                     | ((maskingKey[1] & 0xFF) << 16)
403                     | ((maskingKey[2] & 0xFF) << 8)
404                     | (maskingKey[3] & 0xFF);
405 
406         // If the byte order of our buffers it little endian we have to bring our mask
407         // into the same format, because getInt() and writeInt() will use a reversed byte order
408         if (order == ByteOrder.LITTLE_ENDIAN) {
409             intMask = Integer.reverseBytes(intMask);
410         }
411 
412         for (; i + 3 < end; i += 4) {
413             int unmasked = frame.getInt(i) ^ intMask;
414             frame.setInt(i, unmasked);
415         }
416         for (; i < end; i++) {
417             frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
418         }
419     }
420 
421     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
422         protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
423     }
424 
425     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
426         protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
427     }
428 
429     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
430         state = State.CORRUPT;
431         int readableBytes = in.readableBytes();
432         if (readableBytes > 0) {
433             // Fix for memory leak, caused by ByteToMessageDecoder#channelRead:
434             // buffer 'cumulation' is released ONLY when no more readable bytes available.
435             in.skipBytes(readableBytes);
436         }
437         if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
438             Object closeMessage;
439             if (receivedClosingHandshake) {
440                 closeMessage = Unpooled.EMPTY_BUFFER;
441             } else {
442                 WebSocketCloseStatus closeStatus = ex.closeStatus();
443                 String reasonText = ex.getMessage();
444                 if (reasonText == null) {
445                     reasonText = closeStatus.reasonText();
446                 }
447                 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
448             }
449             ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
450         }
451         throw ex;
452     }
453 
454     private static int toFrameLength(long l) {
455         if (l > Integer.MAX_VALUE) {
456             throw new TooLongFrameException("Length:" + l);
457         } else {
458             return (int) l;
459         }
460     }
461 
462     /** */
463     protected void checkCloseFrameBody(
464             ChannelHandlerContext ctx, ByteBuf buffer) {
465         if (buffer == null || !buffer.isReadable()) {
466             return;
467         }
468         if (buffer.readableBytes() == 1) {
469             protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
470         }
471 
472         // Save reader index
473         int idx = buffer.readerIndex();
474         buffer.readerIndex(0);
475 
476         // Must have 2 byte integer within the valid range
477         int statusCode = buffer.readShort();
478         if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
479             protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
480         }
481 
482         // May have UTF-8 message
483         if (buffer.isReadable()) {
484             try {
485                 new Utf8Validator().check(buffer);
486             } catch (CorruptedWebSocketFrameException ex) {
487                 protocolViolation(ctx, buffer, ex);
488             }
489         }
490 
491         // Restore reader index
492         buffer.readerIndex(idx);
493     }
494 }