1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package io.netty.handler.codec.http;
16
17 import io.netty.channel.ChannelHandlerContext;
18 import io.netty.channel.ChannelOutboundHandler;
19 import io.netty.channel.ChannelPromise;
20 import io.netty.util.AsciiString;
21 import io.netty.util.ReferenceCountUtil;
22 import io.netty.util.internal.ObjectUtil;
23
24 import java.net.SocketAddress;
25 import java.util.Collection;
26 import java.util.LinkedHashSet;
27 import java.util.List;
28 import java.util.Set;
29
30 import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
31 import static io.netty.util.ReferenceCountUtil.release;
32
33
34
35
36
37
38
39
40 public class HttpClientUpgradeHandler extends HttpObjectAggregator implements ChannelOutboundHandler {
41
42
43
44
45 public enum UpgradeEvent {
46
47
48
49 UPGRADE_ISSUED,
50
51
52
53
54 UPGRADE_SUCCESSFUL,
55
56
57
58
59
60 UPGRADE_REJECTED
61 }
62
63
64
65
66 public interface SourceCodec {
67
68
69
70
71
72 void prepareUpgradeFrom(ChannelHandlerContext ctx);
73
74
75
76
77 void upgradeFrom(ChannelHandlerContext ctx);
78 }
79
80
81
82
83 public interface UpgradeCodec {
84
85
86
87 CharSequence protocol();
88
89
90
91
92
93 Collection<CharSequence> setUpgradeHeaders(ChannelHandlerContext ctx, HttpRequest upgradeRequest);
94
95
96
97
98
99
100
101
102
103 void upgradeTo(ChannelHandlerContext ctx, FullHttpResponse upgradeResponse) throws Exception;
104 }
105
106 private final SourceCodec sourceCodec;
107 private final UpgradeCodec upgradeCodec;
108 private UpgradeEvent currentUpgradeEvent;
109
110
111
112
113
114
115
116
117 public HttpClientUpgradeHandler(SourceCodec sourceCodec, UpgradeCodec upgradeCodec,
118 int maxContentLength) {
119 super(maxContentLength);
120 this.sourceCodec = ObjectUtil.checkNotNull(sourceCodec, "sourceCodec");
121 this.upgradeCodec = ObjectUtil.checkNotNull(upgradeCodec, "upgradeCodec");
122 }
123
124 @Override
125 public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
126 ctx.bind(localAddress, promise);
127 }
128
129 @Override
130 public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
131 ChannelPromise promise) throws Exception {
132 ctx.connect(remoteAddress, localAddress, promise);
133 }
134
135 @Override
136 public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
137 ctx.disconnect(promise);
138 }
139
140 @Override
141 public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
142 ctx.close(promise);
143 }
144
145 @Override
146 public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
147 ctx.deregister(promise);
148 }
149
150 @Override
151 public void read(ChannelHandlerContext ctx) throws Exception {
152 ctx.read();
153 }
154
155 @Override
156 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
157 throws Exception {
158 if (!(msg instanceof HttpRequest) || currentUpgradeEvent == UpgradeEvent.UPGRADE_SUCCESSFUL) {
159 ctx.write(msg, promise);
160 return;
161 }
162
163 if (currentUpgradeEvent == UpgradeEvent.UPGRADE_ISSUED) {
164
165 ReferenceCountUtil.release(msg);
166 promise.setFailure(new IllegalStateException(
167 "Attempting to write HTTP request with upgrade in progress"));
168 return;
169 }
170
171 currentUpgradeEvent = UpgradeEvent.UPGRADE_ISSUED;
172 setUpgradeRequestHeaders(ctx, (HttpRequest) msg);
173
174
175 ctx.write(msg, promise);
176
177
178 ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_ISSUED);
179
180 }
181
182 @Override
183 public void flush(ChannelHandlerContext ctx) throws Exception {
184 ctx.flush();
185 }
186
187 @Override
188 protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
189 throws Exception {
190 FullHttpResponse response = null;
191 try {
192 if (currentUpgradeEvent != UpgradeEvent.UPGRADE_ISSUED) {
193 throw new IllegalStateException("Read HTTP response without requesting protocol switch");
194 }
195
196 if (msg instanceof HttpResponse) {
197 HttpResponse rep = (HttpResponse) msg;
198 if (!SWITCHING_PROTOCOLS.equals(rep.status())) {
199
200
201
202
203 currentUpgradeEvent = null;
204 ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED);
205 removeThisHandler(ctx);
206 ctx.fireChannelRead(msg);
207 return;
208 }
209 }
210
211 if (msg instanceof FullHttpResponse) {
212 response = (FullHttpResponse) msg;
213
214 response.retain();
215 out.add(response);
216 } else {
217
218 super.decode(ctx, msg, out);
219 if (out.isEmpty()) {
220
221 return;
222 }
223
224 assert out.size() == 1;
225 response = (FullHttpResponse) out.get(0);
226 }
227
228 CharSequence upgradeHeader = response.headers().get(HttpHeaderNames.UPGRADE);
229 if (upgradeHeader != null && !AsciiString.contentEqualsIgnoreCase(upgradeCodec.protocol(), upgradeHeader)) {
230 throw new IllegalStateException(
231 "Switching Protocols response with unexpected UPGRADE protocol: " + upgradeHeader);
232 }
233
234
235 sourceCodec.prepareUpgradeFrom(ctx);
236 upgradeCodec.upgradeTo(ctx, response);
237
238
239
240 currentUpgradeEvent = UpgradeEvent.UPGRADE_SUCCESSFUL;
241
242
243 ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_SUCCESSFUL);
244
245
246
247 sourceCodec.upgradeFrom(ctx);
248
249
250
251 response.release();
252 out.clear();
253 removeThisHandler(ctx);
254 } catch (Throwable t) {
255 release(response);
256 ctx.fireExceptionCaught(t);
257 removeThisHandler(ctx);
258 }
259 }
260
261 private static void removeThisHandler(ChannelHandlerContext ctx) {
262 ctx.pipeline().remove(ctx.name());
263 }
264
265
266
267
268 private void setUpgradeRequestHeaders(ChannelHandlerContext ctx, HttpRequest request) {
269
270 request.headers().set(HttpHeaderNames.UPGRADE, upgradeCodec.protocol());
271
272
273 Set<CharSequence> connectionParts = new LinkedHashSet<CharSequence>(2);
274 connectionParts.addAll(upgradeCodec.setUpgradeHeaders(ctx, request));
275
276
277 StringBuilder builder = new StringBuilder();
278 for (CharSequence part : connectionParts) {
279 builder.append(part);
280 builder.append(',');
281 }
282 builder.append(HttpHeaderValues.UPGRADE);
283 request.headers().add(HttpHeaderNames.CONNECTION, builder.toString());
284 }
285 }