1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.http.cors;
17
18 import io.netty.channel.ChannelDuplexHandler;
19 import io.netty.channel.ChannelFuture;
20 import io.netty.channel.ChannelFutureListener;
21 import io.netty.channel.ChannelHandlerContext;
22 import io.netty.channel.ChannelPromise;
23 import io.netty.handler.codec.http.DefaultFullHttpResponse;
24 import io.netty.handler.codec.http.HttpHeaders;
25 import io.netty.handler.codec.http.HttpRequest;
26 import io.netty.handler.codec.http.HttpResponse;
27 import io.netty.util.internal.logging.InternalLogger;
28 import io.netty.util.internal.logging.InternalLoggerFactory;
29
30 import static io.netty.handler.codec.http.HttpHeaders.Names.*;
31 import static io.netty.handler.codec.http.HttpMethod.*;
32 import static io.netty.handler.codec.http.HttpResponseStatus.*;
33 import static io.netty.util.ReferenceCountUtil.release;
34
35
36
37
38
39
40
41 public class CorsHandler extends ChannelDuplexHandler {
42
43 private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
44 private static final String ANY_ORIGIN = "*";
45 private final CorsConfig config;
46
47 private HttpRequest request;
48
49 public CorsHandler(final CorsConfig config) {
50 this.config = config;
51 }
52
53 @Override
54 public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
55 if (config.isCorsSupportEnabled() && msg instanceof HttpRequest) {
56 request = (HttpRequest) msg;
57 if (isPreflightRequest(request)) {
58 handlePreflight(ctx, request);
59 return;
60 }
61 if (config.isShortCurcuit() && !validateOrigin()) {
62 forbidden(ctx, request);
63 return;
64 }
65 }
66 ctx.fireChannelRead(msg);
67 }
68
69 private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
70 final HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), OK);
71 if (setOrigin(response)) {
72 setAllowMethods(response);
73 setAllowHeaders(response);
74 setAllowCredentials(response);
75 setMaxAge(response);
76 setPreflightHeaders(response);
77 }
78 if (!response.headers().contains(CONTENT_LENGTH)) {
79 response.headers().set(CONTENT_LENGTH, "0");
80 }
81 release(request);
82 respond(ctx, request, response);
83 }
84
85
86
87
88
89
90
91 private void setPreflightHeaders(final HttpResponse response) {
92 response.headers().add(config.preflightResponseHeaders());
93 }
94
95 private boolean setOrigin(final HttpResponse response) {
96 final String origin = request.headers().get(ORIGIN);
97 if (origin != null) {
98 if ("null".equals(origin) && config.isNullOriginAllowed()) {
99 setAnyOrigin(response);
100 return true;
101 }
102 if (config.isAnyOriginSupported()) {
103 if (config.isCredentialsAllowed()) {
104 echoRequestOrigin(response);
105 setVaryHeader(response);
106 } else {
107 setAnyOrigin(response);
108 }
109 return true;
110 }
111 if (config.origins().contains(origin)) {
112 setOrigin(response, origin);
113 setVaryHeader(response);
114 return true;
115 }
116 logger.debug("Request origin [" + origin + "] was not among the configured origins " + config.origins());
117 }
118 return false;
119 }
120
121 private boolean validateOrigin() {
122 if (config.isAnyOriginSupported()) {
123 return true;
124 }
125
126 final String origin = request.headers().get(ORIGIN);
127 if (origin == null) {
128
129 return true;
130 }
131
132 if ("null".equals(origin) && config.isNullOriginAllowed()) {
133 return true;
134 }
135
136 return config.origins().contains(origin);
137 }
138
139 private void echoRequestOrigin(final HttpResponse response) {
140 setOrigin(response, request.headers().get(ORIGIN));
141 }
142
143 private static void setVaryHeader(final HttpResponse response) {
144 response.headers().set(VARY, ORIGIN);
145 }
146
147 private static void setAnyOrigin(final HttpResponse response) {
148 setOrigin(response, ANY_ORIGIN);
149 }
150
151 private static void setOrigin(final HttpResponse response, final String origin) {
152 response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
153 }
154
155 private void setAllowCredentials(final HttpResponse response) {
156 if (config.isCredentialsAllowed()
157 && !response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
158 response.headers().set(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
159 }
160 }
161
162 private static boolean isPreflightRequest(final HttpRequest request) {
163 final HttpHeaders headers = request.headers();
164 return request.getMethod().equals(OPTIONS) &&
165 headers.contains(ORIGIN) &&
166 headers.contains(ACCESS_CONTROL_REQUEST_METHOD);
167 }
168
169 private void setExposeHeaders(final HttpResponse response) {
170 if (!config.exposedHeaders().isEmpty()) {
171 response.headers().set(ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
172 }
173 }
174
175 private void setAllowMethods(final HttpResponse response) {
176 response.headers().set(ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
177 }
178
179 private void setAllowHeaders(final HttpResponse response) {
180 response.headers().set(ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
181 }
182
183 private void setMaxAge(final HttpResponse response) {
184 response.headers().set(ACCESS_CONTROL_MAX_AGE, config.maxAge());
185 }
186
187 @Override
188 public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
189 throws Exception {
190 if (config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
191 final HttpResponse response = (HttpResponse) msg;
192 if (setOrigin(response)) {
193 setAllowCredentials(response);
194 setExposeHeaders(response);
195 }
196 }
197 ctx.writeAndFlush(msg, promise);
198 }
199
200 @Override
201 public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception {
202 logger.error("Caught error in CorsHandler", cause);
203 ctx.fireExceptionCaught(cause);
204 }
205
206 private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
207 HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), FORBIDDEN);
208 response.headers().set(CONTENT_LENGTH, "0");
209 release(request);
210 respond(ctx, request, response);
211 }
212
213 private static void respond(
214 final ChannelHandlerContext ctx,
215 final HttpRequest request,
216 final HttpResponse response) {
217
218 final boolean keepAlive = HttpHeaders.isKeepAlive(request);
219
220 HttpHeaders.setKeepAlive(response, keepAlive);
221
222 final ChannelFuture future = ctx.writeAndFlush(response);
223 if (!keepAlive) {
224 future.addListener(ChannelFutureListener.CLOSE);
225 }
226 }
227 }
228