View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx.extensions;
17  
18  import io.netty.channel.ChannelDuplexHandler;
19  import io.netty.channel.ChannelHandlerContext;
20  import io.netty.channel.ChannelPromise;
21  import io.netty.handler.codec.CodecException;
22  import io.netty.handler.codec.http.HttpHeaderNames;
23  import io.netty.handler.codec.http.HttpRequest;
24  import io.netty.handler.codec.http.HttpResponse;
25  
26  import java.util.ArrayList;
27  import java.util.Arrays;
28  import java.util.Iterator;
29  import java.util.List;
30  
31  /**
32   * This handler negotiates and initializes the WebSocket Extensions.
33   *
34   * This implementation negotiates the extension with the server in a defined order,
35   * ensures that the successfully negotiated extensions are consistent between them,
36   * and initializes the channel pipeline with the extension decoder and encoder.
37   *
38   * Find a basic implementation for compression extensions at
39   * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler</tt>.
40   */
41  public class WebSocketClientExtensionHandler extends ChannelDuplexHandler {
42  
43      private final List<WebSocketClientExtensionHandshaker> extensionHandshakers;
44  
45      /**
46       * Constructor
47       *
48       * @param extensionHandshakers
49       *      The extension handshaker in priority order. A handshaker could be repeated many times
50       *      with fallback configuration.
51       */
52      public WebSocketClientExtensionHandler(WebSocketClientExtensionHandshaker... extensionHandshakers) {
53          if (extensionHandshakers == null) {
54              throw new NullPointerException("extensionHandshakers");
55          }
56          if (extensionHandshakers.length == 0) {
57              throw new IllegalArgumentException("extensionHandshakers must contains at least one handshaker");
58          }
59          this.extensionHandshakers = Arrays.asList(extensionHandshakers);
60      }
61  
62      @Override
63      public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
64          if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) {
65              HttpRequest request = (HttpRequest) msg;
66              String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
67  
68              for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) {
69                  WebSocketExtensionData extensionData = extensionHandshaker.newRequestData();
70                  headerValue = WebSocketExtensionUtil.appendExtension(headerValue,
71                          extensionData.name(), extensionData.parameters());
72              }
73  
74              request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, headerValue);
75          }
76  
77          super.write(ctx, msg, promise);
78      }
79  
80      @Override
81      public void channelRead(ChannelHandlerContext ctx, Object msg)
82              throws Exception {
83          if (msg instanceof HttpResponse) {
84              HttpResponse response = (HttpResponse) msg;
85  
86              if (WebSocketExtensionUtil.isWebsocketUpgrade(response.headers())) {
87                  String extensionsHeader = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
88  
89                  if (extensionsHeader != null) {
90                      List<WebSocketExtensionData> extensions =
91                              WebSocketExtensionUtil.extractExtensions(extensionsHeader);
92                      List<WebSocketClientExtension> validExtensions =
93                              new ArrayList<WebSocketClientExtension>(extensions.size());
94                      int rsv = 0;
95  
96                      for (WebSocketExtensionData extensionData : extensions) {
97                          Iterator<WebSocketClientExtensionHandshaker> extensionHandshakersIterator =
98                                  extensionHandshakers.iterator();
99                          WebSocketClientExtension validExtension = null;
100 
101                         while (validExtension == null && extensionHandshakersIterator.hasNext()) {
102                             WebSocketClientExtensionHandshaker extensionHandshaker =
103                                     extensionHandshakersIterator.next();
104                             validExtension = extensionHandshaker.handshakeExtension(extensionData);
105                         }
106 
107                         if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
108                             rsv = rsv | validExtension.rsv();
109                             validExtensions.add(validExtension);
110                         } else {
111                             throw new CodecException(
112                                     "invalid WebSocket Extension handshake for \"" + extensionsHeader + '"');
113                         }
114                     }
115 
116                     for (WebSocketClientExtension validExtension : validExtensions) {
117                         WebSocketExtensionDecoder decoder = validExtension.newExtensionDecoder();
118                         WebSocketExtensionEncoder encoder = validExtension.newExtensionEncoder();
119                         ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
120                         ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
121                     }
122                 }
123 
124                 ctx.pipeline().remove(ctx.name());
125             }
126         }
127 
128         super.channelRead(ctx, msg);
129     }
130 }
131