diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java index 998bc2816a..1f6c44c3bd 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java @@ -15,21 +15,29 @@ package io.netty.handler.codec.http2; -import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2CodecUtil.CONTINUATION_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.DATA_FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.GO_AWAY_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.HEADERS_FRAME_HEADER_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_BYTE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT; import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_FRAME_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PUSH_PROMISE_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.RST_STREAM_FRAME_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.WINDOW_UPDATE_FRAME_LENGTH; import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; -import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeader; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeaderInternal; import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedInt; import static io.netty.handler.codec.http2.Http2CodecUtil.writeUnsignedShort; import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; import static io.netty.handler.codec.http2.Http2FrameTypes.DATA; import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY; @@ -42,10 +50,11 @@ import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE; import static io.netty.util.internal.ObjectUtil.checkNotNull; import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; import io.netty.handler.codec.http2.Http2FrameWriter.Configuration; import io.netty.util.collection.IntObjectMap; @@ -55,6 +64,12 @@ import io.netty.util.collection.IntObjectMap; public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSizePolicy, Configuration { private static final String STREAM_ID = "Stream ID"; private static final String STREAM_DEPENDENCY = "Stream Dependency"; + /** + * This buffer is allocated to the maximum padding size needed, and filled with padding. + * When padding is needed it can be taken as a slice of this buffer. Users should call {@link ByteBuf#retain()} + * before using their slice. + */ + private static final ByteBuf ZERO_BUFFER = Unpooled.buffer(MAX_UNSIGNED_BYTE).writeZero(MAX_UNSIGNED_BYTE); private final Http2HeadersEncoder headersEncoder; private int maxFrameSize; @@ -102,6 +117,10 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize @Override public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endStream, ChannelPromise promise) { + boolean releaseData = true; + ByteBuf buf = null; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { verifyStreamId(streamId, STREAM_ID); verifyPadding(padding); @@ -111,38 +130,40 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize int payloadLength = data.readableBytes() + padding + flags.getPaddingPresenceFieldLength(); verifyPayloadLength(payloadLength); - ByteBuf out = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - - writeFrameHeader(out, payloadLength, DATA, flags, streamId); - - writePaddingLength(padding, out); + buf = ctx.alloc().buffer(DATA_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, DATA, flags, streamId); + writePaddingLength(buf, padding); + ctx.write(buf, promiseAggregator.newPromise()); // Write the data. - out.writeBytes(data, data.readerIndex(), data.readableBytes()); + releaseData = false; + ctx.write(data, promiseAggregator.newPromise()); - // Write the required padding. - out.writeZero(padding); - return ctx.write(out, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); - } finally { - data.release(); + if (padding > 0) { // Write the required padding. + ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise()); + } + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + if (releaseData) { + data.release(); + } + return promiseAggregator.setFailure(t); } } @Override public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) { - return writeHeadersInternal(ctx, promise, streamId, headers, padding, endStream, - false, 0, (short) 0, false); + return writeHeadersInternal(ctx, streamId, headers, padding, endStream, + false, 0, (short) 0, false, promise); } @Override public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, short weight, boolean exclusive, int padding, boolean endStream, ChannelPromise promise) { - return writeHeadersInternal(ctx, promise, streamId, headers, padding, endStream, - true, streamDependency, weight, exclusive); + return writeHeadersInternal(ctx, streamId, headers, padding, endStream, + true, streamDependency, weight, exclusive, promise); } @Override @@ -153,17 +174,15 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize verifyStreamId(streamDependency, STREAM_DEPENDENCY); verifyWeight(weight); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH); - writeFrameHeader(frame, PRIORITY_ENTRY_LENGTH, PRIORITY, - new Http2Flags(), streamId); + ByteBuf buf = ctx.alloc().buffer(PRIORITY_FRAME_LENGTH); + writeFrameHeaderInternal(buf, PRIORITY_ENTRY_LENGTH, PRIORITY, new Http2Flags(), streamId); long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency; - writeUnsignedInt(word1, frame); - + writeUnsignedInt(word1, buf); // Adjust the weight so that it fits into a single byte on the wire. - frame.writeByte(weight - 1); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + buf.writeByte(weight - 1); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); } } @@ -174,13 +193,12 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize verifyStreamId(streamId, STREAM_ID); verifyErrorCode(errorCode); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + INT_FIELD_LENGTH); - writeFrameHeader(frame, INT_FIELD_LENGTH, RST_STREAM, new Http2Flags(), - streamId); - writeUnsignedInt(errorCode, frame); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + ByteBuf buf = ctx.alloc().buffer(RST_STREAM_FRAME_LENGTH); + writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, RST_STREAM, new Http2Flags(), streamId); + writeUnsignedInt(errorCode, buf); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); } } @@ -190,44 +208,51 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize try { checkNotNull(settings, "settings"); int payloadLength = SETTING_ENTRY_LENGTH * settings.size(); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - writeFrameHeader(frame, payloadLength, SETTINGS, new Http2Flags(), 0); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH + settings.size() * SETTING_ENTRY_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, SETTINGS, new Http2Flags(), 0); for (IntObjectMap.Entry entry : settings.entries()) { - writeUnsignedShort(entry.key(), frame); - writeUnsignedInt(entry.value(), frame); + writeUnsignedShort(entry.key(), buf); + writeUnsignedInt(entry.value(), buf); } - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); } } @Override public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { try { - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH); - writeFrameHeader(frame, 0, SETTINGS, new Http2Flags().ack(true), 0); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, 0, SETTINGS, new Http2Flags().ack(true), 0); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); } } @Override public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, ByteBuf data, ChannelPromise promise) { + boolean releaseData = true; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + data.readableBytes()); Http2Flags flags = ack ? new Http2Flags().ack(true) : new Http2Flags(); - writeFrameHeader(frame, data.readableBytes(), PING, flags, 0); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, data.readableBytes(), PING, flags, 0); + ctx.write(buf, promiseAggregator.newPromise()); // Write the debug data. - frame.writeBytes(data, data.readerIndex(), data.readableBytes()); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); - } finally { - data.release(); + releaseData = false; + ctx.write(data, promiseAggregator.newPromise()); + + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + if (releaseData) { + data.release(); + } + return promiseAggregator.setFailure(t); } } @@ -235,6 +260,8 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId, Http2Headers headers, int padding, ChannelPromise promise) { ByteBuf headerBlock = null; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { verifyStreamId(streamId, STREAM_ID); verifyStreamId(promisedStreamId, "Promised Stream ID"); @@ -246,38 +273,37 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize // Read the first fragment (possibly everything). Http2Flags flags = new Http2Flags().paddingPresent(padding > 0); - int promisedStreamIdLength = INT_FIELD_LENGTH; - int nonFragmentLength = promisedStreamIdLength + padding + flags.getPaddingPresenceFieldLength(); + // INT_FIELD_LENGTH is for the length of the promisedStreamId + int nonFragmentLength = INT_FIELD_LENGTH + padding + flags.getPaddingPresenceFieldLength(); int maxFragmentLength = maxFrameSize - nonFragmentLength; ByteBuf fragment = - headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); + headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain(); - flags.endOfHeaders(headerBlock.readableBytes() == 0); + flags.endOfHeaders(!headerBlock.isReadable()); int payloadLength = fragment.readableBytes() + nonFragmentLength; - ByteBuf firstFrame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - writeFrameHeader(firstFrame, payloadLength, PUSH_PROMISE, flags, - streamId); - - writePaddingLength(padding, firstFrame); + ByteBuf buf = ctx.alloc().buffer(PUSH_PROMISE_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, PUSH_PROMISE, flags, streamId); + writePaddingLength(buf, padding); // Write out the promised stream ID. - firstFrame.writeInt(promisedStreamId); + buf.writeInt(promisedStreamId); + ctx.write(buf, promiseAggregator.newPromise()); // Write the first fragment. - firstFrame.writeBytes(fragment); + ctx.write(fragment, promiseAggregator.newPromise()); - // Write out the padding, if any. - firstFrame.writeZero(padding); - - if (headerBlock.readableBytes() == 0) { - return ctx.write(firstFrame, promise); + if (padding > 0) { // Write out the padding, if any. + ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise()); } - // Create a composite buffer wrapping the first frame and any continuation frames. - return continueHeaders(ctx, promise, streamId, padding, headerBlock, firstFrame); - } catch (Exception e) { - return promise.setFailure(e); + if (!flags.endOfHeaders()) { + writeContinuationFrames(ctx, streamId, headerBlock, padding, promiseAggregator); + } + + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + return promiseAggregator.setFailure(t); } finally { if (headerBlock != null) { headerBlock.release(); @@ -288,21 +314,28 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize @Override public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData, ChannelPromise promise) { + boolean releaseData = true; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { verifyStreamOrConnectionId(lastStreamId, "Last Stream ID"); verifyErrorCode(errorCode); int payloadLength = 8 + debugData.readableBytes(); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - writeFrameHeader(frame, payloadLength, GO_AWAY, new Http2Flags(), 0); - frame.writeInt(lastStreamId); - writeUnsignedInt(errorCode, frame); - frame.writeBytes(debugData, debugData.readerIndex(), debugData.readableBytes()); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); - } finally { - debugData.release(); + ByteBuf buf = ctx.alloc().buffer(GO_AWAY_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, GO_AWAY, new Http2Flags(), 0); + buf.writeInt(lastStreamId); + writeUnsignedInt(errorCode, buf); + ctx.write(buf, promiseAggregator.newPromise()); + + releaseData = false; + ctx.write(debugData, promiseAggregator.newPromise()); + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + if (releaseData) { + debugData.release(); + } + return promiseAggregator.setFailure(t); } } @@ -313,34 +346,44 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize verifyStreamOrConnectionId(streamId, STREAM_ID); verifyWindowSizeIncrement(windowSizeIncrement); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + INT_FIELD_LENGTH); - writeFrameHeader(frame, INT_FIELD_LENGTH, WINDOW_UPDATE, - new Http2Flags(), streamId); - frame.writeInt(windowSizeIncrement); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + ByteBuf buf = ctx.alloc().buffer(WINDOW_UPDATE_FRAME_LENGTH); + writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, WINDOW_UPDATE, new Http2Flags(), streamId); + buf.writeInt(windowSizeIncrement); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); } } @Override public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload, ChannelPromise promise) { + boolean releaseData = true; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { verifyStreamOrConnectionId(streamId, STREAM_ID); - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payload.readableBytes()); - writeFrameHeader(frame, payload.readableBytes(), frameType, flags, streamId); - frame.writeBytes(payload); - return ctx.write(frame, promise); - } catch (RuntimeException e) { - return promise.setFailure(e); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payload.readableBytes(), frameType, flags, streamId); + ctx.write(buf, promiseAggregator.newPromise()); + + releaseData = false; + ctx.write(payload, promiseAggregator.newPromise()); + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + if (releaseData) { + payload.release(); + } + return promiseAggregator.setFailure(t); } } - private ChannelFuture writeHeadersInternal(ChannelHandlerContext ctx, ChannelPromise promise, + private ChannelFuture writeHeadersInternal(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, boolean endStream, - boolean hasPriority, int streamDependency, short weight, boolean exclusive) { + boolean hasPriority, int streamDependency, short weight, boolean exclusive, ChannelPromise promise) { ByteBuf headerBlock = null; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); try { verifyStreamId(streamId, STREAM_ID); if (hasPriority) { @@ -357,46 +400,42 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize new Http2Flags().endOfStream(endStream).priorityPresent(hasPriority).paddingPresent(padding > 0); // Read the first fragment (possibly everything). - int nonFragmentBytes = - padding + flags.getNumPriorityBytes() + flags.getPaddingPresenceFieldLength(); + int nonFragmentBytes = padding + flags.getNumPriorityBytes() + flags.getPaddingPresenceFieldLength(); int maxFragmentLength = maxFrameSize - nonFragmentBytes; ByteBuf fragment = - headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); + headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)).retain(); // Set the end of headers flag for the first frame. - flags.endOfHeaders(headerBlock.readableBytes() == 0); + flags.endOfHeaders(!headerBlock.isReadable()); int payloadLength = fragment.readableBytes() + nonFragmentBytes; - ByteBuf firstFrame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - writeFrameHeader(firstFrame, payloadLength, HEADERS, flags, - streamId); + ByteBuf buf = ctx.alloc().buffer(HEADERS_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, HEADERS, flags, streamId); + writePaddingLength(buf, padding); - // Write the padding length. - writePaddingLength(padding, firstFrame); - - // Write the priority. if (hasPriority) { long word1 = exclusive ? 0x80000000L | streamDependency : streamDependency; - writeUnsignedInt(word1, firstFrame); + writeUnsignedInt(word1, buf); // Adjust the weight so that it fits into a single byte on the wire. - firstFrame.writeByte(weight - 1); + buf.writeByte(weight - 1); } + ctx.write(buf, promiseAggregator.newPromise()); // Write the first fragment. - firstFrame.writeBytes(fragment); + ctx.write(fragment, promiseAggregator.newPromise()); - // Write out the padding, if any. - firstFrame.writeZero(padding); - - if (flags.endOfHeaders()) { - return ctx.write(firstFrame, promise); + if (padding > 0) { // Write out the padding, if any. + ctx.write(ZERO_BUFFER.slice(0, padding).retain(), promiseAggregator.newPromise()); } - // Create a composite buffer wrapping the first frame and any continuation frames. - return continueHeaders(ctx, promise, streamId, padding, headerBlock, firstFrame); - } catch (Exception e) { - return promise.setFailure(e); + if (!flags.endOfHeaders()) { + writeContinuationFrames(ctx, streamId, headerBlock, padding, promiseAggregator); + } + + return promiseAggregator.doneAllocatingPromises(); + } catch (Throwable t) { + return promiseAggregator.setFailure(t); } finally { if (headerBlock != null) { headerBlock.release(); @@ -405,66 +444,60 @@ public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSize } /** - * Drains the header block and creates a composite buffer containing the first frame and a - * number of CONTINUATION frames. + * Writes as many continuation frames as needed until {@code padding} and {@code headerBlock} are consumed. */ - private ChannelFuture continueHeaders(ChannelHandlerContext ctx, ChannelPromise promise, - int streamId, int padding, ByteBuf headerBlock, ByteBuf firstFrame) { - // Create a composite buffer wrapping the first frame and any continuation frames. - CompositeByteBuf out = ctx.alloc().compositeBuffer(); - out.addComponent(firstFrame); - int numBytes = firstFrame.readableBytes(); - - // Process any continuation frames there might be. - while (headerBlock.isReadable()) { - ByteBuf frame = createContinuationFrame(ctx, streamId, headerBlock, padding); - out.addComponent(frame); - numBytes += frame.readableBytes(); - } - - out.writerIndex(numBytes); - return ctx.write(out, promise); - } - - /** - * Allocates a new buffer and writes a single continuation frame with a fragment of the header - * block to the output buffer. - */ - private ByteBuf createContinuationFrame(ChannelHandlerContext ctx, int streamId, - ByteBuf headerBlock, int padding) { + private ChannelFuture writeContinuationFrames(ChannelHandlerContext ctx, int streamId, + ByteBuf headerBlock, int padding, SimpleChannelPromiseAggregator promiseAggregator) { Http2Flags flags = new Http2Flags().paddingPresent(padding > 0); int nonFragmentLength = padding + flags.getPaddingPresenceFieldLength(); int maxFragmentLength = maxFrameSize - nonFragmentLength; - ByteBuf fragment = - headerBlock.readSlice(Math.min(headerBlock.readableBytes(), maxFragmentLength)); + // TODO: same padding is applied to all frames, is this desired? + if (maxFragmentLength <= 0) { + return promiseAggregator.setFailure(new IllegalArgumentException( + "Padding [" + padding + "] is too large for max frame size [" + maxFrameSize + "]")); + } - int payloadLength = fragment.readableBytes() + nonFragmentLength; - ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); - flags = flags.endOfHeaders(headerBlock.readableBytes() == 0); + if (headerBlock.isReadable()) { + // The frame header (and padding) only changes on the last frame, so allocate it once and re-use + final ByteBuf paddingBuf = padding > 0 ? ZERO_BUFFER.slice(0, padding) : null; + int fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength); + int payloadLength = fragmentReadableBytes + nonFragmentLength; + ByteBuf buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, CONTINUATION, flags, streamId); + writePaddingLength(buf, padding); - writeFrameHeader(frame, payloadLength, CONTINUATION, flags, streamId); + do { + fragmentReadableBytes = Math.min(headerBlock.readableBytes(), maxFragmentLength); + ByteBuf fragment = headerBlock.readSlice(fragmentReadableBytes).retain(); - writePaddingLength(padding, frame); + payloadLength = fragmentReadableBytes + nonFragmentLength; + if (headerBlock.isReadable()) { + ctx.write(buf.retain(), promiseAggregator.newPromise()); + } else { + // The frame header is different for the last frame, so re-allocate and release the old buffer + flags = flags.endOfHeaders(true); + buf.release(); + buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, CONTINUATION, flags, streamId); + writePaddingLength(buf, padding); + ctx.write(buf, promiseAggregator.newPromise()); + } - frame.writeBytes(fragment); + ctx.write(fragment, promiseAggregator.newPromise()); - // Write out the padding, if any. - frame.writeZero(padding); - return frame; + // Write out the padding, if any. + if (paddingBuf != null) { + ctx.write(paddingBuf.retain(), promiseAggregator.newPromise()); + } + } while(headerBlock.isReadable()); + } + return promiseAggregator; } - /** - * Writes the padding length field to the output buffer. - */ - private static void writePaddingLength(int paddingLength, ByteBuf out) { - if (paddingLength > MAX_UNSIGNED_BYTE) { - int padHigh = paddingLength / 256; - out.writeByte(padHigh); - } - // Always include PadLow if there is any padding at all. + private static void writePaddingLength(ByteBuf buf, int paddingLength) { if (paddingLength > 0) { - int padLow = paddingLength % 256; - out.writeByte(padLow); + // It is assumed that the padding length has been bounds checked before this + buf.writeByte(paddingLength); } } diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java index 3b0be13d59..fc330620d9 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -19,9 +19,12 @@ import static io.netty.util.CharsetUtil.UTF_8; import static io.netty.util.internal.ObjectUtil.checkNotNull; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action; +import io.netty.util.concurrent.EventExecutor; /** * Constants and utility method used for encoding/decoding HTTP2 frames. @@ -48,6 +51,18 @@ public final class Http2CodecUtil { public static final short MAX_WEIGHT = 256; public static final short MIN_WEIGHT = 1; + private static final int MAX_PADDING_LENGTH_LENGTH = 1; + public static final int DATA_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH; + public static final int HEADERS_FRAME_HEADER_LENGTH = + FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH + 1; + public static final int PRIORITY_FRAME_LENGTH = FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH; + public static final int RST_STREAM_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH; + public static final int PUSH_PROMISE_FRAME_HEADER_LENGTH = + FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH; + public static final int GO_AWAY_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + 2 * INT_FIELD_LENGTH; + public static final int WINDOW_UPDATE_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH; + public static final int CONTINUATION_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH; + public static final int SETTINGS_HEADER_TABLE_SIZE = 1; public static final int SETTINGS_ENABLE_PUSH = 2; public static final int SETTINGS_MAX_CONCURRENT_STREAMS = 3; @@ -183,12 +198,128 @@ public final class Http2CodecUtil { public static void writeFrameHeader(ByteBuf out, int payloadLength, byte type, Http2Flags flags, int streamId) { out.ensureWritable(FRAME_HEADER_LENGTH + payloadLength); + writeFrameHeaderInternal(out, payloadLength, type, flags, streamId); + } + + static void writeFrameHeaderInternal(ByteBuf out, int payloadLength, byte type, + Http2Flags flags, int streamId) { out.writeMedium(payloadLength); out.writeByte(type); out.writeByte(flags.value()); out.writeInt(streamId); } + /** + * Provides the ability to associate the outcome of multiple {@link ChannelPromise} + * objects into a single {@link ChannelPromise} object. + */ + static class SimpleChannelPromiseAggregator extends DefaultChannelPromise { + private final ChannelPromise promise; + private int expectedCount; + private int successfulCount; + private int failureCount; + private boolean doneAllocating; + + SimpleChannelPromiseAggregator(ChannelPromise promise, Channel c, EventExecutor e) { + super(c, e); + assert promise != null; + this.promise = promise; + } + + /** + * Allocate a new promise which will be used to aggregate the overall success of this promise aggregator. + * @return A new promise which will be aggregated. + * {@code null} if {@link #doneAllocatingPromises()} was previously called. + */ + public ChannelPromise newPromise() { + if (doneAllocating) { + throw new IllegalStateException("Done allocating. No more promises can be allocated."); + } + ++expectedCount; + return this; + } + + /** + * Signify that no more {@link #newPromise()} allocations will be made. + * The aggregation can not be successful until this method is called. + * @return The promise that is the aggregation of all promises allocated with {@link #newPromise()}. + */ + public ChannelPromise doneAllocatingPromises() { + if (!doneAllocating) { + doneAllocating = true; + if (successfulCount == expectedCount) { + promise.setSuccess(); + return super.setSuccess(); + } + } + return this; + } + + @Override + public boolean tryFailure(Throwable cause) { + if (allowNotificationEvent()) { + ++failureCount; + if (failureCount == 1) { + promise.tryFailure(cause); + return super.tryFailure(cause); + } + // TODO: We break the interface a bit here. + // Multiple failure events can be processed without issue because this is an aggregation. + return true; + } + return false; + } + + /** + * Fail this object if it has not already been failed. + *

+ * This method will NOT throw an {@link IllegalStateException} if called multiple times + * because that may be expected. + */ + @Override + public ChannelPromise setFailure(Throwable cause) { + if (allowNotificationEvent()) { + ++failureCount; + if (failureCount == 1) { + promise.setFailure(cause); + return super.setFailure(cause); + } + } + return this; + } + + private boolean allowNotificationEvent() { + return successfulCount + failureCount < expectedCount; + } + + @Override + public ChannelPromise setSuccess(Void result) { + if (allowNotificationEvent()) { + ++successfulCount; + if (successfulCount == expectedCount && doneAllocating) { + promise.setSuccess(result); + return super.setSuccess(result); + } + } + return this; + } + + @Override + public boolean trySuccess(Void result) { + if (allowNotificationEvent()) { + ++successfulCount; + if (successfulCount == expectedCount && doneAllocating) { + promise.trySuccess(result); + return super.trySuccess(result); + } + // TODO: We break the interface a bit here. + // Multiple success events can be processed without issue because this is an aggregation. + return true; + } + return false; + } + } + /** * Fails the given promise with the cause and then re-throws the cause. */ diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java index 36a3595105..03263e4734 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java @@ -28,7 +28,7 @@ public interface Http2DataWriter { * * @param ctx the context to use for writing. * @param streamId the stream for which to send the frame. - * @param data the payload of the frame. + * @param data the payload of the frame. This will be released by this method. * @param padding the amount of padding to be added to the end of the frame * @param endStream indicates if this is the last frame to be sent for the stream. * @param promise the promise for the write. diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java index c97824f773..43c8dec9f9 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java @@ -129,7 +129,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable { * @param ctx the context to use for writing. * @param ack indicates whether this is an ack of a PING frame previously received from the * remote endpoint. - * @param data the payload of the frame. + * @param data the payload of the frame. This will be released by this method. * @param promise the promise for the write. * @return the future for the write. */ @@ -156,7 +156,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable { * @param ctx the context to use for writing. * @param lastStreamId the last known stream of this endpoint. * @param errorCode the error code, if the connection was abnormally terminated. - * @param debugData application-defined debug data. + * @param debugData application-defined debug data. This will be released by this method. * @param promise the promise for the write. * @return the future for the write. */ @@ -183,7 +183,7 @@ public interface Http2FrameWriter extends Http2DataWriter, Closeable { * @param frameType the frame type identifier. * @param streamId the stream for which to send the frame. * @param flags the flags to write for this frame. - * @param payload the payload to write for this frame. + * @param payload the payload to write for this frame. This will be released by this method. * @param promise the promise for the write. * @return the future for the write. */ diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java index 035f4250bf..1633f170b4 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java @@ -17,9 +17,9 @@ package io.netty.handler.codec.http2; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.channel.ChannelPromiseAggregator; import io.netty.handler.codec.http.FullHttpMessage; import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; /** * Translates HTTP/1.x object writes into HTTP/2 frames. @@ -65,6 +65,7 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler { FullHttpMessage httpMsg = (FullHttpMessage) msg; boolean hasData = httpMsg.content().isReadable(); boolean httpMsgNeedRelease = true; + SimpleChannelPromiseAggregator promiseAggregator = null; try { // Provide the user the opportunity to specify the streamId int streamId = getStreamId(httpMsg.headers()); @@ -74,18 +75,20 @@ public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler { Http2ConnectionEncoder encoder = encoder(); if (hasData) { - ChannelPromiseAggregator promiseAggregator = new ChannelPromiseAggregator(promise); - ChannelPromise headerPromise = ctx.newPromise(); - ChannelPromise dataPromise = ctx.newPromise(); - promiseAggregator.add(headerPromise, dataPromise); - encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, headerPromise); + promiseAggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + encoder.writeHeaders(ctx, streamId, http2Headers, 0, false, promiseAggregator.newPromise()); httpMsgNeedRelease = false; - encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, dataPromise); + encoder.writeData(ctx, streamId, httpMsg.content(), 0, true, promiseAggregator.newPromise()); + promiseAggregator.doneAllocatingPromises(); } else { encoder.writeHeaders(ctx, streamId, http2Headers, 0, true, promise); } } catch (Throwable t) { - promise.tryFailure(t); + if (promiseAggregator == null) { + promise.tryFailure(t); + } else { + promiseAggregator.setFailure(t); + } } finally { if (httpMsgNeedRelease) { httpMsg.release(); diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java index 890a02f1a1..47f08a7c0f 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java @@ -14,9 +14,10 @@ */ package io.netty.handler.codec.http2; -import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.TooLongFrameException; @@ -28,7 +29,6 @@ import io.netty.handler.codec.http.HttpHeaderUtil; import io.netty.handler.codec.http.HttpStatusClass; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; -import static io.netty.util.internal.ObjectUtil.*; /** * This adapter provides just header/data events from the HTTP message flow defined diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameIOTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameIOTest.java index 7fafd46a4b..ef98c58372 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameIOTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameIOTest.java @@ -18,22 +18,33 @@ package io.netty.handler.codec.http2; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; import static io.netty.handler.codec.http2.Http2TestUtil.as; import static io.netty.handler.codec.http2.Http2TestUtil.randomString; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.EventExecutor; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** * Integration tests for {@link DefaultHttp2FrameReader} and {@link DefaultHttp2FrameWriter}. @@ -43,6 +54,8 @@ public class DefaultHttp2FrameIOTest { private DefaultHttp2FrameReader reader; private DefaultHttp2FrameWriter writer; private ByteBufAllocator alloc; + private CountDownLatch latch; + private ByteBuf buffer; @Mock private ChannelHandlerContext ctx; @@ -53,13 +66,56 @@ public class DefaultHttp2FrameIOTest { @Mock private ChannelPromise promise; + @Mock + private Channel channel; + + @Mock + private EventExecutor executor; + @Before public void setup() { MockitoAnnotations.initMocks(this); alloc = UnpooledByteBufAllocator.DEFAULT; + buffer = alloc.buffer(); + latch = new CountDownLatch(1); + when(executor.inEventLoop()).thenReturn(true); when(ctx.alloc()).thenReturn(alloc); + when(ctx.channel()).thenReturn(channel); + when(ctx.executor()).thenReturn(executor); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return new DefaultChannelPromise(channel, executor); + } + }).when(ctx).newPromise(); + + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock in) throws Throwable { + latch.countDown(); + return promise; + } + }).when(promise).setSuccess(); + + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) throws Throwable { + if (in.getArguments()[0] instanceof ByteBuf) { + ByteBuf tmp = (ByteBuf) in.getArguments()[0]; + try { + buffer.writeBytes(tmp); + } finally { + tmp.release(); + } + } + if (in.getArguments()[1] instanceof ChannelPromise) { + return ((ChannelPromise) in.getArguments()[1]).setSuccess(); + } + return null; + } + }).when(ctx).write(any(), any(ChannelPromise.class)); reader = new DefaultHttp2FrameReader(); writer = new DefaultHttp2FrameWriter(); @@ -452,10 +508,9 @@ public class DefaultHttp2FrameIOTest { } } - private ByteBuf captureWrite() { - ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); - verify(ctx).write(captor.capture(), eq(promise)); - return captor.getValue(); + private ByteBuf captureWrite() throws InterruptedException { + assertTrue(latch.await(2, TimeUnit.SECONDS)); + return buffer; } private ByteBuf dummyData() { @@ -471,9 +526,8 @@ public class DefaultHttp2FrameIOTest { } private static Http2Headers dummyHeaders() { - return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")) - .authority(as("example.org")).path(as("/some/path")) - .add(as("accept"), as("*/*")); + return new DefaultHttp2Headers().method(as("GET")).scheme(as("https")).authority(as("example.org")) + .path(as("/some/path")).add(as("accept"), as("*/*")); } private static Http2Headers largeHeaders() { diff --git a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java index 04d5f7a743..8584998e86 100644 --- a/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java +++ b/codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java @@ -70,7 +70,7 @@ import org.mockito.stubbing.Answer; * Testing the {@link HttpToHttp2ConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames */ public class HttpToHttp2ConnectionHandlerTest { - private static final int WAIT_TIME_SECONDS = 5; + private static final int WAIT_TIME_SECONDS = 500; @Mock private Http2FrameListener clientListener;