CorsHandler to respect http connection (keep-alive) header.
Motivation: The CorsHandler currently closes the channel when it responds to a preflight (OPTIONS) request or in the event of a short circuit due to failed validation. Especially in an environment where there's a proxy in front of the service this causes unnecessary connection churn. Modifications: CorsHandler now uses HttpUtil to determine if the connection should be closed after responding Result: Channel will stay open when the CorsHandler responds unless the client specifies otherwise or the protocol version is HTTP/1.0
This commit is contained in:
parent
5e148d5670
commit
ecd6e5ce6d
@ -24,6 +24,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
|
|||||||
import io.netty.handler.codec.http.HttpHeaders;
|
import io.netty.handler.codec.http.HttpHeaders;
|
||||||
import io.netty.handler.codec.http.HttpRequest;
|
import io.netty.handler.codec.http.HttpRequest;
|
||||||
import io.netty.handler.codec.http.HttpResponse;
|
import io.netty.handler.codec.http.HttpResponse;
|
||||||
|
import io.netty.handler.codec.http.HttpUtil;
|
||||||
import io.netty.util.internal.logging.InternalLogger;
|
import io.netty.util.internal.logging.InternalLogger;
|
||||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||||
|
|
||||||
@ -80,7 +81,12 @@ public class CorsHandler extends ChannelDuplexHandler {
|
|||||||
setPreflightHeaders(response);
|
setPreflightHeaders(response);
|
||||||
}
|
}
|
||||||
release(request);
|
release(request);
|
||||||
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
|
|
||||||
|
if (HttpUtil.isKeepAlive(request)) {
|
||||||
|
ctx.writeAndFlush(response);
|
||||||
|
} else {
|
||||||
|
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -203,8 +209,14 @@ public class CorsHandler extends ChannelDuplexHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
|
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
|
||||||
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN))
|
|
||||||
.addListener(ChannelFutureListener.CLOSE);
|
if (HttpUtil.isKeepAlive(request)) {
|
||||||
|
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN));
|
||||||
|
} else {
|
||||||
|
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN))
|
||||||
|
.addListener(ChannelFutureListener.CLOSE);
|
||||||
|
}
|
||||||
|
|
||||||
release(request);
|
release(request);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse;
|
|||||||
import io.netty.handler.codec.http.FullHttpRequest;
|
import io.netty.handler.codec.http.FullHttpRequest;
|
||||||
import io.netty.handler.codec.http.HttpMethod;
|
import io.netty.handler.codec.http.HttpMethod;
|
||||||
import io.netty.handler.codec.http.HttpResponse;
|
import io.netty.handler.codec.http.HttpResponse;
|
||||||
|
import io.netty.util.AsciiString;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
@ -35,10 +36,13 @@ import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_O
|
|||||||
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
|
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
|
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
|
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
|
||||||
|
import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
|
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
|
import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
|
import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
|
||||||
import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
|
import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
|
||||||
|
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
|
||||||
|
import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE;
|
||||||
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
|
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
|
||||||
import static io.netty.handler.codec.http.HttpMethod.DELETE;
|
import static io.netty.handler.codec.http.HttpMethod.DELETE;
|
||||||
import static io.netty.handler.codec.http.HttpMethod.GET;
|
import static io.netty.handler.codec.http.HttpMethod.GET;
|
||||||
@ -288,17 +292,100 @@ public class CorsHandlerTest {
|
|||||||
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
|
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shortCurcuitWithConnectionKeepAliveShouldStayOpen() {
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = createHttpRequest(GET);
|
||||||
|
request.headers().set(ORIGIN, "http://localhost:8888");
|
||||||
|
request.headers().set(CONNECTION, KEEP_ALIVE);
|
||||||
|
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(true));
|
||||||
|
assertThat(response.status(), is(FORBIDDEN));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shortCurcuitWithoutConnectionShouldStayOpen() {
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = createHttpRequest(GET);
|
||||||
|
request.headers().set(ORIGIN, "http://localhost:8888");
|
||||||
|
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(true));
|
||||||
|
assertThat(response.status(), is(FORBIDDEN));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shortCurcuitWithConnectionCloseShouldClose() {
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = createHttpRequest(GET);
|
||||||
|
request.headers().set(ORIGIN, "http://localhost:8888");
|
||||||
|
request.headers().set(CONNECTION, CLOSE);
|
||||||
|
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(false));
|
||||||
|
assertThat(response.status(), is(FORBIDDEN));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void preflightRequestShouldReleaseRequest() {
|
public void preflightRequestShouldReleaseRequest() {
|
||||||
final CorsConfig config = forOrigin("http://localhost:8888")
|
final CorsConfig config = forOrigin("http://localhost:8888")
|
||||||
.preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
|
.preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
|
||||||
.build();
|
.build();
|
||||||
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1");
|
final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1", null);
|
||||||
channel.writeInbound(request);
|
channel.writeInbound(request);
|
||||||
assertThat(request.refCnt(), is(0));
|
assertThat(request.refCnt(), is(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preflightRequestWithConnectionKeepAliveShouldStayOpen() throws Exception {
|
||||||
|
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8888").build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", KEEP_ALIVE);
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(true));
|
||||||
|
assertThat(response.status(), is(OK));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preflightRequestWithoutConnectionShouldStayOpen() throws Exception {
|
||||||
|
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8888").build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null);
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(true));
|
||||||
|
assertThat(response.status(), is(OK));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void preflightRequestWithConnectionCloseShouldClose() throws Exception {
|
||||||
|
|
||||||
|
final CorsConfig config = forOrigin("http://localhost:8888").build();
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
|
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", CLOSE);
|
||||||
|
channel.writeInbound(request);
|
||||||
|
final HttpResponse response = channel.readOutbound();
|
||||||
|
|
||||||
|
assertThat(channel.isOpen(), is(false));
|
||||||
|
assertThat(response.status(), is(OK));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void forbiddenShouldReleaseRequest() {
|
public void forbiddenShouldReleaseRequest() {
|
||||||
final CorsConfig config = forOrigin("https://localhost").shortCircuit().build();
|
final CorsConfig config = forOrigin("https://localhost").shortCircuit().build();
|
||||||
@ -339,15 +426,22 @@ public class CorsHandlerTest {
|
|||||||
final String origin,
|
final String origin,
|
||||||
final String requestHeaders) {
|
final String requestHeaders) {
|
||||||
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
|
||||||
channel.writeInbound(optionsRequest(origin, requestHeaders));
|
channel.writeInbound(optionsRequest(origin, requestHeaders, null));
|
||||||
return (HttpResponse) channel.readOutbound();
|
return (HttpResponse) channel.readOutbound();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static FullHttpRequest optionsRequest(final String origin, final String requestHeaders) {
|
private static FullHttpRequest optionsRequest(final String origin,
|
||||||
|
final String requestHeaders,
|
||||||
|
final AsciiString connection) {
|
||||||
final FullHttpRequest httpRequest = createHttpRequest(OPTIONS);
|
final FullHttpRequest httpRequest = createHttpRequest(OPTIONS);
|
||||||
httpRequest.headers().set(ORIGIN, origin);
|
httpRequest.headers().set(ORIGIN, origin);
|
||||||
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString());
|
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString());
|
||||||
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
|
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
|
||||||
|
|
||||||
|
if (connection != null) {
|
||||||
|
httpRequest.headers().set(CONNECTION, connection);
|
||||||
|
}
|
||||||
|
|
||||||
return httpRequest;
|
return httpRequest;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user