1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelFutureListener;
20 import io.netty.channel.ChannelHandlerContext;
21 import io.netty.channel.ChannelInboundHandlerAdapter;
22
23
24
25
26 public class Utf8FrameValidator extends ChannelInboundHandlerAdapter {
27
28 private final boolean closeOnProtocolViolation;
29
30 private int fragmentedFramesCount;
31 private Utf8Validator utf8Validator;
32
33 public Utf8FrameValidator() {
34 this(true);
35 }
36
37 public Utf8FrameValidator(boolean closeOnProtocolViolation) {
38 this.closeOnProtocolViolation = closeOnProtocolViolation;
39 }
40
41
42 private static boolean isControlFrame(WebSocketFrame frame) {
43 return frame instanceof CloseWebSocketFrame ||
44 frame instanceof PingWebSocketFrame ||
45 frame instanceof PongWebSocketFrame;
46 }
47
48 @Override
49 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
50 if (msg instanceof WebSocketFrame) {
51 WebSocketFrame frame = (WebSocketFrame) msg;
52
53 try {
54
55
56 if (frame.isFinalFragment()) {
57
58
59 if (!isControlFrame(frame)) {
60
61
62 fragmentedFramesCount = 0;
63
64
65 if (frame instanceof TextWebSocketFrame ||
66 (utf8Validator != null && utf8Validator.isChecking())) {
67
68 checkUTF8String(frame.content());
69
70
71
72 utf8Validator.finish();
73 }
74 }
75 } else {
76
77
78 if (fragmentedFramesCount == 0) {
79
80 if (frame instanceof TextWebSocketFrame) {
81 checkUTF8String(frame.content());
82 }
83 } else {
84
85 if (utf8Validator != null && utf8Validator.isChecking()) {
86 checkUTF8String(frame.content());
87 }
88 }
89
90
91 fragmentedFramesCount++;
92 }
93 } catch (CorruptedWebSocketFrameException e) {
94 protocolViolation(ctx, frame, e);
95 }
96 }
97
98 super.channelRead(ctx, msg);
99 }
100
101 private void checkUTF8String(ByteBuf buffer) {
102 if (utf8Validator == null) {
103 utf8Validator = new Utf8Validator();
104 }
105 utf8Validator.check(buffer);
106 }
107
108 private void protocolViolation(ChannelHandlerContext ctx, WebSocketFrame frame,
109 CorruptedWebSocketFrameException ex) {
110 frame.release();
111 if (closeOnProtocolViolation && ctx.channel().isOpen()) {
112 WebSocketCloseStatus closeStatus = ex.closeStatus();
113 String reasonText = ex.getMessage();
114 if (reasonText == null) {
115 reasonText = closeStatus.reasonText();
116 }
117
118 CloseWebSocketFrame closeFrame = new CloseWebSocketFrame(closeStatus.code(), reasonText);
119 ctx.writeAndFlush(closeFrame).addListener(ChannelFutureListener.CLOSE);
120 }
121
122 throw ex;
123 }
124
125 @Override
126 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
127 super.exceptionCaught(ctx, cause);
128 }
129 }