1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  package io.netty.handler.codec.http;
16  
17  import io.netty.channel.ChannelHandlerContext;
18  import io.netty.channel.ChannelOutboundHandler;
19  import io.netty.channel.ChannelPromise;
20  import io.netty.util.AsciiString;
21  import io.netty.util.ReferenceCountUtil;
22  import io.netty.util.internal.ObjectUtil;
23  
24  import java.net.SocketAddress;
25  import java.util.Collection;
26  import java.util.LinkedHashSet;
27  import java.util.List;
28  import java.util.Set;
29  
30  import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
31  import static io.netty.util.ReferenceCountUtil.release;
32  
33  
34  
35  
36  
37  
38  
39  
40  public class HttpClientUpgradeHandler extends HttpObjectAggregator implements ChannelOutboundHandler {
41  
42      
43  
44  
45      public enum UpgradeEvent {
46          
47  
48  
49          UPGRADE_ISSUED,
50  
51          
52  
53  
54          UPGRADE_SUCCESSFUL,
55  
56          
57  
58  
59  
60          UPGRADE_REJECTED
61      }
62  
63      
64  
65  
66      public interface SourceCodec {
67  
68          
69  
70  
71  
72          void prepareUpgradeFrom(ChannelHandlerContext ctx);
73  
74          
75  
76  
77          void upgradeFrom(ChannelHandlerContext ctx);
78      }
79  
80      
81  
82  
83      public interface UpgradeCodec {
84          
85  
86  
87          CharSequence protocol();
88  
89          
90  
91  
92  
93          Collection<CharSequence> setUpgradeHeaders(ChannelHandlerContext ctx, HttpRequest upgradeRequest);
94  
95          
96  
97  
98  
99  
100 
101 
102 
103         void upgradeTo(ChannelHandlerContext ctx, FullHttpResponse upgradeResponse) throws Exception;
104     }
105 
106     private final SourceCodec sourceCodec;
107     private final UpgradeCodec upgradeCodec;
108     private UpgradeEvent currentUpgradeEvent;
109 
110     
111 
112 
113 
114 
115 
116 
117     public HttpClientUpgradeHandler(SourceCodec sourceCodec, UpgradeCodec upgradeCodec,
118                                     int maxContentLength) {
119         super(maxContentLength);
120         this.sourceCodec = ObjectUtil.checkNotNull(sourceCodec, "sourceCodec");
121         this.upgradeCodec = ObjectUtil.checkNotNull(upgradeCodec, "upgradeCodec");
122     }
123 
124     @Override
125     public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
126         ctx.bind(localAddress, promise);
127     }
128 
129     @Override
130     public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
131                         ChannelPromise promise) throws Exception {
132         ctx.connect(remoteAddress, localAddress, promise);
133     }
134 
135     @Override
136     public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
137         ctx.disconnect(promise);
138     }
139 
140     @Override
141     public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
142         ctx.close(promise);
143     }
144 
145     @Override
146     public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
147         ctx.deregister(promise);
148     }
149 
150     @Override
151     public void read(ChannelHandlerContext ctx) throws Exception {
152         ctx.read();
153     }
154 
155     @Override
156     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
157             throws Exception {
158         if (!(msg instanceof HttpRequest) || currentUpgradeEvent == UpgradeEvent.UPGRADE_SUCCESSFUL) {
159             ctx.write(msg, promise);
160             return;
161         }
162 
163         if (currentUpgradeEvent == UpgradeEvent.UPGRADE_ISSUED) {
164             
165             ReferenceCountUtil.release(msg);
166             promise.setFailure(new IllegalStateException(
167                     "Attempting to write HTTP request with upgrade in progress"));
168             return;
169         }
170 
171         currentUpgradeEvent = UpgradeEvent.UPGRADE_ISSUED;
172         setUpgradeRequestHeaders(ctx, (HttpRequest) msg);
173 
174         
175         ctx.write(msg, promise);
176 
177         
178         ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_ISSUED);
179         
180     }
181 
182     @Override
183     public void flush(ChannelHandlerContext ctx) throws Exception {
184         ctx.flush();
185     }
186 
187     @Override
188     protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
189             throws Exception {
190         FullHttpResponse response = null;
191         try {
192             if (currentUpgradeEvent != UpgradeEvent.UPGRADE_ISSUED) {
193                 throw new IllegalStateException("Read HTTP response without requesting protocol switch");
194             }
195 
196             if (msg instanceof HttpResponse) {
197                 HttpResponse rep = (HttpResponse) msg;
198                 if (!SWITCHING_PROTOCOLS.equals(rep.status())) {
199                     
200                     
201                     
202                     
203                     currentUpgradeEvent = null;
204                     ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED);
205                     removeThisHandler(ctx);
206                     ctx.fireChannelRead(msg);
207                     return;
208                 }
209             }
210 
211             if (msg instanceof FullHttpResponse) {
212                 response = (FullHttpResponse) msg;
213                 
214                 response.retain();
215                 out.add(response);
216             } else {
217                 
218                 super.decode(ctx, msg, out);
219                 if (out.isEmpty()) {
220                     
221                     return;
222                 }
223 
224                 assert out.size() == 1;
225                 response = (FullHttpResponse) out.get(0);
226             }
227 
228             CharSequence upgradeHeader = response.headers().get(HttpHeaderNames.UPGRADE);
229             if (upgradeHeader != null && !AsciiString.contentEqualsIgnoreCase(upgradeCodec.protocol(), upgradeHeader)) {
230                 throw new IllegalStateException(
231                         "Switching Protocols response with unexpected UPGRADE protocol: " + upgradeHeader);
232             }
233 
234             
235             sourceCodec.prepareUpgradeFrom(ctx);
236             upgradeCodec.upgradeTo(ctx, response);
237 
238             
239             
240             currentUpgradeEvent = UpgradeEvent.UPGRADE_SUCCESSFUL;
241 
242             
243             ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_SUCCESSFUL);
244 
245             
246             
247             sourceCodec.upgradeFrom(ctx);
248 
249             
250             
251             response.release();
252             out.clear();
253             removeThisHandler(ctx);
254         } catch (Throwable t) {
255             ctx.fireExceptionCaught(t);
256             removeThisHandler(ctx);
257         }
258     }
259 
260     private static void removeThisHandler(ChannelHandlerContext ctx) {
261         ctx.pipeline().remove(ctx.name());
262     }
263 
264     
265 
266 
267     private void setUpgradeRequestHeaders(ChannelHandlerContext ctx, HttpRequest request) {
268         
269         request.headers().set(HttpHeaderNames.UPGRADE, upgradeCodec.protocol());
270 
271         
272         Set<CharSequence> connectionParts = new LinkedHashSet<CharSequence>(2);
273         connectionParts.addAll(upgradeCodec.setUpgradeHeaders(ctx, request));
274 
275         
276         StringBuilder builder = new StringBuilder();
277         for (CharSequence part : connectionParts) {
278             builder.append(part);
279             builder.append(',');
280         }
281         builder.append(HttpHeaderValues.UPGRADE);
282         request.headers().add(HttpHeaderNames.CONNECTION, builder.toString());
283     }
284 }