1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelOutboundHandler;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.handler.codec.MessageToMessageDecoder;
25 import io.netty.util.ReferenceCountUtil;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.PromiseNotifier;
28
29 import java.net.SocketAddress;
30 import java.nio.channels.ClosedChannelException;
31 import java.util.List;
32 import java.util.concurrent.TimeUnit;
33
34 abstract class WebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame>
35 implements ChannelOutboundHandler {
36
37 private final boolean dropPongFrames;
38 private final WebSocketCloseStatus closeStatus;
39 private final long forceCloseTimeoutMillis;
40 private ChannelPromise closeSent;
41
42
43
44
45 WebSocketProtocolHandler() {
46 this(true);
47 }
48
49
50
51
52
53
54
55
56 WebSocketProtocolHandler(boolean dropPongFrames) {
57 this(dropPongFrames, null, 0L);
58 }
59
60 WebSocketProtocolHandler(boolean dropPongFrames,
61 WebSocketCloseStatus closeStatus,
62 long forceCloseTimeoutMillis) {
63 this.dropPongFrames = dropPongFrames;
64 this.closeStatus = closeStatus;
65 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
66 }
67
68 @Override
69 protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
70 if (frame instanceof PingWebSocketFrame) {
71 frame.content().retain();
72 ctx.writeAndFlush(new PongWebSocketFrame(frame.content()));
73 readIfNeeded(ctx);
74 return;
75 }
76 if (frame instanceof PongWebSocketFrame && dropPongFrames) {
77 readIfNeeded(ctx);
78 return;
79 }
80
81 out.add(frame.retain());
82 }
83
84 private static void readIfNeeded(ChannelHandlerContext ctx) {
85 if (!ctx.channel().config().isAutoRead()) {
86 ctx.read();
87 }
88 }
89
90 @Override
91 public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
92 if (closeStatus == null || !ctx.channel().isActive()) {
93 ctx.close(promise);
94 } else {
95 if (closeSent == null) {
96 write(ctx, new CloseWebSocketFrame(closeStatus), ctx.newPromise());
97 }
98 flush(ctx);
99 applyCloseSentTimeout(ctx);
100 closeSent.addListener(new ChannelFutureListener() {
101 @Override
102 public void operationComplete(ChannelFuture future) {
103 ctx.close(promise);
104 }
105 });
106 }
107 }
108
109 @Override
110 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
111 if (closeSent != null) {
112 ReferenceCountUtil.release(msg);
113 promise.setFailure(new ClosedChannelException());
114 } else if (msg instanceof CloseWebSocketFrame) {
115 closeSent(promise.unvoid());
116 ctx.write(msg).addListener(new PromiseNotifier<Void, ChannelFuture>(false, closeSent));
117 } else {
118 ctx.write(msg, promise);
119 }
120 }
121
122 void closeSent(ChannelPromise promise) {
123 closeSent = promise;
124 }
125
126 private void applyCloseSentTimeout(ChannelHandlerContext ctx) {
127 if (closeSent.isDone() || forceCloseTimeoutMillis < 0) {
128 return;
129 }
130
131 final Future<?> timeoutTask = ctx.executor().schedule(new Runnable() {
132 @Override
133 public void run() {
134 if (!closeSent.isDone()) {
135 closeSent.tryFailure(buildHandshakeException("send close frame timed out"));
136 }
137 }
138 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
139
140 closeSent.addListener(new ChannelFutureListener() {
141 @Override
142 public void operationComplete(ChannelFuture future) {
143 timeoutTask.cancel(false);
144 }
145 });
146 }
147
148
149
150
151
152 protected WebSocketHandshakeException buildHandshakeException(String message) {
153 return new WebSocketHandshakeException(message);
154 }
155
156 @Override
157 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
158 ChannelPromise promise) throws Exception {
159 ctx.bind(localAddress, promise);
160 }
161
162 @Override
163 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
164 SocketAddress localAddress, ChannelPromise promise) throws Exception {
165 ctx.connect(remoteAddress, localAddress, promise);
166 }
167
168 @Override
169 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise)
170 throws Exception {
171 ctx.disconnect(promise);
172 }
173
174 @Override
175 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
176 ctx.deregister(promise);
177 }
178
179 @Override
180 public void read(ChannelHandlerContext ctx) throws Exception {
181 ctx.read();
182 }
183
184 @Override
185 public void flush(ChannelHandlerContext ctx) throws Exception {
186 ctx.flush();
187 }
188
189 @Override
190 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
191 ctx.fireExceptionCaught(cause);
192 ctx.close();
193 }
194 }