View Javadoc
1   /*
2    * Copyright 2014 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.handler.codec.http.websocketx.extensions;
17  
18  import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19  
20  import io.netty.buffer.ByteBuf;
21  import io.netty.buffer.Unpooled;
22  import io.netty.channel.ChannelDuplexHandler;
23  import io.netty.channel.ChannelHandlerContext;
24  import io.netty.channel.ChannelPromise;
25  import io.netty.handler.codec.http.DefaultHttpRequest;
26  import io.netty.handler.codec.http.DefaultHttpResponse;
27  import io.netty.handler.codec.http.HttpHeaderNames;
28  import io.netty.handler.codec.http.HttpHeaders;
29  import io.netty.handler.codec.http.HttpRequest;
30  import io.netty.handler.codec.http.HttpResponse;
31  import io.netty.handler.codec.http.HttpResponseStatus;
32  import io.netty.handler.codec.http.LastHttpContent;
33  
34  import java.util.ArrayDeque;
35  import java.util.ArrayList;
36  import java.util.Arrays;
37  import java.util.Collections;
38  import java.util.Iterator;
39  import java.util.List;
40  import java.util.Queue;
41  
42  /**
43   * This handler negotiates and initializes the WebSocket Extensions.
44   *
45   * It negotiates the extensions based on the client desired order,
46   * ensures that the successfully negotiated extensions are consistent between them,
47   * and initializes the channel pipeline with the extension decoder and encoder.
48   *
49   * Find a basic implementation for compression extensions at
50   * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler</tt>.
51   */
52  public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
53  
54      private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
55  
56      private final Queue<List<WebSocketServerExtension>> validExtensions = new ArrayDeque<>(4);
57  
58      /**
59       * Constructor
60       *
61       * @param extensionHandshakers
62       *      The extension handshaker in priority order. A handshaker could be repeated many times
63       *      with fallback configuration.
64       */
65      public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
66          this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
67      }
68  
69      @Override
70      public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
71          // JDK type checks vs non-implemented interfaces costs O(N), where
72          // N is the number of interfaces already implemented by the concrete type that's being tested.
73          // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
74          // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
75          // singleton and/or concrete types, to save performing such slow type checks.
76          if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
77              if (msg instanceof DefaultHttpRequest) {
78                  // fast-path
79                  onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
80              } else if (msg instanceof HttpRequest) {
81                  // slow path
82                  onHttpRequestChannelRead(ctx, (HttpRequest) msg);
83              } else {
84                  super.channelRead(ctx, msg);
85              }
86          } else {
87              super.channelRead(ctx, msg);
88          }
89      }
90  
91      /**
92       * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
93       * eg:<br>
94       * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
95       * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
96       * types too, can override it like this:
97       * <pre>
98       *     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
99       *         if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
100      *             if (msg instanceof CustomHttpRequest) {
101      *                 onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
102      *             } else {
103      *                 // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
104      *                 // or have to delegate it to super.channelRead (that can perform redundant checks).
105      *                 // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
106      *                 // ...
107      *                 super.channelRead(ctx, msg);
108      *             }
109      *         } else {
110      *             // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
111      *             ctx.fireChannelRead(msg);
112      *         }
113      *     }
114      * </pre>
115      * <strong>IMPORTANT:</strong>
116      * It already call {@code super.channelRead(ctx, request)} before returning.
117      */
118     protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
119         List<WebSocketServerExtension> validExtensionsList = null;
120 
121         if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
122             String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
123 
124             if (extensionsHeader != null) {
125                 List<WebSocketExtensionData> extensions =
126                         WebSocketExtensionUtil.extractExtensions(extensionsHeader);
127                 int rsv = 0;
128 
129                 for (WebSocketExtensionData extensionData : extensions) {
130                     Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
131                             extensionHandshakers.iterator();
132                     WebSocketServerExtension validExtension = null;
133 
134                     while (validExtension == null && extensionHandshakersIterator.hasNext()) {
135                         WebSocketServerExtensionHandshaker extensionHandshaker =
136                                 extensionHandshakersIterator.next();
137                         validExtension = extensionHandshaker.handshakeExtension(extensionData);
138                     }
139 
140                     if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
141                         if (validExtensionsList == null) {
142                             validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
143                         }
144                         rsv = rsv | validExtension.rsv();
145                         validExtensionsList.add(validExtension);
146                     }
147                 }
148             }
149         }
150 
151         if (validExtensionsList == null) {
152             validExtensionsList = Collections.emptyList();
153         }
154         validExtensions.offer(validExtensionsList);
155         super.channelRead(ctx, request);
156     }
157 
158     @Override
159     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
160         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
161             if (msg instanceof DefaultHttpResponse) {
162                 onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
163             } else if (msg instanceof HttpResponse) {
164                 onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
165             } else {
166                 super.write(ctx, msg, promise);
167             }
168         } else {
169             super.write(ctx, msg, promise);
170         }
171     }
172 
173     /**
174      * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
175      * eg:<br>
176      * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
177      * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
178      * <pre>
179      *     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
180      *         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
181      *             if (msg instanceof CustomHttpResponse) {
182      *                 onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
183      *             } else {
184      *                 // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
185      *                 // or have to delegate it to super.write (that can perform redundant checks).
186      *                 // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
187      *                 // ...
188      *                 super.write(ctx, msg, promise);
189      *             }
190      *         } else {
191      *             // given that msg isn't a HttpResponse type we can just skip calling super.write
192      *             ctx.write(msg, promise);
193      *         }
194      *     }
195      * </pre>
196      * <strong>IMPORTANT:</strong>
197      * It already call {@code super.write(ctx, response, promise)} before returning.
198      */
199     protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
200             throws Exception {
201         List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
202         // checking the status is faster than looking at headers so we do this first
203         if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(response.status())) {
204             handlePotentialUpgrade(ctx, promise, response, validExtensionsList);
205         }
206         super.write(ctx, response, promise);
207     }
208 
209     private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
210                                         ChannelPromise promise, HttpResponse httpResponse,
211                                         final List<WebSocketServerExtension> validExtensionsList) {
212         HttpHeaders headers = httpResponse.headers();
213 
214         if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
215             if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
216                 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
217                 List<WebSocketExtensionData> extraExtensions =
218                   new ArrayList<>(extensionHandshakers.size());
219                 for (WebSocketServerExtension extension : validExtensionsList) {
220                     extraExtensions.add(extension.newReponseData());
221                 }
222                 String newHeaderValue = WebSocketExtensionUtil
223                   .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
224                 promise.addListener(future -> {
225                     if (future.isSuccess()) {
226                         for (WebSocketServerExtension extension : validExtensionsList) {
227                             WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
228                             WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
229                             String name = ctx.name();
230                             ctx.pipeline()
231                                 .addAfter(name, decoder.getClass().getName(), decoder)
232                                 .addAfter(name, encoder.getClass().getName(), encoder);
233                         }
234                     }
235                 });
236 
237                 headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
238             }
239 
240             promise.addListener(future -> {
241                 if (future.isSuccess()) {
242                     ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
243                 }
244             });
245         }
246     }
247 }