1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.websocketx;
17
18
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelOutboundHandler;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.handler.codec.MessageToMessageDecoder;
25 import io.netty.util.ReferenceCountUtil;
26 import io.netty.util.concurrent.Future;
27 import io.netty.util.concurrent.PromiseNotifier;
28
29 import java.net.SocketAddress;
30 import java.nio.channels.ClosedChannelException;
31 import java.util.List;
32 import java.util.concurrent.TimeUnit;
33
34 abstract class WebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame>
35 implements ChannelOutboundHandler {
36
37 private final boolean dropPongFrames;
38 private final WebSocketCloseStatus closeStatus;
39 private final long forceCloseTimeoutMillis;
40 private ChannelPromise closeSent;
41
42
43
44
45 WebSocketProtocolHandler() {
46 this(true);
47 }
48
49
50
51
52
53
54
55
56 WebSocketProtocolHandler(boolean dropPongFrames) {
57 this(dropPongFrames, null, 0L);
58 }
59
60 WebSocketProtocolHandler(boolean dropPongFrames,
61 WebSocketCloseStatus closeStatus,
62 long forceCloseTimeoutMillis) {
63 super(WebSocketFrame.class);
64 this.dropPongFrames = dropPongFrames;
65 this.closeStatus = closeStatus;
66 this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
67 }
68
69 @Override
70 protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
71 if (frame instanceof PingWebSocketFrame) {
72 frame.content().retain();
73 ctx.writeAndFlush(new PongWebSocketFrame(frame.content()));
74 readIfNeeded(ctx);
75 return;
76 }
77 if (frame instanceof PongWebSocketFrame && dropPongFrames) {
78 readIfNeeded(ctx);
79 return;
80 }
81
82 out.add(frame.retain());
83 }
84
85 private static void readIfNeeded(ChannelHandlerContext ctx) {
86 if (!ctx.channel().config().isAutoRead()) {
87 ctx.read();
88 }
89 }
90
91 @Override
92 public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception {
93 if (closeStatus == null || !ctx.channel().isActive()) {
94 ctx.close(promise);
95 } else {
96 if (closeSent == null) {
97 write(ctx, new CloseWebSocketFrame(closeStatus), ctx.newPromise());
98 }
99 flush(ctx);
100 applyCloseSentTimeout(ctx);
101 closeSent.addListener(future -> ctx.close(promise));
102 }
103 }
104
105 @Override
106 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
107 if (closeSent != null) {
108 ReferenceCountUtil.release(msg);
109 promise.setFailure(new ClosedChannelException());
110 } else if (msg instanceof CloseWebSocketFrame) {
111 closeSent(promise.unvoid());
112 ctx.write(msg).addListener(new PromiseNotifier<Void, ChannelFuture>(false, closeSent));
113 } else {
114 ctx.write(msg, promise);
115 }
116 }
117
118 void closeSent(ChannelPromise promise) {
119 closeSent = promise;
120 }
121
122 private void applyCloseSentTimeout(ChannelHandlerContext ctx) {
123 if (closeSent.isDone() || forceCloseTimeoutMillis < 0) {
124 return;
125 }
126
127 final Future<?> timeoutTask = ctx.executor().schedule(new Runnable() {
128 @Override
129 public void run() {
130 if (!closeSent.isDone()) {
131 closeSent.tryFailure(buildHandshakeException("send close frame timed out"));
132 }
133 }
134 }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);
135
136 closeSent.addListener(future -> timeoutTask.cancel(false));
137 }
138
139
140
141
142
143 protected WebSocketHandshakeException buildHandshakeException(String message) {
144 return new WebSocketHandshakeException(message);
145 }
146
147 @Override
148 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
149 ChannelPromise promise) throws Exception {
150 ctx.bind(localAddress, promise);
151 }
152
153 @Override
154 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
155 SocketAddress localAddress, ChannelPromise promise) throws Exception {
156 ctx.connect(remoteAddress, localAddress, promise);
157 }
158
159 @Override
160 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise)
161 throws Exception {
162 ctx.disconnect(promise);
163 }
164
165 @Override
166 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
167 ctx.deregister(promise);
168 }
169
170 @Override
171 public void read(ChannelHandlerContext ctx) throws Exception {
172 ctx.read();
173 }
174
175 @Override
176 public void flush(ChannelHandlerContext ctx) throws Exception {
177 ctx.flush();
178 }
179
180 @Override
181 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
182 ctx.fireExceptionCaught(cause);
183 ctx.close();
184 }
185 }