make the http2 codec void promise ready.

Motivation:

Void promises need special treatment, as they don't behave like normal promises. One
cannot add a listener to a void promise for example.

Modifications:

Inspected the code for addListener() calls, and added extra logic for void promises
where necessary. Essentially, when writing a frame with a void promise, any errors
are reported via the channel pipeline's exceptionCaught event.

Result:

The HTTP/2 codec has better support for void promises.
This commit is contained in:
buchgr 2016-09-12 19:40:26 +02:00 committed by Scott Mitchell
parent e3462a79c7
commit 908464f161
5 changed files with 169 additions and 38 deletions

View File

@ -151,8 +151,7 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
@Override @Override
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId, public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final int streamDependency, final short weight, final Http2Headers headers, final int streamDependency, final short weight,
final boolean exclusive, final int padding, final boolean endOfStream, final boolean exclusive, final int padding, final boolean endOfStream, ChannelPromise promise) {
final ChannelPromise promise) {
try { try {
Http2Stream stream = connection.stream(streamId); Http2Stream stream = connection.stream(streamId);
if (stream == null) { if (stream == null) {
@ -176,18 +175,18 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
// for this stream. // for this stream.
Http2RemoteFlowController flowController = flowController(); Http2RemoteFlowController flowController = flowController();
if (!endOfStream || !flowController.hasFlowControlled(stream)) { if (!endOfStream || !flowController.hasFlowControlled(stream)) {
ChannelFuture future = frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
if (endOfStream) { if (endOfStream) {
final Http2Stream finalStream = stream; final Http2Stream finalStream = stream;
future.addListener(new ChannelFutureListener() { final ChannelFutureListener closeStreamLocalListener = new ChannelFutureListener() {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
lifecycleManager.closeStreamLocal(finalStream, promise); lifecycleManager.closeStreamLocal(finalStream, future);
} }
}); };
promise = promise.unvoid().addListener(closeStreamLocalListener);
} }
return future; return frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
} else { } else {
// Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames. // Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames.
flowController.addFlowControlled(stream, flowController.addFlowControlled(stream,

View File

@ -413,6 +413,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override @Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
promise = promise.unvoid();
// Avoid NotYetConnectedException // Avoid NotYetConnectedException
if (!ctx.channel().isActive()) { if (!ctx.channel().isActive()) {
ctx.close(promise); ctx.close(promise);
@ -629,7 +630,8 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override @Override
public ChannelFuture resetStream(final ChannelHandlerContext ctx, int streamId, long errorCode, public ChannelFuture resetStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
final ChannelPromise promise) { ChannelPromise promise) {
promise = promise.unvoid();
final Http2Stream stream = connection().stream(streamId); final Http2Stream stream = connection().stream(streamId);
if (stream == null) { if (stream == null) {
return resetUnknownStream(ctx, streamId, errorCode, promise); return resetUnknownStream(ctx, streamId, errorCode, promise);
@ -670,12 +672,21 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
final ByteBuf debugData, ChannelPromise promise) { final ByteBuf debugData, ChannelPromise promise) {
boolean release = false; boolean release = false;
try { try {
promise = promise.unvoid();
final Http2Connection connection = connection(); final Http2Connection connection = connection();
if (connection.goAwaySent() && lastStreamId > connection.remote().lastStreamKnownByPeer()) { if (connection().goAwaySent()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " + // Protect against re-entrancy. Could happen if writing the frame fails, and error handling
"sending multiple GOAWAY frames (was '%d', is '%d').", // treating this is a connection handler and doing a graceful shutdown...
connection.remote().lastStreamKnownByPeer(), lastStreamId); if (lastStreamId == connection().remote().lastStreamKnownByPeer()) {
return promise;
}
if (lastStreamId > connection.remote().lastStreamKnownByPeer()) {
throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " +
"sending multiple GOAWAY frames (was '%d', is '%d').",
connection.remote().lastStreamKnownByPeer(), lastStreamId);
}
} }
connection.goAwaySent(lastStreamId, errorCode, debugData); connection.goAwaySent(lastStreamId, errorCode, debugData);
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and // Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and

View File

@ -22,6 +22,7 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE; import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@ -50,11 +51,10 @@ import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError; import junit.framework.AssertionFailedError;
import org.junit.Before; import org.junit.Before;
@ -84,6 +84,9 @@ public class DefaultHttp2ConnectionEncoderTest {
@Mock @Mock
private Channel channel; private Channel channel;
@Mock
private ChannelPipeline pipeline;
@Mock @Mock
private Http2FrameListener listener; private Http2FrameListener listener;
@ -111,6 +114,7 @@ public class DefaultHttp2ConnectionEncoderTest {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(channel.isActive()).thenReturn(true); when(channel.isActive()).thenReturn(true);
when(channel.pipeline()).thenReturn(pipeline);
when(writer.configuration()).thenReturn(writerConfig); when(writer.configuration()).thenReturn(writerConfig);
when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy); when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy);
when(frameSizePolicy.maxFrameSize()).thenReturn(64); when(frameSizePolicy.maxFrameSize()).thenReturn(64);
@ -237,9 +241,9 @@ public class DefaultHttp2ConnectionEncoderTest {
createStream(STREAM_ID, false); createStream(STREAM_ID, false);
final ByteBuf data = dummyData().retain(); final ByteBuf data = dummyData().retain();
ChannelPromise promise1 = newVoidPromise(); ChannelPromise promise1 = newVoidPromise(channel);
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise1); encoder.writeData(ctx, STREAM_ID, data, 0, true, promise1);
ChannelPromise promise2 = newVoidPromise(); ChannelPromise promise2 = newVoidPromise(channel);
encoder.writeData(ctx, STREAM_ID, data, 0, true, promise2); encoder.writeData(ctx, STREAM_ID, data, 0, true, promise2);
// Now merge the two payloads. // Now merge the two payloads.
@ -279,6 +283,29 @@ public class DefaultHttp2ConnectionEncoderTest {
assertEquals(0, data.refCnt()); assertEquals(0, data.refCnt());
} }
@Test
public void writeHeadersUsingVoidPromise() throws Exception {
final Throwable cause = new RuntimeException("fake exception");
when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
anyInt(), anyBoolean(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgumentAt(8, ChannelPromise.class);
assertFalse(promise.isVoid());
return promise.setFailure(cause);
}
});
createStream(STREAM_ID, false);
// END_STREAM flag, so that a listener is added to the future.
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newVoidPromise(channel));
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(),
anyInt(), anyBoolean(), any(ChannelPromise.class));
// When using a void promise, the error should be propagated via the channel pipeline.
verify(pipeline).fireExceptionCaught(cause);
}
private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception { private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception {
createStream(STREAM_ID, false); createStream(STREAM_ID, false);
when(frameSizePolicy.maxFrameSize()).thenReturn(5); when(frameSizePolicy.maxFrameSize()).thenReturn(5);
@ -583,27 +610,6 @@ public class DefaultHttp2ConnectionEncoderTest {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
} }
private ChannelPromise newVoidPromise() {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) {
@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
throw new AssertionFailedError();
}
@Override
public ChannelPromise addListeners(
GenericFutureListener<? extends Future<? super Void>>... listeners) {
throw new AssertionFailedError();
}
@Override
public boolean isVoid() {
return true;
}
};
}
private ChannelFuture newSucceededFuture() { private ChannelFuture newSucceededFuture() {
return newPromise().setSuccess(); return newPromise().setSuccess();
} }

View File

@ -22,8 +22,10 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
@ -47,6 +49,7 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise;
import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -104,6 +107,9 @@ public class Http2ConnectionHandlerTest {
@Mock @Mock
private Channel channel; private Channel channel;
@Mock
private ChannelPipeline pipeline;
@Mock @Mock
private ChannelFuture future; private ChannelFuture future;
@ -154,6 +160,7 @@ public class Http2ConnectionHandlerTest {
when(future.cause()).thenReturn(fakeException); when(future.cause()).thenReturn(fakeException);
when(future.channel()).thenReturn(channel); when(future.channel()).thenReturn(channel);
when(channel.isActive()).thenReturn(true); when(channel.isActive()).thenReturn(true);
when(channel.pipeline()).thenReturn(pipeline);
when(connection.remote()).thenReturn(remote); when(connection.remote()).thenReturn(remote);
when(remote.flowController()).thenReturn(remoteFlowController); when(remote.flowController()).thenReturn(remoteFlowController);
when(connection.local()).thenReturn(local); when(connection.local()).thenReturn(local);
@ -439,6 +446,31 @@ public class Http2ConnectionHandlerTest {
} }
} }
@Test
public void canSendGoAwayUsingVoidPromise() throws Exception {
handler = newHandler();
ByteBuf data = dummyData();
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler = newHandler();
final Throwable cause = new RuntimeException("fake exception");
doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
ChannelPromise promise = invocation.getArgumentAt(4, ChannelPromise.class);
assertFalse(promise.isVoid());
// This is what DefaultHttp2FrameWriter does... I hate mocking :-(.
SimpleChannelPromiseAggregator aggregatedPromise =
new SimpleChannelPromiseAggregator(promise, channel, ImmediateEventExecutor.INSTANCE);
aggregatedPromise.newPromise();
aggregatedPromise.doneAllocatingPromises();
return aggregatedPromise.setFailure(cause);
}
}).when(frameWriter).writeGoAway(
any(ChannelHandlerContext.class), anyInt(), anyInt(), any(ByteBuf.class), any(ChannelPromise.class));
handler.goAway(ctx, STREAM_ID, errorCode, data, newVoidPromise(channel));
verify(pipeline).fireExceptionCaught(cause);
}
@Test @Test
public void channelReadCompleteTriggersFlush() throws Exception { public void channelReadCompleteTriggersFlush() throws Exception {
handler = newHandler(); handler = newHandler();
@ -458,6 +490,33 @@ public class Http2ConnectionHandlerTest {
any(ChannelPromise.class)); any(ChannelPromise.class));
} }
@Test
public void writeRstStreamForUnkownStreamUsingVoidPromise() throws Exception {
writeRstStreamUsingVoidPromise(NON_EXISTANT_STREAM_ID);
}
@Test
public void writeRstStreamForKnownStreamUsingVoidPromise() throws Exception {
writeRstStreamUsingVoidPromise(STREAM_ID);
}
private void writeRstStreamUsingVoidPromise(int streamId) throws Exception {
handler = newHandler();
final Throwable cause = new RuntimeException("fake exception");
when(frameWriter.writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class)))
.then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
ChannelPromise promise = invocationOnMock.getArgumentAt(3, ChannelPromise.class);
assertFalse(promise.isVoid());
return promise.setFailure(cause);
}
});
handler.resetStream(ctx, streamId, STREAM_CLOSED.code(), newVoidPromise(channel));
verify(frameWriter).writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class));
verify(pipeline).fireExceptionCaught(cause);
}
private static ByteBuf dummyData() { private static ByteBuf dummyData() {
return Unpooled.buffer().writeBytes("abcdefgh".getBytes(UTF_8)); return Unpooled.buffer().writeBytes("abcdefgh".getBytes(UTF_8));
} }

View File

@ -16,9 +16,17 @@ package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AsciiString; import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
@ -382,4 +390,52 @@ final class Http2TestUtil {
messageLatch.countDown(); messageLatch.countDown();
} }
} }
static ChannelPromise newVoidPromise(final Channel channel) {
return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) {
@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
throw new AssertionFailedError();
}
@Override
public ChannelPromise addListeners(
GenericFutureListener<? extends Future<? super Void>>... listeners) {
throw new AssertionFailedError();
}
@Override
public boolean isVoid() {
return true;
}
@Override
public boolean tryFailure(Throwable cause) {
channel().pipeline().fireExceptionCaught(cause);
return true;
}
@Override
public ChannelPromise setFailure(Throwable cause) {
tryFailure(cause);
return this;
}
@Override
public ChannelPromise unvoid() {
ChannelPromise promise =
new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
channel().pipeline().fireExceptionCaught(future.cause());
}
}
});
return promise;
}
};
}
} }