1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54 package org.jboss.netty.handler.codec.http.websocketx;
55
56 import org.jboss.netty.buffer.ChannelBuffer;
57 import org.jboss.netty.buffer.ChannelBuffers;
58 import org.jboss.netty.channel.Channel;
59 import org.jboss.netty.channel.ChannelFutureListener;
60 import org.jboss.netty.channel.ChannelHandlerContext;
61 import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
62 import org.jboss.netty.handler.codec.frame.TooLongFrameException;
63 import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
64 import org.jboss.netty.logging.InternalLogger;
65 import org.jboss.netty.logging.InternalLoggerFactory;
66
67
68
69
70
71 public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDecoder.State> {
72
73 private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
74
75 private static final byte OPCODE_CONT = 0x0;
76 private static final byte OPCODE_TEXT = 0x1;
77 private static final byte OPCODE_BINARY = 0x2;
78 private static final byte OPCODE_CLOSE = 0x8;
79 private static final byte OPCODE_PING = 0x9;
80 private static final byte OPCODE_PONG = 0xA;
81
82 private Utf8Validator utf8Validator;
83 private int fragmentedFramesCount;
84
85 private final long maxFramePayloadLength;
86 private boolean frameFinalFlag;
87 private int frameRsv;
88 private int frameOpcode;
89 private long framePayloadLength;
90 private ChannelBuffer framePayload;
91 private int framePayloadBytesRead;
92 private ChannelBuffer maskingKey;
93
94 private final boolean allowExtensions;
95 private final boolean maskedPayload;
96 private boolean receivedClosingHandshake;
97
98 public enum State {
99 FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
100 }
101
102
103
104
105
106
107
108
109
110
111 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions) {
112 this(maskedPayload, allowExtensions, Long.MAX_VALUE);
113 }
114
115
116
117
118
119
120
121
122
123
124
125
126
127 public WebSocket08FrameDecoder(boolean maskedPayload, boolean allowExtensions, long maxFramePayloadLength) {
128 super(State.FRAME_START);
129 this.maskedPayload = maskedPayload;
130 this.allowExtensions = allowExtensions;
131 this.maxFramePayloadLength = maxFramePayloadLength;
132 }
133
134 @Override
135 protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state)
136 throws Exception {
137
138
139 if (receivedClosingHandshake) {
140 buffer.skipBytes(actualReadableBytes());
141 return null;
142 }
143
144 switch (state) {
145 case FRAME_START:
146 framePayloadBytesRead = 0;
147 framePayloadLength = -1;
148 framePayload = null;
149
150
151 byte b = buffer.readByte();
152 frameFinalFlag = (b & 0x80) != 0;
153 frameRsv = (b & 0x70) >> 4;
154 frameOpcode = b & 0x0F;
155
156 if (logger.isDebugEnabled()) {
157 logger.debug("Decoding WebSocket Frame opCode=" + frameOpcode);
158 }
159
160
161 b = buffer.readByte();
162 boolean frameMasked = (b & 0x80) != 0;
163 int framePayloadLen1 = b & 0x7F;
164
165 if (frameRsv != 0 && !allowExtensions) {
166 protocolViolation(channel, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
167 return null;
168 }
169
170 if (maskedPayload && !frameMasked) {
171 protocolViolation(channel, "unmasked client to server frame");
172 return null;
173 }
174 if (frameOpcode > 7) {
175
176
177 if (!frameFinalFlag) {
178 protocolViolation(channel, "fragmented control frame");
179 return null;
180 }
181
182
183 if (framePayloadLen1 > 125) {
184 protocolViolation(channel, "control frame with payload length > 125 octets");
185 return null;
186 }
187
188
189 if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
190 protocolViolation(channel, "control frame using reserved opcode " + frameOpcode);
191 return null;
192 }
193
194
195
196
197 if (frameOpcode == 8 && framePayloadLen1 == 1) {
198 protocolViolation(channel, "received close control frame with payload len 1");
199 return null;
200 }
201 } else {
202
203 if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
204 protocolViolation(channel, "data frame using reserved opcode " + frameOpcode);
205 return null;
206 }
207
208
209 if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
210 protocolViolation(channel, "received continuation data frame outside fragmented message");
211 return null;
212 }
213
214
215 if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
216 protocolViolation(channel, "received non-continuation data frame while inside fragmented message");
217 return null;
218 }
219 }
220
221
222 if (framePayloadLen1 == 126) {
223 framePayloadLength = buffer.readUnsignedShort();
224 if (framePayloadLength < 126) {
225 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
226 return null;
227 }
228 } else if (framePayloadLen1 == 127) {
229 framePayloadLength = buffer.readLong();
230
231
232
233 if (framePayloadLength < 65536) {
234 protocolViolation(channel, "invalid data frame length (not using minimal length encoding)");
235 return null;
236 }
237 } else {
238 framePayloadLength = framePayloadLen1;
239 }
240
241 if (framePayloadLength > maxFramePayloadLength) {
242 protocolViolation(channel, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
243 return null;
244 }
245 if (logger.isDebugEnabled()) {
246 logger.debug("Decoding WebSocket Frame length=" + framePayloadLength);
247 }
248
249 checkpoint(State.MASKING_KEY);
250 case MASKING_KEY:
251 if (maskedPayload) {
252 maskingKey = buffer.readBytes(4);
253 }
254 checkpoint(State.PAYLOAD);
255 case PAYLOAD:
256
257
258 int rbytes = actualReadableBytes();
259 ChannelBuffer payloadBuffer = null;
260
261 long willHaveReadByteCount = framePayloadBytesRead + rbytes;
262
263
264
265 if (willHaveReadByteCount == framePayloadLength) {
266
267 payloadBuffer = buffer.readBytes(rbytes);
268 } else if (willHaveReadByteCount < framePayloadLength) {
269
270
271 payloadBuffer = buffer.readBytes(rbytes);
272 if (framePayload == null) {
273 framePayload = channel.getConfig().getBufferFactory().getBuffer(toFrameLength(framePayloadLength));
274 }
275 framePayload.writeBytes(payloadBuffer);
276 framePayloadBytesRead += rbytes;
277
278
279 return null;
280 } else if (willHaveReadByteCount > framePayloadLength) {
281
282
283 payloadBuffer = buffer.readBytes(toFrameLength(framePayloadLength - framePayloadBytesRead));
284 }
285
286
287
288 checkpoint(State.FRAME_START);
289
290
291 if (framePayload == null) {
292 framePayload = payloadBuffer;
293 } else {
294 framePayload.writeBytes(payloadBuffer);
295 }
296
297
298 if (maskedPayload) {
299 unmask(framePayload);
300 }
301
302
303
304 if (frameOpcode == OPCODE_PING) {
305 return new PingWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
306 }
307 if (frameOpcode == OPCODE_PONG) {
308 return new PongWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
309 }
310 if (frameOpcode == OPCODE_CLOSE) {
311 checkCloseFrameBody(channel, framePayload);
312 receivedClosingHandshake = true;
313 return new CloseWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
314 }
315
316
317
318 if (frameFinalFlag) {
319
320
321 if (frameOpcode != OPCODE_PING) {
322 fragmentedFramesCount = 0;
323
324
325 if (frameOpcode == OPCODE_TEXT || (utf8Validator != null && utf8Validator.isChecking())) {
326
327 checkUTF8String(channel, framePayload.array());
328
329
330
331 utf8Validator.finish();
332 }
333 }
334 } else {
335
336
337 if (fragmentedFramesCount == 0) {
338
339 if (frameOpcode == OPCODE_TEXT) {
340 checkUTF8String(channel, framePayload.array());
341 }
342 } else {
343
344 if (utf8Validator != null && utf8Validator.isChecking()) {
345 checkUTF8String(channel, framePayload.array());
346 }
347 }
348
349
350 fragmentedFramesCount++;
351 }
352
353
354 if (frameOpcode == OPCODE_TEXT) {
355 return new TextWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
356 } else if (frameOpcode == OPCODE_BINARY) {
357 return new BinaryWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
358 } else if (frameOpcode == OPCODE_CONT) {
359 return new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload);
360 } else {
361 throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
362 }
363 case CORRUPT:
364
365
366 buffer.readByte();
367 return null;
368 default:
369 throw new Error("Shouldn't reach here.");
370 }
371 }
372
373 private void unmask(ChannelBuffer frame) {
374 byte[] bytes = frame.array();
375 for (int i = 0; i < bytes.length; i++) {
376 frame.setByte(i, frame.getByte(i) ^ maskingKey.getByte(i % 4));
377 }
378 }
379
380 private void protocolViolation(Channel channel, String reason) throws CorruptedFrameException {
381 protocolViolation(channel, new CorruptedFrameException(reason));
382 }
383
384 private void protocolViolation(Channel channel, CorruptedFrameException ex) throws CorruptedFrameException {
385 checkpoint(State.CORRUPT);
386 if (channel.isConnected()) {
387 channel.write(ChannelBuffers.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
388 }
389 throw ex;
390 }
391
392 private static int toFrameLength(long l) throws TooLongFrameException {
393 if (l > Integer.MAX_VALUE) {
394 throw new TooLongFrameException("Length:" + l);
395 } else {
396 return (int) l;
397 }
398 }
399
400 private void checkUTF8String(Channel channel, byte[] bytes) throws CorruptedFrameException {
401 try {
402 if (utf8Validator == null) {
403 utf8Validator = new Utf8Validator();
404 }
405 utf8Validator.check(bytes);
406 } catch (CorruptedFrameException ex) {
407 protocolViolation(channel, ex);
408 }
409 }
410
411 protected void checkCloseFrameBody(Channel channel, ChannelBuffer buffer) throws CorruptedFrameException {
412 if (buffer == null || buffer.capacity() == 0) {
413 return;
414 }
415 if (buffer.capacity() == 1) {
416 protocolViolation(channel, "Invalid close frame body");
417 }
418
419
420 int idx = buffer.readerIndex();
421 buffer.readerIndex(0);
422
423
424 int statusCode = buffer.readShort();
425 if (statusCode >= 0 && statusCode <= 999 || statusCode >= 1004 && statusCode <= 1006
426 || statusCode >= 1012 && statusCode <= 2999) {
427 protocolViolation(channel, "Invalid close frame status code: " + statusCode);
428 }
429
430
431 if (buffer.readableBytes() > 0) {
432 byte[] b = new byte[buffer.readableBytes()];
433 buffer.readBytes(b);
434 try {
435 Utf8Validator validator = new Utf8Validator();
436 validator.check(b);
437 } catch (CorruptedFrameException ex) {
438 protocolViolation(channel, ex);
439 }
440 }
441
442
443 buffer.readerIndex(idx);
444 }
445 }