1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty5.handler.codec.http.websocketx.extensions;
17
18 import io.netty5.channel.ChannelHandler;
19 import io.netty5.channel.ChannelHandlerContext;
20 import io.netty5.handler.codec.http.HttpHeaderNames;
21 import io.netty5.handler.codec.http.HttpHeaders;
22 import io.netty5.handler.codec.http.HttpRequest;
23 import io.netty5.handler.codec.http.HttpResponse;
24 import io.netty5.handler.codec.http.HttpResponseStatus;
25 import io.netty5.util.concurrent.Future;
26 import io.netty5.util.concurrent.FutureListener;
27
28 import java.util.ArrayList;
29 import java.util.Arrays;
30 import java.util.Iterator;
31 import java.util.List;
32
33 import static io.netty5.util.internal.ObjectUtil.checkNonEmpty;
34
35
36
37
38
39
40
41
42
43
44
45 public class WebSocketServerExtensionHandler implements ChannelHandler {
46
47 private final List<WebSocketServerExtensionHandshaker> extensionHandshakers;
48
49 private List<WebSocketServerExtension> validExtensions;
50
51
52
53
54
55
56
57
58 public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) {
59 this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers"));
60 }
61
62 @Override
63 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
64 if (msg instanceof HttpRequest) {
65 HttpRequest request = (HttpRequest) msg;
66
67 if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) {
68 String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
69
70 if (extensionsHeader != null) {
71 List<WebSocketExtensionData> extensions =
72 WebSocketExtensionUtil.extractExtensions(extensionsHeader);
73 int rsv = 0;
74
75 for (WebSocketExtensionData extensionData : extensions) {
76 Iterator<WebSocketServerExtensionHandshaker> extensionHandshakersIterator =
77 extensionHandshakers.iterator();
78 WebSocketServerExtension validExtension = null;
79
80 while (validExtension == null && extensionHandshakersIterator.hasNext()) {
81 WebSocketServerExtensionHandshaker extensionHandshaker =
82 extensionHandshakersIterator.next();
83 validExtension = extensionHandshaker.handshakeExtension(extensionData);
84 }
85
86 if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) {
87 if (validExtensions == null) {
88 validExtensions = new ArrayList<>(1);
89 }
90 rsv = rsv | validExtension.rsv();
91 validExtensions.add(validExtension);
92 }
93 }
94 }
95 }
96 }
97
98 ctx.fireChannelRead(msg);
99 }
100
101 @Override
102 public Future<Void> write(final ChannelHandlerContext ctx, Object msg) {
103 if (msg instanceof HttpResponse) {
104 HttpResponse httpResponse = (HttpResponse) msg;
105
106
107 if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(httpResponse.status())) {
108 HttpHeaders headers = httpResponse.headers();
109
110 FutureListener<Void> listener = null;
111 if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) {
112 if (validExtensions != null) {
113 String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS);
114 List<WebSocketExtensionData> extraExtensions =
115 new ArrayList<>(extensionHandshakers.size());
116 for (WebSocketServerExtension extension : validExtensions) {
117 extraExtensions.add(extension.newResponseData());
118 }
119 String newHeaderValue = WebSocketExtensionUtil
120 .computeMergeExtensionsHeaderValue(headerValue, extraExtensions);
121 listener = future -> {
122 if (future.isSuccess()) {
123 for (WebSocketServerExtension extension : validExtensions) {
124 WebSocketExtensionDecoder decoder = extension.newExtensionDecoder();
125 WebSocketExtensionEncoder encoder = extension.newExtensionEncoder();
126 String name = ctx.name();
127 ctx.pipeline()
128
129 .addAfter(name, decoder.getClass().getName(), decoder)
130 .addAfter(name, encoder.getClass().getName(), encoder);
131 }
132 }
133 };
134
135 if (newHeaderValue != null) {
136 headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue);
137 }
138 }
139 Future<Void> f = ctx.write(httpResponse);
140 if (listener != null) {
141 f.addListener(listener);
142 }
143 f.addListener(future -> {
144 if (future.isSuccess()) {
145 ctx.pipeline().remove(this);
146 }
147 });
148 return f;
149 }
150 }
151 }
152 return ctx.write(msg);
153 }
154 }