View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  // (BSD License: https://www.opensource.org/licenses/bsd-license)
17  //
18  // Copyright (c) 2011, Joe Walnes and contributors
19  // All rights reserved.
20  //
21  // Redistribution and use in source and binary forms, with or
22  // without modification, are permitted provided that the
23  // following conditions are met:
24  //
25  // * Redistributions of source code must retain the above
26  // copyright notice, this list of conditions and the
27  // following disclaimer.
28  //
29  // * Redistributions in binary form must reproduce the above
30  // copyright notice, this list of conditions and the
31  // following disclaimer in the documentation and/or other
32  // materials provided with the distribution.
33  //
34  // * Neither the name of the Webbit nor the names of
35  // its contributors may be used to endorse or promote products
36  // derived from this software without specific prior written
37  // permission.
38  //
39  // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
40  // CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
41  // INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
42  // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
43  // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
44  // CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
45  // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
46  // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
47  // GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48  // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
49  // LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
50  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
51  // OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
52  // POSSIBILITY OF SUCH DAMAGE.
53  
54  package io.netty.handler.codec.http.websocketx;
55  
56  import io.netty.buffer.ByteBuf;
57  import io.netty.buffer.Unpooled;
58  import io.netty.channel.ChannelFutureListener;
59  import io.netty.channel.ChannelHandlerContext;
60  import io.netty.handler.codec.ByteToMessageDecoder;
61  import io.netty.handler.codec.TooLongFrameException;
62  import io.netty.util.internal.ObjectUtil;
63  import io.netty.util.internal.logging.InternalLogger;
64  import io.netty.util.internal.logging.InternalLoggerFactory;
65  
66  import java.nio.ByteOrder;
67  import java.util.List;
68  
69  import static io.netty.buffer.ByteBufUtil.readBytes;
70  
71  /**
72   * Decodes a web socket frame from wire protocol version 8 format. This code was forked from <a
73   * href="https://github.com/joewalnes/webbit">webbit</a> and modified.
74   */
75  public class WebSocket08FrameDecoder extends ByteToMessageDecoder
76          implements WebSocketFrameDecoder {
77  
78      enum State {
79          READING_FIRST,
80          READING_SECOND,
81          READING_SIZE,
82          MASKING_KEY,
83          PAYLOAD,
84          CORRUPT
85      }
86  
87      private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
88  
89      private static final byte OPCODE_CONT = 0x0;
90      private static final byte OPCODE_TEXT = 0x1;
91      private static final byte OPCODE_BINARY = 0x2;
92      private static final byte OPCODE_CLOSE = 0x8;
93      private static final byte OPCODE_PING = 0x9;
94      private static final byte OPCODE_PONG = 0xA;
95  
96      private final WebSocketDecoderConfig config;
97  
98      private int fragmentedFramesCount;
99      private boolean frameFinalFlag;
100     private boolean frameMasked;
101     private int frameRsv;
102     private int frameOpcode;
103     private long framePayloadLength;
104     private int mask;
105     private int framePayloadLen1;
106     private boolean receivedClosingHandshake;
107     private State state = State.READING_FIRST;
108 
109     /**
110      * Constructor
111      *
112      * @param expectMaskedFrames
113      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
114      *            must set this to false.
115      * @param allowExtensions
116      *            Flag to allow reserved extension bits to be used or not
117      * @param maxFramePayloadLength
118      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
119      *            helps check for denial of services attacks.
120      */
121     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
122         this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
123     }
124 
125     /**
126      * Constructor
127      *
128      * @param expectMaskedFrames
129      *            Web socket servers must set this to true processed incoming masked payload. Client implementations
130      *            must set this to false.
131      * @param allowExtensions
132      *            Flag to allow reserved extension bits to be used or not
133      * @param maxFramePayloadLength
134      *            Maximum length of a frame's payload. Setting this to an appropriate value for you application
135      *            helps check for denial of services attacks.
136      * @param allowMaskMismatch
137      *            When set to true, frames which are not masked properly according to the standard will still be
138      *            accepted.
139      */
140     public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
141                                    boolean allowMaskMismatch) {
142         this(WebSocketDecoderConfig.newBuilder()
143             .expectMaskedFrames(expectMaskedFrames)
144             .allowExtensions(allowExtensions)
145             .maxFramePayloadLength(maxFramePayloadLength)
146             .allowMaskMismatch(allowMaskMismatch)
147             .build());
148     }
149 
150     /**
151      * Constructor
152      *
153      * @param decoderConfig
154      *            Frames decoder configuration.
155      */
156     public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) {
157         this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
158     }
159 
160     @Override
161     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
162         // Discard all data received if closing handshake was received before.
163         if (receivedClosingHandshake) {
164             in.skipBytes(actualReadableBytes());
165             return;
166         }
167 
168         switch (state) {
169         case READING_FIRST:
170             if (!in.isReadable()) {
171                 return;
172             }
173 
174             framePayloadLength = 0;
175 
176             // FIN, RSV, OPCODE
177             byte b = in.readByte();
178             frameFinalFlag = (b & 0x80) != 0;
179             frameRsv = (b & 0x70) >> 4;
180             frameOpcode = b & 0x0F;
181 
182             if (logger.isTraceEnabled()) {
183                 logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
184             }
185 
186             state = State.READING_SECOND;
187         case READING_SECOND:
188             if (!in.isReadable()) {
189                 return;
190             }
191             // MASK, PAYLOAD LEN 1
192             b = in.readByte();
193             frameMasked = (b & 0x80) != 0;
194             framePayloadLen1 = b & 0x7F;
195 
196             if (frameRsv != 0 && !config.allowExtensions()) {
197                 protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
198                 return;
199             }
200 
201             if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
202                 protocolViolation(ctx, in, "received a frame that is not masked as expected");
203                 return;
204             }
205 
206             if (frameOpcode > 7) { // control frame (have MSB in opcode set)
207 
208                 // control frames MUST NOT be fragmented
209                 if (!frameFinalFlag) {
210                     protocolViolation(ctx, in, "fragmented control frame");
211                     return;
212                 }
213 
214                 // control frames MUST have payload 125 octets or less
215                 if (framePayloadLen1 > 125) {
216                     protocolViolation(ctx, in, "control frame with payload length > 125 octets");
217                     return;
218                 }
219 
220                 // check for reserved control frame opcodes
221                 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
222                       || frameOpcode == OPCODE_PONG)) {
223                     protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
224                     return;
225                 }
226 
227                 // close frame : if there is a body, the first two bytes of the
228                 // body MUST be a 2-byte unsigned integer (in network byte
229                 // order) representing a getStatus code
230                 if (frameOpcode == 8 && framePayloadLen1 == 1) {
231                     protocolViolation(ctx, in, "received close control frame with payload len 1");
232                     return;
233                 }
234             } else { // data frame
235                 // check for reserved data frame opcodes
236                 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
237                       || frameOpcode == OPCODE_BINARY)) {
238                     protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
239                     return;
240                 }
241 
242                 // check opcode vs message fragmentation state 1/2
243                 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
244                     protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
245                     return;
246                 }
247 
248                 // check opcode vs message fragmentation state 2/2
249                 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
250                     protocolViolation(ctx, in,
251                                       "received non-continuation data frame while inside fragmented message");
252                     return;
253                 }
254             }
255 
256             state = State.READING_SIZE;
257         case READING_SIZE:
258 
259             // Read frame payload length
260             if (framePayloadLen1 == 126) {
261                 if (in.readableBytes() < 2) {
262                     return;
263                 }
264                 framePayloadLength = in.readUnsignedShort();
265                 if (framePayloadLength < 126) {
266                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
267                     return;
268                 }
269             } else if (framePayloadLen1 == 127) {
270                 if (in.readableBytes() < 8) {
271                     return;
272                 }
273                 framePayloadLength = in.readLong();
274                 if (framePayloadLength < 0) {
275                     protocolViolation(ctx, in, "invalid data frame length (negative length)");
276                     return;
277                 }
278 
279                 if (framePayloadLength < 65536) {
280                     protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
281                     return;
282                 }
283             } else {
284                 framePayloadLength = framePayloadLen1;
285             }
286 
287             if (framePayloadLength > config.maxFramePayloadLength()) {
288                 protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
289                     "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
290                 return;
291             }
292 
293             if (logger.isTraceEnabled()) {
294                 logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
295             }
296 
297             state = State.MASKING_KEY;
298         case MASKING_KEY:
299             if (frameMasked) {
300                 if (in.readableBytes() < 4) {
301                     return;
302                 }
303                 mask = in.readInt();
304             }
305             state = State.PAYLOAD;
306         case PAYLOAD:
307             if (in.readableBytes() < framePayloadLength) {
308                 return;
309             }
310 
311             ByteBuf payloadBuffer = Unpooled.EMPTY_BUFFER;
312             try {
313                 if (framePayloadLength > 0) {
314                     payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
315                 }
316 
317                 // Now we have all the data, the next checkpoint must be the next
318                 // frame
319                 state = State.READING_FIRST;
320 
321                 // Unmask data if needed
322                 if (frameMasked & framePayloadLength > 0) {
323                     unmask(payloadBuffer);
324                 }
325 
326                 // Processing ping/pong/close frames because they cannot be
327                 // fragmented
328                 if (frameOpcode == OPCODE_PING) {
329                     out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
330                     payloadBuffer = null;
331                     return;
332                 }
333                 if (frameOpcode == OPCODE_PONG) {
334                     out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
335                     payloadBuffer = null;
336                     return;
337                 }
338                 if (frameOpcode == OPCODE_CLOSE) {
339                     receivedClosingHandshake = true;
340                     checkCloseFrameBody(ctx, payloadBuffer);
341                     out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
342                     payloadBuffer = null;
343                     return;
344                 }
345 
346                 // Processing for possible fragmented messages for text and binary
347                 // frames
348                 if (frameFinalFlag) {
349                     // Final frame of the sequence. Apparently ping frames are
350                     // allowed in the middle of a fragmented message
351                     fragmentedFramesCount = 0;
352                 } else {
353                     // Increment counter
354                     fragmentedFramesCount++;
355                 }
356 
357                 // Return the frame
358                 if (frameOpcode == OPCODE_TEXT) {
359                     out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
360                     payloadBuffer = null;
361                     return;
362                 } else if (frameOpcode == OPCODE_BINARY) {
363                     out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
364                     payloadBuffer = null;
365                     return;
366                 } else if (frameOpcode == OPCODE_CONT) {
367                     out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
368                                                            payloadBuffer));
369                     payloadBuffer = null;
370                     return;
371                 } else {
372                     throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
373                                                             + frameOpcode);
374                 }
375             } finally {
376                 if (payloadBuffer != null) {
377                     payloadBuffer.release();
378                 }
379             }
380         case CORRUPT:
381             if (in.isReadable()) {
382                 // If we don't keep reading Netty will throw an exception saying
383                 // we can't return null if no bytes read and state not changed.
384                 in.readByte();
385             }
386             return;
387         default:
388             throw new Error("Shouldn't reach here.");
389         }
390     }
391 
392     private void unmask(ByteBuf frame) {
393         int i = frame.readerIndex();
394         int end = frame.writerIndex();
395 
396         ByteOrder order = frame.order();
397 
398         int intMask = mask;
399         // Avoid sign extension on widening primitive conversion
400         long longMask = intMask & 0xFFFFFFFFL;
401         longMask |= longMask << 32;
402 
403         for (int lim = end - 7; i < lim; i += 8) {
404             frame.setLong(i, frame.getLong(i) ^ longMask);
405         }
406 
407         if (i < end - 3) {
408             frame.setInt(i, frame.getInt(i) ^ (int) longMask);
409             i += 4;
410         }
411 
412         if (order == ByteOrder.LITTLE_ENDIAN) {
413             intMask = Integer.reverseBytes(intMask);
414         }
415 
416         int maskOffset = 0;
417         for (; i < end; i++) {
418             frame.setByte(i, frame.getByte(i) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3));
419         }
420     }
421 
422     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) {
423         protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
424     }
425 
426     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) {
427         protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
428     }
429 
430     private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) {
431         state = State.CORRUPT;
432         int readableBytes = in.readableBytes();
433         if (readableBytes > 0) {
434             // Fix for memory leak, caused by ByteToMessageDecoder#channelRead:
435             // buffer 'cumulation' is released ONLY when no more readable bytes available.
436             in.skipBytes(readableBytes);
437         }
438         if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
439             Object closeMessage;
440             if (receivedClosingHandshake) {
441                 closeMessage = Unpooled.EMPTY_BUFFER;
442             } else {
443                 WebSocketCloseStatus closeStatus = ex.closeStatus();
444                 String reasonText = ex.getMessage();
445                 if (reasonText == null) {
446                     reasonText = closeStatus.reasonText();
447                 }
448                 closeMessage = new CloseWebSocketFrame(closeStatus, reasonText);
449             }
450             ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE);
451         }
452         throw ex;
453     }
454 
455     private static int toFrameLength(long l) {
456         if (l > Integer.MAX_VALUE) {
457             throw new TooLongFrameException("Length:" + l);
458         } else {
459             return (int) l;
460         }
461     }
462 
463     /** */
464     protected void checkCloseFrameBody(
465             ChannelHandlerContext ctx, ByteBuf buffer) {
466         if (buffer == null || !buffer.isReadable()) {
467             return;
468         }
469         if (buffer.readableBytes() < 2) {
470             protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
471         }
472 
473         // Must have 2 byte integer within the valid range
474         int statusCode = buffer.getShort(buffer.readerIndex());
475         if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
476             protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
477         }
478 
479         // May have UTF-8 message
480         if (buffer.readableBytes() > 2) {
481             try {
482                 new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2);
483             } catch (CorruptedWebSocketFrameException ex) {
484                 protocolViolation(ctx, buffer, ex);
485             }
486         }
487     }
488 }