diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java
index 9e7756521e..9374839726 100644
--- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java
+++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java
@@ -30,15 +30,19 @@ import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
-import static io.netty.handler.codec.http.HttpMethod.*;
-import static io.netty.handler.codec.http.HttpResponseStatus.*;
-import static io.netty.util.ReferenceCountUtil.*;
-import static io.netty.util.internal.ObjectUtil.checkNotNull;
+import java.util.Collections;
+import java.util.List;
+
+import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
+import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
+import static io.netty.handler.codec.http.HttpResponseStatus.OK;
+import static io.netty.util.ReferenceCountUtil.release;
+import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
/**
* Handles Cross Origin Resource Sharing (CORS) requests.
*
- * This handler can be configured using a {@link CorsConfig}, please
+ * This handler can be configured using one or more {@link CorsConfig}, please
* refer to this class for details about the configuration options available.
*/
public class CorsHandler extends ChannelDuplexHandler {
@@ -46,26 +50,43 @@ public class CorsHandler extends ChannelDuplexHandler {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
private static final String ANY_ORIGIN = "*";
private static final String NULL_ORIGIN = "null";
- private final CorsConfig config;
+ private CorsConfig config;
private HttpRequest request;
+ private final List configList;
+ private boolean isShortCircuit;
/**
- * Creates a new instance with the specified {@link CorsConfig}.
+ * Creates a new instance with a single {@link CorsConfig}.
*/
public CorsHandler(final CorsConfig config) {
- this.config = checkNotNull(config, "config");
+ this(Collections.singletonList(config), config.isShortCircuit());
+ }
+
+ /**
+ * Creates a new instance with the specified config list. If more than one
+ * config matches a certain origin, the first in the List will be used.
+ *
+ * @param configList List of {@link CorsConfig}
+ * @param isShortCircuit Same as {@link CorsConfig#shortCircuit} but applicable to all supplied configs.
+ */
+ public CorsHandler(final List configList, boolean isShortCircuit) {
+ checkNonEmpty(configList, "configList");
+ this.configList = configList;
+ this.isShortCircuit = isShortCircuit;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
- if (config.isCorsSupportEnabled() && msg instanceof HttpRequest) {
+ if (msg instanceof HttpRequest) {
request = (HttpRequest) msg;
+ final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
+ config = getForOrigin(origin);
if (isPreflightRequest(request)) {
handlePreflight(ctx, request);
return;
}
- if (config.isShortCircuit() && !validateOrigin()) {
+ if (isShortCircuit && !(origin == null || config != null)) {
forbidden(ctx, request);
return;
}
@@ -99,6 +120,21 @@ public class CorsHandler extends ChannelDuplexHandler {
response.headers().add(config.preflightResponseHeaders());
}
+ private CorsConfig getForOrigin(String requestOrigin) {
+ for (CorsConfig corsConfig : configList) {
+ if (corsConfig.isAnyOriginSupported()) {
+ return corsConfig;
+ }
+ if (corsConfig.origins().contains(requestOrigin)) {
+ return corsConfig;
+ }
+ if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
+ return corsConfig;
+ }
+ }
+ return null;
+ }
+
private boolean setOrigin(final HttpResponse response) {
final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
if (origin != null) {
@@ -125,24 +161,6 @@ public class CorsHandler extends ChannelDuplexHandler {
return false;
}
- private boolean validateOrigin() {
- if (config.isAnyOriginSupported()) {
- return true;
- }
-
- final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
- if (origin == null) {
- // Not a CORS request so we cannot validate it. It may be a non CORS request.
- return true;
- }
-
- if ("null".equals(origin) && config.isNullOriginAllowed()) {
- return true;
- }
-
- return config.origins().contains(origin);
- }
-
private void echoRequestOrigin(final HttpResponse response) {
setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
}
@@ -198,7 +216,7 @@ public class CorsHandler extends ChannelDuplexHandler {
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
throws Exception {
- if (config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
+ if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
final HttpResponse response = (HttpResponse) msg;
if (setOrigin(response)) {
setAllowCredentials(response);
diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java
index f3fe97efd0..884c368796 100644
--- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java
+++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java
@@ -26,41 +26,26 @@ import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;
+import org.hamcrest.core.IsEqual;
import org.junit.Test;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
import java.util.concurrent.Callable;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
-import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
-import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
-import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
-import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
-import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
-import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
-import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
+import static io.netty.handler.codec.http.HttpHeaderNames.*;
import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE;
+import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
-import static io.netty.handler.codec.http.HttpMethod.DELETE;
-import static io.netty.handler.codec.http.HttpMethod.GET;
-import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
+import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
-import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forAnyOrigin;
-import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigin;
-import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigins;
-import static org.hamcrest.CoreMatchers.containsString;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.CoreMatchers.notNullValue;
-import static org.hamcrest.CoreMatchers.nullValue;
+import static io.netty.handler.codec.http.cors.CorsConfigBuilder.*;
+import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.core.IsEqual.equalTo;
public class CorsHandlerTest {
@@ -422,6 +407,46 @@ public class CorsHandlerTest {
assertThat(channel.finish(), is(false));
}
+ @Test
+ public void differentConfigsPerOrigin() {
+ String host1 = "http://host1:80";
+ String host2 = "http://host2";
+ CorsConfig rule1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).build();
+ CorsConfig rule2 = forOrigin(host2).allowedRequestMethods(HttpMethod.GET, HttpMethod.POST)
+ .allowCredentials().build();
+
+ List corsConfigs = Arrays.asList(rule1, rule2);
+
+ final HttpResponse preFlightHost1 = preflightRequest(corsConfigs, host1, "", false);
+ assertThat(preFlightHost1.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET"));
+ assertThat(preFlightHost1.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(nullValue()));
+
+ final HttpResponse preFlightHost2 = preflightRequest(corsConfigs, host2, "", false);
+ assertValues(preFlightHost2, ACCESS_CONTROL_ALLOW_METHODS.toString(), "GET", "POST");
+ assertThat(preFlightHost2.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), IsEqual.equalTo("true"));
+ }
+
+ @Test
+ public void specificConfigPrecedenceOverGeneric() {
+ String host1 = "http://host1";
+ String host2 = "http://host2";
+
+ CorsConfig forHost1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).maxAge(3600L).build();
+ CorsConfig allowAll = forAnyOrigin().allowedRequestMethods(HttpMethod.POST, HttpMethod.GET, HttpMethod.OPTIONS)
+ .maxAge(1800).build();
+
+ List rules = Arrays.asList(forHost1, allowAll);
+
+ final HttpResponse host1Response = preflightRequest(rules, host1, "", false);
+ assertThat(host1Response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET"));
+ assertThat(host1Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("3600"));
+
+ final HttpResponse host2Response = preflightRequest(rules, host2, "", false);
+ assertValues(host2Response, ACCESS_CONTROL_ALLOW_METHODS.toString(), "POST", "GET", "OPTIONS");
+ assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*"));
+ assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("1800"));
+ }
+
private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
return simpleRequest(config, origin, null);
}
@@ -451,7 +476,14 @@ public class CorsHandlerTest {
private static HttpResponse preflightRequest(final CorsConfig config,
final String origin,
final String requestHeaders) {
- final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
+ return preflightRequest(Collections.singletonList(config), origin, requestHeaders, config.isShortCircuit());
+ }
+
+ private static HttpResponse preflightRequest(final List configs,
+ final String origin,
+ final String requestHeaders,
+ final boolean isSHortCircuit) {
+ final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(configs, isSHortCircuit));
assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false));
HttpResponse response = channel.readOutbound();
assertThat(channel.finish(), is(false));
diff --git a/common/src/main/java/io/netty/util/internal/ObjectUtil.java b/common/src/main/java/io/netty/util/internal/ObjectUtil.java
index 28c3729fb9..cbac561f32 100644
--- a/common/src/main/java/io/netty/util/internal/ObjectUtil.java
+++ b/common/src/main/java/io/netty/util/internal/ObjectUtil.java
@@ -14,6 +14,8 @@
*/
package io.netty.util.internal;
+import java.util.Collection;
+
/**
* A grab-bag of useful utility methods.
*/
@@ -88,6 +90,17 @@ public final class ObjectUtil {
return array;
}
+ /**
+ * Checks that the given argument is neither null nor empty.
+ * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}.
+ * Otherwise, returns the argument.
+ */
+ public static > T checkNonEmpty(T collection, String name) {
+ checkNotNull(collection, name);
+ checkPositive(collection.size(), name + ".size");
+ return collection;
+ }
+
/**
* Resolves a possibly null Integer to a primitive int, using a default value.
* @param wrapper the wrapper