Enable per origin Cors configuration (#7800)

Motivation:

Finer granularity when configuring CorsHandler, enabling different policies for different origins.

Modifications:

The CorsHandler has an extra constructor that accepts a List<CorsConfig> that are evaluated sequentially when processing a Cors request

Result:

The changes don't break backwards compatibility. The extra ctor can be used to provide more than one CorsConfig object.
This commit is contained in:
Gustavo Fernandes 2018-04-11 09:06:13 +01:00 committed by Norman Maurer
parent f8ff834f03
commit 76c5f6cd03
3 changed files with 117 additions and 54 deletions

View File

@ -30,15 +30,19 @@ import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static io.netty.util.ReferenceCountUtil.*;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import java.util.Collections;
import java.util.List;
import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.util.ReferenceCountUtil.release;
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
/**
* Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
* <p>
* This handler can be configured using a {@link CorsConfig}, please
* This handler can be configured using one or more {@link CorsConfig}, please
* refer to this class for details about the configuration options available.
*/
public class CorsHandler extends ChannelDuplexHandler {
@ -46,26 +50,43 @@ public class CorsHandler extends ChannelDuplexHandler {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class);
private static final String ANY_ORIGIN = "*";
private static final String NULL_ORIGIN = "null";
private final CorsConfig config;
private CorsConfig config;
private HttpRequest request;
private final List<CorsConfig> configList;
private boolean isShortCircuit;
/**
* Creates a new instance with the specified {@link CorsConfig}.
* Creates a new instance with a single {@link CorsConfig}.
*/
public CorsHandler(final CorsConfig config) {
this.config = checkNotNull(config, "config");
this(Collections.singletonList(config), config.isShortCircuit());
}
/**
* Creates a new instance with the specified config list. If more than one
* config matches a certain origin, the first in the List will be used.
*
* @param configList List of {@link CorsConfig}
* @param isShortCircuit Same as {@link CorsConfig#shortCircuit} but applicable to all supplied configs.
*/
public CorsHandler(final List<CorsConfig> configList, boolean isShortCircuit) {
checkNonEmpty(configList, "configList");
this.configList = configList;
this.isShortCircuit = isShortCircuit;
}
@Override
public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
if (config.isCorsSupportEnabled() && msg instanceof HttpRequest) {
if (msg instanceof HttpRequest) {
request = (HttpRequest) msg;
final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
config = getForOrigin(origin);
if (isPreflightRequest(request)) {
handlePreflight(ctx, request);
return;
}
if (config.isShortCircuit() && !validateOrigin()) {
if (isShortCircuit && !(origin == null || config != null)) {
forbidden(ctx, request);
return;
}
@ -99,6 +120,21 @@ public class CorsHandler extends ChannelDuplexHandler {
response.headers().add(config.preflightResponseHeaders());
}
private CorsConfig getForOrigin(String requestOrigin) {
for (CorsConfig corsConfig : configList) {
if (corsConfig.isAnyOriginSupported()) {
return corsConfig;
}
if (corsConfig.origins().contains(requestOrigin)) {
return corsConfig;
}
if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) {
return corsConfig;
}
}
return null;
}
private boolean setOrigin(final HttpResponse response) {
final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
if (origin != null) {
@ -125,24 +161,6 @@ public class CorsHandler extends ChannelDuplexHandler {
return false;
}
private boolean validateOrigin() {
if (config.isAnyOriginSupported()) {
return true;
}
final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
if (origin == null) {
// Not a CORS request so we cannot validate it. It may be a non CORS request.
return true;
}
if ("null".equals(origin) && config.isNullOriginAllowed()) {
return true;
}
return config.origins().contains(origin);
}
private void echoRequestOrigin(final HttpResponse response) {
setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN));
}
@ -198,7 +216,7 @@ public class CorsHandler extends ChannelDuplexHandler {
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise)
throws Exception {
if (config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) {
final HttpResponse response = (HttpResponse) msg;
if (setOrigin(response)) {
setAllowCredentials(response);

View File

@ -26,41 +26,26 @@ import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;
import org.hamcrest.core.IsEqual;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeaderNames.*;
import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE;
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
import static io.netty.handler.codec.http.HttpMethod.DELETE;
import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpMethod.OPTIONS;
import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forAnyOrigin;
import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigin;
import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigins;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static io.netty.handler.codec.http.cors.CorsConfigBuilder.*;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.IsEqual.equalTo;
public class CorsHandlerTest {
@ -422,6 +407,46 @@ public class CorsHandlerTest {
assertThat(channel.finish(), is(false));
}
@Test
public void differentConfigsPerOrigin() {
String host1 = "http://host1:80";
String host2 = "http://host2";
CorsConfig rule1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).build();
CorsConfig rule2 = forOrigin(host2).allowedRequestMethods(HttpMethod.GET, HttpMethod.POST)
.allowCredentials().build();
List<CorsConfig> corsConfigs = Arrays.asList(rule1, rule2);
final HttpResponse preFlightHost1 = preflightRequest(corsConfigs, host1, "", false);
assertThat(preFlightHost1.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET"));
assertThat(preFlightHost1.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(nullValue()));
final HttpResponse preFlightHost2 = preflightRequest(corsConfigs, host2, "", false);
assertValues(preFlightHost2, ACCESS_CONTROL_ALLOW_METHODS.toString(), "GET", "POST");
assertThat(preFlightHost2.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), IsEqual.equalTo("true"));
}
@Test
public void specificConfigPrecedenceOverGeneric() {
String host1 = "http://host1";
String host2 = "http://host2";
CorsConfig forHost1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).maxAge(3600L).build();
CorsConfig allowAll = forAnyOrigin().allowedRequestMethods(HttpMethod.POST, HttpMethod.GET, HttpMethod.OPTIONS)
.maxAge(1800).build();
List<CorsConfig> rules = Arrays.asList(forHost1, allowAll);
final HttpResponse host1Response = preflightRequest(rules, host1, "", false);
assertThat(host1Response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET"));
assertThat(host1Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("3600"));
final HttpResponse host2Response = preflightRequest(rules, host2, "", false);
assertValues(host2Response, ACCESS_CONTROL_ALLOW_METHODS.toString(), "POST", "GET", "OPTIONS");
assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*"));
assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("1800"));
}
private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
return simpleRequest(config, origin, null);
}
@ -451,7 +476,14 @@ public class CorsHandlerTest {
private static HttpResponse preflightRequest(final CorsConfig config,
final String origin,
final String requestHeaders) {
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
return preflightRequest(Collections.singletonList(config), origin, requestHeaders, config.isShortCircuit());
}
private static HttpResponse preflightRequest(final List<CorsConfig> configs,
final String origin,
final String requestHeaders,
final boolean isSHortCircuit) {
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(configs, isSHortCircuit));
assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false));
HttpResponse response = channel.readOutbound();
assertThat(channel.finish(), is(false));

View File

@ -14,6 +14,8 @@
*/
package io.netty.util.internal;
import java.util.Collection;
/**
* A grab-bag of useful utility methods.
*/
@ -88,6 +90,17 @@ public final class ObjectUtil {
return array;
}
/**
* Checks that the given argument is neither null nor empty.
* If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}.
* Otherwise, returns the argument.
*/
public static <T extends Collection<?>> T checkNonEmpty(T collection, String name) {
checkNotNull(collection, name);
checkPositive(collection.size(), name + ".size");
return collection;
}
/**
* Resolves a possibly null Integer to a primitive int, using a default value.
* @param wrapper the wrapper