Adding short-curcuit option for CORS

Motivation:
CORS request are currently processed, and potentially failed, after the
target ChannelHandler(s) have been invoked. This might not be desired, for
example a HTTP PUT or POST might have been performed.

Modifications:
Added a shortCurcuit option to CorsConfig which when set will
cause a validation of the HTTP request's 'Origin' header and verify that
it is valid according to the configuration. If found invalid an 403
"Forbidden" response will be returned and not further processing will
take place.

This is indeed no help for non browser request, like using curl, which
can set the 'Origin' header.

Result:
Users can now configure if the 'Origin' header should be validated
upfront and have the request rejected before any further processing
takes place.
This commit is contained in:
Daniel Bevenius 2014-05-06 08:24:21 +02:00 committed by Norman Maurer
parent 286e0c7e87
commit 4a1d739e0f
4 changed files with 111 additions and 22 deletions

View File

@ -47,6 +47,7 @@ public final class CorsConfig {
private final Set<String> allowedRequestHeaders; private final Set<String> allowedRequestHeaders;
private final boolean allowNullOrigin; private final boolean allowNullOrigin;
private final Map<CharSequence, Callable<?>> preflightHeaders; private final Map<CharSequence, Callable<?>> preflightHeaders;
private final boolean shortCurcuit;
private CorsConfig(final Builder builder) { private CorsConfig(final Builder builder) {
origins = new LinkedHashSet<String>(builder.origins); origins = new LinkedHashSet<String>(builder.origins);
@ -59,6 +60,7 @@ public final class CorsConfig {
allowedRequestHeaders = builder.requestHeaders; allowedRequestHeaders = builder.requestHeaders;
allowNullOrigin = builder.allowNullOrigin; allowNullOrigin = builder.allowNullOrigin;
preflightHeaders = builder.preflightHeaders; preflightHeaders = builder.preflightHeaders;
shortCurcuit = builder.shortCurcuit;
} }
/** /**
@ -214,6 +216,20 @@ public final class CorsConfig {
return preflightHeaders; return preflightHeaders;
} }
/**
* Determines whether a CORS request should be rejected if it's invalid before being
* further processing.
*
* CORS headers are set after a request is processed. This may not always be desired
* and this setting will check that the Origin is valid and if it is not valid no
* further processing will take place, and a error will be returned to the calling client.
*
* @return {@code true} if a CORS request should short-curcuit upon receiving an invalid Origin header.
*/
public boolean isShortCurcuit() {
return shortCurcuit;
}
private static <T> T getValue(final Callable<T> callable) { private static <T> T getValue(final Callable<T> callable) {
try { try {
return callable.call(); return callable.call();
@ -281,6 +297,7 @@ public final class CorsConfig {
private final Set<String> requestHeaders = new HashSet<String>(); private final Set<String> requestHeaders = new HashSet<String>();
private final Map<CharSequence, Callable<?>> preflightHeaders = new HashMap<CharSequence, Callable<?>>(); private final Map<CharSequence, Callable<?>> preflightHeaders = new HashMap<CharSequence, Callable<?>>();
private boolean noPreflightHeaders; private boolean noPreflightHeaders;
private boolean shortCurcuit;
/** /**
* Creates a new Builder instance with the origin passed in. * Creates a new Builder instance with the origin passed in.
@ -498,6 +515,21 @@ public final class CorsConfig {
} }
return new CorsConfig(this); return new CorsConfig(this);
} }
/**
* Specifies that a CORS request should be rejected if it's invalid before being
* further processing.
*
* CORS headers are set after a request is processed. This may not always be desired
* and this setting will check that the Origin is valid and if it is not valid no
* further processing will take place, and a error will be returned to the calling client.
*
* @return {@link Builder} to support method chaining.
*/
public Builder shortCurcuit() {
shortCurcuit = true;
return this;
}
} }
/** /**

View File

@ -19,7 +19,7 @@ import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
@ -33,7 +33,7 @@ import static io.netty.handler.codec.http.HttpResponseStatus.*;
/** /**
* Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests. * Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
* <p> * <p>
* This handler can be configured using a {@link io.netty.handler.codec.http.cors.CorsConfig}, please * This handler can be configured using a {@link CorsConfig}, please
* refer to this class for details about the configuration options available. * refer to this class for details about the configuration options available.
*/ */
public class CorsHandler extends ChannelDuplexHandler { public class CorsHandler extends ChannelDuplexHandler {
@ -55,12 +55,16 @@ public class CorsHandler extends ChannelDuplexHandler {
handlePreflight(ctx, request); handlePreflight(ctx, request);
return; return;
} }
if (config.isShortCurcuit() && !validateOrigin()) {
forbidden(ctx, request);
return;
}
} }
ctx.fireChannelRead(msg); ctx.fireChannelRead(msg);
} }
private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
final HttpResponse response = new DefaultHttpResponse(request.getProtocolVersion(), OK); final HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), OK);
if (setOrigin(response)) { if (setOrigin(response)) {
setAllowMethods(response); setAllowMethods(response);
setAllowHeaders(response); setAllowHeaders(response);
@ -107,19 +111,37 @@ public class CorsHandler extends ChannelDuplexHandler {
return false; return false;
} }
private boolean validateOrigin() {
if (config.isAnyOriginSupported()) {
return true;
}
final String origin = request.headers().get(ORIGIN);
if (origin == null) {
// Not a CORS request so we cannot validate it. It may be a non CORS request.
return true;
}
if ("null".equals(origin) && config.isNullOriginAllowed()) {
return true;
}
return config.origins().contains(origin);
}
private void echoRequestOrigin(final HttpResponse response) { private void echoRequestOrigin(final HttpResponse response) {
setOrigin(response, request.headers().get(ORIGIN)); setOrigin(response, request.headers().get(ORIGIN));
} }
private void setVaryHeader(final HttpResponse response) { private static void setVaryHeader(final HttpResponse response) {
response.headers().set(VARY, ORIGIN); response.headers().set(VARY, ORIGIN);
} }
private void setAnyOrigin(final HttpResponse response) { private static void setAnyOrigin(final HttpResponse response) {
setOrigin(response, "*"); setOrigin(response, "*");
} }
private void setOrigin(final HttpResponse response, final String origin) { private static void setOrigin(final HttpResponse response, final String origin) {
response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin); response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
} }
@ -173,5 +195,10 @@ public class CorsHandler extends ChannelDuplexHandler {
logger.error("Caught error in CorsHandler", cause); logger.error("Caught error in CorsHandler", cause);
ctx.fireExceptionCaught(cause); ctx.fireExceptionCaught(cause);
} }
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.getProtocolVersion(), FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE);
}
} }

View File

@ -123,4 +123,10 @@ public class CorsConfigTest {
withOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build(); withOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build();
} }
@Test
public void shortCurcuit() {
final CorsConfig cors = withOrigin("http://localhost:8080").shortCurcuit().build();
assertThat(cors.isShortCurcuit(), is(true));
}
} }

View File

@ -19,11 +19,10 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.junit.Test; import org.junit.Test;
import java.util.Arrays; import java.util.Arrays;
@ -31,6 +30,8 @@ import java.util.concurrent.Callable;
import static io.netty.handler.codec.http.HttpHeaders.Names.*; import static io.netty.handler.codec.http.HttpHeaders.Names.*;
import static io.netty.handler.codec.http.HttpMethod.*; import static io.netty.handler.codec.http.HttpMethod.*;
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.*; import static io.netty.handler.codec.http.HttpVersion.*;
import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.*; import static org.hamcrest.MatcherAssert.*;
@ -82,7 +83,7 @@ public class CorsHandlerTest {
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -95,7 +96,7 @@ public class CorsHandlerTest {
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET"));
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -104,7 +105,7 @@ public class CorsHandlerTest {
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().get(CONTENT_LENGTH), is("0")); assertThat(response.headers().get(CONTENT_LENGTH), is("0"));
assertThat(response.headers().get(DATE), is(notNullValue())); assertThat(response.headers().get(DATE), is(notNullValue()));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -114,7 +115,7 @@ public class CorsHandlerTest {
.build(); .build();
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().get("CustomHeader"), equalTo("somevalue")); assertThat(response.headers().get("CustomHeader"), equalTo("somevalue"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -124,7 +125,7 @@ public class CorsHandlerTest {
.build(); .build();
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -134,7 +135,7 @@ public class CorsHandlerTest {
.build(); .build();
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -148,7 +149,7 @@ public class CorsHandlerTest {
}).build(); }).build();
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
assertThat(response.headers().get("GenHeader"), equalTo("generatedValue")); assertThat(response.headers().get("GenHeader"), equalTo("generatedValue"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
@ -178,7 +179,7 @@ public class CorsHandlerTest {
@Test @Test
public void simpleRequestCustomHeaders() { public void simpleRequestCustomHeaders() {
final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build(); final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*"));
assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom1")); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom1"));
} }
@ -186,33 +187,56 @@ public class CorsHandlerTest {
@Test @Test
public void simpleRequestAllowCredentials() { public void simpleRequestAllowCredentials() {
final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
} }
@Test @Test
public void simpleRequestDoNotAllowCredentials() { public void simpleRequestDoNotAllowCredentials() {
final CorsConfig config = CorsConfig.withAnyOrigin().build(); final CorsConfig config = CorsConfig.withAnyOrigin().build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false)); assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false));
} }
@Test @Test
public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() { public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() {
final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777"));
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); assertThat(response.headers().get(VARY), equalTo(ORIGIN));
} }
@Test @Test
public void simpleRequestExposeHeaders() { public void simpleRequestExposeHeaders() {
final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build(); final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two")); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two"));
} }
@Test
public void simpleRequestShortCurcuit() {
final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").shortCurcuit().build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.getStatus(), is(FORBIDDEN));
}
@Test
public void simpleRequestNoShortCurcuit() {
final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").build();
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
assertThat(response.getStatus(), is(OK));
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
}
@Test
public void shortCurcuitNonCorsRequest() {
final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build();
final HttpResponse response = simpleRequest(config, null);
assertThat(response.getStatus(), is(OK));
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
}
private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
return simpleRequest(config, origin, null); return simpleRequest(config, origin, null);
} }
@ -259,7 +283,7 @@ public class CorsHandlerTest {
@Override @Override
public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
ctx.writeAndFlush(new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.OK)); ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK));
} }
} }