1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package org.jboss.netty.handler.codec.http.websocketx;
55
56 import org.jboss.netty.buffer.ChannelBuffer;
57 import org.jboss.netty.buffer.ChannelBuffers;
58 import org.jboss.netty.channel.Channel;
59 import org.jboss.netty.channel.ChannelFutureListener;
60 import org.jboss.netty.channel.ChannelHandlerContext;
61 import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
62 import org.jboss.netty.handler.codec.frame.TooLongFrameException;
63 import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
64 import org.jboss.netty.logging.InternalLogger;
65 import org.jboss.netty.logging.InternalLoggerFactory;
66
67
68
69
70
71 public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State> {
72
73 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
74
75 private static final byte OPCODE_CONT = 0x0;
76 private static final byte OPCODE_TEXT = 0x1;
77 private static final byte OPCODE_BINARY = 0x2;
78 private static final byte OPCODE_CLOSE = 0x8;
79 private static final byte OPCODE_PING = 0x9;
80 private static final byte OPCODE_PONG = 0xA;
81
82 private UTF8Output fragmentedFramesText;
83 private int fragmentedFramesCount;
84
85 private final long maxFramePayloadLength;
86 private boolean frameFinalFlag;
87 private int frameRsv;
88 private int frameOpcode;
89 private long framePayloadLength;
90 private ChannelBuffer framePayload;
91 private int framePayloadBytesRead;
92 private ChannelBuffer maskingKey;
93
94 private final boolean allowExtensions;
95 private final boolean maskedPayload;
96 private boolean receivedClosingHandshake;
97
98 public enum State {
99 FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
100 }
101
102
103
104
105
106
107
108
109
110
111 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) {
112 this(maskedPayload, allowExtensions, Long.MAX_VALUE);
113 }
114
115
116
117
118
119
120
121
122
123
124
125
126
127 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
128 super(State.FRAME_START);
129 this.maskedPayload = maskedPayload;
130 this.allowExtensions = allowExtensions;
131 this.maxFramePayloadLength = maxFramePayloadLength;
132 }
133
134 @Override
135 protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state)
136 throws Exception {
137
138
139 if (receivedClosingHandshake) {
140 buffer.skipBytes(actualReadableBytes());
141 return null;
142 }
143
144 switch (state) {
145 case FRAME_START:
146 framePayloadBytesRead = 0;
147 framePayloadLength = -1;
148 framePayload = null;
149
150
151 byte b = buffer.readByte();
152 frameFinalFlag = (b & 0x80) != 0;
153 frameRsv = (b & 0x70) >> 4;
154 frameOpcode = b & 0x0F;
155
156 if (logger.isDebugEnabled()) {
157 logger.debug("Decoding WebSocket Frame opCode=" + frameOpcode);
158 }
159
160
161 b = buffer.readByte();
162 boolean frameMasked = (b & 0x80) != 0;
163 int framePayloadLen1 = b & 0x7F;
164
165 if (frameRsv != 0 && !allowExtensions) {
166 protocolViolation(channel, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
167 return null;
168 }
169
170 if (maskedPayload && !frameMasked) {
171 protocolViolation(channel, "unmasked client to server frame");
172 return null;
173 }
174 if (frameOpcode > 7) {
175
176
177 if (!frameFinalFlag) {
178 protocolViolation(channel, "fragmented control frame");
179 return null;
180 }
181
182
183 if (framePayloadLen1 > 125) {
184 protocolViolation(channel, "control frame with payload length > 125 octets");
185 return null;
186 }
187
188
189 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
190 protocolViolation(channel, "control frame using reserved opcode " + frameOpcode);
191 return null;
192 }
193
194
195
196
197 if (frameOpcode == 8 && framePayloadLen1 == 1) {
198 protocolViolation(channel, "received close control frame with payload len 1");
199 return null;
200 }
201 } else {
202
203 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
204 protocolViolation(channel, "data frame using reserved opcode " + frameOpcode);
205 return null;
206 }
207
208
209 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
210 protocolViolation(channel, "received continuation data frame outside fragmented message");
211 return null;
212 }
213
214
215 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
216 protocolViolation(channel, "received non-continuation data frame while inside fragmented message");
217 return null;
218 }
219 }
220
221
222 if (framePayloadLen1 == 126) {
223 framePayloadLength = buffer.readUnsignedShort();
224 if (framePayloadLength < 126) {
225 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
226 return null;
227 }
228 } else if (framePayloadLen1 == 127) {
229 framePayloadLength = buffer.readLong();
230
231
232
233 if (framePayloadLength < 65536) {
234 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
235 return null;
236 }
237 } else {
238 framePayloadLength = framePayloadLen1;
239 }
240
241 if (framePayloadLength > maxFramePayloadLength) {
242 protocolViolation(channel, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
243 return null;
244 }
245 if (logger.isDebugEnabled()) {
246 logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
247 }
248
249 checkpoint(State.MASKING_KEY);
250 case MASKING_KEY:
251 if (maskedPayload) {
252 maskingKey = buffer.readBytes(4);
253 }
254 checkpoint(State.PAYLOAD);
255 case PAYLOAD:
256
257
258 int rbytes = actualReadableBytes();
259 ChannelBuffer payloadBuffer = null;
260
261 long willHaveReadByteCount = framePayloadBytesRead + rbytes;
262
263
264
265 if (willHaveReadByteCount == framePayloadLength) {
266
267 payloadBuffer = buffer.readBytes(rbytes);
268 } else if (willHaveReadByteCount < framePayloadLength) {
269
270
271 payloadBuffer = buffer.readBytes(rbytes);
272 if (framePayload == null) {
273 framePayload = channel.getConfig().getBufferFactory().getBuffer(toFrameLength(framePayloadLength));
274 }
275 framePayload.writeBytes(payloadBuffer);
276 framePayloadBytesRead += rbytes;
277
278
279 return null;
280 } else if (willHaveReadByteCount > framePayloadLength) {
281
282
283 payloadBuffer = buffer.readBytes(toFrameLength(framePayloadLength - framePayloadBytesRead));
284 }
285
286
287
288 checkpoint(State.FRAME_START);
289
290
291 if (framePayload == null) {
292 framePayload = payloadBuffer;
293 } else {
294 framePayload.writeBytes(payloadBuffer);
295 }
296
297
298 if (maskedPayload) {
299 unmask(framePayload);
300 }
301
302
303
304 if (frameOpcode == OPCODE_PING) {
305 return new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
306 } else if (frameOpcode == OPCODE_PONG) {
307 return new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
308 } else if (frameOpcode == OPCODE_CLOSE) {
309 checkCloseFrameBody(channel, framePayload);
310 receivedClosingHandshake = true;
311 return new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
312 }
313
314
315
316 String aggregatedText = null;
317 if (frameFinalFlag) {
318
319
320 if (frameOpcode != OPCODE_PING) {
321 fragmentedFramesCount = 0;
322
323
324 if (frameOpcode == OPCODE_TEXT || fragmentedFramesText != null) {
325
326 checkUTF8String(channel, framePayload.array());
327
328
329
330 aggregatedText = fragmentedFramesText.toString();
331
332 fragmentedFramesText = null;
333 }
334 }
335 } else {
336
337
338 if (fragmentedFramesCount == 0) {
339
340 fragmentedFramesText = null;
341 if (frameOpcode == OPCODE_TEXT) {
342 checkUTF8String(channel, framePayload.array());
343 }
344 } else {
345
346 if (fragmentedFramesText != null) {
347 checkUTF8String(channel, framePayload.array());
348 }
349 }
350
351
352 fragmentedFramesCount++;
353 }
354
355
356 if (frameOpcode == OPCODE_TEXT) {
357 return new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
358 } else if (frameOpcode == OPCODE_BINARY) {
359 return new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
360 } else if (frameOpcode == OPCODE_CONT) {
361 return new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload, aggregatedText);
362 } else {
363 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
364 }
365 case CORRUPT:
366
367
368 buffer.readByte();
369 return null;
370 default:
371 throw new Error("Shouldn't reach here.");
372 }
373 }
374
375 private void unmask(ChannelBuffer frame) {
376 byte[] bytes = frame.array();
377 for (int i = 0; i < bytes.length; i++) {
378 frame.setByte(i, frame.getByte(i) ^ maskingKey.getByte(i % 4));
379 }
380 }
381
382 private void protocolViolation(Channel channel, String reason) throws CorruptedFrameException {
383 checkpoint(State.CORRUPT);
384 if (channel.isConnected()) {
385 channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
386 }
387 throw new CorruptedFrameException(reason);
388 }
389
390 private static int toFrameLength(long l) throws TooLongFrameException {
391 if (l > Integer.MAX_VALUE) {
392 throw new TooLongFrameException("Length:" + l);
393 } else {
394 return (int) l;
395 }
396 }
397
398 private void checkUTF8String(Channel channel, byte[] bytes) throws CorruptedFrameException {
399 try {
400
401
402
403
404
405
406
407 if (fragmentedFramesText == null) {
408 fragmentedFramesText = new UTF8Output(bytes);
409 } else {
410 fragmentedFramesText.write(bytes);
411 }
412 } catch (UTF8Exception ex) {
413 protocolViolation(channel, "invalid UTF-8 bytes");
414 }
415 }
416
417 protected void checkCloseFrameBody(Channel channel, ChannelBuffer buffer) throws CorruptedFrameException {
418 if (buffer == null || buffer.capacity() == 0) {
419 return;
420 }
421 if (buffer.capacity() == 1) {
422 protocolViolation(channel, "Invalid close frame body");
423 }
424
425
426 int idx = buffer.readerIndex();
427 buffer.readerIndex(0);
428
429
430 int statusCode = buffer.readShort();
431 if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
432 || statusCode >= 1012 && statusCode <= 2999) {
433 protocolViolation(channel, "Invalid close frame status code: " + statusCode);
434 }
435
436
437 if (buffer.readableBytes() > 0) {
438 byte[] b = new byte[buffer.readableBytes()];
439 buffer.readBytes(b);
440 try {
441 new UTF8Output(b);
442 } catch (UTF8Exception ex) {
443 protocolViolation(channel, "Invalid close frame reason text. Invalid UTF-8 bytes");
444 }
445 }
446
447
448 buffer.readerIndex(idx);
449 }
450 }