From 1a22186645e7d638b729e8acc32bc1ccdc197724 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 6 May 2014 08:24:21 +0200 Subject: [PATCH] Adding short-curcuit option for CORS Motivation: CORS request are currently processed, and potentially failed, after the target ChannelHandler(s) have been invoked. This might not be desired, for example a HTTP PUT or POST might have been performed. Modifications: Added a shortCurcuit option to CorsConfig which when set will cause a validation of the HTTP request's 'Origin' header and verify that it is valid according to the configuration. If found invalid an 403 "Forbidden" response will be returned and not further processing will take place. This is indeed no help for non browser request, like using curl, which can set the 'Origin' header. Result: Users can now configure if the 'Origin' header should be validated upfront and have the request rejected before any further processing takes place. --- .../handler/codec/http/cors/CorsConfig.java | 32 +++++++++++ .../handler/codec/http/cors/CorsHandler.java | 39 +++++++++++-- .../codec/http/cors/CorsConfigTest.java | 6 ++ .../codec/http/cors/CorsHandlerTest.java | 56 +++++++++++++------ 4 files changed, 111 insertions(+), 22 deletions(-) diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java index b7b2b0dcc0..5f612dc704 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java @@ -47,6 +47,7 @@ public final class CorsConfig { private final Set allowedRequestHeaders; private final boolean allowNullOrigin; private final Map> preflightHeaders; + private final boolean shortCurcuit; private CorsConfig(final Builder builder) { origins = new LinkedHashSet(builder.origins); @@ -59,6 +60,7 @@ public final class CorsConfig { allowedRequestHeaders = builder.requestHeaders; allowNullOrigin = builder.allowNullOrigin; preflightHeaders = builder.preflightHeaders; + shortCurcuit = builder.shortCurcuit; } /** @@ -214,6 +216,20 @@ public final class CorsConfig { return preflightHeaders; } + /** + * Determines whether a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and a error will be returned to the calling client. + * + * @return {@code true} if a CORS request should short-curcuit upon receiving an invalid Origin header. + */ + public boolean isShortCurcuit() { + return shortCurcuit; + } + private static T getValue(final Callable callable) { try { return callable.call(); @@ -281,6 +297,7 @@ public final class CorsConfig { private final Set requestHeaders = new HashSet(); private final Map> preflightHeaders = new HashMap>(); private boolean noPreflightHeaders; + private boolean shortCurcuit; /** * Creates a new Builder instance with the origin passed in. @@ -498,6 +515,21 @@ public final class CorsConfig { } return new CorsConfig(this); } + + /** + * Specifies that a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and a error will be returned to the calling client. + * + * @return {@link Builder} to support method chaining. + */ + public Builder shortCurcuit() { + shortCurcuit = true; + return this; + } } /** diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index 9c5f716f2f..fa16ac0859 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -19,7 +19,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; @@ -33,7 +33,7 @@ import static io.netty.handler.codec.http.HttpResponseStatus.*; /** * Handles Cross Origin Resource Sharing (CORS) requests. *

- * This handler can be configured using a {@link io.netty.handler.codec.http.cors.CorsConfig}, please + * This handler can be configured using a {@link CorsConfig}, please * refer to this class for details about the configuration options available. */ public class CorsHandler extends ChannelDuplexHandler { @@ -55,12 +55,16 @@ public class CorsHandler extends ChannelDuplexHandler { handlePreflight(ctx, request); return; } + if (config.isShortCurcuit() && !validateOrigin()) { + forbidden(ctx, request); + return; + } } ctx.fireChannelRead(msg); } private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { - final HttpResponse response = new DefaultHttpResponse(request.getProtocolVersion(), OK); + final HttpResponse response = new DefaultFullHttpResponse(request.getProtocolVersion(), OK); if (setOrigin(response)) { setAllowMethods(response); setAllowHeaders(response); @@ -107,19 +111,37 @@ public class CorsHandler extends ChannelDuplexHandler { return false; } + private boolean validateOrigin() { + if (config.isAnyOriginSupported()) { + return true; + } + + final String origin = request.headers().get(ORIGIN); + if (origin == null) { + // Not a CORS request so we cannot validate it. It may be a non CORS request. + return true; + } + + if ("null".equals(origin) && config.isNullOriginAllowed()) { + return true; + } + + return config.origins().contains(origin); + } + private void echoRequestOrigin(final HttpResponse response) { setOrigin(response, request.headers().get(ORIGIN)); } - private void setVaryHeader(final HttpResponse response) { + private static void setVaryHeader(final HttpResponse response) { response.headers().set(VARY, ORIGIN); } - private void setAnyOrigin(final HttpResponse response) { + private static void setAnyOrigin(final HttpResponse response) { setOrigin(response, "*"); } - private void setOrigin(final HttpResponse response, final String origin) { + private static void setOrigin(final HttpResponse response, final String origin) { response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin); } @@ -173,5 +195,10 @@ public class CorsHandler extends ChannelDuplexHandler { logger.error("Caught error in CorsHandler", cause); ctx.fireExceptionCaught(cause); } + + private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { + ctx.writeAndFlush(new DefaultFullHttpResponse(request.getProtocolVersion(), FORBIDDEN)) + .addListener(ChannelFutureListener.CLOSE); + } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java index 3f091825ab..748bb7017f 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java @@ -123,4 +123,10 @@ public class CorsConfigTest { withOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build(); } + @Test + public void shortCurcuit() { + final CorsConfig cors = withOrigin("http://localhost:8080").shortCurcuit().build(); + assertThat(cors.isShortCurcuit(), is(true)); + } + } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index 5f718c2122..3e20a9111e 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -19,11 +19,10 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; import org.junit.Test; import java.util.Arrays; @@ -31,6 +30,8 @@ import java.util.concurrent.Callable; import static io.netty.handler.codec.http.HttpHeaders.Names.*; import static io.netty.handler.codec.http.HttpMethod.*; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.*; import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.*; @@ -82,7 +83,7 @@ public class CorsHandlerTest { final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("GET", "DELETE")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -95,7 +96,7 @@ public class CorsHandlerTest { assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_METHODS), hasItems("OPTIONS", "GET")); assertThat(response.headers().getAll(ACCESS_CONTROL_ALLOW_HEADERS), hasItems("content-type", "xheader1")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -104,7 +105,7 @@ public class CorsHandlerTest { final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get(CONTENT_LENGTH), is("0")); assertThat(response.headers().get(DATE), is(notNullValue())); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -114,7 +115,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get("CustomHeader"), equalTo("somevalue")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -124,7 +125,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -134,7 +135,7 @@ public class CorsHandlerTest { .build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().getAll("CustomHeader"), hasItems("value1", "value2")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -148,7 +149,7 @@ public class CorsHandlerTest { }).build(); final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); assertThat(response.headers().get("GenHeader"), equalTo("generatedValue")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test @@ -178,7 +179,7 @@ public class CorsHandlerTest { @Test public void simpleRequestCustomHeaders() { final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("custom1", "custom2").build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("custom1", "custom1")); } @@ -186,33 +187,56 @@ public class CorsHandlerTest { @Test public void simpleRequestAllowCredentials() { final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); } @Test public void simpleRequestDoNotAllowCredentials() { final CorsConfig config = CorsConfig.withAnyOrigin().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false)); } @Test public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() { final CorsConfig config = CorsConfig.withAnyOrigin().allowCredentials().build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777")); - assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN)); } @Test public void simpleRequestExposeHeaders() { final CorsConfig config = CorsConfig.withAnyOrigin().exposeHeaders("one", "two").build(); - final HttpResponse response = simpleRequest(config, "http://localhost:7777", ""); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); assertThat(response.headers().getAll(ACCESS_CONTROL_EXPOSE_HEADERS), hasItems("one", "two")); } + @Test + public void simpleRequestShortCurcuit() { + final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").shortCurcuit().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.getStatus(), is(FORBIDDEN)); + } + + @Test + public void simpleRequestNoShortCurcuit() { + final CorsConfig config = CorsConfig.withOrigin("http://localhost:8080").build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.getStatus(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + } + + @Test + public void shortCurcuitNonCorsRequest() { + final CorsConfig config = CorsConfig.withOrigin("https://localhost").shortCurcuit().build(); + final HttpResponse response = simpleRequest(config, null); + assertThat(response.getStatus(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + } + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { return simpleRequest(config, origin, null); } @@ -258,7 +282,7 @@ public class CorsHandlerTest { private static class EchoHandler extends SimpleChannelInboundHandler { @Override public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { - ctx.writeAndFlush(new DefaultHttpResponse(HTTP_1_1, HttpResponseStatus.OK)); + ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK)); } } }