1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.cors;
17
18 import io.netty.channel.ChannelDuplexHandler;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.http.DefaultFullHttpResponse;
24 import io.netty.handler.codec.http.HttpHeaderNames;
25 import io.netty.handler.codec.http.HttpHeaderValues;
26 import io.netty.handler.codec.http.HttpHeaders;
27 import io.netty.handler.codec.http.HttpRequest;
28 import io.netty.handler.codec.http.HttpResponse;
29 import io.netty.handler.codec.http.HttpUtil;
30 import io.netty.util.internal.logging.InternalLogger;
31 import io.netty.util.internal.logging.InternalLoggerFactory;
32
33 import java.util.Collections;
34 import java.util.List;
35
36 import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
37 import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
38 import static io.netty.handler.codec.http.HttpResponseStatus.OK;
39 import static io.netty.util.ReferenceCountUtil.release;
40 import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
41 import static io.netty.util.internal.ObjectUtil.checkNotNull;
42
43
44
45
46
47
48
49 public class CorsHandler extends ChannelDuplexHandler {
50
51 private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
52 private static final String ANY_ORIGIN = "*";
53 private static final String NULL_ORIGIN = "null";
54 private CorsConfig config;
55
56 private HttpRequest request;
57 private final List<CorsConfig> configList;
58 private final boolean isShortCircuit;
59
60
61
62
63 public CorsHandler(final CorsConfig config) {
64 this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit());
65 }
66
67
68
69
70
71
72
73
74 public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
75 checkNonEmpty(configList, "configList");
76 this.configList = configList;
77 this.isShortCircuit = isShortCircuit;
78 }
79
80 @Override
81 public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
82 if (msg instanceof HttpRequest) {
83 request = (HttpRequest) msg;
84 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
85 config = getForOrigin(origin);
86 if (isPreflightRequest(request)) {
87 handlePreflight(ctx, request);
88 return;
89 }
90 if (isShortCircuit && !(origin == null || config != null)) {
91 forbidden(ctx, request);
92 return;
93 }
94 }
95 ctx.fireChannelRead(msg);
96 }
97
98 private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
99 final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), OK, true, true);
100 if (setOrigin(response)) {
101 setAllowMethods(response);
102 setAllowHeaders(response);
103 setAllowCredentials(response);
104 setMaxAge(response);
105 setPreflightHeaders(response);
106 setAllowPrivateNetwork(response);
107 }
108 if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) {
109 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
110 }
111 release(request);
112 respond(ctx, request, response);
113 }
114
115
116
117
118
119
120
121 private void setPreflightHeaders(final HttpResponse response) {
122 response.headers().add(config.preflightResponseHeaders());
123 }
124
125 private CorsConfig getForOrigin(String requestOrigin) {
126 for (CorsConfig corsConfig : configList) {
127 if (corsConfig.isAnyOriginSupported()) {
128 return corsConfig;
129 }
130 if (corsConfig.origins().contains(requestOrigin)) {
131 return corsConfig;
132 }
133 if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
134 return corsConfig;
135 }
136 }
137 return null;
138 }
139
140 private boolean setOrigin(final HttpResponse response) {
141 final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
142 if (origin != null && config != null) {
143 if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
144 setNullOrigin(response);
145 return true;
146 }
147 if (config.isAnyOriginSupported()) {
148 if (config.isCredentialsAllowed()) {
149 echoRequestOrigin(response);
150 setVaryHeader(response);
151 } else {
152 setAnyOrigin(response);
153 }
154 return true;
155 }
156 if (config.origins().contains(origin)) {
157 setOrigin(response, origin);
158 setVaryHeader(response);
159 return true;
160 }
161 logger.debug("Request origin [{}]] was not among the configured origins [{}]", origin, config.origins());
162 }
163 return false;
164 }
165
166 private void echoRequestOrigin(final HttpResponse response) {
167 setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
168 }
169
170 private static void setVaryHeader(final HttpResponse response) {
171 response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
172 }
173
174 private static void setAnyOrigin(final HttpResponse response) {
175 setOrigin(response, ANY_ORIGIN);
176 }
177
178 private static void setNullOrigin(final HttpResponse response) {
179 setOrigin(response, NULL_ORIGIN);
180 }
181
182 private static void setOrigin(final HttpResponse response, final String origin) {
183 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
184 }
185
186 private void setAllowCredentials(final HttpResponse response) {
187 if (config.isCredentialsAllowed()
188 && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
189 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
190 }
191 }
192
193 private static boolean isPreflightRequest(final HttpRequest request) {
194 final HttpHeaders headers = request.headers();
195 return OPTIONS.equals(request.method()) &&
196 headers.contains(HttpHeaderNames.ORIGIN) &&
197 headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
198 }
199
200 private void setExposeHeaders(final HttpResponse response) {
201 if (!config.exposedHeaders().isEmpty()) {
202 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
203 }
204 }
205
206 private void setAllowMethods(final HttpResponse response) {
207 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
208 }
209
210 private void setAllowHeaders(final HttpResponse response) {
211 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
212 }
213
214 private void setMaxAge(final HttpResponse response) {
215 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
216 }
217
218 private void setAllowPrivateNetwork(final HttpResponse response) {
219 if (request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)) {
220 if (config.isPrivateNetworkAllowed()) {
221 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "true");
222 } else {
223 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "false");
224 }
225 }
226 }
227
228 @Override
229 public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
230 throws Exception {
231 if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
232 final HttpResponse response = (HttpResponse) msg;
233 if (setOrigin(response)) {
234 setAllowCredentials(response);
235 setExposeHeaders(response);
236 }
237 }
238 ctx.write(msg, promise);
239 }
240
241 private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
242 HttpResponse response = new DefaultFullHttpResponse(
243 request.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0));
244 response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
245 release(request);
246 respond(ctx, request, response);
247 }
248
249 private static void respond(
250 final ChannelHandlerContext ctx,
251 final HttpRequest request,
252 final HttpResponse response) {
253
254 final boolean keepAlive = HttpUtil.isKeepAlive(request);
255
256 HttpUtil.setKeepAlive(response, keepAlive);
257
258 final ChannelFuture future = ctx.writeAndFlush(response);
259 if (!keepAlive) {
260 future.addListener(ChannelFutureListener.CLOSE);
261 }
262 }
263 }