View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx.extensions;
17  
18  import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19  
20  import io.netty.channel.ChannelDuplexHandler;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.CodecException;
24  import io.netty.handler.codec.http.HttpHeaderNames;
25  import io.netty.handler.codec.http.HttpRequest;
26  import io.netty.handler.codec.http.HttpResponse;
27  
28  import java.util.ArrayList;
29  import java.util.Arrays;
30  import java.util.Iterator;
31  import java.util.List;
32  
33  /**
34   * This handler negotiates and initializes the WebSocket Extensions.
35   *
36   * This implementation negotiates the extension with the server in a defined order,
37   * ensures that the successfully negotiated extensions are consistent between them,
38   * and initializes the channel pipeline with the extension decoder and encoder.
39   *
40   * Find a basic implementation for compression extensions at
41   * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler</tt>.
42   */
43  public class WebSocketClientExtensionHandler extends ChannelDuplexHandler {
44  
45      private final List<WebSocketClientExtensionHandshaker> extensionHandshakers;
46  
47      /**
48       * Constructor
49       *
50       * @param extensionHandshakers
51       *      The extension handshaker in priority order. A handshaker could be repeated many times
52       *      with fallback configuration.
53       */
54      public WebSocketClientExtensionHandler(WebSocketClientExtensionHandshaker... extensionHandshakers) {
55          this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
56      }
57  
58      @Override
59      public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
60          if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) {
61              HttpRequest request = (HttpRequest) msg;
62              String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
63              List<WebSocketExtensionData> extraExtensions =
64                new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
65              for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) {
66                  extraExtensions.add(extensionHandshaker.newRequestData());
67              }
68              String newHeaderValue = WebSocketExtensionUtil
69                .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
70  
71              request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
72          }
73  
74          super.write(ctx, msg, promise);
75      }
76  
77      @Override
78      public void channelRead(ChannelHandlerContext ctx, Object msg)
79              throws Exception {
80          if (msg instanceof HttpResponse) {
81              HttpResponse response = (HttpResponse) msg;
82  
83              if (WebSocketExtensionUtil.isWebsocketUpgrade(response.headers())) {
84                  String extensionsHeader = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
85  
86                  if (extensionsHeader != null) {
87                      List<WebSocketExtensionData> extensions =
88                              WebSocketExtensionUtil.extractExtensions(extensionsHeader);
89                      List<WebSocketClientExtension> validExtensions =
90                              new ArrayList<WebSocketClientExtension>(extensions.size());
91                      int rsv = 0;
92  
93                      for (WebSocketExtensionData extensionData : extensions) {
94                          Iterator<WebSocketClientExtensionHandshaker> extensionHandshakersIterator =
95                                  extensionHandshakers.iterator();
96                          WebSocketClientExtension validExtension = null;
97  
98                          while (validExtension == null && extensionHandshakersIterator.hasNext()) {
99                              WebSocketClientExtensionHandshaker extensionHandshaker =
100                                     extensionHandshakersIterator.next();
101                             validExtension = extensionHandshaker.handshakeExtension(extensionData);
102                         }
103 
104                         if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
105                             rsv = rsv | validExtension.rsv();
106                             validExtensions.add(validExtension);
107                         } else {
108                             throw new CodecException(
109                                     "invalid WebSocket Extension handshake for \"" + extensionsHeader + '"');
110                         }
111                     }
112 
113                     for (WebSocketClientExtension validExtension : validExtensions) {
114                         WebSocketExtensionDecoder decoder = validExtension.newExtensionDecoder();
115                         WebSocketExtensionEncoder encoder = validExtension.newExtensionEncoder();
116                         ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder);
117                         ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder);
118                     }
119                 }
120 
121                 ctx.pipeline().remove(ctx.name());
122             }
123         }
124 
125         super.channelRead(ctx, msg);
126     }
127 }
128