1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.cors;
17
18 import io.netty.buffer.Unpooled;
19 import io.netty.channel.ChannelDuplexHandler;
20 import io.netty.channel.ChannelFuture;
21 import io.netty.channel.ChannelFutureListener;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelPromise;
24 import io.netty.handler.codec.http.DefaultFullHttpResponse;
25 import io.netty.handler.codec.http.HttpHeaderNames;
26 import io.netty.handler.codec.http.HttpHeaderValues;
27 import io.netty.handler.codec.http.HttpHeaders;
28 import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
29 import io.netty.handler.codec.http.HttpRequest;
30 import io.netty.handler.codec.http.HttpResponse;
31 import io.netty.handler.codec.http.HttpUtil;
32 import io.netty.util.internal.logging.InternalLogger;
33 import io.netty.util.internal.logging.InternalLoggerFactory;
34
35 import java.util.Collections;
36 import java.util.List;
37
38 import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
39 import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
40 import static io.netty.handler.codec.http.HttpResponseStatus.OK;
41 import static io.netty.util.ReferenceCountUtil.release;
42 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
43 import static io.netty.util.internal.ObjectUtil.checkNotNull;
44
45
46
47
48
49
50
51 public class CorsHandler extends ChannelDuplexHandler {
52
53 private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
54 private static final String ANY_ORIGIN = "*";
55 private static final String NULL_ORIGIN = "null";
56 private CorsConfig config;
57
58 private HttpRequest request;
59 private final List<CorsConfig> configList;
60 private final boolean isShortCircuit;
61
62
63
64
65 public CorsHandler(final CorsConfig config) {
66 this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit());
67 }
68
69
70
71
72
73
74
75
76 public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
77 checkNonEmpty(configList, "configList");
78 this.configList = configList;
79 this.isShortCircuit = isShortCircuit;
80 }
81
82 @Override
83 public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
84 if (msg instanceof HttpRequest) {
85 request = (HttpRequest) msg;
86 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
87 config = getForOrigin(origin);
88 if (isPreflightRequest(request)) {
89 handlePreflight(ctx, request);
90 return;
91 }
92 if (isShortCircuit && !(origin == null || config != null)) {
93 forbidden(ctx, request);
94 return;
95 }
96 }
97 ctx.fireChannelRead(msg);
98 }
99
100 private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
101 final HttpResponse response = new DefaultFullHttpResponse(
102 request.protocolVersion(),
103 OK,
104 Unpooled.buffer(0),
105 DefaultHttpHeadersFactory.headersFactory().withCombiningHeaders(true),
106 DefaultHttpHeadersFactory.trailersFactory().withCombiningHeaders(true));
107 if (setOrigin(response)) {
108 setAllowMethods(response);
109 setAllowHeaders(response);
110 setAllowCredentials(response);
111 setMaxAge(response);
112 setPreflightHeaders(response);
113 setAllowPrivateNetwork(response);
114 }
115 if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) {
116 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
117 }
118 release(request);
119 respond(ctx, request, response);
120 }
121
122
123
124
125
126
127
128 private void setPreflightHeaders(final HttpResponse response) {
129 response.headers().add(config.preflightResponseHeaders());
130 }
131
132 private CorsConfig getForOrigin(String requestOrigin) {
133 for (CorsConfig corsConfig : configList) {
134 if (corsConfig.isAnyOriginSupported()) {
135 return corsConfig;
136 }
137 if (corsConfig.origins().contains(requestOrigin)) {
138 return corsConfig;
139 }
140 if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
141 return corsConfig;
142 }
143 }
144 return null;
145 }
146
147 private boolean setOrigin(final HttpResponse response) {
148 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
149 if (origin != null && config != null) {
150 if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
151 setNullOrigin(response);
152 return true;
153 }
154 if (config.isAnyOriginSupported()) {
155 if (config.isCredentialsAllowed()) {
156 echoRequestOrigin(response);
157 setVaryHeader(response);
158 } else {
159 setAnyOrigin(response);
160 }
161 return true;
162 }
163 if (config.origins().contains(origin)) {
164 setOrigin(response, origin);
165 setVaryHeader(response);
166 return true;
167 }
168 logger.debug("Request origin [{}]] was not among the configured origins [{}]", origin, config.origins());
169 }
170 return false;
171 }
172
173 private void echoRequestOrigin(final HttpResponse response) {
174 setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
175 }
176
177 private static void setVaryHeader(final HttpResponse response) {
178 response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
179 }
180
181 private static void setAnyOrigin(final HttpResponse response) {
182 setOrigin(response, ANY_ORIGIN);
183 }
184
185 private static void setNullOrigin(final HttpResponse response) {
186 setOrigin(response, NULL_ORIGIN);
187 }
188
189 private static void setOrigin(final HttpResponse response, final String origin) {
190 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
191 }
192
193 private void setAllowCredentials(final HttpResponse response) {
194 if (config.isCredentialsAllowed()
195 && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
196 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
197 }
198 }
199
200 private static boolean isPreflightRequest(final HttpRequest request) {
201 final HttpHeaders headers = request.headers();
202 return OPTIONS.equals(request.method()) &&
203 headers.contains(HttpHeaderNames.ORIGIN) &&
204 headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
205 }
206
207 private void setExposeHeaders(final HttpResponse response) {
208 if (!config.exposedHeaders().isEmpty()) {
209 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
210 }
211 }
212
213 private void setAllowMethods(final HttpResponse response) {
214 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
215 }
216
217 private void setAllowHeaders(final HttpResponse response) {
218 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
219 }
220
221 private void setMaxAge(final HttpResponse response) {
222 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
223 }
224
225 private void setAllowPrivateNetwork(final HttpResponse response) {
226 if (request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)) {
227 if (config.isPrivateNetworkAllowed()) {
228 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "true");
229 } else {
230 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "false");
231 }
232 }
233 }
234
235 @Override
236 public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
237 throws Exception {
238 if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
239 final HttpResponse response = (HttpResponse) msg;
240 if (setOrigin(response)) {
241 setAllowCredentials(response);
242 setExposeHeaders(response);
243 }
244 }
245 ctx.write(msg, promise);
246 }
247
248 private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
249 HttpResponse response = new DefaultFullHttpResponse(
250 request.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0));
251 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
252 release(request);
253 respond(ctx, request, response);
254 }
255
256 private static void respond(
257 final ChannelHandlerContext ctx,
258 final HttpRequest request,
259 final HttpResponse response) {
260
261 final boolean keepAlive = HttpUtil.isKeepAlive(request);
262
263 HttpUtil.setKeepAlive(response, keepAlive);
264
265 final ChannelFuture future = ctx.writeAndFlush(response);
266 if (!keepAlive) {
267 future.addListener(ChannelFutureListener.CLOSE);
268 }
269 }
270 }