diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 9e7756521e..9374839726 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -30,15 +30,19 @@ import io.netty.handler.codec.http.HttpUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; -import static io.netty.handler.codec.http.HttpMethod.*; -import static io.netty.handler.codec.http.HttpResponseStatus.*; -import static io.netty.util.ReferenceCountUtil.*; -import static io.netty.util.internal.ObjectUtil.checkNotNull; +import java.util.Collections; +import java.util.List; + +import static io.netty.handler.codec.http.HttpMethod.OPTIONS; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.util.ReferenceCountUtil.release; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; /** * Handles Cross Origin Resource Sharing (CORS) requests. *

- * This handler can be configured using a {@link CorsConfig}, please + * This handler can be configured using one or more {@link CorsConfig}, please * refer to this class for details about the configuration options available. */ public class CorsHandler extends ChannelDuplexHandler { @@ -46,26 +50,43 @@ public class CorsHandler extends ChannelDuplexHandler { private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class); private static final String ANY_ORIGIN = "*"; private static final String NULL_ORIGIN = "null"; - private final CorsConfig config; + private CorsConfig config; private HttpRequest request; + private final List configList; + private boolean isShortCircuit; /** - * Creates a new instance with the specified {@link CorsConfig}. + * Creates a new instance with a single {@link CorsConfig}. */ public CorsHandler(final CorsConfig config) { - this.config = checkNotNull(config, "config"); + this(Collections.singletonList(config), config.isShortCircuit()); + } + + /** + * Creates a new instance with the specified config list. If more than one + * config matches a certain origin, the first in the List will be used. + * + * @param configList List of {@link CorsConfig} + * @param isShortCircuit Same as {@link CorsConfig#shortCircuit} but applicable to all supplied configs. + */ + public CorsHandler(final List configList, boolean isShortCircuit) { + checkNonEmpty(configList, "configList"); + this.configList = configList; + this.isShortCircuit = isShortCircuit; } @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { - if (config.isCorsSupportEnabled() && msg instanceof HttpRequest) { + if (msg instanceof HttpRequest) { request = (HttpRequest) msg; + final String origin = request.headers().get(HttpHeaderNames.ORIGIN); + config = getForOrigin(origin); if (isPreflightRequest(request)) { handlePreflight(ctx, request); return; } - if (config.isShortCircuit() && !validateOrigin()) { + if (isShortCircuit && !(origin == null || config != null)) { forbidden(ctx, request); return; } @@ -99,6 +120,21 @@ public class CorsHandler extends ChannelDuplexHandler { response.headers().add(config.preflightResponseHeaders()); } + private CorsConfig getForOrigin(String requestOrigin) { + for (CorsConfig corsConfig : configList) { + if (corsConfig.isAnyOriginSupported()) { + return corsConfig; + } + if (corsConfig.origins().contains(requestOrigin)) { + return corsConfig; + } + if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) { + return corsConfig; + } + } + return null; + } + private boolean setOrigin(final HttpResponse response) { final String origin = request.headers().get(HttpHeaderNames.ORIGIN); if (origin != null) { @@ -125,24 +161,6 @@ public class CorsHandler extends ChannelDuplexHandler { return false; } - private boolean validateOrigin() { - if (config.isAnyOriginSupported()) { - return true; - } - - final String origin = request.headers().get(HttpHeaderNames.ORIGIN); - if (origin == null) { - // Not a CORS request so we cannot validate it. It may be a non CORS request. - return true; - } - - if ("null".equals(origin) && config.isNullOriginAllowed()) { - return true; - } - - return config.origins().contains(origin); - } - private void echoRequestOrigin(final HttpResponse response) { setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN)); } @@ -198,7 +216,7 @@ public class CorsHandler extends ChannelDuplexHandler { @Override public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) throws Exception { - if (config.isCorsSupportEnabled() && msg instanceof HttpResponse) { + if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) { final HttpResponse response = (HttpResponse) msg; if (setOrigin(response)) { setAllowCredentials(response); diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index f3fe97efd0..884c368796 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -26,41 +26,26 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpUtil; import io.netty.util.AsciiString; import io.netty.util.ReferenceCountUtil; +import org.hamcrest.core.IsEqual; import org.junit.Test; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.concurrent.Callable; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS; -import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD; -import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; -import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpHeaderNames.DATE; -import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN; -import static io.netty.handler.codec.http.HttpHeaderNames.VARY; -import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE; +import static io.netty.handler.codec.http.HttpHeaderNames.*; import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; +import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; -import static io.netty.handler.codec.http.HttpMethod.DELETE; -import static io.netty.handler.codec.http.HttpMethod.GET; -import static io.netty.handler.codec.http.HttpMethod.OPTIONS; +import static io.netty.handler.codec.http.HttpMethod.*; import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; -import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forAnyOrigin; -import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigin; -import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigins; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.notNullValue; -import static org.hamcrest.CoreMatchers.nullValue; +import static io.netty.handler.codec.http.cors.CorsConfigBuilder.*; +import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; public class CorsHandlerTest { @@ -422,6 +407,46 @@ public class CorsHandlerTest { assertThat(channel.finish(), is(false)); } + @Test + public void differentConfigsPerOrigin() { + String host1 = "http://host1:80"; + String host2 = "http://host2"; + CorsConfig rule1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).build(); + CorsConfig rule2 = forOrigin(host2).allowedRequestMethods(HttpMethod.GET, HttpMethod.POST) + .allowCredentials().build(); + + List corsConfigs = Arrays.asList(rule1, rule2); + + final HttpResponse preFlightHost1 = preflightRequest(corsConfigs, host1, "", false); + assertThat(preFlightHost1.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET")); + assertThat(preFlightHost1.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(nullValue())); + + final HttpResponse preFlightHost2 = preflightRequest(corsConfigs, host2, "", false); + assertValues(preFlightHost2, ACCESS_CONTROL_ALLOW_METHODS.toString(), "GET", "POST"); + assertThat(preFlightHost2.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), IsEqual.equalTo("true")); + } + + @Test + public void specificConfigPrecedenceOverGeneric() { + String host1 = "http://host1"; + String host2 = "http://host2"; + + CorsConfig forHost1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).maxAge(3600L).build(); + CorsConfig allowAll = forAnyOrigin().allowedRequestMethods(HttpMethod.POST, HttpMethod.GET, HttpMethod.OPTIONS) + .maxAge(1800).build(); + + List rules = Arrays.asList(forHost1, allowAll); + + final HttpResponse host1Response = preflightRequest(rules, host1, "", false); + assertThat(host1Response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET")); + assertThat(host1Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("3600")); + + final HttpResponse host2Response = preflightRequest(rules, host2, "", false); + assertValues(host2Response, ACCESS_CONTROL_ALLOW_METHODS.toString(), "POST", "GET", "OPTIONS"); + assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); + assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("1800")); + } + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { return simpleRequest(config, origin, null); } @@ -451,7 +476,14 @@ public class CorsHandlerTest { private static HttpResponse preflightRequest(final CorsConfig config, final String origin, final String requestHeaders) { - final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + return preflightRequest(Collections.singletonList(config), origin, requestHeaders, config.isShortCircuit()); + } + + private static HttpResponse preflightRequest(final List configs, + final String origin, + final String requestHeaders, + final boolean isSHortCircuit) { + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(configs, isSHortCircuit)); assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false)); HttpResponse response = channel.readOutbound(); assertThat(channel.finish(), is(false)); diff --git a/common/src/main/java/io/netty/util/internal/ObjectUtil.java b/common/src/main/java/io/netty/util/internal/ObjectUtil.java index 28c3729fb9..cbac561f32 100644 --- a/common/src/main/java/io/netty/util/internal/ObjectUtil.java +++ b/common/src/main/java/io/netty/util/internal/ObjectUtil.java @@ -14,6 +14,8 @@ */ package io.netty.util.internal; +import java.util.Collection; + /** * A grab-bag of useful utility methods. */ @@ -88,6 +90,17 @@ public final class ObjectUtil { return array; } + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static > T checkNonEmpty(T collection, String name) { + checkNotNull(collection, name); + checkPositive(collection.size(), name + ".size"); + return collection; + } + /** * Resolves a possibly null Integer to a primitive int, using a default value. * @param wrapper the wrapper