View Javadoc
1   /*
2    * Copyright 2013 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License, version
5    * 2.0 (the "License"); you may not use this file except in compliance with the
6    * License. You may obtain a copy of the License at:
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations under
14   * the License.
15   */
16  package io.netty.handler.codec.http.cors;
17  
18  import io.netty.channel.ChannelDuplexHandler;
19  import io.netty.channel.ChannelFuture;
20  import io.netty.channel.ChannelFutureListener;
21  import io.netty.channel.ChannelHandlerContext;
22  import io.netty.channel.ChannelPromise;
23  import io.netty.handler.codec.http.DefaultFullHttpResponse;
24  import io.netty.handler.codec.http.HttpHeaderNames;
25  import io.netty.handler.codec.http.HttpHeaderValues;
26  import io.netty.handler.codec.http.HttpHeaders;
27  import io.netty.handler.codec.http.HttpRequest;
28  import io.netty.handler.codec.http.HttpResponse;
29  import io.netty.handler.codec.http.HttpUtil;
30  import io.netty.util.internal.logging.InternalLogger;
31  import io.netty.util.internal.logging.InternalLoggerFactory;
32  
33  import java.util.Collections;
34  import java.util.List;
35  
36  import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
37  import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
38  import static io.netty.handler.codec.http.HttpResponseStatus.OK;
39  import static io.netty.util.ReferenceCountUtil.release;
40  import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
41  import static io.netty.util.internal.ObjectUtil.checkNotNull;
42  
43  /**
44   * Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
45   * <p>
46   * This handler can be configured using one or more {@link CorsConfig}, please
47   * refer to this class for details about the configuration options available.
48   */
49  public class CorsHandler extends ChannelDuplexHandler {
50  
51      private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
52      private static final String ANY_ORIGIN = "*";
53      private static final String NULL_ORIGIN = "null";
54      private CorsConfig config;
55  
56      private HttpRequest request;
57      private final List<CorsConfig> configList;
58      private boolean isShortCircuit;
59  
60      /**
61       * Creates a new instance with a single {@link CorsConfig}.
62       */
63      public CorsHandler(final CorsConfig config) {
64          this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit());
65      }
66  
67      /**
68       * Creates a new instance with the specified config list. If more than one
69       * config matches a certain origin, the first in the List will be used.
70       *
71       * @param configList     List of {@link CorsConfig}
72       * @param isShortCircuit Same as {@link CorsConfig#shortCircuit} but applicable to all supplied configs.
73       */
74      public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
75          checkNonEmpty(configList, "configList");
76          this.configList = configList;
77          this.isShortCircuit = isShortCircuit;
78      }
79  
80      @Override
81      public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
82          if (msg instanceof HttpRequest) {
83              request = (HttpRequest) msg;
84              final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
85              config = getForOrigin(origin);
86              if (isPreflightRequest(request)) {
87                  handlePreflight(ctx, request);
88                  return;
89              }
90              if (isShortCircuit && !(origin == null || config != null)) {
91                  forbidden(ctx, request);
92                  return;
93              }
94          }
95          ctx.fireChannelRead(msg);
96      }
97  
98      private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
99          final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), OK, true, true);
100         if (setOrigin(response)) {
101             setAllowMethods(response);
102             setAllowHeaders(response);
103             setAllowCredentials(response);
104             setMaxAge(response);
105             setPreflightHeaders(response);
106         }
107         if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) {
108             response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
109         }
110         release(request);
111         respond(ctx, request, response);
112     }
113 
114     /**
115      * This is a non CORS specification feature which enables the setting of preflight
116      * response headers that might be required by intermediaries.
117      *
118      * @param response the HttpResponse to which the preflight response headers should be added.
119      */
120     private void setPreflightHeaders(final HttpResponse response) {
121         response.headers().add(config.preflightResponseHeaders());
122     }
123 
124     private CorsConfig getForOrigin(String requestOrigin) {
125         for (CorsConfig corsConfig : configList) {
126             if (corsConfig.isAnyOriginSupported()) {
127                 return corsConfig;
128             }
129             if (corsConfig.origins().contains(requestOrigin)) {
130                 return corsConfig;
131             }
132             if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
133                 return corsConfig;
134             }
135         }
136         return null;
137     }
138 
139     private boolean setOrigin(final HttpResponse response) {
140         final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
141         if (origin != null && config != null) {
142             if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
143                 setNullOrigin(response);
144                 return true;
145             }
146             if (config.isAnyOriginSupported()) {
147                 if (config.isCredentialsAllowed()) {
148                     echoRequestOrigin(response);
149                     setVaryHeader(response);
150                 } else {
151                     setAnyOrigin(response);
152                 }
153                 return true;
154             }
155             if (config.origins().contains(origin)) {
156                 setOrigin(response, origin);
157                 setVaryHeader(response);
158                 return true;
159             }
160             logger.debug("Request origin [{}]] was not among the configured origins [{}]", origin, config.origins());
161         }
162         return false;
163     }
164 
165     private void echoRequestOrigin(final HttpResponse response) {
166         setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
167     }
168 
169     private static void setVaryHeader(final HttpResponse response) {
170         response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
171     }
172 
173     private static void setAnyOrigin(final HttpResponse response) {
174         setOrigin(response, ANY_ORIGIN);
175     }
176 
177     private static void setNullOrigin(final HttpResponse response) {
178         setOrigin(response, NULL_ORIGIN);
179     }
180 
181     private static void setOrigin(final HttpResponse response, final String origin) {
182         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
183     }
184 
185     private void setAllowCredentials(final HttpResponse response) {
186         if (config.isCredentialsAllowed()
187                 && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
188             response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
189         }
190     }
191 
192     private static boolean isPreflightRequest(final HttpRequest request) {
193         final HttpHeaders headers = request.headers();
194         return request.method().equals(OPTIONS) &&
195                 headers.contains(HttpHeaderNames.ORIGIN) &&
196                 headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
197     }
198 
199     private void setExposeHeaders(final HttpResponse response) {
200         if (!config.exposedHeaders().isEmpty()) {
201             response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders());
202         }
203     }
204 
205     private void setAllowMethods(final HttpResponse response) {
206         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods());
207     }
208 
209     private void setAllowHeaders(final HttpResponse response) {
210         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
211     }
212 
213     private void setMaxAge(final HttpResponse response) {
214         response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
215     }
216 
217     @Override
218     public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
219             throws Exception {
220         if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
221             final HttpResponse response = (HttpResponse) msg;
222             if (setOrigin(response)) {
223                 setAllowCredentials(response);
224                 setExposeHeaders(response);
225             }
226         }
227         ctx.write(msg, promise);
228     }
229 
230     private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
231         HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN);
232         response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO);
233         release(request);
234         respond(ctx, request, response);
235     }
236 
237     private static void respond(
238             final ChannelHandlerContext ctx,
239             final HttpRequest request,
240             final HttpResponse response) {
241 
242         final boolean keepAlive = HttpUtil.isKeepAlive(request);
243 
244         HttpUtil.setKeepAlive(response, keepAlive);
245 
246         final ChannelFuture future = ctx.writeAndFlush(response);
247         if (!keepAlive) {
248             future.addListener(ChannelFutureListener.CLOSE);
249         }
250     }
251 }