diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java index b3bda3d004..63bffe15f0 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -16,6 +16,7 @@ package io.netty.handler.codec.http.cors; import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; @@ -24,6 +25,7 @@ import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -80,7 +82,7 @@ public class CorsHandler extends ChannelDuplexHandler { setPreflightHeaders(response); } release(request); - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + respond(ctx, request, response); } /** @@ -203,8 +205,22 @@ public class CorsHandler extends ChannelDuplexHandler { } private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { - ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN)) - .addListener(ChannelFutureListener.CLOSE); release(request); + respond(ctx, request, new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN)); + } + + private static void respond( + final ChannelHandlerContext ctx, + final HttpRequest request, + final HttpResponse response) { + + final boolean keepAlive = HttpUtil.isKeepAlive(request); + + HttpUtil.setKeepAlive(response, keepAlive); + + final ChannelFuture future = ctx.writeAndFlush(response); + if (!keepAlive) { + future.addListener(ChannelFutureListener.CLOSE); + } } } diff --git a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java index 9e170326b9..ef443728e1 100644 --- a/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java +++ b/codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -23,6 +23,9 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; import org.junit.Test; import java.util.Arrays; @@ -35,10 +38,13 @@ import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_O import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS; import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS; import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD; +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaderNames.DATE; import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN; import static io.netty.handler.codec.http.HttpHeaderNames.VARY; +import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE; +import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static io.netty.handler.codec.http.HttpMethod.DELETE; import static io.netty.handler.codec.http.HttpMethod.GET; @@ -288,15 +294,118 @@ public class CorsHandlerTest { assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); } + @Test + public void shortCurcuitWithConnectionKeepAliveShouldStayOpen() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + request.headers().set(CONNECTION, KEEP_ALIVE); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void shortCurcuitWithoutConnectionShouldStayOpen() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void shortCurcuitWithConnectionCloseShouldClose() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + request.headers().set(CONNECTION, CLOSE); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(false)); + + assertThat(channel.isOpen(), is(false)); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + @Test public void preflightRequestShouldReleaseRequest() { final CorsConfig config = forOrigin("http://localhost:8888") .preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2")) .build(); final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); - final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1"); - channel.writeInbound(request); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1", null); + assertThat(channel.writeInbound(request), is(false)); assertThat(request.refCnt(), is(0)); + assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithConnectionKeepAliveShouldStayOpen() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", KEEP_ALIVE); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithoutConnectionShouldStayOpen() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithConnectionCloseShouldClose() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", CLOSE); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(false)); + + assertThat(channel.isOpen(), is(false)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); } @Test @@ -305,8 +414,10 @@ public class CorsHandlerTest { final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler()); final FullHttpRequest request = createHttpRequest(GET); request.headers().set(ORIGIN, "http://localhost:8888"); - channel.writeInbound(request); + assertThat(channel.writeInbound(request), is(false)); assertThat(request.refCnt(), is(0)); + assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true)); + assertThat(channel.finish(), is(false)); } private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { @@ -331,7 +442,7 @@ public class CorsHandlerTest { if (requestHeaders != null) { httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders); } - channel.writeInbound(httpRequest); + assertThat(channel.writeInbound(httpRequest), is(false)); return (HttpResponse) channel.readOutbound(); } @@ -339,15 +450,23 @@ public class CorsHandlerTest { final String origin, final String requestHeaders) { final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); - channel.writeInbound(optionsRequest(origin, requestHeaders)); - return (HttpResponse) channel.readOutbound(); + assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false)); + HttpResponse response = channel.readOutbound(); + assertThat(channel.finish(), is(false)); + return response; } - private static FullHttpRequest optionsRequest(final String origin, final String requestHeaders) { + private static FullHttpRequest optionsRequest(final String origin, + final String requestHeaders, + final AsciiString connection) { final FullHttpRequest httpRequest = createHttpRequest(OPTIONS); httpRequest.headers().set(ORIGIN, origin); httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString()); httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders); + if (connection != null) { + httpRequest.headers().set(CONNECTION, connection); + } + return httpRequest; }