Selective Message Aggregation (#8793)

Motivation

Implementations of MessageAggregator (HttpObjectAggregator in particular) may wish to
selectively aggrerage requests and responses on a case-by-case basis such as for example
only POST requests or only responses of a certain content-type.

Modifications

Adding a flag to MessageAggregator that toggles between true/false depending on if aggregation
is desired for the current message or not.

Result

Fixes #8772
This commit is contained in:
Roger Kapsi 2019-02-04 03:57:54 -05:00 committed by Norman Maurer
parent 95bc819513
commit 32563bfcc1
2 changed files with 150 additions and 3 deletions

View File

@ -23,7 +23,10 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.DecoderResultProvider;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;
import org.mockito.Mockito;
@ -40,6 +43,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assert.assertSame;
public class HttpObjectAggregatorTest {
@ -517,4 +521,123 @@ public class HttpObjectAggregatorTest {
aggregatedRep.release();
replacedRep.release();
}
@Test
public void testSelectiveRequestAggregation() {
HttpObjectAggregator myPostAggregator = new HttpObjectAggregator(1024 * 1024) {
@Override
protected boolean isStartMessage(HttpObject msg) throws Exception {
if (msg instanceof HttpRequest) {
HttpRequest request = (HttpRequest) msg;
HttpMethod method = request.method();
if (method.equals(HttpMethod.POST)) {
return true;
}
}
return false;
}
};
EmbeddedChannel channel = new EmbeddedChannel(myPostAggregator);
try {
// Aggregate: POST
HttpRequest request1 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/");
HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8));
request1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
assertTrue(channel.writeInbound(request1, content1, LastHttpContent.EMPTY_LAST_CONTENT));
// Getting an aggregated response out
Object msg1 = channel.readInbound();
try {
assertTrue(msg1 instanceof FullHttpRequest);
} finally {
ReferenceCountUtil.release(msg1);
}
// Don't aggregate: non-POST
HttpRequest request2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/");
HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8));
request2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
try {
assertTrue(channel.writeInbound(request2, content2, LastHttpContent.EMPTY_LAST_CONTENT));
// Getting the same response objects out
assertSame(request2, channel.readInbound());
assertSame(content2, channel.readInbound());
assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound());
} finally {
ReferenceCountUtil.release(request2);
ReferenceCountUtil.release(content2);
}
assertFalse(channel.finish());
} finally {
channel.close();
}
}
@Test
public void testSelectiveResponseAggregation() {
HttpObjectAggregator myTextAggregator = new HttpObjectAggregator(1024 * 1024) {
@Override
protected boolean isStartMessage(HttpObject msg) throws Exception {
if (msg instanceof HttpResponse) {
HttpResponse response = (HttpResponse) msg;
HttpHeaders headers = response.headers();
String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE);
if (AsciiString.contentEqualsIgnoreCase(contentType, HttpHeaderValues.TEXT_PLAIN)) {
return true;
}
}
return false;
}
};
EmbeddedChannel channel = new EmbeddedChannel(myTextAggregator);
try {
// Aggregate: text/plain
HttpResponse response1 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8));
response1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
assertTrue(channel.writeInbound(response1, content1, LastHttpContent.EMPTY_LAST_CONTENT));
// Getting an aggregated response out
Object msg1 = channel.readInbound();
try {
assertTrue(msg1 instanceof FullHttpResponse);
} finally {
ReferenceCountUtil.release(msg1);
}
// Don't aggregate: application/json
HttpResponse response2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("{key: 'value'}", CharsetUtil.UTF_8));
response2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
try {
assertTrue(channel.writeInbound(response2, content2, LastHttpContent.EMPTY_LAST_CONTENT));
// Getting the same response objects out
assertSame(response2, channel.readInbound());
assertSame(content2, channel.readInbound());
assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound());
} finally {
ReferenceCountUtil.release(response2);
ReferenceCountUtil.release(content2);
}
assertFalse(channel.finish());
} finally {
channel.close();
}
}
}

View File

@ -62,6 +62,8 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
private ChannelHandlerContext ctx;
private ChannelFutureListener continueResponseWriteListener;
private boolean aggregating;
/**
* Creates a new instance.
*
@ -95,7 +97,20 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
@SuppressWarnings("unchecked")
I in = (I) msg;
return (isContentMessage(in) || isStartMessage(in)) && !isAggregated(in);
if (isAggregated(in)) {
return false;
}
// NOTE: It's tempting to make this check only if aggregating is false. There are however
// side conditions in decode(...) in respect to large messages.
if (isStartMessage(in)) {
aggregating = true;
return true;
} else if (aggregating && isContentMessage(in)) {
return true;
}
return false;
}
/**
@ -191,6 +206,8 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
@Override
protected void decode(final ChannelHandlerContext ctx, I msg, List<Object> out) throws Exception {
assert aggregating;
if (isStartMessage(msg)) {
handlingOversizedMessage = false;
if (currentMessage != null) {
@ -245,7 +262,7 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
} else {
aggregated = beginAggregation(m, EMPTY_BUFFER);
}
finishAggregation(aggregated);
finishAggregation0(aggregated);
out.add(aggregated);
return;
}
@ -300,7 +317,7 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
}
if (last) {
finishAggregation(currentMessage);
finishAggregation0(currentMessage);
// All done
out.add(currentMessage);
@ -370,6 +387,11 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
*/
protected void aggregate(O aggregated, C content) throws Exception { }
private void finishAggregation0(O aggregated) throws Exception {
aggregating = false;
finishAggregation(aggregated);
}
/**
* Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline.
*/
@ -377,6 +399,7 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
handlingOversizedMessage = true;
aggregating = false;
currentMessage = null;
try {
handleOversizedMessage(ctx, oversized);
@ -440,6 +463,7 @@ public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends
currentMessage.release();
currentMessage = null;
handlingOversizedMessage = false;
aggregating = false;
}
}
}