Cleanup HTTP/2 tests for Http2FrameCodec and Http2MultiplexCodec (#8646)

Motiviation:

Http2FrameCodecTest and Http2MultiplexCodecTest were quite fragile and often not went through the whole pipeline which made testing sometimes hard and error-prone.

Modification:

- Refactor tests to have data flow through the whole pipeline and so made the test more robust (by testing the while implementation).

Result:

Easier to write tests for the codecs in the future and more robust testing in general.

Beside this it also fixes https://github.com/netty/netty/issues/6036.
This commit is contained in:
Norman Maurer 2018-12-14 10:10:04 +00:00
parent 66ccd1483a
commit fa84e2b3af
8 changed files with 1028 additions and 641 deletions

View File

@ -148,7 +148,7 @@ public class Http2FrameCodec extends Http2ConnectionHandler {
private final Integer initialFlowControlWindowSize; private final Integer initialFlowControlWindowSize;
private ChannelHandlerContext ctx; ChannelHandlerContext ctx;
/** Number of buffered streams if the {@link StreamBufferingEncoder} is used. **/ /** Number of buffered streams if the {@link StreamBufferingEncoder} is used. **/
private int numBufferedStreams; private int numBufferedStreams;

View File

@ -414,28 +414,11 @@ public class Http2MultiplexCodec extends Http2FrameCodec {
} }
} }
// Allow to override for testing final void flush0(ChannelHandlerContext ctx) {
void flush0(ChannelHandlerContext ctx) {
flush(ctx); flush(ctx);
} }
/** static final class Http2MultiplexCodecStream extends DefaultHttp2FrameStream {
* Return bytes to flow control.
* <p>
* Package private to allow to override for testing
* @param ctx The {@link ChannelHandlerContext} associated with the parent channel.
* @param stream The object representing the HTTP/2 stream.
* @param bytes The number of bytes to return to flow control.
* @return {@code true} if a frame has been written as a result of this method call.
* @throws Http2Exception If this operation violates the flow control limits.
*/
boolean onBytesConsumed(@SuppressWarnings("unused") ChannelHandlerContext ctx,
Http2FrameStream stream, int bytes) throws Http2Exception {
return consumeBytes(stream.id(), bytes);
}
// Allow to extend for testing
static class Http2MultiplexCodecStream extends DefaultHttp2FrameStream {
DefaultHttp2StreamChannel channel; DefaultHttp2StreamChannel channel;
} }
@ -1084,7 +1067,7 @@ public class Http2MultiplexCodec extends Http2FrameCodec {
allocHandle.lastBytesRead(numBytesToBeConsumed); allocHandle.lastBytesRead(numBytesToBeConsumed);
if (numBytesToBeConsumed != 0) { if (numBytesToBeConsumed != 0) {
try { try {
writeDoneAndNoFlush |= onBytesConsumed(ctx, stream, numBytesToBeConsumed); writeDoneAndNoFlush |= consumeBytes(stream.id(), numBytesToBeConsumed);
} catch (Http2Exception e) { } catch (Http2Exception e) {
pipeline().fireExceptionCaught(e); pipeline().fireExceptionCaught(e);
} }

View File

@ -27,6 +27,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
@UnstableApi @UnstableApi
public class Http2MultiplexCodecBuilder public class Http2MultiplexCodecBuilder
extends AbstractHttp2ConnectionHandlerBuilder<Http2MultiplexCodec, Http2MultiplexCodecBuilder> { extends AbstractHttp2ConnectionHandlerBuilder<Http2MultiplexCodec, Http2MultiplexCodecBuilder> {
private Http2FrameWriter frameWriter;
final ChannelHandler childHandler; final ChannelHandler childHandler;
private ChannelHandler upgradeStreamHandler; private ChannelHandler upgradeStreamHandler;
@ -44,6 +45,12 @@ public class Http2MultiplexCodecBuilder
return handler; return handler;
} }
// For testing only.
Http2MultiplexCodecBuilder frameWriter(Http2FrameWriter frameWriter) {
this.frameWriter = checkNotNull(frameWriter, "frameWriter");
return this;
}
/** /**
* Creates a builder for a HTTP/2 client. * Creates a builder for a HTTP/2 client.
* *
@ -160,6 +167,28 @@ public class Http2MultiplexCodecBuilder
@Override @Override
public Http2MultiplexCodec build() { public Http2MultiplexCodec build() {
Http2FrameWriter frameWriter = this.frameWriter;
if (frameWriter != null) {
// This is to support our tests and will never be executed by the user as frameWriter(...)
// is package-private.
DefaultHttp2Connection connection = new DefaultHttp2Connection(isServer(), maxReservedStreams());
Long maxHeaderListSize = initialSettings().maxHeaderListSize();
Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ?
new DefaultHttp2HeadersDecoder(true) :
new DefaultHttp2HeadersDecoder(true, maxHeaderListSize));
if (frameLogger() != null) {
frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger());
frameReader = new Http2InboundFrameLogger(frameReader, frameLogger());
}
Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter);
if (encoderEnforceMaxConcurrentStreams()) {
encoder = new StreamBufferingEncoder(encoder);
}
Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);
return build(decoder, encoder, initialSettings());
}
return super.build(); return super.build();
} }

View File

@ -15,9 +15,7 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -57,6 +55,11 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid;
import static io.netty.handler.codec.http2.Http2TestUtil.anyChannelPromise;
import static io.netty.handler.codec.http2.Http2TestUtil.anyHttp2Settings;
import static io.netty.handler.codec.http2.Http2TestUtil.assertEqualsAndRelease;
import static io.netty.handler.codec.http2.Http2TestUtil.bb;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
@ -73,7 +76,6 @@ import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.same; import static org.mockito.Mockito.same;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -86,9 +88,10 @@ public class Http2FrameCodecTest {
private Http2FrameWriter frameWriter; private Http2FrameWriter frameWriter;
private Http2FrameCodec frameCodec; private Http2FrameCodec frameCodec;
private EmbeddedChannel channel; private EmbeddedChannel channel;
// For injecting inbound frames // For injecting inbound frames
private Http2FrameListener frameListener; private Http2FrameInboundWriter frameInboundWriter;
private ChannelHandlerContext http2HandlerCtx;
private LastInboundHandler inboundHandler; private LastInboundHandler inboundHandler;
private final Http2Headers request = new DefaultHttp2Headers() private final Http2Headers request = new DefaultHttp2Headers()
@ -123,29 +126,29 @@ public class Http2FrameCodecTest {
*/ */
tearDown(); tearDown();
frameWriter = spy(new VerifiableHttp2FrameWriter()); frameWriter = Http2TestUtil.mockedFrameWriter();
frameCodec = frameCodecBuilder.frameWriter(frameWriter).frameLogger(new Http2FrameLogger(LogLevel.TRACE)) frameCodec = frameCodecBuilder.frameWriter(frameWriter).frameLogger(new Http2FrameLogger(LogLevel.TRACE))
.initialSettings(initialRemoteSettings).build(); .initialSettings(initialRemoteSettings).build();
frameListener = ((DefaultHttp2ConnectionDecoder) frameCodec.decoder())
.internalFrameListener();
inboundHandler = new LastInboundHandler(); inboundHandler = new LastInboundHandler();
channel = new EmbeddedChannel(); channel = new EmbeddedChannel();
frameInboundWriter = new Http2FrameInboundWriter(channel);
channel.connect(new InetSocketAddress(0)); channel.connect(new InetSocketAddress(0));
channel.pipeline().addLast(frameCodec); channel.pipeline().addLast(frameCodec);
channel.pipeline().addLast(inboundHandler); channel.pipeline().addLast(inboundHandler);
channel.pipeline().fireChannelActive(); channel.pipeline().fireChannelActive();
http2HandlerCtx = channel.pipeline().context(frameCodec);
// Handshake // Handshake
verify(frameWriter).writeSettings(eq(http2HandlerCtx), verify(frameWriter).writeSettings(eqFrameCodecCtx(), anyHttp2Settings(), anyChannelPromise());
anyHttp2Settings(), anyChannelPromise());
verifyNoMoreInteractions(frameWriter); verifyNoMoreInteractions(frameWriter);
channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf());
frameListener.onSettingsRead(http2HandlerCtx, initialRemoteSettings);
verify(frameWriter).writeSettingsAck(eq(http2HandlerCtx), anyChannelPromise()); frameInboundWriter.writeInboundSettings(initialRemoteSettings);
frameListener.onSettingsAckRead(http2HandlerCtx);
verify(frameWriter).writeSettingsAck(eqFrameCodecCtx(), anyChannelPromise());
frameInboundWriter.writeInboundSettingsAck();
Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); Http2SettingsFrame settingsFrame = inboundHandler.readInbound();
assertNotNull(settingsFrame); assertNotNull(settingsFrame);
@ -153,7 +156,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void stateChanges() throws Exception { public void stateChanges() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 1, request, 31, true); frameInboundWriter.writeInboundHeaders(1, request, 31, true);
Http2Stream stream = frameCodec.connection().stream(1); Http2Stream stream = frameCodec.connection().stream(1);
assertNotNull(stream); assertNotNull(stream);
@ -169,12 +172,12 @@ public class Http2FrameCodecTest {
assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2));
assertNull(inboundHandler.readInbound()); assertNull(inboundHandler.readInbound());
inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2));
verify(frameWriter).writeHeaders( verify(frameWriter).writeHeaders(
eq(http2HandlerCtx), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), eqFrameCodecCtx(), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(),
eq(27), eq(true), anyChannelPromise()); eq(27), eq(true), anyChannelPromise());
verify(frameWriter, never()).writeRstStream( verify(frameWriter, never()).writeRstStream(
any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise()); eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise());
assertEquals(State.CLOSED, stream.state()); assertEquals(State.CLOSED, stream.state());
event = inboundHandler.readInboundMessageOrUserEvent(); event = inboundHandler.readInboundMessageOrUserEvent();
@ -185,7 +188,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void headerRequestHeaderResponse() throws Exception { public void headerRequestHeaderResponse() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 1, request, 31, true); frameInboundWriter.writeInboundHeaders(1, request, 31, true);
Http2Stream stream = frameCodec.connection().stream(1); Http2Stream stream = frameCodec.connection().stream(1);
assertNotNull(stream); assertNotNull(stream);
@ -198,12 +201,12 @@ public class Http2FrameCodecTest {
assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2));
assertNull(inboundHandler.readInbound()); assertNull(inboundHandler.readInbound());
inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2));
verify(frameWriter).writeHeaders( verify(frameWriter).writeHeaders(
eq(http2HandlerCtx), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(), eqFrameCodecCtx(), eq(1), eq(response), anyInt(), anyShort(), anyBoolean(),
eq(27), eq(true), anyChannelPromise()); eq(27), eq(true), anyChannelPromise());
verify(frameWriter, never()).writeRstStream( verify(frameWriter, never()).writeRstStream(
any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise()); eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise());
assertEquals(State.CLOSED, stream.state()); assertEquals(State.CLOSED, stream.state());
assertTrue(channel.isActive()); assertTrue(channel.isActive());
@ -225,7 +228,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void entityRequestEntityResponse() throws Exception { public void entityRequestEntityResponse() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 1, request, 0, false); frameInboundWriter.writeInboundHeaders(1, request, 0, false);
Http2Stream stream = frameCodec.connection().stream(1); Http2Stream stream = frameCodec.connection().stream(1);
assertNotNull(stream); assertNotNull(stream);
@ -239,39 +242,35 @@ public class Http2FrameCodecTest {
assertNull(inboundHandler.readInbound()); assertNull(inboundHandler.readInbound());
ByteBuf hello = bb("hello"); ByteBuf hello = bb("hello");
frameListener.onDataRead(http2HandlerCtx, 1, hello, 31, true); frameInboundWriter.writeInboundData(1, hello, 31, true);
// Release hello to emulate ByteToMessageDecoder
hello.release();
Http2DataFrame inboundData = inboundHandler.readInbound(); Http2DataFrame inboundData = inboundHandler.readInbound();
Http2DataFrame expected = new DefaultHttp2DataFrame(bb("hello"), true, 31).stream(stream2); Http2DataFrame expected = new DefaultHttp2DataFrame(bb("hello"), true, 31).stream(stream2);
assertEquals(expected, inboundData); assertEqualsAndRelease(expected, inboundData);
assertEquals(1, inboundData.refCnt());
expected.release();
inboundData.release();
assertNull(inboundHandler.readInbound()); assertNull(inboundHandler.readInbound());
inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, false).stream(stream2)); channel.writeOutbound(new DefaultHttp2HeadersFrame(response, false).stream(stream2));
verify(frameWriter).writeHeaders(eq(http2HandlerCtx), eq(1), eq(response), anyInt(), verify(frameWriter).writeHeaders(eqFrameCodecCtx(), eq(1), eq(response), anyInt(),
anyShort(), anyBoolean(), eq(0), eq(false), anyChannelPromise()); anyShort(), anyBoolean(), eq(0), eq(false), anyChannelPromise());
inboundHandler.writeOutbound(new DefaultHttp2DataFrame(bb("world"), true, 27).stream(stream2)); channel.writeOutbound(new DefaultHttp2DataFrame(bb("world"), true, 27).stream(stream2));
ArgumentCaptor<ByteBuf> outboundData = ArgumentCaptor.forClass(ByteBuf.class); ArgumentCaptor<ByteBuf> outboundData = ArgumentCaptor.forClass(ByteBuf.class);
verify(frameWriter).writeData(eq(http2HandlerCtx), eq(1), outboundData.capture(), eq(27), verify(frameWriter).writeData(eqFrameCodecCtx(), eq(1), outboundData.capture(), eq(27),
eq(true), anyChannelPromise()); eq(true), anyChannelPromise());
ByteBuf bb = bb("world"); ByteBuf bb = bb("world");
assertEquals(bb, outboundData.getValue()); assertEquals(bb, outboundData.getValue());
assertEquals(1, outboundData.getValue().refCnt()); assertEquals(1, outboundData.getValue().refCnt());
bb.release(); bb.release();
verify(frameWriter, never()).writeRstStream( outboundData.getValue().release();
any(ChannelHandlerContext.class), anyInt(), anyLong(), anyChannelPromise());
verify(frameWriter, never()).writeRstStream(eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise());
assertTrue(channel.isActive()); assertTrue(channel.isActive());
} }
@Test @Test
public void sendRstStream() throws Exception { public void sendRstStream() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, true); frameInboundWriter.writeInboundHeaders(3, request, 31, true);
Http2Stream stream = frameCodec.connection().stream(3); Http2Stream stream = frameCodec.connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
@ -285,16 +284,15 @@ public class Http2FrameCodecTest {
assertNotNull(stream2); assertNotNull(stream2);
assertEquals(3, stream2.id()); assertEquals(3, stream2.id());
inboundHandler.writeOutbound(new DefaultHttp2ResetFrame(314 /* non-standard error */).stream(stream2)); channel.writeOutbound(new DefaultHttp2ResetFrame(314 /* non-standard error */).stream(stream2));
verify(frameWriter).writeRstStream( verify(frameWriter).writeRstStream(eqFrameCodecCtx(), eq(3), eq(314L), anyChannelPromise());
eq(http2HandlerCtx), eq(3), eq(314L), anyChannelPromise());
assertEquals(State.CLOSED, stream.state()); assertEquals(State.CLOSED, stream.state());
assertTrue(channel.isActive()); assertTrue(channel.isActive());
} }
@Test @Test
public void receiveRstStream() throws Exception { public void receiveRstStream() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Stream stream = frameCodec.connection().stream(3); Http2Stream stream = frameCodec.connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
@ -304,7 +302,7 @@ public class Http2FrameCodecTest {
Http2HeadersFrame actualHeaders = inboundHandler.readInbound(); Http2HeadersFrame actualHeaders = inboundHandler.readInbound();
assertEquals(expectedHeaders.stream(actualHeaders.stream()), actualHeaders); assertEquals(expectedHeaders.stream(actualHeaders.stream()), actualHeaders);
frameListener.onRstStreamRead(http2HandlerCtx, 3, Http2Error.NO_ERROR.code()); frameInboundWriter.writeInboundRstStream(3, Http2Error.NO_ERROR.code());
Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(Http2Error.NO_ERROR).stream(actualHeaders.stream()); Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(Http2Error.NO_ERROR).stream(actualHeaders.stream());
Http2ResetFrame actualRst = inboundHandler.readInbound(); Http2ResetFrame actualRst = inboundHandler.readInbound();
@ -315,8 +313,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void sendGoAway() throws Exception { public void sendGoAway() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Stream stream = frameCodec.connection().stream(3); Http2Stream stream = frameCodec.connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
@ -324,32 +321,29 @@ public class Http2FrameCodecTest {
ByteBuf debugData = bb("debug"); ByteBuf debugData = bb("debug");
ByteBuf expected = debugData.copy(); ByteBuf expected = debugData.copy();
Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData.slice()); Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData);
goAwayFrame.setExtraStreamIds(2); goAwayFrame.setExtraStreamIds(2);
inboundHandler.writeOutbound(goAwayFrame); channel.writeOutbound(goAwayFrame);
verify(frameWriter).writeGoAway( verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(7),
eq(http2HandlerCtx), eq(7), eq(Http2Error.NO_ERROR.code()), eq(expected), anyChannelPromise()); eq(Http2Error.NO_ERROR.code()), eq(expected), anyChannelPromise());
assertEquals(1, debugData.refCnt()); assertEquals(1, debugData.refCnt());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertTrue(channel.isActive()); assertTrue(channel.isActive());
expected.release(); expected.release();
debugData.release();
} }
@Test @Test
public void receiveGoaway() throws Exception { public void receiveGoaway() throws Exception {
ByteBuf debugData = bb("foo"); ByteBuf debugData = bb("foo");
frameListener.onGoAwayRead(http2HandlerCtx, 2, Http2Error.NO_ERROR.code(), debugData); frameInboundWriter.writeInboundGoAway(2, Http2Error.NO_ERROR.code(), debugData);
// Release debugData to emulate ByteToMessageDecoder
debugData.release();
Http2GoAwayFrame expectedFrame = new DefaultHttp2GoAwayFrame(2, Http2Error.NO_ERROR.code(), bb("foo")); Http2GoAwayFrame expectedFrame = new DefaultHttp2GoAwayFrame(2, Http2Error.NO_ERROR.code(), bb("foo"));
Http2GoAwayFrame actualFrame = inboundHandler.readInbound(); Http2GoAwayFrame actualFrame = inboundHandler.readInbound();
assertEquals(expectedFrame, actualFrame); assertEqualsAndRelease(expectedFrame, actualFrame);
assertNull(inboundHandler.readInbound());
expectedFrame.release(); assertNull(inboundHandler.readInbound());
actualFrame.release();
} }
@Test @Test
@ -383,7 +377,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void goAwayLastStreamIdOverflowed() throws Exception { public void goAwayLastStreamIdOverflowed() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 5, request, 31, false); frameInboundWriter.writeInboundHeaders(5, request, 31, false);
Http2Stream stream = frameCodec.connection().stream(5); Http2Stream stream = frameCodec.connection().stream(5);
assertNotNull(stream); assertNotNull(stream);
@ -393,10 +387,10 @@ public class Http2FrameCodecTest {
Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData.slice()); Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(Http2Error.NO_ERROR.code(), debugData.slice());
goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE); goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE);
inboundHandler.writeOutbound(goAwayFrame); channel.writeOutbound(goAwayFrame);
// When the last stream id computation overflows, the last stream id should just be set to 2^31 - 1. // When the last stream id computation overflows, the last stream id should just be set to 2^31 - 1.
verify(frameWriter).writeGoAway(eq(http2HandlerCtx), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(Integer.MAX_VALUE),
eq(debugData), anyChannelPromise()); eq(Http2Error.NO_ERROR.code()), eq(debugData), anyChannelPromise());
assertEquals(1, debugData.refCnt()); assertEquals(1, debugData.refCnt());
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
assertTrue(channel.isActive()); assertTrue(channel.isActive());
@ -404,13 +398,13 @@ public class Http2FrameCodecTest {
@Test @Test
public void streamErrorShouldFireExceptionForInbound() throws Exception { public void streamErrorShouldFireExceptionForInbound() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Stream stream = frameCodec.connection().stream(3); Http2Stream stream = frameCodec.connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo");
frameCodec.onError(http2HandlerCtx, false, streamEx); channel.pipeline().fireExceptionCaught(streamEx);
Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent();
assertEquals(Http2FrameStreamEvent.Type.State, event.type()); assertEquals(Http2FrameStreamEvent.Type.State, event.type());
@ -430,13 +424,13 @@ public class Http2FrameCodecTest {
@Test @Test
public void streamErrorShouldNotFireExceptionForOutbound() throws Exception { public void streamErrorShouldNotFireExceptionForOutbound() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Stream stream = frameCodec.connection().stream(3); Http2Stream stream = frameCodec.connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo");
frameCodec.onError(http2HandlerCtx, true, streamEx); frameCodec.onError(frameCodec.ctx, true, streamEx);
Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent();
assertEquals(Http2FrameStreamEvent.Type.State, event.type()); assertEquals(Http2FrameStreamEvent.Type.State, event.type());
@ -452,14 +446,14 @@ public class Http2FrameCodecTest {
@Test @Test
public void windowUpdateFrameDecrementsConsumedBytes() throws Exception { public void windowUpdateFrameDecrementsConsumedBytes() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Connection connection = frameCodec.connection(); Http2Connection connection = frameCodec.connection();
Http2Stream stream = connection.stream(3); Http2Stream stream = connection.stream(3);
assertNotNull(stream); assertNotNull(stream);
ByteBuf data = Unpooled.buffer(100).writeZero(100); ByteBuf data = Unpooled.buffer(100).writeZero(100);
frameListener.onDataRead(http2HandlerCtx, 3, data, 0, true); frameInboundWriter.writeInboundData(3, data, 0, false);
Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); Http2HeadersFrame inboundHeaders = inboundHandler.readInbound();
assertNotNull(inboundHeaders); assertNotNull(inboundHeaders);
@ -472,12 +466,11 @@ public class Http2FrameCodecTest {
int after = connection.local().flowController().unconsumedBytes(stream); int after = connection.local().flowController().unconsumedBytes(stream);
assertEquals(100, before - after); assertEquals(100, before - after);
assertTrue(f.isSuccess()); assertTrue(f.isSuccess());
data.release();
} }
@Test @Test
public void windowUpdateMayFail() throws Exception { public void windowUpdateMayFail() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
Http2Connection connection = frameCodec.connection(); Http2Connection connection = frameCodec.connection();
Http2Stream stream = connection.stream(3); Http2Stream stream = connection.stream(3);
assertNotNull(stream); assertNotNull(stream);
@ -496,10 +489,10 @@ public class Http2FrameCodecTest {
@Test @Test
public void inboundWindowUpdateShouldBeForwarded() throws Exception { public void inboundWindowUpdateShouldBeForwarded() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 31, false); frameInboundWriter.writeInboundHeaders(3, request, 31, false);
frameListener.onWindowUpdateRead(http2HandlerCtx, 3, 100); frameInboundWriter.writeInboundWindowUpdate(3, 100);
// Connection-level window update // Connection-level window update
frameListener.onWindowUpdateRead(http2HandlerCtx, 0, 100); frameInboundWriter.writeInboundWindowUpdate(0, 100);
Http2HeadersFrame headersFrame = inboundHandler.readInbound(); Http2HeadersFrame headersFrame = inboundHandler.readInbound();
assertNotNull(headersFrame); assertNotNull(headersFrame);
@ -558,7 +551,7 @@ public class Http2FrameCodecTest {
unknownFrame.stream(stream); unknownFrame.stream(stream);
channel.write(unknownFrame); channel.write(unknownFrame);
verify(frameWriter).writeFrame(eq(http2HandlerCtx), eq(unknownFrame.frameType()), verify(frameWriter).writeFrame(eqFrameCodecCtx(), eq(unknownFrame.frameType()),
eq(unknownFrame.stream().id()), eq(unknownFrame.flags()), eq(buffer), any(ChannelPromise.class)); eq(unknownFrame.stream().id()), eq(unknownFrame.flags()), eq(buffer), any(ChannelPromise.class));
} }
@ -567,7 +560,7 @@ public class Http2FrameCodecTest {
Http2Settings settings = new Http2Settings(); Http2Settings settings = new Http2Settings();
channel.write(new DefaultHttp2SettingsFrame(settings)); channel.write(new DefaultHttp2SettingsFrame(settings));
verify(frameWriter).writeSettings(eq(http2HandlerCtx), same(settings), any(ChannelPromise.class)); verify(frameWriter).writeSettings(eqFrameCodecCtx(), same(settings), any(ChannelPromise.class));
} }
@Test(timeout = 5000) @Test(timeout = 5000)
@ -619,7 +612,7 @@ public class Http2FrameCodecTest {
assertFalse(promise2.isDone()); assertFalse(promise2.isDone());
// Increase concurrent streams limit to 2 // Increase concurrent streams limit to 2
frameListener.onSettingsRead(http2HandlerCtx, new Http2Settings().maxConcurrentStreams(2)); frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(2));
channel.flush(); channel.flush();
@ -643,7 +636,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void receivePing() throws Http2Exception { public void receivePing() throws Http2Exception {
frameListener.onPingRead(http2HandlerCtx, 12345L); frameInboundWriter.writeInboundPing(false, 12345L);
Http2PingFrame pingFrame = inboundHandler.readInbound(); Http2PingFrame pingFrame = inboundHandler.readInbound();
assertNotNull(pingFrame); assertNotNull(pingFrame);
@ -656,13 +649,13 @@ public class Http2FrameCodecTest {
public void sendPing() { public void sendPing() {
channel.writeAndFlush(new DefaultHttp2PingFrame(12345)); channel.writeAndFlush(new DefaultHttp2PingFrame(12345));
verify(frameWriter).writePing(eq(http2HandlerCtx), eq(false), eq(12345L), anyChannelPromise()); verify(frameWriter).writePing(eqFrameCodecCtx(), eq(false), eq(12345L), anyChannelPromise());
} }
@Test @Test
public void receiveSettings() throws Http2Exception { public void receiveSettings() throws Http2Exception {
Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); Http2Settings settings = new Http2Settings().maxConcurrentStreams(1);
frameListener.onSettingsRead(http2HandlerCtx, settings); frameInboundWriter.writeInboundSettings(settings);
Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); Http2SettingsFrame settingsFrame = inboundHandler.readInbound();
assertNotNull(settingsFrame); assertNotNull(settingsFrame);
@ -674,7 +667,7 @@ public class Http2FrameCodecTest {
Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); Http2Settings settings = new Http2Settings().maxConcurrentStreams(1);
channel.writeAndFlush(new DefaultHttp2SettingsFrame(settings)); channel.writeAndFlush(new DefaultHttp2SettingsFrame(settings));
verify(frameWriter).writeSettings(eq(http2HandlerCtx), eq(settings), anyChannelPromise()); verify(frameWriter).writeSettings(eqFrameCodecCtx(), eq(settings), anyChannelPromise());
} }
@Test @Test
@ -682,7 +675,7 @@ public class Http2FrameCodecTest {
setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true),
new Http2Settings().maxConcurrentStreams(1)); new Http2Settings().maxConcurrentStreams(1));
frameListener.onHeadersRead(http2HandlerCtx, 3, request, 0, false); frameInboundWriter.writeInboundHeaders(3, request, 0, false);
Http2HeadersFrame headersFrame = inboundHandler.readInbound(); Http2HeadersFrame headersFrame = inboundHandler.readInbound();
assertNotNull(headersFrame); assertNotNull(headersFrame);
@ -736,8 +729,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void upgradeEventNoRefCntError() throws Exception { public void upgradeEventNoRefCntError() throws Exception {
frameListener.onHeadersRead(http2HandlerCtx, Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false);
// Using reflect as the constructor is package-private and the class is final. // Using reflect as the constructor is package-private and the class is final.
Constructor<UpgradeEvent> constructor = Constructor<UpgradeEvent> constructor =
UpgradeEvent.class.getDeclaredConstructor(CharSequence.class, FullHttpRequest.class); UpgradeEvent.class.getDeclaredConstructor(CharSequence.class, FullHttpRequest.class);
@ -753,7 +745,7 @@ public class Http2FrameCodecTest {
@Test @Test
public void upgradeWithoutFlowControlling() throws Exception { public void upgradeWithoutFlowControlling() throws Exception {
channel.pipeline().addAfter(http2HandlerCtx.name(), null, new ChannelInboundHandlerAdapter() { channel.pipeline().addAfter(frameCodec.ctx.name(), null, new ChannelInboundHandlerAdapter() {
@Override @Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof Http2DataFrame) { if (msg instanceof Http2DataFrame) {
@ -774,7 +766,7 @@ public class Http2FrameCodecTest {
} }
}); });
frameListener.onHeadersRead(http2HandlerCtx, Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false);
// Using reflect as the constructor is package-private and the class is final. // Using reflect as the constructor is package-private and the class is final.
Constructor<UpgradeEvent> constructor = Constructor<UpgradeEvent> constructor =
@ -792,24 +784,7 @@ public class Http2FrameCodecTest {
channel.pipeline().fireUserEventTriggered(upgradeEvent); channel.pipeline().fireUserEventTriggered(upgradeEvent);
} }
private static ChannelPromise anyChannelPromise() { private ChannelHandlerContext eqFrameCodecCtx() {
return any(ChannelPromise.class); return eq(frameCodec.ctx);
}
private static Http2Settings anyHttp2Settings() {
return any(Http2Settings.class);
}
private static ByteBuf bb(String s) {
return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s);
}
private static class VerifiableHttp2FrameWriter extends DefaultHttp2FrameWriter {
@Override
public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data,
int padding, boolean endStream, ChannelPromise promise) {
// duplicate 'data' to prevent readerIndex from being changed, to ease verification
return super.writeData(ctx, streamId, data.duplicate(), padding, endStream, promise);
}
} }
} }

View File

@ -0,0 +1,340 @@
/*
* Copyright 2018 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelProgressivePromise;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.EventExecutor;
import java.net.SocketAddress;
/**
* Utility class which allows easy writing of HTTP2 frames via {@link EmbeddedChannel#writeInbound(Object...)}.
*/
final class Http2FrameInboundWriter {
private final ChannelHandlerContext ctx;
private final Http2FrameWriter writer;
Http2FrameInboundWriter(EmbeddedChannel channel) {
this(channel, new DefaultHttp2FrameWriter());
}
Http2FrameInboundWriter(EmbeddedChannel channel, Http2FrameWriter writer) {
this.ctx = new WriteInboundChannelHandlerContext(channel);
this.writer = writer;
}
void writeInboundData(int streamId, ByteBuf data, int padding, boolean endStream) {
writer.writeData(ctx, streamId, data, padding, endStream, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundHeaders(int streamId, Http2Headers headers,
int padding, boolean endStream) {
writer.writeHeaders(ctx, streamId, headers, padding, endStream, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundHeaders(int streamId, Http2Headers headers,
int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) {
writer.writeHeaders(ctx, streamId, headers, streamDependency,
weight, exclusive, padding, endStream, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundPriority(int streamId, int streamDependency,
short weight, boolean exclusive) {
writer.writePriority(ctx, streamId, streamDependency, weight,
exclusive, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundRstStream(int streamId, long errorCode) {
writer.writeRstStream(ctx, streamId, errorCode, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundSettings(Http2Settings settings) {
writer.writeSettings(ctx, settings, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundSettingsAck() {
writer.writeSettingsAck(ctx, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundPing(boolean ack, long data) {
writer.writePing(ctx, ack, data, ctx.newPromise()).syncUninterruptibly();
}
void writePushPromise(int streamId, int promisedStreamId,
Http2Headers headers, int padding) {
writer.writePushPromise(ctx, streamId, promisedStreamId,
headers, padding, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundGoAway(int lastStreamId, long errorCode, ByteBuf debugData) {
writer.writeGoAway(ctx, lastStreamId, errorCode, debugData, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundWindowUpdate(int streamId, int windowSizeIncrement) {
writer.writeWindowUpdate(ctx, streamId, windowSizeIncrement, ctx.newPromise()).syncUninterruptibly();
}
void writeInboundFrame(byte frameType, int streamId,
Http2Flags flags, ByteBuf payload) {
writer.writeFrame(ctx, frameType, streamId, flags, payload, ctx.newPromise()).syncUninterruptibly();
}
private static final class WriteInboundChannelHandlerContext extends ChannelOutboundHandlerAdapter
implements ChannelHandlerContext {
private final EmbeddedChannel channel;
WriteInboundChannelHandlerContext(EmbeddedChannel channel) {
this.channel = channel;
}
@Override
public Channel channel() {
return channel;
}
@Override
public EventExecutor executor() {
return channel.eventLoop();
}
@Override
public String name() {
return "WriteInbound";
}
@Override
public ChannelHandler handler() {
return this;
}
@Override
public boolean isRemoved() {
return false;
}
@Override
public ChannelHandlerContext fireChannelRegistered() {
channel.pipeline().fireChannelRegistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelUnregistered() {
channel.pipeline().fireChannelUnregistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelActive() {
channel.pipeline().fireChannelActive();
return this;
}
@Override
public ChannelHandlerContext fireChannelInactive() {
channel.pipeline().fireChannelInactive();
return this;
}
@Override
public ChannelHandlerContext fireExceptionCaught(Throwable cause) {
channel.pipeline().fireExceptionCaught(cause);
return this;
}
@Override
public ChannelHandlerContext fireUserEventTriggered(Object evt) {
channel.pipeline().fireUserEventTriggered(evt);
return this;
}
@Override
public ChannelHandlerContext fireChannelRead(Object msg) {
channel.pipeline().fireChannelRead(msg);
return this;
}
@Override
public ChannelHandlerContext fireChannelReadComplete() {
channel.pipeline().fireChannelReadComplete();
return this;
}
@Override
public ChannelHandlerContext fireChannelWritabilityChanged() {
channel.pipeline().fireChannelWritabilityChanged();
return this;
}
@Override
public ChannelHandlerContext read() {
channel.read();
return this;
}
@Override
public ChannelHandlerContext flush() {
channel.pipeline().fireChannelReadComplete();
return this;
}
@Override
public ChannelPipeline pipeline() {
return channel.pipeline();
}
@Override
public ByteBufAllocator alloc() {
return channel.alloc();
}
@Override
public <T> Attribute<T> attr(AttributeKey<T> key) {
return channel.attr(key);
}
@Override
public <T> boolean hasAttr(AttributeKey<T> key) {
return channel.hasAttr(key);
}
@Override
public ChannelFuture bind(SocketAddress localAddress) {
return channel.bind(localAddress);
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress) {
return channel.connect(remoteAddress);
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) {
return channel.connect(remoteAddress, localAddress);
}
@Override
public ChannelFuture disconnect() {
return channel.disconnect();
}
@Override
public ChannelFuture close() {
return channel.close();
}
@Override
public ChannelFuture deregister() {
return channel.deregister();
}
@Override
public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) {
return channel.bind(localAddress, promise);
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) {
return channel.connect(remoteAddress, promise);
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
return channel.connect(remoteAddress, localAddress, promise);
}
@Override
public ChannelFuture disconnect(ChannelPromise promise) {
return channel.disconnect(promise);
}
@Override
public ChannelFuture close(ChannelPromise promise) {
return channel.close(promise);
}
@Override
public ChannelFuture deregister(ChannelPromise promise) {
return channel.deregister(promise);
}
@Override
public ChannelFuture write(Object msg) {
return write(msg, newPromise());
}
@Override
public ChannelFuture write(Object msg, ChannelPromise promise) {
return writeAndFlush(msg, promise);
}
@Override
public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
try {
channel.writeInbound(msg);
channel.runPendingTasks();
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
}
return promise;
}
@Override
public ChannelFuture writeAndFlush(Object msg) {
return writeAndFlush(msg, newPromise());
}
@Override
public ChannelPromise newPromise() {
return channel.newPromise();
}
@Override
public ChannelProgressivePromise newProgressivePromise() {
return channel.newProgressivePromise();
}
@Override
public ChannelFuture newSucceededFuture() {
return channel.newSucceededFuture();
}
@Override
public ChannelFuture newFailedFuture(Throwable cause) {
return channel.newFailedFuture(cause);
}
@Override
public ChannelPromise voidPromise() {
return channel.voidPromise();
}
}
}

View File

@ -15,7 +15,9 @@
package io.netty.handler.codec.http2; package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
@ -24,18 +26,34 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise; import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AsciiString; import io.netty.util.AsciiString;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor; import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError; import junit.framework.AssertionFailedError;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE;
import static io.netty.util.ReferenceCountUtil.release;
import static java.lang.Math.min; import static java.lang.Math.min;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyShort;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
/** /**
* Utilities for the integration tests. * Utilities for the integration tests.
@ -506,4 +524,179 @@ public final class Http2TestUtil {
return isWriteAllowed ? (int) min(pendingBytes, Integer.MAX_VALUE) : -1; return isWriteAllowed ? (int) min(pendingBytes, Integer.MAX_VALUE) : -1;
} }
} }
static Http2FrameWriter mockedFrameWriter() {
Http2FrameWriter.Configuration configuration = new Http2FrameWriter.Configuration() {
private final Http2HeadersEncoder.Configuration headerConfiguration =
new Http2HeadersEncoder.Configuration() {
@Override
public void maxHeaderTableSize(long max) {
// NOOP
}
@Override
public long maxHeaderTableSize() {
return 0;
}
@Override
public void maxHeaderListSize(long max) {
// NOOP
}
@Override
public long maxHeaderListSize() {
return 0;
}
};
private final Http2FrameSizePolicy policy = new Http2FrameSizePolicy() {
@Override
public void maxFrameSize(int max) {
// NOOP
}
@Override
public int maxFrameSize() {
return 0;
}
};
@Override
public Http2HeadersEncoder.Configuration headersConfiguration() {
return headerConfiguration;
}
@Override
public Http2FrameSizePolicy frameSizePolicy() {
return policy;
}
};
final ConcurrentLinkedQueue<ByteBuf> buffers = new ConcurrentLinkedQueue<ByteBuf>();
Http2FrameWriter frameWriter = Mockito.mock(Http2FrameWriter.class);
doAnswer(new Answer() {
@Override
public Object answer(InvocationOnMock invocationOnMock) {
for (;;) {
ByteBuf buf = buffers.poll();
if (buf == null) {
break;
}
buf.release();
}
return null;
}
}).when(frameWriter).close();
when(frameWriter.configuration()).thenReturn(configuration);
when(frameWriter.writeSettings(any(ChannelHandlerContext.class), any(Http2Settings.class),
any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(2)).setSuccess();
}
});
when(frameWriter.writeSettingsAck(any(ChannelHandlerContext.class), any(ChannelPromise.class)))
.thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(1)).setSuccess();
}
});
when(frameWriter.writeGoAway(any(ChannelHandlerContext.class), anyInt(),
anyLong(), any(ByteBuf.class), any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
buffers.offer((ByteBuf) invocationOnMock.getArgument(3));
return ((ChannelPromise) invocationOnMock.getArgument(4)).setSuccess();
}
});
when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(),
anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess();
}
});
when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(),
any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(),
any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(8)).setSuccess();
}
});
when(frameWriter.writeData(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(),
anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
buffers.offer((ByteBuf) invocationOnMock.getArgument(2));
return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess();
}
});
when(frameWriter.writeRstStream(any(ChannelHandlerContext.class), anyInt(),
anyLong(), any(ChannelPromise.class))).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess();
}
});
when(frameWriter.writeWindowUpdate(any(ChannelHandlerContext.class), anyInt(), anyInt(),
any(ChannelPromise.class))).then(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess();
}
});
when(frameWriter.writePushPromise(any(ChannelHandlerContext.class), anyInt(), anyInt(), any(Http2Headers.class),
anyInt(), anyChannelPromise())).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess();
}
});
when(frameWriter.writeFrame(any(ChannelHandlerContext.class), anyByte(), anyInt(), any(Http2Flags.class),
any(ByteBuf.class), anyChannelPromise())).thenAnswer(new Answer<ChannelFuture>() {
@Override
public ChannelFuture answer(InvocationOnMock invocationOnMock) {
buffers.offer((ByteBuf) invocationOnMock.getArgument(4));
return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess();
}
});
return frameWriter;
}
static ChannelPromise anyChannelPromise() {
return any(ChannelPromise.class);
}
static Http2Settings anyHttp2Settings() {
return any(Http2Settings.class);
}
static ByteBuf bb(String s) {
return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s);
}
static void assertEqualsAndRelease(Http2Frame expected, Http2Frame actual) {
try {
assertEquals(expected, actual);
} finally {
release(expected);
release(actual);
// Will return -1 when not implements ReferenceCounted.
assertTrue(ReferenceCountUtil.refCnt(expected) <= 0);
assertTrue(ReferenceCountUtil.refCnt(actual) <= 0);
}
}
} }

View File

@ -50,9 +50,9 @@ public class TestChannelInitializer extends ChannelInitializer<Channel> {
/** /**
* Designed to read a single byte at a time to control the number of reads done at a fine granularity. * Designed to read a single byte at a time to control the number of reads done at a fine granularity.
*/ */
private static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator { static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator {
private final AtomicInteger numReads; private final AtomicInteger numReads;
TestNumReadsRecvByteBufAllocator(AtomicInteger numReads) { private TestNumReadsRecvByteBufAllocator(AtomicInteger numReads) {
this.numReads = numReads; this.numReads = numReads;
} }