View Javadoc
1   /*
2    * Copyright 2013 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License, version
5    * 2.0 (the "License"); you may not use this file except in compliance with the
6    * License. You may obtain a copy of the License at:
7    *
8    * https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations under
14   * the License.
15   */
16  package io.netty.handler.codec.http.cors;
17  
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.Collections;
import java.util.List;

import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.util.ReferenceCountUtil.release;
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
44  
45  /**
46   * Handles <a href="https://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
47   * <p>
48   * This handler can be configured using one or more {@link CorsConfig}, please
49   * refer to this class for details about the configuration options available.
50   */
51  public class CorsHandler extends ChannelDuplexHandler {
52  
53      private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
54      private static final String ANY_ORIGIN = "*";
55      private static final String NULL_ORIGIN = "null";
56      private CorsConfig config;
57  
58      private HttpRequest request;
59      private final List<CorsConfig> configList;
60      private final boolean isShortCircuit;
61  
62      /**
63       * Creates a new instance with a single {@link CorsConfig}.
64       */
65      public CorsHandler(final CorsConfig config) {
66          this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit());
67      }
68  
69      /**
70       * Creates a new instance with the specified config list. If more than one
71       * config matches a certain origin, the first in the List will be used.
72       *
73       * @param configList     List of {@link CorsConfig}
74       * @param isShortCircuit Same as {@link CorsConfig#isShortCircuit} but applicable to all supplied configs.
75       */
76      public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
77          checkNonEmpty(configList, "configList");
78          this.configList = configList;
79          this.isShortCircuit = isShortCircuit;
80      }
81  
82      @Override
83      public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
84          if (msg instanceof HttpRequest) {
85              request = (HttpRequest) msg;
86              final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
87              config = getForOrigin(origin);
88              if (isPreflightRequest(request)) {
89                  handlePreflight(ctx, request);
90                  return;
91              }
92              if (isShortCircuit && !(origin == null || config != null)) {
93                  forbidden(ctx, request);
94                  return;
95              }
96          }
97          ctx.fireChannelRead(msg);
98      }
99  
100     private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
101         final HttpResponse response = new DefaultFullHttpResponse(
102                 request.protocolVersion(),
103                 OK,
104                 Unpooled.buffer(0),
105                 DefaultHttpHeadersFactory.headersFactory().withCombiningHeaders(true),
106                 DefaultHttpHeadersFactory.trailersFactory().withCombiningHeaders(true));
107         if (setOrigin(response)) {
108             setAllowMethods(response);
109             setAllowHeaders(response);
110             setAllowCredentials(response);
111             setMaxAge(response);
112             setPreflightHeaders(response);
113             setAllowPrivateNetwork(response);
114         }
115         if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) {
116             response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
117         }
118         release(request);
119         respond(ctx, request, response);
120     }
121 
122     /**
123      * This is a non CORS specification feature which enables the setting of preflight
124      * response headers that might be required by intermediaries.
125      *
126      * @param response the HttpResponse to which the preflight response headers should be added.
127      */
128     private void setPreflightHeaders(final HttpResponse response) {
129         response.headers().add(config.preflightResponseHeaders());
130     }
131 
132     private CorsConfig getForOrigin(String requestOrigin) {
133         for (CorsConfig corsConfig : configList) {
134             if (corsConfig.isAnyOriginSupported()) {
135                 return corsConfig;
136             }
137             if (corsConfig.origins().contains(requestOrigin)) {
138                 return corsConfig;
139             }
140             if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
141                 return corsConfig;
142             }
143         }
144         return null;
145     }
146 
147     private boolean setOrigin(final HttpResponse response) {
148         final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
149         if (origin != null && config != null) {
150             if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
151                 setNullOrigin(response);
152                 return true;
153             }
154             if (config.isAnyOriginSupported()) {
155                 if (config.isCredentialsAllowed()) {
156                     echoRequestOrigin(response);
157                     setVaryHeader(response);
158                 } else {
159                     setAnyOrigin(response);
160                 }
161                 return true;
162             }
163             if (config.origins().contains(origin)) {
164                 setOrigin(response, origin);
165                 setVaryHeader(response);
166                 return true;
167             }
168             logger.debug("Request origin [{}]] was not among the configured origins [{}]", origin, config.origins());
169         }
170         return false;
171     }
172 
173     private void echoRequestOrigin(final HttpResponse response) {
174         setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
175     }
176 
177     private static void setVaryHeader(final HttpResponse response) {
178         response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
179     }
180 
181     private static void setAnyOrigin(final HttpResponse response) {
182         setOrigin(response, ANY_ORIGIN);
183     }
184 
185     private static void setNullOrigin(final HttpResponse response) {
186         setOrigin(response, NULL_ORIGIN);
187     }
188 
189     private static void setOrigin(final HttpResponse response, final String origin) {
190         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
191     }
192 
193     private void setAllowCredentials(final HttpResponse response) {
194         if (config.isCredentialsAllowed()
195                 && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
196             response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
197         }
198     }
199 
200     private static boolean isPreflightRequest(final HttpRequest request) {
201         final HttpHeaders headers = request.headers();
202         return OPTIONS.equals(request.method()) &&
203                 headers.contains(HttpHeaderNames.ORIGIN) &&
204                 headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
205     }
206 
207     private void setExposeHeaders(final HttpResponse response) {
208         if (!config.exposedHeaders().isEmpty()) {
209             response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
210         }
211     }
212 
213     private void setAllowMethods(final HttpResponse response) {
214         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
215     }
216 
217     private void setAllowHeaders(final HttpResponse response) {
218         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
219     }
220 
221     private void setMaxAge(final HttpResponse response) {
222         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
223     }
224 
225     private void setAllowPrivateNetwork(final HttpResponse response) {
226         if (request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)) {
227             if (config.isPrivateNetworkAllowed()) {
228                 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "true");
229             } else {
230                 response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "false");
231             }
232         }
233     }
234 
235     @Override
236     public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
237             throws Exception {
238         if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
239             final HttpResponse response = (HttpResponse) msg;
240             if (setOrigin(response)) {
241                 setAllowCredentials(response);
242                 setExposeHeaders(response);
243             }
244         }
245         ctx.write(msg, promise);
246     }
247 
248     private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
249         HttpResponse response = new DefaultFullHttpResponse(
250                 request.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0));
251         response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
252         release(request);
253         respond(ctx, request, response);
254     }
255 
256     private static void respond(
257             final ChannelHandlerContext ctx,
258             final HttpRequest request,
259             final HttpResponse response) {
260 
261         final boolean keepAlive = HttpUtil.isKeepAlive(request);
262 
263         HttpUtil.setKeepAlive(response, keepAlive);
264 
265         final ChannelFuture future = ctx.writeAndFlush(response);
266         if (!keepAlive) {
267             future.addListener(ChannelFutureListener.CLOSE);
268         }
269     }
270 }