Adding short-curcuit option for CORS
Motivation: CORS request are currently processed, and potentially failed, after the target ChannelHandler(s) have been invoked. This might not be desired, for example a HTTP PUT or POST might have been performed. Modifications: Added a shortCurcuit option to CorsConfig which when set will cause a validation of the HTTP request's 'Origin' header and verify that it is valid according to the configuration. If found invalid an 403 "Forbidden" response will be returned and not further processing will take place. This is indeed no help for non browser request, like using curl, which can set the 'Origin' header. Result: Users can now configure if the 'Origin' header should be validated upfront and have the request rejected before any further processing takes place.
This commit is contained in:
parent
286e0c7e87
commit
4a1d739e0f
@ -47,6 +47,7 @@ public final class CorsConfig {
|
||||
private final Set<String> allowedRequestHeaders;
|
||||
private final boolean allowNullOrigin;
|
||||
private final Map<CharSequence, Callable<?>> preflightHeaders;
|
||||
private final boolean shortCurcuit;
|
||||
|
||||
private CorsConfig(final Builder builder) {
|
||||
origins = new LinkedHashSet<String>(builder.origins);
|
||||
@ -59,6 +60,7 @@ public final class CorsConfig {
|
||||
allowedRequestHeaders = builder.requestHeaders;
|
||||
allowNullOrigin = builder.allowNullOrigin;
|
||||
preflightHeaders = builder.preflightHeaders;
|
||||
shortCurcuit = builder.shortCurcuit;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -214,6 +216,20 @@ public final class CorsConfig {
|
||||
return preflightHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether a CORS request should be rejected if it's invalid before being
|
||||
* further processing.
|
||||
*
|
||||
* CORS headers are set after a request is processed. This may not always be desired
|
||||
* and this setting will check that the Origin is valid and if it is not valid no
|
||||
* further processing will take place, and a error will be returned to the calling client.
|
||||
*
|
||||
* @return {@code true} if a CORS request should short-curcuit upon receiving an invalid Origin header.
|
||||
*/
|
||||
public boolean isShortCurcuit() {
|
||||
return shortCurcuit;
|
||||
}
|
||||
|
||||
private static <T> T getValue(final Callable<T> callable) {
|
||||
try {
|
||||
return callable.call();
|
||||
@ -281,6 +297,7 @@ public final class CorsConfig {
|
||||
private final Set<String> requestHeaders = new HashSet<String>();
|
||||
private final Map<CharSequence, Callable<?>> preflightHeaders = new HashMap<CharSequence, Callable<?>>();
|
||||
private boolean noPreflightHeaders;
|
||||
private boolean shortCurcuit;
|
||||
|
||||
/**
|
||||
* Creates a new Builder instance with the origin passed in.
|
||||
@ -498,6 +515,21 @@ public final class CorsConfig {
|
||||
}
|
||||
return new CorsConfig(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Specifies that a CORS request should be rejected if it's invalid before being
|
||||
* further processing.
|
||||
*
|
||||
* CORS headers are set after a request is processed. This may not always be desired
|
||||
* and this setting will check that the Origin is valid and if it is not valid no
|
||||
* further processing will take place, and a error will be returned to the calling client.
|
||||
*
|
||||
* @return {@link Builder} to support method chaining.
|
||||
*/
|
||||
public Builder shortCurcuit() {
|
||||
shortCurcuit = true;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -19,7 +19,7 @@ import io.netty.channel.ChannelDuplexHandler;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.http.DefaultHttpResponse;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||
import io.netty.handler.codec.http.HttpHeaders;
|
||||
import io.netty.handler.codec.http.HttpRequest;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
@ -33,7 +33,7 @@ import static io.netty.handler.codec.http.HttpResponseStatus.*;
|
||||
/**
|
||||
* Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
|
||||
* <p>
|
||||
* This handler can be configured using a {@link io.netty.handler.codec.http.cors.CorsConfig}, please
|
||||
* This handler can be configured using a {@link CorsConfig}, please
|
||||
* refer to this class for details about the configuration options available.
|
||||
*/
|
||||
public class CorsHandler extends ChannelDuplexHandler {
|
||||
@ -55,12 +55,16 @@ public class CorsHandler extends ChannelDuplexHandler {
|
||||
handlePreflight(ctx, request);
|
||||
return;
|
||||
}
|
||||
if (config.isShortCurcuit() && !validateOrigin()) {
|
||||
forbidden(ctx, request);
|
||||
return;
|
||||
}
|
||||
}
|
||||
ctx.fireChannelRead(msg);
|
||||
}
|
||||
|
||||
private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
|
||||
final HttpResponse response = new DefaultHttpResponse(request.getProtocolVersion(), OK);
|
||||
final HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), OK);
|
||||
if (setOrigin(response)) {
|
||||
setAllowMethods(response);
|
||||
setAllowHeaders(response);
|
||||
@ -107,19 +111,37 @@ public class CorsHandler extends ChannelDuplexHandler {
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean validateOrigin() {
|
||||
if (config.isAnyOriginSupported()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
final String origin = request.headers().get(ORIGIN);
|
||||
if (origin == null) {
|
||||
// Not a CORS request so we cannot validate it. It may be a non CORS request.
|
||||
return true;
|
||||
}
|
||||
|
||||
if ("null".equals(origin) && config.isNullOriginAllowed()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return config.origins().contains(origin);
|
||||
}
|
||||
|
||||
private void echoRequestOrigin(final HttpResponse response) {
|
||||
setOrigin(response, request.headers().get(ORIGIN));
|
||||
}
|
||||
|
||||
private void setVaryHeader(final HttpResponse response) {
|
||||
private static void setVaryHeader(final HttpResponse response) {
|
||||
response.headers().set(VARY, ORIGIN);
|
||||
}
|
||||
|
||||
private void setAnyOrigin(final HttpResponse response) {
|
||||
private static void setAnyOrigin(final HttpResponse response) {
|
||||
setOrigin(response, "*");
|
||||
}
|
||||
|
||||
private void setOrigin(final HttpResponse response, final String origin) {
|
||||
private static void setOrigin(final HttpResponse response, final String origin) {
|
||||
response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
|
||||
}
|
||||
|
||||
@ -173,5 +195,10 @@ public class CorsHandler extends ChannelDuplexHandler {
|
||||
logger.error("Caught error in CorsHandler", cause);
|
||||
ctx.fireExceptionCaught(cause);
|
||||
}
|
||||
|
||||
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
|
||||
ctx.writeAndFlush(new DefaultFullHttpResponse(request.getProtocolVersion(), FORBIDDEN))
|
||||
.addListener(ChannelFutureListener.CLOSE);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,4 +123,10 @@ public class CorsConfigTest {
|
||||
withOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shortCurcuit() {
|
||||
final CorsConfig cors = withOrigin("http://localhost:8080").shortCurcuit().build();
|
||||
assertThat(cors.isShortCurcuit(), is(true));
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -19,11 +19,10 @@ import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.channel.embedded.EmbeddedChannel;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpRequest;
|
||||
import io.netty.handler.codec.http.DefaultHttpResponse;
|
||||
import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
||||
import io.netty.handler.codec.http.FullHttpRequest;
|
||||
import io.netty.handler.codec.http.HttpMethod;
|
||||
import io.netty.handler.codec.http.HttpResponse;
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
@ -31,6 +30,8 @@ import java.util.concurrent.Callable;
|
||||
|
||||
import static io.netty.handler.codec.http.HttpHeaders.Names.*;
|
||||
import static io.netty.handler.codec.http.HttpMethod.*;
|
||||
import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN;
|
||||
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
|
||||
import static io.netty.handler.codec.http.HttpVersion.*;
|
||||
import static org.hamcrest.CoreMatchers.*;
|
||||
import static org.hamcrest.MatcherAssert.*;
|
||||
@ -82,7 +83,7 @@ public class CorsHandlerTest {
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
|
||||
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -95,7 +96,7 @@ public class CorsHandlerTest {
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888"));
|
||||
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET"));
|
||||
assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -104,7 +105,7 @@ public class CorsHandlerTest {
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().get(CONTENT_LENGTH), is("0"));
|
||||
assertThat(response.headers().get(DATE), is(notNullValue()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -114,7 +115,7 @@ public class CorsHandlerTest {
|
||||
.build();
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().get("CustomHeader"), equalTo("somevalue"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -124,7 +125,7 @@ public class CorsHandlerTest {
|
||||
.build();
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -134,7 +135,7 @@ public class CorsHandlerTest {
|
||||
.build();
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -148,7 +149,7 @@ public class CorsHandlerTest {
|
||||
}).build();
|
||||
final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1");
|
||||
assertThat(response.headers().get("GenHeader"), equalTo("generatedValue"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -178,7 +179,7 @@ public class CorsHandlerTest {
|
||||
@Test
|
||||
public void simpleRequestCustomHeaders() {
|
||||
final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777", "");
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*"));
|
||||
assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom1"));
|
||||
}
|
||||
@ -186,33 +187,56 @@ public class CorsHandlerTest {
|
||||
@Test
|
||||
public void simpleRequestAllowCredentials() {
|
||||
final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777", "");
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simpleRequestDoNotAllowCredentials() {
|
||||
final CorsConfig config = CorsConfig.withAnyOrigin().build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777", "");
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() {
|
||||
final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777", "");
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777"));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString()));
|
||||
assertThat(response.headers().get(VARY), equalTo(ORIGIN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simpleRequestExposeHeaders() {
|
||||
final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777", "");
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simpleRequestShortCurcuit() {
|
||||
final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").shortCurcuit().build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.getStatus(), is(FORBIDDEN));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void simpleRequestNoShortCurcuit() {
|
||||
final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").build();
|
||||
final HttpResponse response = simpleRequest(config, "http://localhost:7777");
|
||||
assertThat(response.getStatus(), is(OK));
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shortCurcuitNonCorsRequest() {
|
||||
final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build();
|
||||
final HttpResponse response = simpleRequest(config, null);
|
||||
assertThat(response.getStatus(), is(OK));
|
||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
|
||||
}
|
||||
|
||||
private static HttpResponse simpleRequest(final CorsConfig config, final String origin) {
|
||||
return simpleRequest(config, origin, null);
|
||||
}
|
||||
@ -259,7 +283,7 @@ public class CorsHandlerTest {
|
||||
|
||||
@Override
|
||||
public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
ctx.writeAndFlush(new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.OK));
|
||||
ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user