diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java index b7b2b0dcc0..5f612dc704 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java @@ -47,6 +47,7 @@ public final class CorsConfig { private final Set allowedRequestHeaders; private final boolean allowNullOrigin; private final Map> preflightHeaders; + private final boolean shortCurcuit; private CorsConfig(final Builder builder) { origins = new LinkedHashSet(builder.origins); @@ -59,6 +60,7 @@ public final class CorsConfig { allowedRequestHeaders = builder.requestHeaders; allowNullOrigin = builder.allowNullOrigin; preflightHeaders = builder.preflightHeaders; + shortCurcuit = builder.shortCurcuit; } /** @@ -214,6 +216,20 @@ public final class CorsConfig { return preflightHeaders; } + /** + * Determines whether a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and a error will be returned to the calling client. + * + * @return {@code true} if a CORS request should short-curcuit upon receiving an invalid Origin header. + */ + public boolean isShortCurcuit() { + return shortCurcuit; + } + private static T getValue(final Callable callable) { try { return callable.call(); @@ -281,6 +297,7 @@ public final class CorsConfig { private final Set requestHeaders = new HashSet(); private final Map> preflightHeaders = new HashMap>(); private boolean noPreflightHeaders; + private boolean shortCurcuit; /** * Creates a new Builder instance with the origin passed in. @@ -498,6 +515,21 @@ public final class CorsConfig { } return new CorsConfig(this); } + + /** + * Specifies that a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and a error will be returned to the calling client. + * + * @return {@link Builder} to support method chaining. + */ + public Builder shortCurcuit() { + shortCurcuit = true; + return this; + } } /** diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 9c5f716f2f..fa16ac0859 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -19,7 +19,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; @@ -33,7 +33,7 @@ import static io.netty.handler.codec.http.HttpResponseStatus.*; /** * Handles Cross Origin Resource Sharing (CORS) requests. *

- * This handler can be configured using a {@link io.netty.handler.codec.http.cors.CorsConfig}, please + * This handler can be configured using a {@link CorsConfig}, please * refer to this class for details about the configuration options available. */ public class CorsHandler extends ChannelDuplexHandler { @@ -55,12 +55,16 @@ public class CorsHandler extends ChannelDuplexHandler { handlePreflight(ctx, request); return; } + if (config.isShortCurcuit() && !validateOrigin()) { + forbidden(ctx, request); + return; + } } ctx.fireChannelRead(msg); } private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { - final HttpResponse response = new DefaultHttpResponse(request.getProtocolVersion(), OK); + final HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), OK); if (setOrigin(response)) { setAllowMethods(response); setAllowHeaders(response); @@ -107,19 +111,37 @@ public class CorsHandler extends ChannelDuplexHandler { return false; } + private boolean validateOrigin() { + if (config.isAnyOriginSupported()) { + return true; + } + + final String origin = request.headers().get(ORIGIN); + if (origin == null) { + // Not a CORS request so we cannot validate it. It may be a non CORS request. + return true; + } + + if ("null".equals(origin) && config.isNullOriginAllowed()) { + return true; + } + + return config.origins().contains(origin); + } + private void echoRequestOrigin(final HttpResponse response) { setOrigin(response, request.headers().get(ORIGIN)); } - private void setVaryHeader(final HttpResponse response) { + private static void setVaryHeader(final HttpResponse response) { response.headers().set(VARY, ORIGIN); } - private void setAnyOrigin(final HttpResponse response) { + private static void setAnyOrigin(final HttpResponse response) { setOrigin(response, "*"); } - private void setOrigin(final HttpResponse response, final String origin) { + private static void setOrigin(final HttpResponse response, final String origin) { response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin); } @@ -173,5 +195,10 @@ public class CorsHandler extends ChannelDuplexHandler { logger.error("Caught error in CorsHandler", cause); ctx.fireExceptionCaught(cause); } + + private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { + ctx.writeAndFlush(new DefaultFullHttpResponse(request.getProtocolVersion(), FORBIDDEN)) + .addListener(ChannelFutureListener.CLOSE); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java index 3f091825ab..748bb7017f 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java @@ -123,4 +123,10 @@ public class CorsConfigTest { withOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build(); } + @Test + public void shortCurcuit() { + final CorsConfig cors = withOrigin("http://localhost:8080").shortCurcuit().build(); + assertThat(cors.isShortCurcuit(), is(true)); + } + } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index 5f718c2122..3e20a9111e 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -19,11 +19,10 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; import org.junit.Test; import java.util.Arrays; @@ -31,6 +30,8 @@ import java.util.concurrent.Callable; import static io.netty.handler.codec.http.HttpHeaders.Names.*; import static io.netty.handler.codec.http.HttpMethod.*; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.*; import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.*; @@ -82,7 +83,7 @@ public class CorsHandlerTest { final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -95,7 +96,7 @@ public class CorsHandlerTest { assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -104,7 +105,7 @@ public class CorsHandlerTest { final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get(CONTENT_LENGTH), is("0")); assertThat(response.headers().get(DATE), is(notNullValue())); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -114,7 +115,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get("CustomHeader"), equalTo("somevalue")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -124,7 +125,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -134,7 +135,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -148,7 +149,7 @@ public class CorsHandlerTest { }).build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get("GenHeader"), equalTo("generatedValue")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -178,7 +179,7 @@ public class CorsHandlerTest { @Test public void simpleRequestCustomHeaders() { final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom1")); } @@ -186,33 +187,56 @@ public class CorsHandlerTest { @Test public void simpleRequestAllowCredentials() { final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); } @Test public void simpleRequestDoNotAllowCredentials() { final CorsConfig config = CorsConfig.withAnyOrigin().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false)); } @Test public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() { final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test public void simpleRequestExposeHeaders() { final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two")); } + @Test + public void simpleRequestShortCurcuit() { + final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").shortCurcuit().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.getStatus(), is(FORBIDDEN)); + } + + @Test + public void simpleRequestNoShortCurcuit() { + final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.getStatus(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + } + + @Test + public void shortCurcuitNonCorsRequest() { + final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build(); + final HttpResponse response = simpleRequest(config, null); + assertThat(response.getStatus(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + } + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { return simpleRequest(config, origin, null); } @@ -258,7 +282,7 @@ public class CorsHandlerTest { private static class EchoHandler extends SimpleChannelInboundHandler { @Override public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { - ctx.writeAndFlush(new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.OK)); + ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK)); } } }