View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx.extensions;
17  
18  import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  import io.netty.handler.codec.http.DefaultHttpRequest;
28  import io.netty.handler.codec.http.DefaultHttpResponse;
29  import io.netty.handler.codec.http.HttpHeaderNames;
30  import io.netty.handler.codec.http.HttpHeaders;
31  import io.netty.handler.codec.http.HttpRequest;
32  import io.netty.handler.codec.http.HttpResponse;
33  import io.netty.handler.codec.http.HttpResponseStatus;
34  import io.netty.handler.codec.http.LastHttpContent;
35  
36  import java.util.ArrayDeque;
37  import java.util.ArrayList;
38  import java.util.Arrays;
39  import java.util.Collections;
40  import java.util.Iterator;
41  import java.util.List;
42  import java.util.Queue;
43  
44  /**
45   * This handler negotiates and initializes the WebSocket Extensions.
46   *
47   * It negotiates the extensions based on the client desired order,
48   * ensures that the successfully negotiated extensions are consistent between them,
49   * and initializes the channel pipeline with the extension decoder and encoder.
50   *
51   * Find a basic implementation for compression extensions at
52   * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler</tt>.
53   */
54  public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
55  
56      private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
57  
58      private final Queue<List<WebSocketServerExtension>> validExtensions =
59              new ArrayDeque<List<WebSocketServerExtension>>(4);
60  
61      /**
62       * Constructor
63       *
64       * @param extensionHandshakers
65       *      The extension handshaker in priority order. A handshaker could be repeated many times
66       *      with fallback configuration.
67       */
68      public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
69          this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
70      }
71  
72      @Override
73      public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
74          // JDK type checks vs non-implemented interfaces costs O(N), where
75          // N is the number of interfaces already implemented by the concrete type that's being tested.
76          // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
77          // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
78          // singleton and/or concrete types, to save performing such slow type checks.
79          if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
80              if (msg instanceof DefaultHttpRequest) {
81                  // fast-path
82                  onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
83              } else if (msg instanceof HttpRequest) {
84                  // slow path
85                  onHttpRequestChannelRead(ctx, (HttpRequest) msg);
86              } else {
87                  super.channelRead(ctx, msg);
88              }
89          } else {
90              super.channelRead(ctx, msg);
91          }
92      }
93  
94      /**
95       * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
96       * eg:<br>
97       * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
98       * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
99       * types too, can override it like this:
100      * <pre>
101      *     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
102      *         if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
103      *             if (msg instanceof CustomHttpRequest) {
104      *                 onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
105      *             } else {
106      *                 // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
107      *                 // or have to delegate it to super.channelRead (that can perform redundant checks).
108      *                 // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
109      *                 // ...
110      *                 super.channelRead(ctx, msg);
111      *             }
112      *         } else {
113      *             // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
114      *             ctx.fireChannelRead(msg);
115      *         }
116      *     }
117      * </pre>
118      * <strong>IMPORTANT:</strong>
119      * It already call {@code super.channelRead(ctx, request)} before returning.
120      */
121     protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
122         List<WebSocketServerExtension> validExtensionsList = null;
123 
124         if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
125             String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
126 
127             if (extensionsHeader != null) {
128                 List<WebSocketExtensionData> extensions =
129                         WebSocketExtensionUtil.extractExtensions(extensionsHeader);
130                 int rsv = 0;
131 
132                 for (WebSocketExtensionData extensionData : extensions) {
133                     Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
134                             extensionHandshakers.iterator();
135                     WebSocketServerExtension validExtension = null;
136 
137                     while (validExtension == null && extensionHandshakersIterator.hasNext()) {
138                         WebSocketServerExtensionHandshaker extensionHandshaker =
139                                 extensionHandshakersIterator.next();
140                         validExtension = extensionHandshaker.handshakeExtension(extensionData);
141                     }
142 
143                     if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
144                         if (validExtensionsList == null) {
145                             validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
146                         }
147                         rsv = rsv | validExtension.rsv();
148                         validExtensionsList.add(validExtension);
149                     }
150                 }
151             }
152         }
153 
154         if (validExtensionsList == null) {
155             validExtensionsList = Collections.emptyList();
156         }
157         validExtensions.offer(validExtensionsList);
158         super.channelRead(ctx, request);
159     }
160 
161     @Override
162     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
163         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
164             if (msg instanceof DefaultHttpResponse) {
165                 onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
166             } else if (msg instanceof HttpResponse) {
167                 onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
168             } else {
169                 super.write(ctx, msg, promise);
170             }
171         } else {
172             super.write(ctx, msg, promise);
173         }
174     }
175 
176     /**
177      * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
178      * eg:<br>
179      * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
180      * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
181      * <pre>
182      *     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
183      *         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
184      *             if (msg instanceof CustomHttpResponse) {
185      *                 onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
186      *             } else {
187      *                 // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
188      *                 // or have to delegate it to super.write (that can perform redundant checks).
189      *                 // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
190      *                 // ...
191      *                 super.write(ctx, msg, promise);
192      *             }
193      *         } else {
194      *             // given that msg isn't a HttpResponse type we can just skip calling super.write
195      *             ctx.write(msg, promise);
196      *         }
197      *     }
198      * </pre>
199      * <strong>IMPORTANT:</strong>
200      * It already call {@code super.write(ctx, response, promise)} before returning.
201      */
202     protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
203             throws Exception {
204         List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
205         // checking the status is faster than looking at headers so we do this first
206         if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(response.status())) {
207             handlePotentialUpgrade(ctx, promise, response, validExtensionsList);
208         }
209         super.write(ctx, response, promise);
210     }
211 
212     private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
213                                         ChannelPromise promise, HttpResponse httpResponse,
214                                         final List<WebSocketServerExtension> validExtensionsList) {
215         HttpHeaders headers = httpResponse.headers();
216 
217         if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
218             if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
219                 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
220                 List<WebSocketExtensionData> extraExtensions =
221                   new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
222                 for (WebSocketServerExtension extension : validExtensionsList) {
223                     extraExtensions.add(extension.newReponseData());
224                 }
225                 String newHeaderValue = WebSocketExtensionUtil
226                   .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
227                 promise.addListener(new ChannelFutureListener() {
228                     @Override
229                     public void operationComplete(ChannelFuture future) {
230                         if (future.isSuccess()) {
231                             for (WebSocketServerExtension extension : validExtensionsList) {
232                                 WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
233                                 WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
234                                 String name = ctx.name();
235                                 ctx.pipeline()
236                                     .addAfter(name, decoder.getClass().getName(), decoder)
237                                     .addAfter(name, encoder.getClass().getName(), encoder);
238                             }
239                         }
240                     }
241                 });
242 
243                 if (newHeaderValue != null) {
244                     headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
245                 }
246             }
247 
248             promise.addListener(new ChannelFutureListener() {
249                 @Override
250                 public void operationComplete(ChannelFuture future) {
251                     if (future.isSuccess()) {
252                         ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
253                     }
254                 }
255             });
256         }
257     }
258 }