View Javadoc
1   /*
2    * Copyright 2019 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty5.handler.codec.http.websocketx;
17  
18  import io.netty5.buffer.api.BufferAllocator;
19  import io.netty5.channel.Channel;
20  import io.netty5.channel.ChannelFutureListeners;
21  import io.netty5.channel.ChannelHandler;
22  import io.netty5.channel.ChannelHandlerContext;
23  import io.netty5.channel.ChannelOutboundInvoker;
24  import io.netty5.channel.ChannelPipeline;
25  import io.netty5.channel.SimpleChannelInboundHandler;
26  import io.netty5.handler.codec.http.FullHttpRequest;
27  import io.netty5.handler.codec.http.FullHttpResponse;
28  import io.netty5.handler.codec.http.HttpContentCompressor;
29  import io.netty5.handler.codec.http.HttpHeaders;
30  import io.netty5.handler.codec.http.HttpObjectAggregator;
31  import io.netty5.handler.codec.http.HttpRequest;
32  import io.netty5.handler.codec.http.HttpRequestDecoder;
33  import io.netty5.handler.codec.http.HttpResponseEncoder;
34  import io.netty5.handler.codec.http.HttpServerCodec;
35  import io.netty5.util.ReferenceCountUtil;
36  import io.netty5.util.concurrent.Future;
37  import io.netty5.util.concurrent.Promise;
38  import io.netty5.util.internal.EmptyArrays;
39  import io.netty5.util.internal.logging.InternalLogger;
40  import io.netty5.util.internal.logging.InternalLoggerFactory;
41  
42  import java.nio.channels.ClosedChannelException;
43  import java.util.Collections;
44  import java.util.LinkedHashSet;
45  import java.util.Set;
46  
47  import static java.util.Objects.requireNonNull;
48  
49  /**
50   * Base class for server side web socket opening and closing handshakes
51   */
52  public abstract class WebSocketServerHandshaker {
53      protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class);
54  
55      private final String uri;
56  
57      private final String[] subprotocols;
58  
59      private final WebSocketVersion version;
60  
61      private final WebSocketDecoderConfig decoderConfig;
62  
63      private String selectedSubprotocol;
64  
65      /**
66       * Use this as wildcard to support all requested sub-protocols
67       */
68      public static final String SUB_PROTOCOL_WILDCARD = "*";
69  
70      /**
71       * Constructor specifying the destination web socket location
72       *
73       * @param version
74       *            the protocol version
75       * @param uri
76       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
77       *            sent to this URL.
78       * @param subprotocols
79       *            CSV of supported protocols. Null if sub protocols not supported.
80       * @param maxFramePayloadLength
81       *            Maximum length of a frame's payload
82       */
83      protected WebSocketServerHandshaker(
84              WebSocketVersion version, String uri, String subprotocols,
85              int maxFramePayloadLength) {
86          this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder()
87              .maxFramePayloadLength(maxFramePayloadLength)
88              .build());
89      }
90  
91      /**
92       * Constructor specifying the destination web socket location
93       *
94       * @param version
95       *            the protocol version
96       * @param uri
97       *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
98       *            sent to this URL.
99       * @param subprotocols
100      *            CSV of supported protocols. Null if sub protocols not supported.
101      * @param decoderConfig
102      *            Frames decoder configuration.
103      */
104     protected WebSocketServerHandshaker(
105             WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
106         this.version = version;
107         this.uri = uri;
108         if (subprotocols != null) {
109             String[] subprotocolArray = subprotocols.split(",");
110             for (int i = 0; i < subprotocolArray.length; i++) {
111                 subprotocolArray[i] = subprotocolArray[i].trim();
112             }
113             this.subprotocols = subprotocolArray;
114         } else {
115             this.subprotocols = EmptyArrays.EMPTY_STRINGS;
116         }
117         this.decoderConfig = requireNonNull(decoderConfig, "decoderConfig");
118     }
119 
120     /**
121      * Returns the URL of the web socket
122      */
123     public String uri() {
124         return uri;
125     }
126 
127     /**
128      * Returns the CSV of supported sub protocols
129      */
130     public Set<String> subprotocols() {
131         Set<String> ret = new LinkedHashSet<>();
132         Collections.addAll(ret, subprotocols);
133         return ret;
134     }
135 
136     /**
137      * Returns the version of the specification being supported
138      */
139     public WebSocketVersion version() {
140         return version;
141     }
142 
143     /**
144      * Gets the maximum length for any frame's payload.
145      *
146      * @return The maximum length for a frame's payload
147      */
148     public int maxFramePayloadLength() {
149         return decoderConfig.maxFramePayloadLength();
150     }
151 
152     /**
153      * Gets this decoder configuration.
154      *
155      * @return This decoder configuration.
156      */
157     public WebSocketDecoderConfig decoderConfig() {
158         return decoderConfig;
159     }
160 
161     /**
162      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
163      * {@link FullHttpRequest} which is passed in.
164      *
165      * @param channel
166      *              Channel
167      * @param req
168      *              HTTP Request
169      * @return future
170      *              The {@link Future} which is notified once the opening handshake completes
171      */
172     public Future<Void> handshake(Channel channel, FullHttpRequest req) {
173         return handshake(channel, req, null);
174     }
175 
176     /**
177      * Performs the opening handshake
178      *
179      * When call this method you <strong>MUST NOT</strong> retain the {@link FullHttpRequest} which is passed in.
180      *
181      * @param channel
182      *            Channel
183      * @param req
184      *            HTTP Request
185      * @param responseHeaders
186      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
187      * @return future
188      *            the {@link Future} which is notified when the opening handshake is done
189      */
190     public final Future<Void> handshake(Channel channel, FullHttpRequest req, HttpHeaders responseHeaders) {
191 
192         if (logger.isDebugEnabled()) {
193             logger.debug("{} WebSocket version {} server handshake", channel, version());
194         }
195         FullHttpResponse response = newHandshakeResponse(channel.bufferAllocator(), req, responseHeaders);
196         ChannelPipeline p = channel.pipeline();
197         if (p.get(HttpObjectAggregator.class) != null) {
198             p.remove(HttpObjectAggregator.class);
199         }
200         if (p.get(HttpContentCompressor.class) != null) {
201             p.remove(HttpContentCompressor.class);
202         }
203         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
204         final String encoderName;
205         if (ctx == null) {
206             // this means the user use an HttpServerCodec
207             ctx = p.context(HttpServerCodec.class);
208             if (ctx == null) {
209                 return channel.newFailedFuture(
210                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
211             }
212             p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
213             p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
214             encoderName = ctx.name();
215         } else {
216             p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
217 
218             encoderName = p.context(HttpResponseEncoder.class).name();
219             p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
220         }
221         return channel.writeAndFlush(response).addListener(channel, (ch, future) -> {
222             if (future.isSuccess()) {
223                 ChannelPipeline p1 = ch.pipeline();
224                 p1.remove(encoderName);
225             }
226         });
227     }
228 
229     /**
230      * Performs the opening handshake. When call this method you <strong>MUST NOT</strong> retain the
231      * {@link FullHttpRequest} which is passed in.
232      *
233      * @param channel
234      *              Channel
235      * @param req
236      *              HTTP Request
237      * @return future
238      *              The {@link Future} which is notified once the opening handshake completes
239      */
240     public Future<Void> handshake(Channel channel, HttpRequest req) {
241         return handshake(channel, req, null);
242     }
243 
244     /**
245      * Performs the opening handshake
246      *
247      * When call this method you <strong>MUST NOT</strong> retain the {@link HttpRequest} which is passed in.
248      *
249      * @param channel
250      *            Channel
251      * @param req
252      *            HTTP Request
253      * @param responseHeaders
254      *            Extra headers to add to the handshake response or {@code null} if no extra headers should be added
255      * @return future
256      *            the {@link Future} which is notified when the opening handshake is done
257      */
258     public final Future<Void> handshake(final Channel channel, HttpRequest req,
259                                          final HttpHeaders responseHeaders) {
260 
261         if (req instanceof FullHttpRequest) {
262             return handshake(channel, (FullHttpRequest) req, responseHeaders);
263         }
264         if (logger.isDebugEnabled()) {
265             logger.debug("{} WebSocket version {} server handshake", channel, version());
266         }
267         ChannelPipeline p = channel.pipeline();
268         ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
269         if (ctx == null) {
270             // this means the user use an HttpServerCodec
271             ctx = p.context(HttpServerCodec.class);
272             if (ctx == null) {
273                 return channel.newFailedFuture(
274                         new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
275             }
276         }
277 
278         Promise<Void> promise = channel.newPromise();
279         // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit o 8192 should be more then
280         // enough for the websockets handshake payload.
281         //
282         // TODO: Make handshake work without HttpObjectAggregator at all.
283         String aggregatorName = "httpAggregator";
284         p.addAfter(ctx.name(), aggregatorName, new HttpObjectAggregator(8192));
285         p.addAfter(aggregatorName, "handshaker", new SimpleChannelInboundHandler<FullHttpRequest>() {
286             @Override
287             protected void messageReceived(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
288                 // Remove ourself and do the actual handshake
289                 ctx.pipeline().remove(this);
290                 handshake(channel, msg, responseHeaders);
291             }
292 
293             @Override
294             public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
295                 // Remove ourself and fail the handshake promise.
296                 promise.tryFailure(cause);
297                 ctx.fireChannelExceptionCaught(cause);
298                 ctx.pipeline().remove(this);
299             }
300 
301             @Override
302             public void channelInactive(ChannelHandlerContext ctx) throws Exception {
303                 // Fail promise if Channel was closed
304                 if (!promise.isDone()) {
305                     promise.tryFailure(new ClosedChannelException());
306                 }
307                 ctx.fireChannelInactive();
308             }
309         });
310         try {
311             ctx.fireChannelRead(ReferenceCountUtil.retain(req));
312         } catch (Throwable cause) {
313             promise.setFailure(cause);
314         }
315         return promise.asFuture();
316     }
317 
318     /**
319      * Returns a new {@link FullHttpResponse) which will be used for as response to the handshake request.
320      */
321     protected abstract FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
322                                                              HttpHeaders responseHeaders);
323     /**
324      * Performs the closing handshake.
325      *
326      * When called from within a {@link ChannelHandler} you most likely want to use
327      * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
328      *
329      * @param channel
330      *            the {@link Channel} to use.
331      * @param frame
332      *            Closing Frame that was received.
333      */
334     public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
335         requireNonNull(channel, "channel");
336         return close0(channel, frame);
337     }
338 
339     /**
340      * Performs the closing handshake.
341      *
342      * @param ctx
343      *            the {@link ChannelHandlerContext} to use.
344      * @param frame
345      *            Closing Frame that was received.
346      */
347     public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
348         requireNonNull(ctx, "ctx");
349         return close0(ctx, frame);
350     }
351 
352     private static Future<Void> close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame) {
353         return invoker.writeAndFlush(frame).addListener(invoker, ChannelFutureListeners.CLOSE);
354     }
355 
356     /**
357      * Selects the first matching supported sub protocol
358      *
359      * @param requestedSubprotocols
360      *            CSV of protocols to be supported. e.g. "chat, superchat"
361      * @return First matching supported sub protocol. Null if not found.
362      */
363     protected String selectSubprotocol(String requestedSubprotocols) {
364         if (requestedSubprotocols == null || subprotocols.length == 0) {
365             return null;
366         }
367 
368         String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
369         for (String p: requestedSubprotocolArray) {
370             String requestedSubprotocol = p.trim();
371 
372             for (String supportedSubprotocol: subprotocols) {
373                 if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol)
374                         || requestedSubprotocol.equals(supportedSubprotocol)) {
375                     selectedSubprotocol = requestedSubprotocol;
376                     return requestedSubprotocol;
377                 }
378             }
379         }
380 
381         // No match found
382         return null;
383     }
384 
385     /**
386      * Returns the selected subprotocol. Null if no subprotocol has been selected.
387      * <p>
388      * This is only available AFTER <tt>handshake()</tt> has been called.
389      * </p>
390      */
391     public String selectedSubprotocol() {
392         return selectedSubprotocol;
393     }
394 
395     /**
396      * Returns the decoder to use after handshake is complete.
397      */
398     protected abstract WebSocketFrameDecoder newWebsocketDecoder();
399 
400     /**
401      * Returns the encoder to use after the handshake is complete.
402      */
403     protected abstract WebSocketFrameEncoder newWebSocketEncoder();
404 
405 }