HTTP/2 write of released buffer should not write and should fail the promise

Motivation:
HTTP/2 allows writes of 0 length data frames. However in some cases EMPTY_BUFFER is used instead of the actual buffer that was written. This may mask writes of released buffers or otherwise invalid buffer objects. It is also possible that if the buffer is invalid AbstractCoalescingBufferQueue will not release the aggregated buffer nor fail the associated promise.

Modifications:
- DefaultHttp2FrameCodec should take care to fail the promise, even if releasing the data throws
- AbstractCoalescingBufferQueue should release any aggregated data and fail the associated promise if something goes wrong during aggregation

Result:
More correct handling of invalid buffers in HTTP/2 code.
This commit is contained in:
Scott Mitchell 2017-11-03 17:09:42 -07:00
parent bcad9dbf97
commit 35b0cd58fb
6 changed files with 194 additions and 66 deletions

View File

@ -15,7 +15,6 @@
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
@ -386,10 +385,10 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
if (!endOfStream) {
if (queuedData == 0) {
// There's no need to write any data frames because there are only empty data frames in the queue
// and it is not end of stream yet. Just complete their promises by writing an empty buffer.
// and it is not end of stream yet. Just complete their promises by getting the buffer corresponding
// to 0 bytes and writing it to the channel (to preserve notification order).
ChannelPromise writePromise = ctx.newPromise().addListener(this);
queue.remove(0, writePromise).release();
ctx.write(Unpooled.EMPTY_BUFFER, writePromise);
ctx.write(queue.remove(0, writePromise), writePromise);
return;
}

View File

@ -173,13 +173,18 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
}
} while (!lastFrame);
} catch (Throwable t) {
if (needToReleaseHeaders) {
header.release();
try {
if (needToReleaseHeaders) {
header.release();
}
if (needToReleaseData) {
data.release();
}
} finally {
promiseAggregator.setFailure(t);
promiseAggregator.doneAllocatingPromises();
}
if (needToReleaseData) {
data.release();
}
promiseAggregator.setFailure(t);
return promiseAggregator;
}
return promiseAggregator.doneAllocatingPromises();
}
@ -264,27 +269,35 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
}
@Override
public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data,
ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data, ChannelPromise promise) {
final SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try {
verifyPingPayload(data);
Http2Flags flags = ack ? new Http2Flags().ack(true) : new Http2Flags();
int payloadLength = data.readableBytes();
ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, data.readableBytes(), PING, flags, 0);
// Assume nothing below will throw until buf is written. That way we don't have to take care of ownership
// in the catch block.
writeFrameHeaderInternal(buf, payloadLength, PING, flags, 0);
ctx.write(buf, promiseAggregator.newPromise());
} catch (Throwable t) {
try {
data.release();
} finally {
promiseAggregator.setFailure(t);
promiseAggregator.doneAllocatingPromises();
}
return promiseAggregator;
}
try {
// Write the debug data.
releaseData = false;
ctx.write(data, promiseAggregator.newPromise());
} catch (Throwable t) {
if (releaseData) {
data.release();
}
promiseAggregator.setFailure(t);
}
return promiseAggregator.doneAllocatingPromises();
}
@ -349,7 +362,6 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
@Override
public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode,
ByteBuf debugData, ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try {
@ -358,17 +370,25 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
int payloadLength = 8 + debugData.readableBytes();
ByteBuf buf = ctx.alloc().buffer(GO_AWAY_FRAME_HEADER_LENGTH);
// Assume nothing below will throw until buf is written. That way we don't have to take care of ownership
// in the catch block.
writeFrameHeaderInternal(buf, payloadLength, GO_AWAY, new Http2Flags(), 0);
buf.writeInt(lastStreamId);
buf.writeInt((int) errorCode);
ctx.write(buf, promiseAggregator.newPromise());
} catch (Throwable t) {
try {
debugData.release();
} finally {
promiseAggregator.setFailure(t);
promiseAggregator.doneAllocatingPromises();
}
return promiseAggregator;
}
releaseData = false;
try {
ctx.write(debugData, promiseAggregator.newPromise());
} catch (Throwable t) {
if (releaseData) {
debugData.release();
}
promiseAggregator.setFailure(t);
}
return promiseAggregator.doneAllocatingPromises();
@ -393,21 +413,27 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
@Override
public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId,
Http2Flags flags, ByteBuf payload, ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try {
verifyStreamOrConnectionId(streamId, STREAM_ID);
ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH);
// Assume nothing below will throw until buf is written. That way we don't have to take care of ownership
// in the catch block.
writeFrameHeaderInternal(buf, payload.readableBytes(), frameType, flags, streamId);
ctx.write(buf, promiseAggregator.newPromise());
releaseData = false;
} catch (Throwable t) {
try {
payload.release();
} finally {
promiseAggregator.setFailure(t);
promiseAggregator.doneAllocatingPromises();
}
return promiseAggregator;
}
try {
ctx.write(payload, promiseAggregator.newPromise());
} catch (Throwable t) {
if (releaseData) {
payload.release();
}
promiseAggregator.setFailure(t);
}
return promiseAggregator.doneAllocatingPromises();

View File

@ -36,6 +36,7 @@ import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.util.AsciiString;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.concurrent.Future;
import org.junit.After;
import org.junit.Before;
@ -48,6 +49,7 @@ import org.mockito.stubbing.Answer;
import java.io.ByteArrayOutputStream;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
@ -60,12 +62,14 @@ import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
@ -649,6 +653,79 @@ public class Http2ConnectionRoundtripTest {
assertFalse(clientChannel.isOpen());
}
private enum WriteEmptyBufferMode {
SINGLE_END_OF_STREAM,
SECOND_END_OF_STREAM,
SINGLE_WITH_TRAILERS,
SECOND_WITH_TRAILERS
}
@Test
public void writeOfEmptyReleasedBufferSingleBufferQueuedInFlowControllerShouldFail() throws Exception {
writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SINGLE_END_OF_STREAM);
}
@Test
public void writeOfEmptyReleasedBufferSingleBufferTrailersQueuedInFlowControllerShouldFail() throws Exception {
writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SINGLE_WITH_TRAILERS);
}
@Test
public void writeOfEmptyReleasedBufferMultipleBuffersQueuedInFlowControllerShouldFail() throws Exception {
writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SECOND_END_OF_STREAM);
}
@Test
public void writeOfEmptyReleasedBufferMultipleBuffersTrailersQueuedInFlowControllerShouldFail() throws Exception {
writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SECOND_WITH_TRAILERS);
}
public void writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(final WriteEmptyBufferMode mode)
throws Exception {
bootstrapEnv(1, 1, 2, 1);
final ChannelPromise emptyDataPromise = newPromise();
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, (short) 16, false, 0, false,
newPromise());
ByteBuf emptyBuf = Unpooled.buffer();
emptyBuf.release();
switch (mode) {
case SINGLE_END_OF_STREAM:
http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, true, emptyDataPromise);
break;
case SECOND_END_OF_STREAM:
http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise);
http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, true, newPromise());
break;
case SINGLE_WITH_TRAILERS:
http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise);
http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0,
(short) 16, false, 0, true, newPromise());
break;
case SECOND_WITH_TRAILERS:
http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise);
http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, false, newPromise());
http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0,
(short) 16, false, 0, true, newPromise());
break;
default:
throw new Error();
}
http2Client.flush(ctx());
}
});
try {
emptyDataPromise.get();
fail();
} catch (ExecutionException e) {
assertThat(e.getCause(), is(instanceOf(IllegalReferenceCountException.class)));
}
}
@Test
public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception {
bootstrapEnv(1, 1, 2, 1);

View File

@ -1852,13 +1852,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
return composite;
}
if (attemptCopyToCumulation(cumulation, next, wrapDataSize)) {
return cumulation;
}
CompositeByteBuf composite = alloc.compositeDirectBuffer(size() + 2);
composite.addComponent(true, cumulation);
composite.addComponent(true, next);
return composite;
return attemptCopyToCumulation(cumulation, next, wrapDataSize) ? cumulation :
composeIntoComposite(alloc, cumulation, next);
}
@Override
@ -1866,7 +1861,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (first instanceof CompositeByteBuf) {
CompositeByteBuf composite = (CompositeByteBuf) first;
first = allocator.directBuffer(composite.readableBytes());
first.writeBytes(composite);
try {
first.writeBytes(composite);
} catch (Throwable cause) {
first.release();
PlatformDependent.throwException(cause);
}
composite.release();
}
return first;

View File

@ -16,7 +16,9 @@ package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.UnstableApi;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -137,31 +139,42 @@ public abstract class AbstractCoalescingBufferQueue {
bytes = Math.min(bytes, readableBytes);
ByteBuf toReturn = null;
ByteBuf entryBuffer = null;
int originalBytes = bytes;
for (;;) {
Object entry = bufAndListenerPairs.poll();
if (entry == null) {
break;
}
if (entry instanceof ChannelFutureListener) {
aggregatePromise.addListener((ChannelFutureListener) entry);
continue;
}
ByteBuf entryBuffer = (ByteBuf) entry;
if (entryBuffer.readableBytes() > bytes) {
// Add the buffer back to the queue as we can't consume all of it.
bufAndListenerPairs.addFirst(entryBuffer);
if (bytes > 0) {
// Take a slice of what we can consume and retain it.
ByteBuf next = entryBuffer.readRetainedSlice(bytes);
toReturn = toReturn == null ? composeFirst(alloc, next) : compose(alloc, toReturn, next);
bytes = 0;
try {
for (;;) {
Object entry = bufAndListenerPairs.poll();
if (entry == null) {
break;
}
break;
} else {
bytes -= entryBuffer.readableBytes();
toReturn = toReturn == null ? composeFirst(alloc, entryBuffer) : compose(alloc, toReturn, entryBuffer);
if (entry instanceof ChannelFutureListener) {
aggregatePromise.addListener((ChannelFutureListener) entry);
continue;
}
entryBuffer = (ByteBuf) entry;
if (entryBuffer.readableBytes() > bytes) {
// Add the buffer back to the queue as we can't consume all of it.
bufAndListenerPairs.addFirst(entryBuffer);
if (bytes > 0) {
// Take a slice of what we can consume and retain it.
entryBuffer = entryBuffer.readRetainedSlice(bytes);
toReturn = toReturn == null ? composeFirst(alloc, entryBuffer)
: compose(alloc, toReturn, entryBuffer);
bytes = 0;
}
break;
} else {
bytes -= entryBuffer.readableBytes();
toReturn = toReturn == null ? composeFirst(alloc, entryBuffer)
: compose(alloc, toReturn, entryBuffer);
}
entryBuffer = null;
}
} catch (Throwable cause) {
ReferenceCountUtil.safeRelease(entryBuffer);
ReferenceCountUtil.safeRelease(toReturn);
aggregatePromise.setFailure(cause);
PlatformDependent.throwException(cause);
}
decrementReadableBytes(originalBytes - bytes);
return toReturn;
@ -245,6 +258,23 @@ public abstract class AbstractCoalescingBufferQueue {
*/
protected abstract ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next);
/**
* Compose {@code cumulation} and {@code next} into a new {@link CompositeByteBuf}.
*/
protected final ByteBuf composeIntoComposite(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) {
// Create a composite buffer to accumulate this pair and potentially all the buffers
// in the queue. Using +2 as we have already dequeued current and next.
CompositeByteBuf composite = alloc.compositeBuffer(size() + 2);
try {
composite.addComponent(true, cumulation);
composite.addComponent(true, next);
} catch (Throwable cause) {
composite.release();
PlatformDependent.throwException(cause);
}
return composite;
}
/**
* Calculate the first {@link ByteBuf} which will be used in subsequent calls to
* {@link #compose(ByteBufAllocator, ByteBuf, ByteBuf)}.

View File

@ -19,6 +19,7 @@ import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
/**
* A FIFO queue of bytes where producers add bytes by repeatedly adding {@link ByteBuf} and consumers take bytes in
@ -76,12 +77,7 @@ public final class CoalescingBufferQueue extends AbstractCoalescingBufferQueue {
composite.addComponent(true, next);
return composite;
}
// Create a composite buffer to accumulate this pair and potentially all the buffers
// in the queue. Using +2 as we have already dequeued current and next.
CompositeByteBuf composite = alloc.compositeBuffer(size() + 2);
composite.addComponent(true, cumulation);
composite.addComponent(true, next);
return composite;
return composeIntoComposite(alloc, cumulation, next);
}
@Override