CorsHandler to respect http connection (keep-alive) header.

Motivation:

The CorsHandler currently closes the channel when it responds to a preflight (OPTIONS)
request or in the event of a short circuit due to failed validation.

Especially in an environment where there's a proxy in front of the service this causes
unnecessary connection churn.

Modifications:

CorsHandler now uses HttpUtil to determine if the connection should be closed
after responding

Result:

Channel will stay open when the CorsHandler responds unless the client specifies otherwise
or the protocol version is HTTP/1.0
This commit is contained in:
William Blackie 2016-08-22 12:40:34 -04:00 committed by Norman Maurer
parent 5e148d5670
commit ecd6e5ce6d
2 changed files with 112 additions and 6 deletions

View File

@ -24,6 +24,7 @@ import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -80,7 +81,12 @@ public class CorsHandler extends ChannelDuplexHandler {
setPreflightHeaders(response); setPreflightHeaders(response);
} }
release(request); release(request);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
if (HttpUtil.isKeepAlive(request)) {
ctx.writeAndFlush(response);
} else {
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
}
} }
/** /**
@ -203,8 +209,14 @@ public class CorsHandler extends ChannelDuplexHandler {
} }
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE); if (HttpUtil.isKeepAlive(request)) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN));
} else {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE);
}
release(request); release(request);
} }
} }

View File

@ -23,6 +23,7 @@ import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.util.AsciiString;
import org.junit.Test; import org.junit.Test;
import java.util.Arrays; import java.util.Arrays;
@ -35,10 +36,13 @@ import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_ALLOW_O
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS; import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS; import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS;
import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD; import static io.netty.handler.codec.http.HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD;
import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpHeaderNames.DATE; import static io.netty.handler.codec.http.HttpHeaderNames.DATE;
import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN; import static io.netty.handler.codec.http.HttpHeaderNames.ORIGIN;
import static io.netty.handler.codec.http.HttpHeaderNames.VARY; import static io.netty.handler.codec.http.HttpHeaderNames.VARY;
import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE;
import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE;
import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; import static io.netty.handler.codec.http.HttpHeadersTestUtils.of;
import static io.netty.handler.codec.http.HttpMethod.DELETE; import static io.netty.handler.codec.http.HttpMethod.DELETE;
import static io.netty.handler.codec.http.HttpMethod.GET; import static io.netty.handler.codec.http.HttpMethod.GET;
@ -288,17 +292,100 @@ public class CorsHandlerTest {
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue()));
} }
@Test
public void shortCurcuitWithConnectionKeepAliveShouldStayOpen() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
request.headers().set(CONNECTION, KEEP_ALIVE);
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(FORBIDDEN));
}
@Test
public void shortCurcuitWithoutConnectionShouldStayOpen() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(FORBIDDEN));
}
@Test
public void shortCurcuitWithConnectionCloseShouldClose() {
final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = createHttpRequest(GET);
request.headers().set(ORIGIN, "http://localhost:8888");
request.headers().set(CONNECTION, CLOSE);
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(false));
assertThat(response.status(), is(FORBIDDEN));
}
@Test @Test
public void preflightRequestShouldReleaseRequest() { public void preflightRequestShouldReleaseRequest() {
final CorsConfig config = forOrigin("http://localhost:8888") final CorsConfig config = forOrigin("http://localhost:8888")
.preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2")) .preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2"))
.build(); .build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1"); final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1", null);
channel.writeInbound(request); channel.writeInbound(request);
assertThat(request.refCnt(), is(0)); assertThat(request.refCnt(), is(0));
} }
@Test
public void preflightRequestWithConnectionKeepAliveShouldStayOpen() throws Exception {
final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", KEEP_ALIVE);
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(OK));
}
@Test
public void preflightRequestWithoutConnectionShouldStayOpen() throws Exception {
final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null);
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(true));
assertThat(response.status(), is(OK));
}
@Test
public void preflightRequestWithConnectionCloseShouldClose() throws Exception {
final CorsConfig config = forOrigin("http://localhost:8888").build();
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
final FullHttpRequest request = optionsRequest("http://localhost:8888", "", CLOSE);
channel.writeInbound(request);
final HttpResponse response = channel.readOutbound();
assertThat(channel.isOpen(), is(false));
assertThat(response.status(), is(OK));
}
@Test @Test
public void forbiddenShouldReleaseRequest() { public void forbiddenShouldReleaseRequest() {
final CorsConfig config = forOrigin("https://localhost").shortCircuit().build(); final CorsConfig config = forOrigin("https://localhost").shortCircuit().build();
@ -339,15 +426,22 @@ public class CorsHandlerTest {
final String origin, final String origin,
final String requestHeaders) { final String requestHeaders) {
final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config));
channel.writeInbound(optionsRequest(origin, requestHeaders)); channel.writeInbound(optionsRequest(origin, requestHeaders, null));
return (HttpResponse) channel.readOutbound(); return (HttpResponse) channel.readOutbound();
} }
private static FullHttpRequest optionsRequest(final String origin, final String requestHeaders) { private static FullHttpRequest optionsRequest(final String origin,
final String requestHeaders,
final AsciiString connection) {
final FullHttpRequest httpRequest = createHttpRequest(OPTIONS); final FullHttpRequest httpRequest = createHttpRequest(OPTIONS);
httpRequest.headers().set(ORIGIN, origin); httpRequest.headers().set(ORIGIN, origin);
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString()); httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString());
httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders); httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders);
if (connection != null) {
httpRequest.headers().set(CONNECTION, connection);
}
return httpRequest; return httpRequest;
} }