Http2DefaultFrameWriter direct write instead of copy

Motivation:
The Http2DefaultFrameWriter copies all contents into a buffer (or uses a CompositeBuffer in 1 case) and then writes that buffer to the socket. There is an opportunity to avoid the copy operations and write directly to the socket.

Modifications:
- Http2DefaultFrameWriter should avoid copy operations where possible.
- The Http2FrameWriter interface should be clarified to indicate that ByteBuf objects will be released.

Result:
Hopefully less allocation/copy leads to memory and throughput performance benefit.
This commit is contained in:
Scott Mitchell 2015-01-30 22:54:35 -05:00 committed by nmittler
parent abf7afca76
commit 8b5f2d7716
8 changed files with 416 additions and 195 deletions

View File

@ -15,21 +15,29 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2CodecUtil.CONTINUATION_FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.DATA_FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.GO_AWAY_FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.HEADERS_FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_BYTE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_BYTE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_FRAME_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.PUSH_PROMISE_FRAME_HEADER_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.RST_STREAM_FRAME_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.WINDOW_UPDATE_FRAME_LENGTH;
import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid;
import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeader; import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeaderInternal;
import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedInt; import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedInt;
import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedShort; import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedShort;
import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION;
import static io.netty.handler.codec.http2.Http2FrameTypes.DATA; import static io.netty.handler.codec.http2.Http2FrameTypes.DATA;
import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY; import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY;
@ -42,10 +50,11 @@ import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS;
import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE; import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
import io.netty.handler.codec.http2.Http2FrameWriter.Configuration; import io.netty.handler.codec.http2.Http2FrameWriter.Configuration;
import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.IntObjectMap;
@ -55,6 +64,12 @@ import io.netty.util.collection.IntObjectMap;
public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSizePolicy, Configuration { public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSizePolicy, Configuration {
private static final String STREAM_ID = "Stream ID"; private static final String STREAM_ID = "Stream ID";
private static final String STREAM_DEPENDENCY = "Stream Dependency"; private static final String STREAM_DEPENDENCY = "Stream Dependency";
/**
* This buffer is allocated to the maximum padding size needed, and filled with padding.
* When padding is needed it can be taken as a slice of this buffer. Users should call {@link ByteBuf#retain()}
* before using their slice.
*/
private static final ByteBuf ZERO_BUFFER = Unpooled.buffer(MAX_UNSIGNED_BYTE).writeZero(MAX_UNSIGNED_BYTE);
private final Http2HeadersEncoder headersEncoder; private final Http2HeadersEncoder headersEncoder;
private int maxFrameSize; private int maxFrameSize;
@ -102,6 +117,10 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
@Override @Override
public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
int padding, boolean endStream, ChannelPromise promise) { int padding, boolean endStream, ChannelPromise promise) {
boolean releaseData = true;
ByteBuf buf = null;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
verifyStreamId(streamId, STREAM_ID); verifyStreamId(streamId, STREAM_ID);
verifyPadding(padding); verifyPadding(padding);
@ -111,38 +130,40 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
int payloadLength = data.readableBytes() + padding + flags.getPaddingPresenceFieldLength(); int payloadLength = data.readableBytes() + padding + flags.getPaddingPresenceFieldLength();
verifyPayloadLength(payloadLength); verifyPayloadLength(payloadLength);
ByteBuf out = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); buf = ctx.alloc().buffer(DATA_FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, payloadLength, DATA, flags, streamId);
writeFrameHeader(out, payloadLength, DATA, flags, streamId); writePaddingLength(buf, padding);
ctx.write(buf, promiseAggregator.newPromise());
writePaddingLength(padding, out);
// Write the data. // Write the data.
out.writeBytes(data, data.readerIndex(), data.readableBytes()); releaseData = false;
ctx.write(data, promiseAggregator.newPromise());
// Write the required padding. if (padding > 0) { // Write the required padding.
out.writeZero(padding); ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
return ctx.write(out, promise); }
} catch (RuntimeException e) { return promiseAggregator.doneAllocatingPromises();
return promise.setFailure(e); } catch (Throwable t) {
} finally { if (releaseData) {
data.release(); data.release();
}
return promiseAggregator.setFailure(t);
} }
} }
@Override @Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId,
Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) { Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) {
return writeHeadersInternal(ctx, promise, streamId, headers, padding, endStream, return writeHeadersInternal(ctx, streamId, headers, padding, endStream,
false, 0, (short) 0, false); false, 0, (short) 0, false, promise);
} }
@Override @Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId,
Http2Headers headers, int streamDependency, short weight, boolean exclusive, Http2Headers headers, int streamDependency, short weight, boolean exclusive,
int padding, boolean endStream, ChannelPromise promise) { int padding, boolean endStream, ChannelPromise promise) {
return writeHeadersInternal(ctx, promise, streamId, headers, padding, endStream, return writeHeadersInternal(ctx, streamId, headers, padding, endStream,
true, streamDependency, weight, exclusive); true, streamDependency, weight, exclusive, promise);
} }
@Override @Override
@ -153,17 +174,15 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
verifyStreamId(streamDependency, STREAM_DEPENDENCY); verifyStreamId(streamDependency, STREAM_DEPENDENCY);
verifyWeight(weight); verifyWeight(weight);
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH); ByteBuf buf = ctx.alloc().buffer(PRIORITY_FRAME_LENGTH);
writeFrameHeader(frame, PRIORITY_ENTRY_LENGTH, PRIORITY, writeFrameHeaderInternal(buf, PRIORITY_ENTRY_LENGTH, PRIORITY, new Http2Flags(), streamId);
new Http2Flags(), streamId);
long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency; long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency;
writeUnsignedInt(word1, frame); writeUnsignedInt(word1, buf);
// Adjust the weight so that it fits into a single byte on the wire. // Adjust the weight so that it fits into a single byte on the wire.
frame.writeByte(weight - 1); buf.writeByte(weight - 1);
return ctx.write(frame, promise); return ctx.write(buf, promise);
} catch (RuntimeException e) { } catch (Throwable t) {
return promise.setFailure(e); return promise.setFailure(t);
} }
} }
@ -174,13 +193,12 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
verifyStreamId(streamId, STREAM_ID); verifyStreamId(streamId, STREAM_ID);
verifyErrorCode(errorCode); verifyErrorCode(errorCode);
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + INT_FIELD_LENGTH); ByteBuf buf = ctx.alloc().buffer(RST_STREAM_FRAME_LENGTH);
writeFrameHeader(frame, INT_FIELD_LENGTH, RST_STREAM, new Http2Flags(), writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, RST_STREAM, new Http2Flags(), streamId);
streamId); writeUnsignedInt(errorCode, buf);
writeUnsignedInt(errorCode, frame); return ctx.write(buf, promise);
return ctx.write(frame, promise); } catch (Throwable t) {
} catch (RuntimeException e) { return promise.setFailure(t);
return promise.setFailure(e);
} }
} }
@ -190,44 +208,51 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
try { try {
checkNotNull(settings, "settings"); checkNotNull(settings, "settings");
int payloadLength = SETTING_ENTRY_LENGTH * settings.size(); int payloadLength = SETTING_ENTRY_LENGTH * settings.size();
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH + settings.size() * SETTING_ENTRY_LENGTH);
writeFrameHeader(frame, payloadLength, SETTINGS, new Http2Flags(), 0); writeFrameHeaderInternal(buf, payloadLength, SETTINGS, new Http2Flags(), 0);
for (IntObjectMap.Entry<Long> entry : settings.entries()) { for (IntObjectMap.Entry<Long> entry : settings.entries()) {
writeUnsignedShort(entry.key(), frame); writeUnsignedShort(entry.key(), buf);
writeUnsignedInt(entry.value(), frame); writeUnsignedInt(entry.value(), buf);
} }
return ctx.write(frame, promise); return ctx.write(buf, promise);
} catch (RuntimeException e) { } catch (Throwable t) {
return promise.setFailure(e); return promise.setFailure(t);
} }
} }
@Override @Override
public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) {
try { try {
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH); ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH);
writeFrameHeader(frame, 0, SETTINGS, new Http2Flags().ack(true), 0); writeFrameHeaderInternal(buf, 0, SETTINGS, new Http2Flags().ack(true), 0);
return ctx.write(frame, promise); return ctx.write(buf, promise);
} catch (RuntimeException e) { } catch (Throwable t) {
return promise.setFailure(e); return promise.setFailure(t);
} }
} }
@Override @Override
public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data, public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data,
ChannelPromise promise) { ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + data.readableBytes());
Http2Flags flags = ack ? new Http2Flags().ack(true) : new Http2Flags(); Http2Flags flags = ack ? new Http2Flags().ack(true) : new Http2Flags();
writeFrameHeader(frame, data.readableBytes(), PING, flags, 0); ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, data.readableBytes(), PING, flags, 0);
ctx.write(buf, promiseAggregator.newPromise());
// Write the debug data. // Write the debug data.
frame.writeBytes(data, data.readerIndex(), data.readableBytes()); releaseData = false;
return ctx.write(frame, promise); ctx.write(data, promiseAggregator.newPromise());
} catch (RuntimeException e) {
return promise.setFailure(e); return promiseAggregator.doneAllocatingPromises();
} finally { } catch (Throwable t) {
data.release(); if (releaseData) {
data.release();
}
return promiseAggregator.setFailure(t);
} }
} }
@ -235,6 +260,8 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId,
int promisedStreamId, Http2Headers headers, int padding, ChannelPromise promise) { int promisedStreamId, Http2Headers headers, int padding, ChannelPromise promise) {
ByteBuf headerBlock = null; ByteBuf headerBlock = null;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
verifyStreamId(streamId, STREAM_ID); verifyStreamId(streamId, STREAM_ID);
verifyStreamId(promisedStreamId, "Promised Stream ID"); verifyStreamId(promisedStreamId, "Promised Stream ID");
@ -246,38 +273,37 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
// Read the first fragment (possibly everything). // Read the first fragment (possibly everything).
Http2Flags flags = new Http2Flags().paddingPresent(padding > 0); Http2Flags flags = new Http2Flags().paddingPresent(padding > 0);
int promisedStreamIdLength = INT_FIELD_LENGTH; // INT_FIELD_LENGTH is for the length of the promisedStreamId
int nonFragmentLength = promisedStreamIdLength + padding + flags.getPaddingPresenceFieldLength(); int nonFragmentLength = INT_FIELD_LENGTH + padding + flags.getPaddingPresenceFieldLength();
int maxFragmentLength = maxFrameSize - nonFragmentLength; int maxFragmentLength = maxFrameSize - nonFragmentLength;
ByteBuf fragment = ByteBuf fragment =
headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain();
flags.endOfHeaders(headerBlock.readableBytes() == 0); flags.endOfHeaders(!headerBlock.isReadable());
int payloadLength = fragment.readableBytes() + nonFragmentLength; int payloadLength = fragment.readableBytes() + nonFragmentLength;
ByteBuf firstFrame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); ByteBuf buf = ctx.alloc().buffer(PUSH_PROMISE_FRAME_HEADER_LENGTH);
writeFrameHeader(firstFrame, payloadLength, PUSH_PROMISE, flags, writeFrameHeaderInternal(buf, payloadLength, PUSH_PROMISE, flags, streamId);
streamId); writePaddingLength(buf, padding);
writePaddingLength(padding, firstFrame);
// Write out the promised stream ID. // Write out the promised stream ID.
firstFrame.writeInt(promisedStreamId); buf.writeInt(promisedStreamId);
ctx.write(buf, promiseAggregator.newPromise());
// Write the first fragment. // Write the first fragment.
firstFrame.writeBytes(fragment); ctx.write(fragment, promiseAggregator.newPromise());
// Write out the padding, if any. if (padding > 0) { // Write out the padding, if any.
firstFrame.writeZero(padding); ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
if (headerBlock.readableBytes() == 0) {
return ctx.write(firstFrame, promise);
} }
// Create a composite buffer wrapping the first frame and any continuation frames. if (!flags.endOfHeaders()) {
return continueHeaders(ctx, promise, streamId, padding, headerBlock, firstFrame); writeContinuationFrames(ctx, streamId, headerBlock, padding, promiseAggregator);
} catch (Exception e) { }
return promise.setFailure(e);
return promiseAggregator.doneAllocatingPromises();
} catch (Throwable t) {
return promiseAggregator.setFailure(t);
} finally { } finally {
if (headerBlock != null) { if (headerBlock != null) {
headerBlock.release(); headerBlock.release();
@ -288,21 +314,28 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
@Override @Override
public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode,
ByteBuf debugData, ChannelPromise promise) { ByteBuf debugData, ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
verifyStreamOrConnectionId(lastStreamId, "Last Stream ID"); verifyStreamOrConnectionId(lastStreamId, "Last Stream ID");
verifyErrorCode(errorCode); verifyErrorCode(errorCode);
int payloadLength = 8 + debugData.readableBytes(); int payloadLength = 8 + debugData.readableBytes();
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); ByteBuf buf = ctx.alloc().buffer(GO_AWAY_FRAME_HEADER_LENGTH);
writeFrameHeader(frame, payloadLength, GO_AWAY, new Http2Flags(), 0); writeFrameHeaderInternal(buf, payloadLength, GO_AWAY, new Http2Flags(), 0);
frame.writeInt(lastStreamId); buf.writeInt(lastStreamId);
writeUnsignedInt(errorCode, frame); writeUnsignedInt(errorCode, buf);
frame.writeBytes(debugData, debugData.readerIndex(), debugData.readableBytes()); ctx.write(buf, promiseAggregator.newPromise());
return ctx.write(frame, promise);
} catch (RuntimeException e) { releaseData = false;
return promise.setFailure(e); ctx.write(debugData, promiseAggregator.newPromise());
} finally { return promiseAggregator.doneAllocatingPromises();
debugData.release(); } catch (Throwable t) {
if (releaseData) {
debugData.release();
}
return promiseAggregator.setFailure(t);
} }
} }
@ -313,34 +346,44 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
verifyStreamOrConnectionId(streamId, STREAM_ID); verifyStreamOrConnectionId(streamId, STREAM_ID);
verifyWindowSizeIncrement(windowSizeIncrement); verifyWindowSizeIncrement(windowSizeIncrement);
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + INT_FIELD_LENGTH); ByteBuf buf = ctx.alloc().buffer(WINDOW_UPDATE_FRAME_LENGTH);
writeFrameHeader(frame, INT_FIELD_LENGTH, WINDOW_UPDATE, writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, WINDOW_UPDATE, new Http2Flags(), streamId);
new Http2Flags(), streamId); buf.writeInt(windowSizeIncrement);
frame.writeInt(windowSizeIncrement); return ctx.write(buf, promise);
return ctx.write(frame, promise); } catch (Throwable t) {
} catch (RuntimeException e) { return promise.setFailure(t);
return promise.setFailure(e);
} }
} }
@Override @Override
public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId,
Http2Flags flags, ByteBuf payload, ChannelPromise promise) { Http2Flags flags, ByteBuf payload, ChannelPromise promise) {
boolean releaseData = true;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
verifyStreamOrConnectionId(streamId, STREAM_ID); verifyStreamOrConnectionId(streamId, STREAM_ID);
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payload.readableBytes()); ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH);
writeFrameHeader(frame, payload.readableBytes(), frameType, flags, streamId); writeFrameHeaderInternal(buf, payload.readableBytes(), frameType, flags, streamId);
frame.writeBytes(payload); ctx.write(buf, promiseAggregator.newPromise());
return ctx.write(frame, promise);
} catch (RuntimeException e) { releaseData = false;
return promise.setFailure(e); ctx.write(payload, promiseAggregator.newPromise());
return promiseAggregator.doneAllocatingPromises();
} catch (Throwable t) {
if (releaseData) {
payload.release();
}
return promiseAggregator.setFailure(t);
} }
} }
private ChannelFuture writeHeadersInternal(ChannelHandlerContext ctx, ChannelPromise promise, private ChannelFuture writeHeadersInternal(ChannelHandlerContext ctx,
int streamId, Http2Headers headers, int padding, boolean endStream, int streamId, Http2Headers headers, int padding, boolean endStream,
boolean hasPriority, int streamDependency, short weight, boolean exclusive) { boolean hasPriority, int streamDependency, short weight, boolean exclusive, ChannelPromise promise) {
ByteBuf headerBlock = null; ByteBuf headerBlock = null;
SimpleChannelPromiseAggregator promiseAggregator =
new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
try { try {
verifyStreamId(streamId, STREAM_ID); verifyStreamId(streamId, STREAM_ID);
if (hasPriority) { if (hasPriority) {
@ -357,46 +400,42 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
new Http2Flags().endOfStream(endStream).priorityPresent(hasPriority).paddingPresent(padding > 0); new Http2Flags().endOfStream(endStream).priorityPresent(hasPriority).paddingPresent(padding > 0);
// Read the first fragment (possibly everything). // Read the first fragment (possibly everything).
int nonFragmentBytes = int nonFragmentBytes = padding + flags.getNumPriorityBytes() + flags.getPaddingPresenceFieldLength();
padding + flags.getNumPriorityBytes() + flags.getPaddingPresenceFieldLength();
int maxFragmentLength = maxFrameSize - nonFragmentBytes; int maxFragmentLength = maxFrameSize - nonFragmentBytes;
ByteBuf fragment = ByteBuf fragment =
headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain();
// Set the end of headers flag for the first frame. // Set the end of headers flag for the first frame.
flags.endOfHeaders(headerBlock.readableBytes() == 0); flags.endOfHeaders(!headerBlock.isReadable());
int payloadLength = fragment.readableBytes() + nonFragmentBytes; int payloadLength = fragment.readableBytes() + nonFragmentBytes;
ByteBuf firstFrame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); ByteBuf buf = ctx.alloc().buffer(HEADERS_FRAME_HEADER_LENGTH);
writeFrameHeader(firstFrame, payloadLength, HEADERS, flags, writeFrameHeaderInternal(buf, payloadLength, HEADERS, flags, streamId);
streamId); writePaddingLength(buf, padding);
// Write the padding length.
writePaddingLength(padding, firstFrame);
// Write the priority.
if (hasPriority) { if (hasPriority) {
long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency; long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency;
writeUnsignedInt(word1, firstFrame); writeUnsignedInt(word1, buf);
// Adjust the weight so that it fits into a single byte on the wire. // Adjust the weight so that it fits into a single byte on the wire.
firstFrame.writeByte(weight - 1); buf.writeByte(weight - 1);
} }
ctx.write(buf, promiseAggregator.newPromise());
// Write the first fragment. // Write the first fragment.
firstFrame.writeBytes(fragment); ctx.write(fragment, promiseAggregator.newPromise());
// Write out the padding, if any. if (padding > 0) { // Write out the padding, if any.
firstFrame.writeZero(padding); ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise());
if (flags.endOfHeaders()) {
return ctx.write(firstFrame, promise);
} }
// Create a composite buffer wrapping the first frame and any continuation frames. if (!flags.endOfHeaders()) {
return continueHeaders(ctx, promise, streamId, padding, headerBlock, firstFrame); writeContinuationFrames(ctx, streamId, headerBlock, padding, promiseAggregator);
} catch (Exception e) { }
return promise.setFailure(e);
return promiseAggregator.doneAllocatingPromises();
} catch (Throwable t) {
return promiseAggregator.setFailure(t);
} finally { } finally {
if (headerBlock != null) { if (headerBlock != null) {
headerBlock.release(); headerBlock.release();
@ -405,66 +444,60 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize
} }
/** /**
* Drains the header block and creates a composite buffer containing the first frame and a * Writes as many continuation frames as needed until {@code padding} and {@code headerBlock} are consumed.
* number of CONTINUATION frames.
*/ */
private ChannelFuture continueHeaders(ChannelHandlerContext ctx, ChannelPromise promise, private ChannelFuture writeContinuationFrames(ChannelHandlerContext ctx, int streamId,
int streamId, int padding, ByteBuf headerBlock, ByteBuf firstFrame) { ByteBuf headerBlock, int padding, SimpleChannelPromiseAggregator promiseAggregator) {
// Create a composite buffer wrapping the first frame and any continuation frames.
CompositeByteBuf out = ctx.alloc().compositeBuffer();
out.addComponent(firstFrame);
int numBytes = firstFrame.readableBytes();
// Process any continuation frames there might be.
while (headerBlock.isReadable()) {
ByteBuf frame = createContinuationFrame(ctx, streamId, headerBlock, padding);
out.addComponent(frame);
numBytes += frame.readableBytes();
}
out.writerIndex(numBytes);
return ctx.write(out, promise);
}
/**
* Allocates a new buffer and writes a single continuation frame with a fragment of the header
* block to the output buffer.
*/
private ByteBuf createContinuationFrame(ChannelHandlerContext ctx, int streamId,
ByteBuf headerBlock, int padding) {
Http2Flags flags = new Http2Flags().paddingPresent(padding > 0); Http2Flags flags = new Http2Flags().paddingPresent(padding > 0);
int nonFragmentLength = padding + flags.getPaddingPresenceFieldLength(); int nonFragmentLength = padding + flags.getPaddingPresenceFieldLength();
int maxFragmentLength = maxFrameSize - nonFragmentLength; int maxFragmentLength = maxFrameSize - nonFragmentLength;
ByteBuf fragment = // TODO: same padding is applied to all frames, is this desired?
headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); if (maxFragmentLength <= 0) {
return promiseAggregator.setFailure(new IllegalArgumentException(
"Padding [" + padding + "] is too large for max frame size [" + maxFrameSize + "]"));
}
int payloadLength = fragment.readableBytes() + nonFragmentLength; if (headerBlock.isReadable()) {
ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); // The frame header (and padding) only changes on the last frame, so allocate it once and re-use
flags = flags.endOfHeaders(headerBlock.readableBytes() == 0); final ByteBuf paddingBuf = padding > 0 ? ZERO_BUFFER.slice(0, padding) : null;
int fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength);
int payloadLength = fragmentReadableBytes + nonFragmentLength;
ByteBuf buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, payloadLength, CONTINUATION, flags, streamId);
writePaddingLength(buf, padding);
writeFrameHeader(frame, payloadLength, CONTINUATION, flags, streamId); do {
fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength);
ByteBuf fragment = headerBlock.readSlice(fragmentReadableBytes).retain();
writePaddingLength(padding, frame); payloadLength = fragmentReadableBytes + nonFragmentLength;
if (headerBlock.isReadable()) {
ctx.write(buf.retain(), promiseAggregator.newPromise());
} else {
// The frame header is different for the last frame, so re-allocate and release the old buffer
flags = flags.endOfHeaders(true);
buf.release();
buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH);
writeFrameHeaderInternal(buf, payloadLength, CONTINUATION, flags, streamId);
writePaddingLength(buf, padding);
ctx.write(buf, promiseAggregator.newPromise());
}
frame.writeBytes(fragment); ctx.write(fragment, promiseAggregator.newPromise());
// Write out the padding, if any. // Write out the padding, if any.
frame.writeZero(padding); if (paddingBuf != null) {
return frame; ctx.write(paddingBuf.retain(), promiseAggregator.newPromise());
}
} while(headerBlock.isReadable());
}
return promiseAggregator;
} }
/** private static void writePaddingLength(ByteBuf buf, int paddingLength) {
* Writes the padding length field to the output buffer.
*/
private static void writePaddingLength(int paddingLength, ByteBuf out) {
if (paddingLength > MAX_UNSIGNED_BYTE) {
int padHigh = paddingLength / 256;
out.writeByte(padHigh);
}
// Always include PadLow if there is any padding at all.
if (paddingLength > 0) { if (paddingLength > 0) {
int padLow = paddingLength % 256; // It is assumed that the padding length has been bounds checked before this
out.writeByte(padLow); buf.writeByte(paddingLength);
} }
} }

View File

@ -19,9 +19,12 @@ import static io.netty.util.CharsetUtil.UTF_8;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action; import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action;
import io.netty.util.concurrent.EventExecutor;
/** /**
* Constants and utility method used for encoding/decoding HTTP2 frames. * Constants and utility method used for encoding/decoding HTTP2 frames.
@ -48,6 +51,18 @@ public final class Http2CodecUtil {
public static final short MAX_WEIGHT = 256; public static final short MAX_WEIGHT = 256;
public static final short MIN_WEIGHT = 1; public static final short MIN_WEIGHT = 1;
private static final int MAX_PADDING_LENGTH_LENGTH = 1;
public static final int DATA_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH;
public static final int HEADERS_FRAME_HEADER_LENGTH =
FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH + 1;
public static final int PRIORITY_FRAME_LENGTH = FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH;
public static final int RST_STREAM_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH;
public static final int PUSH_PROMISE_FRAME_HEADER_LENGTH =
FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH;
public static final int GO_AWAY_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + 2 * INT_FIELD_LENGTH;
public static final int WINDOW_UPDATE_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH;
public static final int CONTINUATION_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH;
public static final int SETTINGS_HEADER_TABLE_SIZE = 1; public static final int SETTINGS_HEADER_TABLE_SIZE = 1;
public static final int SETTINGS_ENABLE_PUSH = 2; public static final int SETTINGS_ENABLE_PUSH = 2;
public static final int SETTINGS_MAX_CONCURRENT_STREAMS = 3; public static final int SETTINGS_MAX_CONCURRENT_STREAMS = 3;
@ -183,12 +198,128 @@ public final class Http2CodecUtil {
public static void writeFrameHeader(ByteBuf out, int payloadLength, byte type, public static void writeFrameHeader(ByteBuf out, int payloadLength, byte type,
Http2Flags flags, int streamId) { Http2Flags flags, int streamId) {
out.ensureWritable(FRAME_HEADER_LENGTH + payloadLength); out.ensureWritable(FRAME_HEADER_LENGTH + payloadLength);
writeFrameHeaderInternal(out, payloadLength, type, flags, streamId);
}
static void writeFrameHeaderInternal(ByteBuf out, int payloadLength, byte type,
Http2Flags flags, int streamId) {
out.writeMedium(payloadLength); out.writeMedium(payloadLength);
out.writeByte(type); out.writeByte(type);
out.writeByte(flags.value()); out.writeByte(flags.value());
out.writeInt(streamId); out.writeInt(streamId);
} }
/**
* Provides the ability to associate the outcome of multiple {@link ChannelPromise}
* objects into a single {@link ChannelPromise} object.
*/
static class SimpleChannelPromiseAggregator extends DefaultChannelPromise {
private final ChannelPromise promise;
private int expectedCount;
private int successfulCount;
private int failureCount;
private boolean doneAllocating;
SimpleChannelPromiseAggregator(ChannelPromise promise, Channel c, EventExecutor e) {
super(c, e);
assert promise != null;
this.promise = promise;
}
/**
* Allocate a new promise which will be used to aggregate the overall success of this promise aggregator.
* @return A new promise which will be aggregated.
* {@code null} if {@link #doneAllocatingPromises()} was previously called.
*/
public ChannelPromise newPromise() {
if (doneAllocating) {
throw new IllegalStateException("Done allocating. No more promises can be allocated.");
}
++expectedCount;
return this;
}
/**
* Signify that no more {@link #newPromise()} allocations will be made.
* The aggregation can not be successful until this method is called.
* @return The promise that is the aggregation of all promises allocated with {@link #newPromise()}.
*/
public ChannelPromise doneAllocatingPromises() {
if (!doneAllocating) {
doneAllocating = true;
if (successfulCount == expectedCount) {
promise.setSuccess();
return super.setSuccess();
}
}
return this;
}
@Override
public boolean tryFailure(Throwable cause) {
if (allowNotificationEvent()) {
++failureCount;
if (failureCount == 1) {
promise.tryFailure(cause);
return super.tryFailure(cause);
}
// TODO: We break the interface a bit here.
// Multiple failure events can be processed without issue because this is an aggregation.
return true;
}
return false;
}
/**
* Fail this object if it has not already been failed.
* <p>
* This method will NOT throw an {@link IllegalStateException} if called multiple times
* because that may be expected.
*/
@Override
public ChannelPromise setFailure(Throwable cause) {
if (allowNotificationEvent()) {
++failureCount;
if (failureCount == 1) {
promise.setFailure(cause);
return super.setFailure(cause);
}
}
return this;
}
private boolean allowNotificationEvent() {
return successfulCount + failureCount < expectedCount;
}
@Override
public ChannelPromise setSuccess(Void result) {
if (allowNotificationEvent()) {
++successfulCount;
if (successfulCount == expectedCount && doneAllocating) {
promise.setSuccess(result);
return super.setSuccess(result);
}
}
return this;
}
@Override
public boolean trySuccess(Void result) {
if (allowNotificationEvent()) {
++successfulCount;
if (successfulCount == expectedCount && doneAllocating) {
promise.trySuccess(result);
return super.trySuccess(result);
}
// TODO: We break the interface a bit here.
// Multiple success events can be processed without issue because this is an aggregation.
return true;
}
return false;
}
}
/** /**
* Fails the given promise with the cause and then re-throws the cause. * Fails the given promise with the cause and then re-throws the cause.
*/ */

View File

@ -28,7 +28,7 @@ public interface Http2DataWriter {
* *
* @param ctx the context to use for writing. * @param ctx the context to use for writing.
* @param streamId the stream for which to send the frame. * @param streamId the stream for which to send the frame.
* @param data the payload of the frame. * @param data the payload of the frame. This will be released by this method.
* @param padding the amount of padding to be added to the end of the frame * @param padding the amount of padding to be added to the end of the frame
* @param endStream indicates if this is the last frame to be sent for the stream. * @param endStream indicates if this is the last frame to be sent for the stream.
* @param promise the promise for the write. * @param promise the promise for the write.

View File

@ -129,7 +129,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable {
* @param ctx the context to use for writing. * @param ctx the context to use for writing.
* @param ack indicates whether this is an ack of a PING frame previously received from the * @param ack indicates whether this is an ack of a PING frame previously received from the
* remote endpoint. * remote endpoint.
* @param data the payload of the frame. * @param data the payload of the frame. This will be released by this method.
* @param promise the promise for the write. * @param promise the promise for the write.
* @return the future for the write. * @return the future for the write.
*/ */
@ -156,7 +156,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable {
* @param ctx the context to use for writing. * @param ctx the context to use for writing.
* @param lastStreamId the last known stream of this endpoint. * @param lastStreamId the last known stream of this endpoint.
* @param errorCode the error code, if the connection was abnormally terminated. * @param errorCode the error code, if the connection was abnormally terminated.
* @param debugData application-defined debug data. * @param debugData application-defined debug data. This will be released by this method.
* @param promise the promise for the write. * @param promise the promise for the write.
* @return the future for the write. * @return the future for the write.
*/ */
@ -183,7 +183,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable {
* @param frameType the frame type identifier. * @param frameType the frame type identifier.
* @param streamId the stream for which to send the frame. * @param streamId the stream for which to send the frame.
* @param flags the flags to write for this frame. * @param flags the flags to write for this frame.
* @param payload the payload to write for this frame. * @param payload the payload to write for this frame. This will be released by this method.
* @param promise the promise for the write. * @param promise the promise for the write.
* @return the future for the write. * @return the future for the write.
*/ */

View File

@ -17,9 +17,9 @@ package io.netty.handler.codec.http2;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.ChannelPromiseAggregator;
import io.netty.handler.codec.http.FullHttpMessage; import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator;
/** /**
* Translates HTTP/1.x object writes into HTTP/2 frames. * Translates HTTP/1.x object writes into HTTP/2 frames.
@ -65,6 +65,7 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
FullHttpMessage httpMsg = (FullHttpMessage) msg; FullHttpMessage httpMsg = (FullHttpMessage) msg;
boolean hasData = httpMsg.content().isReadable(); boolean hasData = httpMsg.content().isReadable();
boolean httpMsgNeedRelease = true; boolean httpMsgNeedRelease = true;
SimpleChannelPromiseAggregator promiseAggregator = null;
try { try {
// Provide the user the opportunity to specify the streamId // Provide the user the opportunity to specify the streamId
int streamId = getStreamId(httpMsg.headers()); int streamId = getStreamId(httpMsg.headers());
@ -74,18 +75,20 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler {
Http2ConnectionEncoder encoder = encoder(); Http2ConnectionEncoder encoder = encoder();
if (hasData) { if (hasData) {
ChannelPromiseAggregator promiseAggregator = new ChannelPromiseAggregator(promise); promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor());
ChannelPromise headerPromise = ctx.newPromise(); encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise());
ChannelPromise dataPromise = ctx.newPromise();
promiseAggregator.add(headerPromise, dataPromise);
encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, headerPromise);
httpMsgNeedRelease = false; httpMsgNeedRelease = false;
encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, dataPromise); encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise());
promiseAggregator.doneAllocatingPromises();
} else { } else {
encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise); encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise);
} }
} catch (Throwable t) { } catch (Throwable t) {
promise.tryFailure(t); if (promiseAggregator == null) {
promise.tryFailure(t);
} else {
promiseAggregator.setFailure(t);
}
} finally { } finally {
if (httpMsgNeedRelease) { if (httpMsgNeedRelease) {
httpMsg.release(); httpMsg.release();

View File

@ -14,9 +14,10 @@
*/ */
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2Exception.connectionError;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.TooLongFrameException; import io.netty.handler.codec.TooLongFrameException;
@ -28,7 +29,6 @@ import io.netty.handler.codec.http.HttpHeaderUtil;
import io.netty.handler.codec.http.HttpStatusClass; import io.netty.handler.codec.http.HttpStatusClass;
import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.IntObjectMap;
import static io.netty.util.internal.ObjectUtil.*;
/** /**
* This adapter provides just header/data events from the HTTP message flow defined * This adapter provides just header/data events from the HTTP message flow defined

View File

@ -18,22 +18,33 @@ package io.netty.handler.codec.http2;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT;
import static io.netty.handler.codec.http2.Http2TestUtil.as; import static io.netty.handler.codec.http2.Http2TestUtil.as;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString; import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.concurrent.EventExecutor;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** /**
* Integration tests for {@link DefaultHttp2FrameReader} and {@link DefaultHttp2FrameWriter}. * Integration tests for {@link DefaultHttp2FrameReader} and {@link DefaultHttp2FrameWriter}.
@ -43,6 +54,8 @@ public class DefaultHttp2FrameIOTest {
private DefaultHttp2FrameReader reader; private DefaultHttp2FrameReader reader;
private DefaultHttp2FrameWriter writer; private DefaultHttp2FrameWriter writer;
private ByteBufAllocator alloc; private ByteBufAllocator alloc;
private CountDownLatch latch;
private ByteBuf buffer;
@Mock @Mock
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
@ -53,13 +66,56 @@ public class DefaultHttp2FrameIOTest {
@Mock @Mock
private ChannelPromise promise; private ChannelPromise promise;
@Mock
private Channel channel;
@Mock
private EventExecutor executor;
@Before @Before
public void setup() { public void setup() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
alloc = UnpooledByteBufAllocator.DEFAULT; alloc = UnpooledByteBufAllocator.DEFAULT;
buffer = alloc.buffer();
latch = new CountDownLatch(1);
when(executor.inEventLoop()).thenReturn(true);
when(ctx.alloc()).thenReturn(alloc); when(ctx.alloc()).thenReturn(alloc);
when(ctx.channel()).thenReturn(channel);
when(ctx.executor()).thenReturn(executor);
doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock invocation) throws Throwable {
return new DefaultChannelPromise(channel, executor);
}
}).when(ctx).newPromise();
doAnswer(new Answer<ChannelPromise>() {
@Override
public ChannelPromise answer(InvocationOnMock in) throws Throwable {
latch.countDown();
return promise;
}
}).when(promise).setSuccess();
doAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock in) throws Throwable {
if (in.getArguments()[0] instanceof ByteBuf) {
ByteBuf tmp = (ByteBuf) in.getArguments()[0];
try {
buffer.writeBytes(tmp);
} finally {
tmp.release();
}
}
if (in.getArguments()[1] instanceof ChannelPromise) {
return ((ChannelPromise) in.getArguments()[1]).setSuccess();
}
return null;
}
}).when(ctx).write(any(), any(ChannelPromise.class));
reader = new DefaultHttp2FrameReader(); reader = new DefaultHttp2FrameReader();
writer = new DefaultHttp2FrameWriter(); writer = new DefaultHttp2FrameWriter();
@ -452,10 +508,9 @@ public class DefaultHttp2FrameIOTest {
} }
} }
private ByteBuf captureWrite() { private ByteBuf captureWrite() throws InterruptedException {
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); assertTrue(latch.await(2, TimeUnit.SECONDS));
verify(ctx).write(captor.capture(), eq(promise)); return buffer;
return captor.getValue();
} }
private ByteBuf dummyData() { private ByteBuf dummyData() {
@ -471,9 +526,8 @@ public class DefaultHttp2FrameIOTest {
} }
private static Http2Headers dummyHeaders() { private static Http2Headers dummyHeaders() {
return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")) return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")).authority(as("example.org"))
.authority(as("example.org")).path(as("/some/path")) .path(as("/some/path")).add(as("accept"), as("*/*"));
.add(as("accept"), as("*/*"));
} }
private static Http2Headers largeHeaders() { private static Http2Headers largeHeaders() {

View File

@ -70,7 +70,7 @@ import org.mockito.stubbing.Answer;
* Testing the {@link HttpToHttp2ConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames * Testing the {@link HttpToHttp2ConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames
*/ */
public class HttpToHttp2ConnectionHandlerTest { public class HttpToHttp2ConnectionHandlerTest {
private static final int WAIT_TIME_SECONDS = 5; private static final int WAIT_TIME_SECONDS = 500;
@Mock @Mock
private Http2FrameListener clientListener; private Http2FrameListener clientListener;