1 /*
2 * Copyright 2014 The Netty Project
3 *
4 * The Netty Project licenses this file to you under the Apache License,
5 * version 2.0 (the "License"); you may not use this file except in compliance
6 * with the License. You may obtain a copy of the License at:
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13 * License for the specific language governing permissions and limitations
14 * under the License.
15 */
16 package io.netty.handler.codec.http.websocketx.extensions;
17
18 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
19
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.Unpooled;
22 import io.netty.channel.ChannelDuplexHandler;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelPromise;
25 import io.netty.handler.codec.http.DefaultHttpRequest;
26 import io.netty.handler.codec.http.DefaultHttpResponse;
27 import io.netty.handler.codec.http.HttpHeaderNames;
28 import io.netty.handler.codec.http.HttpHeaders;
29 import io.netty.handler.codec.http.HttpRequest;
30 import io.netty.handler.codec.http.HttpResponse;
31 import io.netty.handler.codec.http.HttpResponseStatus;
32 import io.netty.handler.codec.http.LastHttpContent;
33
34 import java.util.ArrayDeque;
35 import java.util.ArrayList;
36 import java.util.Arrays;
37 import java.util.Collections;
38 import java.util.Iterator;
39 import java.util.List;
40 import java.util.Queue;
41
42 /**
43 * This handler negotiates and initializes the WebSocket Extensions.
44 *
45 * It negotiates the extensions based on the client desired order,
46 * ensures that the successfully negotiated extensions are consistent between them,
47 * and initializes the channel pipeline with the extension decoder and encoder.
48 *
49 * Find a basic implementation for compression extensions at
50 * <tt>io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler</tt>.
51 */
52 public class WebSocketServerExtensionHandler extends ChannelDuplexHandler {
53
54 private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
55
56 private final Queue<List<WebSocketServerExtension>> validExtensions = new ArrayDeque<>(4);
57
58 /**
59 * Constructor
60 *
61 * @param extensionHandshakers
62 * The extension handshaker in priority order. A handshaker could be repeated many times
63 * with fallback configuration.
64 */
65 public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
66 this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
67 }
68
69 @Override
70 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
71 // JDK type checks vs non-implemented interfaces costs O(N), where
72 // N is the number of interfaces already implemented by the concrete type that's being tested.
73 // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead
74 // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met
75 // singleton and/or concrete types, to save performing such slow type checks.
76 if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
77 if (msg instanceof DefaultHttpRequest) {
78 // fast-path
79 onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg);
80 } else if (msg instanceof HttpRequest) {
81 // slow path
82 onHttpRequestChannelRead(ctx, (HttpRequest) msg);
83 } else {
84 super.channelRead(ctx, msg);
85 }
86 } else {
87 super.channelRead(ctx, msg);
88 }
89 }
90
91 /**
92 * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
93 * eg:<br>
94 * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and
95 * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg}
96 * types too, can override it like this:
97 * <pre>
98 * public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
99 * if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
100 * if (msg instanceof CustomHttpRequest) {
101 * onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
102 * } else {
103 * // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
104 * // or have to delegate it to super.channelRead (that can perform redundant checks).
105 * // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
106 * // ...
107 * super.channelRead(ctx, msg);
108 * }
109 * } else {
110 * // given that msg isn't a HttpRequest type we can just skip calling super.channelRead
111 * ctx.fireChannelRead(msg);
112 * }
113 * }
114 * </pre>
115 * <strong>IMPORTANT:</strong>
116 * It already call {@code super.channelRead(ctx, request)} before returning.
117 */
118 protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception {
119 List<WebSocketServerExtension> validExtensionsList = null;
120
121 if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
122 String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
123
124 if (extensionsHeader != null) {
125 List<WebSocketExtensionData> extensions =
126 WebSocketExtensionUtil.extractExtensions(extensionsHeader);
127 int rsv = 0;
128
129 for (WebSocketExtensionData extensionData : extensions) {
130 Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
131 extensionHandshakers.iterator();
132 WebSocketServerExtension validExtension = null;
133
134 while (validExtension == null && extensionHandshakersIterator.hasNext()) {
135 WebSocketServerExtensionHandshaker extensionHandshaker =
136 extensionHandshakersIterator.next();
137 validExtension = extensionHandshaker.handshakeExtension(extensionData);
138 }
139
140 if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
141 if (validExtensionsList == null) {
142 validExtensionsList = new ArrayList<WebSocketServerExtension>(1);
143 }
144 rsv = rsv | validExtension.rsv();
145 validExtensionsList.add(validExtension);
146 }
147 }
148 }
149 }
150
151 if (validExtensionsList == null) {
152 validExtensionsList = Collections.emptyList();
153 }
154 validExtensions.offer(validExtensionsList);
155 super.channelRead(ctx, request);
156 }
157
158 @Override
159 public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
160 if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
161 if (msg instanceof DefaultHttpResponse) {
162 onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise);
163 } else if (msg instanceof HttpResponse) {
164 onHttpResponseWrite(ctx, (HttpResponse) msg, promise);
165 } else {
166 super.write(ctx, msg, promise);
167 }
168 } else {
169 super.write(ctx, msg, promise);
170 }
171 }
172
173 /**
174 * This is a method exposed to perform fail-fast checks of user-defined http types.<p>
175 * eg:<br>
176 * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and
177 * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this:
178 * <pre>
179 * public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
180 * if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
181 * if (msg instanceof CustomHttpResponse) {
182 * onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
183 * } else {
184 * // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
185 * // or have to delegate it to super.write (that can perform redundant checks).
186 * // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
187 * // ...
188 * super.write(ctx, msg, promise);
189 * }
190 * } else {
191 * // given that msg isn't a HttpResponse type we can just skip calling super.write
192 * ctx.write(msg, promise);
193 * }
194 * }
195 * </pre>
196 * <strong>IMPORTANT:</strong>
197 * It already call {@code super.write(ctx, response, promise)} before returning.
198 */
199 protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise)
200 throws Exception {
201 List<WebSocketServerExtension> validExtensionsList = validExtensions.poll();
202 // checking the status is faster than looking at headers so we do this first
203 if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(response.status())) {
204 handlePotentialUpgrade(ctx, promise, response, validExtensionsList);
205 }
206 super.write(ctx, response, promise);
207 }
208
209 private void handlePotentialUpgrade(final ChannelHandlerContext ctx,
210 ChannelPromise promise, HttpResponse httpResponse,
211 final List<WebSocketServerExtension> validExtensionsList) {
212 HttpHeaders headers = httpResponse.headers();
213
214 if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
215 if (validExtensionsList != null && !validExtensionsList.isEmpty()) {
216 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
217 List<WebSocketExtensionData> extraExtensions =
218 new ArrayList<>(extensionHandshakers.size());
219 for (WebSocketServerExtension extension : validExtensionsList) {
220 extraExtensions.add(extension.newReponseData());
221 }
222 String newHeaderValue = WebSocketExtensionUtil
223 .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
224 promise.addListener(future -> {
225 if (future.isSuccess()) {
226 for (WebSocketServerExtension extension : validExtensionsList) {
227 WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
228 WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
229 String name = ctx.name();
230 ctx.pipeline()
231 .addAfter(name, decoder.getClass().getName(), decoder)
232 .addAfter(name, encoder.getClass().getName(), encoder);
233 }
234 }
235 });
236
237 headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
238 }
239
240 promise.addListener(future -> {
241 if (future.isSuccess()) {
242 ctx.pipeline().remove(WebSocketServerExtensionHandler.this);
243 }
244 });
245 }
246 }
247 }