Keeping HTTP/2 HEADERS frames in-place WRT DATA frames.

Motivation:

Currently due to flow control, HEADERS frames can be written
out-of-order WRT DATA frames.

Modifications:

When data is written, we preserve the future as the lastWriteFuture in
the outbound flow controller.  The encoder then uses the lastWriteFuture
such that headers are only written after the lastWriteFuture completes.

Result:

HEADERS/DATA write order is correctly preserved.
This commit is contained in:
nmittler 2014-10-31 14:15:14 -07:00
parent 4a86dce2ca
commit 1794f424d1
7 changed files with 206 additions and 34 deletions

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.ChannelPromiseAggregator;
import java.util.ArrayDeque;
@ -196,6 +197,11 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
return future;
}
@Override
public ChannelFuture lastWriteForStream(int streamId) {
return outboundFlow.lastWriteForStream(streamId);
}
@Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endStream, ChannelPromise promise) {
@ -203,10 +209,12 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
}
@Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
int streamDependency, short weight, boolean exclusive, int padding, boolean endOfStream,
ChannelPromise promise) {
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId,
final Http2Headers headers, final int streamDependency, final short weight,
final boolean exclusive, final int padding, final boolean endOfStream,
final ChannelPromise promise) {
Http2Stream stream = connection.stream(streamId);
ChannelFuture lastDataWrite = lastWriteForStream(streamId);
try {
if (connection.isGoAway()) {
throw protocolError("Sending headers after connection going away.");
@ -238,6 +246,12 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
"Stream %d in unexpected state: %s", stream.id(), stream.state()));
}
}
if (lastDataWrite != null && !endOfStream) {
throw new IllegalStateException(
"Sending non-trailing headers after data has been sent for stream: "
+ streamId);
}
} catch (Throwable e) {
if (e instanceof Http2NoMoreStreamIdsException) {
lifecycleManager.onException(ctx, e);
@ -245,8 +259,49 @@ public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder {
return promise.setFailure(e);
}
if (lastDataWrite == null) {
// No previous DATA frames to keep in sync with, just send it now.
return writeHeaders(ctx, stream, headers, streamDependency, weight, exclusive, padding,
endOfStream, promise);
}
// There were previous DATA frames sent. We need to send the HEADERS only after the most
// recent DATA frame to keep them in sync...
// Wrap the original promise in an aggregate which will complete the original promise
// once the headers are written.
final ChannelPromiseAggregator aggregatePromise = new ChannelPromiseAggregator(promise);
final ChannelPromise innerPromise = ctx.newPromise();
aggregatePromise.add(innerPromise);
// Only write the HEADERS frame after the previous DATA frame has been written.
final Http2Stream theStream = stream;
lastDataWrite.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) {
// The DATA write failed, also fail this write.
innerPromise.setFailure(future.cause());
return;
}
// Perform the write.
writeHeaders(ctx, theStream, headers, streamDependency, weight, exclusive, padding,
endOfStream, innerPromise);
}
});
return promise;
}
/**
* Writes the given {@link Http2Headers} to the remote endpoint and updates stream state if appropriate.
*/
private ChannelFuture writeHeaders(ChannelHandlerContext ctx, Http2Stream stream,
Http2Headers headers, int streamDependency, short weight, boolean exclusive,
int padding, boolean endOfStream, ChannelPromise promise) {
ChannelFuture future =
frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, weight,
frameWriter.writeHeaders(ctx, stream.id(), headers, streamDependency, weight,
exclusive, padding, endOfStream, promise);
ctx.flush();

View File

@ -203,8 +203,14 @@ public class DefaultHttp2OutboundFlowController implements Http2OutboundFlowCont
return promise;
}
@Override
public ChannelFuture lastWriteForStream(int streamId) {
OutboundFlowState state = state(streamId);
return state != null ? state.lastNewFrame() : null;
}
private static OutboundFlowState state(Http2Stream stream) {
return (OutboundFlowState) stream.outboundFlow();
return stream != null ? (OutboundFlowState) stream.outboundFlow() : null;
}
private OutboundFlowState connectionState() {
@ -384,6 +390,7 @@ public class DefaultHttp2OutboundFlowController implements Http2OutboundFlowCont
private int pendingBytes;
private int priorityBytes;
private int allocatedPriorityBytes;
private ChannelFuture lastNewFrame;
private OutboundFlowState(Http2Stream stream) {
this.stream = stream;
@ -412,6 +419,13 @@ public class DefaultHttp2OutboundFlowController implements Http2OutboundFlowCont
return window;
}
/**
* Returns the future for the last new frame created for this stream.
*/
ChannelFuture lastNewFrame() {
return lastNewFrame;
}
/**
* Returns the maximum writable window (minimum of the stream and connection windows).
*/
@ -461,6 +475,8 @@ public class DefaultHttp2OutboundFlowController implements Http2OutboundFlowCont
* Creates a new frame with the given values but does not add it to the pending queue.
*/
private Frame newFrame(ChannelPromise promise, ByteBuf data, int padding, boolean endStream) {
// Store this as the future for the most recent write attempt.
lastNewFrame = promise;
return new Frame(new ChannelPromiseAggregator(promise), data, padding, endStream);
}

View File

@ -39,6 +39,12 @@ public interface Http2OutboundFlowController extends Http2DataWriter {
ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding,
boolean endStream, ChannelPromise promise);
/**
* Returns the {@link ChannelFuture} for the most recent write for the given
* stream. If no previous write for the stream has occurred, returns {@code null}.
*/
ChannelFuture lastWriteForStream(int streamId);
/**
* Sets the initial size of the connection's outbound flow control window. The outbound flow
* control windows for all streams are updated by the delta in the initial window size. This is

View File

@ -44,6 +44,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Before;
import org.junit.Test;
@ -205,24 +206,55 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void headersWriteForUnknownStreamShouldCreateStream() throws Exception {
when(local.createStream(eq(5), eq(false))).thenReturn(stream);
encoder.writeHeaders(ctx, 5, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(local).createStream(eq(5), eq(false));
verify(writer).writeHeaders(eq(ctx), eq(5), eq(EmptyHttp2Headers.INSTANCE), eq(0),
int streamId = 5;
when(stream.id()).thenReturn(streamId);
mockFutureAddListener(true);
when(local.createStream(eq(streamId), eq(false))).thenReturn(stream);
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(local).createStream(eq(streamId), eq(false));
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false), eq(promise));
}
@Test
public void headersWriteShouldCreateHalfClosedStream() throws Exception {
when(local.createStream(eq(5), eq(true))).thenReturn(stream);
encoder.writeHeaders(ctx, 5, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(local).createStream(eq(5), eq(true));
verify(writer).writeHeaders(eq(ctx), eq(5), eq(EmptyHttp2Headers.INSTANCE), eq(0),
int streamId = 5;
when(stream.id()).thenReturn(5);
mockFutureAddListener(true);
when(local.createStream(eq(streamId), eq(true))).thenReturn(stream);
encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(local).createStream(eq(streamId), eq(true));
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise));
}
@Test
public void headersWriteAfterDataShouldWait() throws Exception {
final AtomicReference<ChannelFutureListener> listener = new AtomicReference<ChannelFutureListener>();
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
listener.set((ChannelFutureListener) invocation.getArguments()[0]);
return null;
}
}).when(future).addListener(any(ChannelFutureListener.class));
// Indicate that there was a previous data write operation that the headers must wait for.
when(outboundFlow.lastWriteForStream(anyInt())).thenReturn(future);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(writer, never()).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise));
// Now complete the previous data write operation and verify that the headers were written.
when(future.isSuccess()).thenReturn(true);
listener.get().operationComplete(future);
verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true), eq(promise));
}
@Test
public void headersWriteShouldOpenStreamForPush() throws Exception {
mockFutureAddListener(true);
when(stream.state()).thenReturn(RESERVED_LOCAL);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise);
verify(stream).openForPush();
@ -233,6 +265,7 @@ public class DefaultHttp2ConnectionEncoderTest {
@Test
public void headersWriteShouldClosePushStream() throws Exception {
mockFutureAddListener(true);
when(stream.state()).thenReturn(RESERVED_LOCAL).thenReturn(HALF_CLOSED_LOCAL);
encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise);
verify(stream).openForPush();
@ -315,6 +348,21 @@ public class DefaultHttp2ConnectionEncoderTest {
verify(writer).writeSettings(eq(ctx), eq(settings), eq(promise));
}
private void mockFutureAddListener(boolean success) {
when(future.isSuccess()).thenReturn(success);
if (!success) {
when(future.cause()).thenReturn(new Exception("Fake Exception"));
}
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
ChannelFutureListener listener = (ChannelFutureListener) invocation.getArguments()[0];
listener.operationComplete(future);
return null;
}
}).when(future).addListener(any(ChannelFutureListener.class));
}
private static ByteBuf dummyData() {
// The buffer is purposely 8 bytes so it will even work for a ping frame.
return wrappedBuffer("abcdefgh".getBytes(UTF_8));

View File

@ -29,6 +29,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.DefaultHttp2OutboundFlowController.OutboundFlowState;
@ -139,6 +140,24 @@ public class DefaultHttp2OutboundFlowControllerTest {
}
}
@Test
public void lastWriteFutureShouldBeSaved() throws Http2Exception {
ChannelPromise promise2 = Mockito.mock(ChannelPromise.class);
final ByteBuf data = dummyData(5, 5);
try {
// Write one frame.
ChannelFuture future1 = controller.writeData(ctx, STREAM_A, data, 0, false, promise);
assertEquals(future1, controller.lastWriteForStream(STREAM_A));
// Now write another and verify that the last write is updated.
ChannelFuture future2 = controller.writeData(ctx, STREAM_A, data, 0, false, promise2);
assertTrue(future1 != future2);
assertEquals(future2, controller.lastWriteForStream(STREAM_A));
} finally {
manualSafeRelease(data);
}
}
@Test
public void frameLargerThanMaxFrameSizeShouldBeSplit() throws Http2Exception {
when(frameWriterSizePolicy.maxFrameSize()).thenReturn(3);
@ -1223,7 +1242,8 @@ public class DefaultHttp2OutboundFlowControllerTest {
}
private void send(int streamId, ByteBuf data, int padding) throws Http2Exception {
controller.writeData(ctx, streamId, data, padding, false, promise);
ChannelFuture future = controller.writeData(ctx, streamId, data, padding, false, promise);
assertEquals(future, controller.lastWriteForStream(streamId));
}
private void verifyWrite(int streamId, ByteBuf data, int padding) {

View File

@ -85,8 +85,9 @@ public class Http2ConnectionRoundtripTest {
private Channel serverChannel;
private Channel clientChannel;
private Http2TestUtil.FrameCountDown serverFrameCountDown;
private volatile CountDownLatch requestLatch;
private volatile CountDownLatch dataLatch;
private CountDownLatch requestLatch;
private CountDownLatch dataLatch;
private CountDownLatch trailersLatch;
@Before
public void setup() throws Exception {
@ -106,7 +107,7 @@ public class Http2ConnectionRoundtripTest {
@Test
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
bootstrapEnv(1, 1);
bootstrapEnv(1, 1, 1);
// Create a latch to track when the close occurs.
final CountDownLatch closeLatch = new CountDownLatch(1);
@ -150,7 +151,7 @@ public class Http2ConnectionRoundtripTest {
any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), eq((short) 16),
eq(false), eq(0), eq(false));
bootstrapEnv(1, 1);
bootstrapEnv(1, 1, 1);
// Create a latch to track when the close occurs.
final CountDownLatch closeLatch = new CountDownLatch(1);
@ -180,7 +181,7 @@ public class Http2ConnectionRoundtripTest {
@Test
public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception {
bootstrapEnv(1, 1);
bootstrapEnv(1, 1, 1);
// Create a latch to track when the close occurs.
final CountDownLatch closeLatch = new CountDownLatch(1);
@ -219,7 +220,7 @@ public class Http2ConnectionRoundtripTest {
@Test
public void noMoreStreamIdsShouldSendGoAway() throws Exception {
bootstrapEnv(1, 3);
bootstrapEnv(1, 3, 1);
// Create a single stream by sending a HEADERS frame to the server.
final Http2Headers headers = dummyHeaders();
@ -262,7 +263,7 @@ public class Http2ConnectionRoundtripTest {
any(ByteBuf.class), eq(0), Mockito.anyBoolean());
try {
// Initialize the data latch based on the number of bytes expected.
bootstrapEnv(length, 2);
bootstrapEnv(length, 2, 1);
// Create the stream and send all of the data at once.
runInChannel(clientChannel, new Http2Runnable() {
@ -270,20 +271,25 @@ public class Http2ConnectionRoundtripTest {
public void run() {
http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
false, newPromise());
http2Client.encoder().writeData(ctx(), 3, data.retain(), 0, true, newPromise());
http2Client.encoder().writeData(ctx(), 3, data.retain(), 0, false, newPromise());
// Write trailers.
http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
true, newPromise());
}
});
// Wait for all DATA frames to be received at the server.
assertTrue(dataLatch.await(5, TimeUnit.SECONDS));
// Wait for the trailers to be received.
assertTrue(trailersLatch.await(5, TimeUnit.SECONDS));
// Verify that headers were received and only one DATA frame was received with endStream set.
// Verify that headers and trailers were received.
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
eq((short) 16), eq(false), eq(0), eq(false));
verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0),
eq(true));
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
eq((short) 16), eq(false), eq(0), eq(true));
// Verify we received all the bytes.
assertEquals(0, dataLatch.getCount());
out.flush();
byte[] received = out.toByteArray();
assertArrayEquals(bytes, received);
@ -317,9 +323,9 @@ public class Http2ConnectionRoundtripTest {
return null;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), eq(data),
eq(0), eq(true));
eq(0), eq(false));
try {
bootstrapEnv(numStreams * text.length(), numStreams * 3);
bootstrapEnv(numStreams * text.length(), numStreams * 4, numStreams);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() {
@ -329,18 +335,23 @@ public class Http2ConnectionRoundtripTest {
http2Client.encoder().writePing(ctx(), false, pingData.slice().retain(),
newPromise());
http2Client.encoder().writeData(ctx(), nextStream, data.slice().retain(),
0, true, newPromise());
0, false, newPromise());
// Write trailers.
http2Client.encoder().writeHeaders(ctx(), nextStream, headers, 0,
(short) 16, false, 0, true, newPromise());
}
}
});
// Wait for all frames to be received.
assertTrue(requestLatch.await(30, SECONDS));
assertTrue(trailersLatch.await(30, SECONDS));
verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false));
verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(true));
verify(serverListener, times(numStreams)).onPingRead(any(ChannelHandlerContext.class),
any(ByteBuf.class));
verify(serverListener, times(numStreams)).onDataRead(any(ChannelHandlerContext.class),
anyInt(), any(ByteBuf.class), eq(0), eq(true));
anyInt(), any(ByteBuf.class), eq(0), eq(false));
assertEquals(numStreams, receivedPingBuffers.size());
assertEquals(numStreams, receivedDataBuffers.size());
for (String receivedData : receivedDataBuffers) {
@ -355,9 +366,10 @@ public class Http2ConnectionRoundtripTest {
}
}
private void bootstrapEnv(int dataCountDown, int requestCountDown) throws Exception {
private void bootstrapEnv(int dataCountDown, int requestCountDown, int trailersCountDown) throws Exception {
requestLatch = new CountDownLatch(requestCountDown);
dataLatch = new CountDownLatch(dataCountDown);
trailersLatch = new CountDownLatch(trailersCountDown);
sb = new ServerBootstrap();
cb = new Bootstrap();
@ -367,7 +379,9 @@ public class Http2ConnectionRoundtripTest {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
serverFrameCountDown = new Http2TestUtil.FrameCountDown(serverListener, requestLatch, dataLatch);
serverFrameCountDown =
new Http2TestUtil.FrameCountDown(serverListener, requestLatch, dataLatch,
trailersLatch);
p.addLast(new Http2ConnectionHandler(true, serverFrameCountDown));
p.addLast(Http2CodecUtil.ignoreSettingsHandler());
}

View File

@ -262,15 +262,22 @@ final class Http2TestUtil {
private final Http2FrameListener listener;
private final CountDownLatch messageLatch;
private final CountDownLatch dataLatch;
private final CountDownLatch trailersLatch;
public FrameCountDown(Http2FrameListener listener, CountDownLatch messageLatch) {
this(listener, messageLatch, null);
}
public FrameCountDown(Http2FrameListener listener, CountDownLatch messageLatch, CountDownLatch dataLatch) {
this(listener, messageLatch, dataLatch, null);
}
public FrameCountDown(Http2FrameListener listener, CountDownLatch messageLatch,
CountDownLatch dataLatch, CountDownLatch trailersLatch) {
this.listener = listener;
this.messageLatch = messageLatch;
this.dataLatch = dataLatch;
this.trailersLatch = trailersLatch;
}
@Override
@ -291,6 +298,9 @@ final class Http2TestUtil {
boolean endStream) throws Http2Exception {
listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
messageLatch.countDown();
if (trailersLatch != null && endStream) {
trailersLatch.countDown();
}
}
@Override
@ -298,6 +308,9 @@ final class Http2TestUtil {
short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
messageLatch.countDown();
if (trailersLatch != null && endStream) {
trailersLatch.countDown();
}
}
@Override