1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.cors;
17
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelDuplexHandler;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.handler.codec.http.DefaultFullHttpResponse;
25 import io.netty.handler.codec.http.HttpContent;
26 import io.netty.handler.codec.http.HttpHeaderNames;
27 import io.netty.handler.codec.http.HttpHeaderValues;
28 import io.netty.handler.codec.http.HttpHeaders;
29 import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
30 import io.netty.handler.codec.http.HttpRequest;
31 import io.netty.handler.codec.http.HttpResponse;
32 import io.netty.handler.codec.http.HttpUtil;
33 import io.netty.util.ReferenceCountUtil;
34 import io.netty.util.internal.logging.InternalLogger;
35 import io.netty.util.internal.logging.InternalLoggerFactory;
36
37 import java.util.Collections;
38 import java.util.List;
39
40 import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
41 import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
42 import static io.netty.handler.codec.http.HttpResponseStatus.OK;
43 import static io.netty.util.ReferenceCountUtil.release;
44 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
45 import static io.netty.util.internal.ObjectUtil.checkNotNull;
46
47
48
49
50
51
52
53 public class CorsHandler extends ChannelDuplexHandler {
54
55 private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
56 private static final String ANY_ORIGIN = "*";
57 private static final String NULL_ORIGIN = "null";
58 private CorsConfig config;
59
60 private HttpRequest request;
61 private final List<CorsConfig> configList;
62 private final boolean isShortCircuit;
63 private boolean consumeContent;
64
65
66
67
68 public CorsHandler(final CorsConfig config) {
69 this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit());
70 }
71
72
73
74
75
76
77
78
79 public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
80 checkNonEmpty(configList, "configList");
81 this.configList = configList;
82 this.isShortCircuit = isShortCircuit;
83 }
84
85 @Override
86 public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
87 if (msg instanceof HttpRequest) {
88 request = (HttpRequest) msg;
89 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
90 config = getForOrigin(origin);
91 if (isPreflightRequest(request)) {
92 handlePreflight(ctx, request);
93
94
95 consumeContent = true;
96 return;
97 }
98 if (isShortCircuit && !(origin == null || config != null)) {
99 forbidden(ctx, request);
100 consumeContent = true;
101 return;
102 }
103
104
105 consumeContent = false;
106 ctx.fireChannelRead(msg);
107 return;
108 }
109
110 if (consumeContent && (msg instanceof HttpContent)) {
111 ReferenceCountUtil.release(msg);
112 return;
113 }
114
115 ctx.fireChannelRead(msg);
116 }
117
118 private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
119 final HttpResponse response = new DefaultFullHttpResponse(
120 request.protocolVersion(),
121 OK,
122 Unpooled.buffer(0),
123 DefaultHttpHeadersFactory.headersFactory().withCombiningHeaders(true),
124 DefaultHttpHeadersFactory.trailersFactory().withCombiningHeaders(true));
125 if (setOrigin(response)) {
126 setAllowMethods(response);
127 setAllowHeaders(response);
128 setAllowCredentials(response);
129 setMaxAge(response);
130 setPreflightHeaders(response);
131 setAllowPrivateNetwork(response);
132 }
133 if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) {
134 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
135 }
136 release(request);
137 respond(ctx, request, response);
138 }
139
140
141
142
143
144
145
146 private void setPreflightHeaders(final HttpResponse response) {
147 response.headers().add(config.preflightResponseHeaders());
148 }
149
150 private CorsConfig getForOrigin(String requestOrigin) {
151 for (CorsConfig corsConfig : configList) {
152 if (corsConfig.isAnyOriginSupported()) {
153 return corsConfig;
154 }
155 if (corsConfig.origins().contains(requestOrigin)) {
156 return corsConfig;
157 }
158 if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
159 return corsConfig;
160 }
161 }
162 return null;
163 }
164
165 private boolean setOrigin(final HttpResponse response) {
166 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
167 if (origin != null && config != null) {
168 if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
169 setNullOrigin(response);
170 return true;
171 }
172 if (config.isAnyOriginSupported()) {
173 if (config.isCredentialsAllowed()) {
174 echoRequestOrigin(response);
175 setVaryHeader(response);
176 } else {
177 setAnyOrigin(response);
178 }
179 return true;
180 }
181 if (config.origins().contains(origin)) {
182 setOrigin(response, origin);
183 setVaryHeader(response);
184 return true;
185 }
186 logger.debug("Request origin [{}]] was not among the configured origins [{}]", origin, config.origins());
187 }
188 return false;
189 }
190
191 private void echoRequestOrigin(final HttpResponse response) {
192 setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
193 }
194
195 private static void setVaryHeader(final HttpResponse response) {
196 response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
197 }
198
199 private static void setAnyOrigin(final HttpResponse response) {
200 setOrigin(response, ANY_ORIGIN);
201 }
202
203 private static void setNullOrigin(final HttpResponse response) {
204 setOrigin(response, NULL_ORIGIN);
205 }
206
207 private static void setOrigin(final HttpResponse response, final String origin) {
208 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
209 }
210
211 private void setAllowCredentials(final HttpResponse response) {
212 if (config.isCredentialsAllowed()
213 && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
214 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
215 }
216 }
217
218 private static boolean isPreflightRequest(final HttpRequest request) {
219 final HttpHeaders headers = request.headers();
220 return OPTIONS.equals(request.method()) &&
221 headers.contains(HttpHeaderNames.ORIGIN) &&
222 headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
223 }
224
225 private void setExposeHeaders(final HttpResponse response) {
226 if (!config.exposedHeaders().isEmpty()) {
227 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
228 }
229 }
230
231 private void setAllowMethods(final HttpResponse response) {
232 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
233 }
234
235 private void setAllowHeaders(final HttpResponse response) {
236 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
237 }
238
239 private void setMaxAge(final HttpResponse response) {
240 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
241 }
242
243 private void setAllowPrivateNetwork(final HttpResponse response) {
244 if (request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)) {
245 if (config.isPrivateNetworkAllowed()) {
246 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "true");
247 } else {
248 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "false");
249 }
250 }
251 }
252
253 @Override
254 public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
255 throws Exception {
256 if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
257 final HttpResponse response = (HttpResponse) msg;
258 if (setOrigin(response)) {
259 setAllowCredentials(response);
260 setExposeHeaders(response);
261 }
262 }
263 ctx.write(msg, promise);
264 }
265
266 private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
267 HttpResponse response = new DefaultFullHttpResponse(
268 request.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0));
269 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
270 release(request);
271 respond(ctx, request, response);
272 }
273
274 private static void respond(
275 final ChannelHandlerContext ctx,
276 final HttpRequest request,
277 final HttpResponse response) {
278
279 final boolean keepAlive = HttpUtil.isKeepAlive(request);
280
281 HttpUtil.setKeepAlive(response, keepAlive);
282
283 final ChannelFuture future = ctx.writeAndFlush(response);
284 if (!keepAlive) {
285 future.addListener(ChannelFutureListener.CLOSE);
286 }
287 }
288 }