make the http2 codec void promise ready.

Motivation:

Void promises need special treatment, as they don't behave like normal promises. One
cannot add a listener to a void promise for example.

Modifications:

Inspected the code for addListener() calls, and added extra logic for void promises
where necessary. Essentially, when writing a frame with a void promise, any errors
are reported via the channel pipeline's exceptionCaught event.

Result:

The HTTP/2 codec has better support for void promises.
This commit is contained in:
buchgr 2016-09-12 19:40:26 +02:00 committed by Scott Mitchell
parent e3462a79c7
commit 908464f161
5 changed files with 169 additions and 38 deletions

View File

@ -151,8 +151,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
@Override
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final int streamDependency, final short weight,
final boolean exclusive, final int padding, final boolean endOfStream,
final ChannelPromise promise) {
final boolean exclusive, final int padding, final boolean endOfStream, ChannelPromise promise) {
try {
Http2Stream stream = connection.stream(streamId);
if (stream == null) {
@ -176,18 +175,18 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
// for this stream.
Http2RemoteFlowController flowController = flowController();
if (!endOfStream || !flowController.hasFlowControlled(stream)) {
ChannelFuture future = frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
if (endOfStream) {
final Http2Stream finalStream = stream;
future.addListener(new ChannelFutureListener() {
final ChannelFutureListener closeStreamLocalListener = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
lifecycleManager.closeStreamLocal(finalStream, promise);
lifecycleManager.closeStreamLocal(finalStream, future);
}
});
};
promise = promise.unvoid().addListener(closeStreamLocalListener);
}
return future;
return frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
} else {
// Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames.
flowController.addFlowControlled(stream,

View File

@ -413,6 +413,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
promise = promise.unvoid();
// Avoid NotYetConnectedException
if (!ctx.channel().isActive()) {
ctx.close(promise);
@ -629,7 +630,8 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override
public ChannelFuture resetStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
final ChannelPromise promise) {
ChannelPromise promise) {
promise = promise.unvoid();
final Http2Stream stream = connection().stream(streamId);
if (stream == null) {
return resetUnknownStream(ctx, streamId, errorCode, promise);
@ -670,12 +672,21 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
final ByteBuf debugData, ChannelPromise promise) {
boolean release = false;
try {
promise = promise.unvoid();
final Http2Connection connection = connection();
if (connection.goAwaySent() && lastStreamId > connection.remote().lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastStreamKnownByPeer(), lastStreamId);
if (connection().goAwaySent()) {
// Protect against re-entrancy. Could happen if writing the frame fails, and error handling
// treating this is a connection handler and doing a graceful shutdown...
if (lastStreamId == connection().remote().lastStreamKnownByPeer()) {
return promise;
}
if (lastStreamId > connection.remote().lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastStreamKnownByPeer(), lastStreamId);
}
}
connection.goAwaySent(lastStreamId, errorCode, debugData);
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and

View File

@ -22,6 +22,7 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertEquals;
@ -50,11 +51,10 @@ import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError;
import org.junit.Before;
@ -84,6 +84,9 @@ public class DefaultHttp2ConnectionEncoderTest {
@Mock
private Channel channel;
@Mock
private ChannelPipeline pipeline;
@Mock
private Http2FrameListener listener;
@ -111,6 +114,7 @@ public class DefaultHttp2ConnectionEncoderTest {
MockitoAnnotations.initMocks(this);
when(channel.isActive()).thenReturn(true);
when(channel.pipeline()).thenReturn(pipeline);
when(writer.configuration()).thenReturn(writerConfig);
when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy);
when(frameSizePolicy.maxFrameSize()).thenReturn(64);
@ -237,9 +241,9 @@ public class DefaultHttp2ConnectionEncoderTest {
createStream(STREAM_ID, false);
final ByteBuf data = dummyData().retain();
ChannelPromise promise1 = newVoidPromise();
ChannelPromise promise1 = newVoidPromise(channel);
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise1);
ChannelPromise promise2 = newVoidPromise();
ChannelPromise promise2 = newVoidPromise(channel);
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise2);
// Now merge the two payloads.
@ -279,6 +283,29 @@ public class DefaultHttp2ConnectionEncoderTest {
assertEquals(0, data.refCnt());
}
@Test
public void writeHeadersUsingVoidPromise() throws Exception {
final Throwable cause = new RuntimeException("fake exception");
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgumentAt(8, ChannelPromise.class);
assertFalse(promise.isVoid());
return promise.setFailure(cause);
}
});
createStream(STREAM_ID, false);
// END_STREAM flag, so that a listener is added to the future.
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newVoidPromise(channel));
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
anyInt(), anyBoolean(), any(ChannelPromise.class));
// When using a void promise, the error should be propagated via the channel pipeline.
verify(pipeline).fireExceptionCaught(cause);
}
private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception {
createStream(STREAM_ID, false);
when(frameSizePolicy.maxFrameSize()).thenReturn(5);
@ -583,27 +610,6 @@ public class DefaultHttp2ConnectionEncoderTest {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
}
private ChannelPromise newVoidPromise() {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) {
@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
throw new AssertionFailedError();
}
@Override
public ChannelPromise addListeners(
GenericFutureListener<? extends Future<? super Void>>... listeners) {
throw new AssertionFailedError();
}
@Override
public boolean isVoid() {
return true;
}
};
}
private ChannelFuture newSucceededFuture() {
return newPromise().setSuccess();
}

View File

@ -22,8 +22,10 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.GenericFutureListener;
@ -47,6 +49,7 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@ -104,6 +107,9 @@ public class Http2ConnectionHandlerTest {
@Mock
private Channel channel;
@Mock
private ChannelPipeline pipeline;
@Mock
private ChannelFuture future;
@ -154,6 +160,7 @@ public class Http2ConnectionHandlerTest {
when(future.cause()).thenReturn(fakeException);
when(future.channel()).thenReturn(channel);
when(channel.isActive()).thenReturn(true);
when(channel.pipeline()).thenReturn(pipeline);
when(connection.remote()).thenReturn(remote);
when(remote.flowController()).thenReturn(remoteFlowController);
when(connection.local()).thenReturn(local);
@ -439,6 +446,31 @@ public class Http2ConnectionHandlerTest {
}
}
@Test
public void canSendGoAwayUsingVoidPromise() throws Exception {
handler = newHandler();
ByteBuf data = dummyData();
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler = newHandler();
final Throwable cause = new RuntimeException("fake exception");
doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
ChannelPromise promise = invocation.getArgumentAt(4, ChannelPromise.class);
assertFalse(promise.isVoid());
// This is what DefaultHttp2FrameWriter does... I hate mocking :-(.
SimpleChannelPromiseAggregator aggregatedPromise =
new SimpleChannelPromiseAggregator(promise, channel, ImmediateEventExecutor.INSTANCE);
aggregatedPromise.newPromise();
aggregatedPromise.doneAllocatingPromises();
return aggregatedPromise.setFailure(cause);
}
}).when(frameWriter).writeGoAway(
any(ChannelHandlerContext.class), anyInt(), anyInt(), any(ByteBuf.class), any(ChannelPromise.class));
handler.goAway(ctx, STREAM_ID, errorCode, data, newVoidPromise(channel));
verify(pipeline).fireExceptionCaught(cause);
}
@Test
public void channelReadCompleteTriggersFlush() throws Exception {
handler = newHandler();
@ -458,6 +490,33 @@ public class Http2ConnectionHandlerTest {
any(ChannelPromise.class));
}
@Test
public void writeRstStreamForUnkownStreamUsingVoidPromise() throws Exception {
writeRstStreamUsingVoidPromise(NON_EXISTANT_STREAM_ID);
}
@Test
public void writeRstStreamForKnownStreamUsingVoidPromise() throws Exception {
writeRstStreamUsingVoidPromise(STREAM_ID);
}
private void writeRstStreamUsingVoidPromise(int streamId) throws Exception {
handler = newHandler();
final Throwable cause = new RuntimeException("fake exception");
when(frameWriter.writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgumentAt(3, ChannelPromise.class);
assertFalse(promise.isVoid());
return promise.setFailure(cause);
}
});
handler.resetStream(ctx, streamId, STREAM_CLOSED.code(), newVoidPromise(channel));
verify(frameWriter).writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class));
verify(pipeline).fireExceptionCaught(cause);
}
private static ByteBuf dummyData() {
return Unpooled.buffer().writeBytes("abcdefgh".getBytes(UTF_8));
}

View File

@ -16,9 +16,17 @@ package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError;
import java.util.List;
import java.util.Random;
@ -382,4 +390,52 @@ final class Http2TestUtil {
messageLatch.countDown();
}
}
static ChannelPromise newVoidPromise(final Channel channel) {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) {
@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
throw new AssertionFailedError();
}
@Override
public ChannelPromise addListeners(
GenericFutureListener<? extends Future<? super Void>>... listeners) {
throw new AssertionFailedError();
}
@Override
public boolean isVoid() {
return true;
}
@Override
public boolean tryFailure(Throwable cause) {
channel().pipeline().fireExceptionCaught(cause);
return true;
}
@Override
public ChannelPromise setFailure(Throwable cause) {
tryFailure(cause);
return this;
}
@Override
public ChannelPromise unvoid() {
ChannelPromise promise =
new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
channel().pipeline().fireExceptionCaught(future.cause());
}
}
});
return promise;
}
};
}
}