View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx.extensions;
17  
18  import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelFuture;
24  import io.netty.channel.ChannelFutureListener;
25  import io.netty.channel.ChannelHandlerContext;
26  import io.netty.channel.ChannelPromise;
27  import io.netty.handler.codec.http.DefaultHttpRequest;
28  import io.netty.handler.codec.http.DefaultHttpResponse;
29  import io.netty.handler.codec.http.HttpHeaderNames;
30  import io.netty.handler.codec.http.HttpHeaders;
31  import io.netty.handler.codec.http.HttpRequest;
32  import io.netty.handler.codec.http.HttpResponse;
33  import io.netty.handler.codec.http.HttpResponseStatus;
34  import io.netty.handler.codec.http.LastHttpContent;
35  import io.netty.util.internal.UnstableApi;
36  
37  import java.util.ArrayDeque;
38  import java.util.ArrayList;
39  import java.util.Arrays;
40  import java.util.Collections;
41  import java.util.Iterator;
42  import java.util.List;
43  import java.util.Queue;
44  
45  /**
46   * This handler negotiates and initializes the WebSocket Extensions.
47   *
48   * It negotiates the extensions based on the client desired order,
49   * ensures that the successfully negotiated extensions are consistent between them,
50   * and initializes the channel pipeline with the extension decoder and encoder.
51   *
52   * Find a basic implementation for compression extensions at
53   * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler</tt>.
54   */
55  public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
56  
57      private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
58  
59      private final Queue<List<WebSocketServerExtension>> validExtensions =
60              new ArrayDeque<List<WebSocketServerExtension>>(4);
61  
62      /**
63       * Constructor
64       *
65       * @param extensionHandshakers
66       *      The extension handshaker in priority order. A handshaker could be repeated many times
67       *      with fallback configuration.
68       */
69      public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
70          this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
71      }
72  
73      @Override
74      public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
75          // JDK type checks vs non-implemented interfaces costs O(N), where
76          // N is the number of interfaces already implemented by the concrete type that's being tested.
77          // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
78          // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
79          // singleton and/or concrete types, to save performing such slow type checks.
80          if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
81              if (msg instanceof DefaultHttpRequest) {
82                  // fast-path
83                  onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
84              } else if (msg instanceof HttpRequest) {
85                  // slow path
86                  onHttpRequestChannelRead(ctx, (HttpRequest) msg);
87              } else {
88                  super.channelRead(ctx, msg);
89              }
90          } else {
91              super.channelRead(ctx, msg);
92          }
93      }
94  
95      /**
96       * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
97       * eg:<br>
98       * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
99       * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
100      * types too, can override it like this:
101      * <pre>
102      *     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
103      *         if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
104      *             if (msg instanceof CustomHttpRequest) {
105      *                 onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
106      *             } else {
107      *                 // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
108      *                 // or have to delegate it to super.channelRead (that can perform redundant checks).
109      *                 // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
110      *                 // ...
111      *                 super.channelRead(ctx, msg);
112      *             }
113      *         } else {
114      *             // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
115      *             ctx.fireChannelRead(msg);
116      *         }
117      *     }
118      * </pre>
119      * <strong>IMPORTANT:</strong>
120      * It already call {@code super.channelRead(ctx, request)} before returning.
121      */
122     @UnstableApi
123     protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
124         List<WebSocketServerExtension> validExtensionsList = null;
125 
126         if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
127             String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
128 
129             if (extensionsHeader != null) {
130                 List<WebSocketExtensionData> extensions =
131                         WebSocketExtensionUtil.extractExtensions(extensionsHeader);
132                 int rsv = 0;
133 
134                 for (WebSocketExtensionData extensionData : extensions) {
135                     Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
136                             extensionHandshakers.iterator();
137                     WebSocketServerExtension validExtension = null;
138 
139                     while (validExtension == null && extensionHandshakersIterator.hasNext()) {
140                         WebSocketServerExtensionHandshaker extensionHandshaker =
141                                 extensionHandshakersIterator.next();
142                         validExtension = extensionHandshaker.handshakeExtension(extensionData);
143                     }
144 
145                     if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
146                         if (validExtensionsList == null) {
147                             validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
148                         }
149                         rsv = rsv | validExtension.rsv();
150                         validExtensionsList.add(validExtension);
151                     }
152                 }
153             }
154         }
155 
156         if (validExtensionsList == null) {
157             validExtensionsList = Collections.emptyList();
158         }
159         validExtensions.offer(validExtensionsList);
160         super.channelRead(ctx, request);
161     }
162 
163     @Override
164     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
165         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
166             if (msg instanceof DefaultHttpResponse) {
167                 onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
168             } else if (msg instanceof HttpResponse) {
169                 onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
170             } else {
171                 super.write(ctx, msg, promise);
172             }
173         } else {
174             super.write(ctx, msg, promise);
175         }
176     }
177 
178     /**
179      * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
180      * eg:<br>
181      * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
182      * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
183      * <pre>
184      *     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
185      *         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
186      *             if (msg instanceof CustomHttpResponse) {
187      *                 onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
188      *             } else {
189      *                 // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
190      *                 // or have to delegate it to super.write (that can perform redundant checks).
191      *                 // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
192      *                 // ...
193      *                 super.write(ctx, msg, promise);
194      *             }
195      *         } else {
196      *             // given that msg isn't a HttpResponse type we can just skip calling super.write
197      *             ctx.write(msg, promise);
198      *         }
199      *     }
200      * </pre>
201      * <strong>IMPORTANT:</strong>
202      * It already call {@code super.write(ctx, response, promise)} before returning.
203      */
204     @UnstableApi
205     protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
206             throws Exception {
207         List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
208         // checking the status is faster than looking at headers so we do this first
209         if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(response.status())) {
210             handlePotentialUpgrade(ctx, promise, response, validExtensionsList);
211         }
212         super.write(ctx, response, promise);
213     }
214 
215     private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
216                                         ChannelPromise promise, HttpResponse httpResponse,
217                                         final List<WebSocketServerExtension> validExtensionsList) {
218         HttpHeaders headers = httpResponse.headers();
219 
220         if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
221             if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
222                 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
223                 List<WebSocketExtensionData> extraExtensions =
224                   new ArrayList<WebSocketExtensionData>(extensionHandshakers.size());
225                 for (WebSocketServerExtension extension : validExtensionsList) {
226                     extraExtensions.add(extension.newReponseData());
227                 }
228                 String newHeaderValue = WebSocketExtensionUtil
229                   .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
230                 promise.addListener(new ChannelFutureListener() {
231                     @Override
232                     public void operationComplete(ChannelFuture future) {
233                         if (future.isSuccess()) {
234                             for (WebSocketServerExtension extension : validExtensionsList) {
235                                 WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
236                                 WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
237                                 String name = ctx.name();
238                                 ctx.pipeline()
239                                     .addAfter(name, decoder.getClass().getName(), decoder)
240                                     .addAfter(name, encoder.getClass().getName(), encoder);
241                             }
242                         }
243                     }
244                 });
245 
246                 if (newHeaderValue != null) {
247                     headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
248                 }
249             }
250 
251             promise.addListener(new ChannelFutureListener() {
252                 @Override
253                 public void operationComplete(ChannelFuture future) {
254                     if (future.isSuccess()) {
255                         ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
256                     }
257                 }
258             });
259         }
260     }
261 }