HTTP/2 GOAWAY Reference Count Issue
Motiviation: The Http2ConnectionHandler is incrementing the reference count in the goAway method for the debugData buffer after it has already been sent and maybe consumed. This may result in an IllegalRefCountException to be thrown. The unit tests also encounter buffer leaks because they have not been updated to invoke the listener which releases the buffer in the goAway method. Modifications: - The retain() call should be before the frameWriter().writeGoAway(...) call - The unit tests which call goAway must also invoke the operationComplete(..) method for the listener. Result: No IllegalRefCountException. Less buffer leaks in tests.
This commit is contained in:
parent
25b90927f4
commit
2419f0c417
@ -18,6 +18,7 @@ package io.netty.handler.codec.http2;
|
||||
import static io.netty.util.CharsetUtil.UTF_8;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
@ -135,10 +136,9 @@ public final class Http2CodecUtil {
|
||||
return Unpooled.EMPTY_BUFFER;
|
||||
}
|
||||
|
||||
// Create the debug message.
|
||||
byte[] msg = cause.getMessage().getBytes(UTF_8);
|
||||
ByteBuf debugData = ctx.alloc().buffer(msg.length);
|
||||
debugData.writeBytes(msg);
|
||||
// Create the debug message. `* 3` because UTF-8 max character consumes 3 bytes.
|
||||
ByteBuf debugData = ctx.alloc().buffer(cause.getMessage().length() * 3);
|
||||
ByteBufUtil.writeUtf8(debugData, cause.getMessage());
|
||||
return debugData;
|
||||
}
|
||||
|
||||
|
@ -25,7 +25,6 @@ import static io.netty.handler.codec.http2.Http2Exception.isStreamError;
|
||||
import static io.netty.util.CharsetUtil.UTF_8;
|
||||
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
||||
import static java.lang.String.format;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufUtil;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
@ -35,8 +34,6 @@ import io.netty.channel.ChannelPromise;
|
||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
|
||||
import io.netty.handler.codec.http2.Http2Exception.StreamException;
|
||||
import io.netty.util.CharsetUtil;
|
||||
import io.netty.util.concurrent.GenericFutureListener;
|
||||
import io.netty.util.internal.logging.InternalLogger;
|
||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||
|
||||
@ -536,45 +533,26 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
|
||||
}
|
||||
connection.goAwaySent(lastStreamId, errorCode, debugData);
|
||||
|
||||
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and
|
||||
// result in an IllegalRefCountException.
|
||||
debugData.retain();
|
||||
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
|
||||
|
||||
// Need to retain the buffer so that it's available when the future completes.
|
||||
debugData.retain();
|
||||
future.addListener(new GenericFutureListener<ChannelFuture>() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
try {
|
||||
if (future.isSuccess()) {
|
||||
if (errorCode != NO_ERROR.code()) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug(
|
||||
format("Sent GOAWAY: lastStreamId '%d', errorCode '%d', " +
|
||||
"debugData '%s'. Forcing shutdown of the connection.",
|
||||
lastStreamId, errorCode, debugData.toString(UTF_8)),
|
||||
future.cause());
|
||||
}
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error(
|
||||
format("Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', " +
|
||||
"debugData '%s'. Forcing shutdown of the connection.",
|
||||
lastStreamId, errorCode, debugData.toString(UTF_8)), future.cause());
|
||||
}
|
||||
ctx.close();
|
||||
}
|
||||
} finally {
|
||||
// We're done with the debug data now.
|
||||
debugData.release();
|
||||
if (future.isDone()) {
|
||||
processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
|
||||
} else {
|
||||
future.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return future;
|
||||
} catch (Http2Exception e) {
|
||||
} catch (Throwable cause) { // Make sure to catch Throwable because we are doing a retain() in this method.
|
||||
debugData.release();
|
||||
return promise.setFailure(e);
|
||||
return promise.setFailure(cause);
|
||||
}
|
||||
}
|
||||
|
||||
@ -596,6 +574,35 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
|
||||
return connection.isServer() ? connectionPrefaceBuf() : null;
|
||||
}
|
||||
|
||||
private static void processGoAwayWriteResult(final ChannelHandlerContext ctx, final int lastStreamId,
|
||||
final long errorCode, final ByteBuf debugData, ChannelFuture future) {
|
||||
try {
|
||||
if (future.isSuccess()) {
|
||||
if (errorCode != NO_ERROR.code()) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug(
|
||||
format("Sent GOAWAY: lastStreamId '%d', errorCode '%d', " +
|
||||
"debugData '%s'. Forcing shutdown of the connection.",
|
||||
lastStreamId, errorCode, debugData.toString(UTF_8)),
|
||||
future.cause());
|
||||
}
|
||||
ctx.close();
|
||||
}
|
||||
} else {
|
||||
if (logger.isErrorEnabled()) {
|
||||
logger.error(
|
||||
format("Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', " +
|
||||
"debugData '%s'. Forcing shutdown of the connection.",
|
||||
lastStreamId, errorCode, debugData.toString(UTF_8)), future.cause());
|
||||
}
|
||||
ctx.close();
|
||||
}
|
||||
} finally {
|
||||
// We're done with the debug data now.
|
||||
debugData.release();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the channel when the future completes.
|
||||
*/
|
||||
|
@ -21,8 +21,9 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
|
||||
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
|
||||
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
|
||||
import static io.netty.util.CharsetUtil.UTF_8;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyBoolean;
|
||||
@ -30,7 +31,6 @@ import static org.mockito.Matchers.anyInt;
|
||||
import static org.mockito.Matchers.anyLong;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
@ -41,9 +41,11 @@ import io.netty.buffer.Unpooled;
|
||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.DefaultChannelPromise;
|
||||
import io.netty.util.CharsetUtil;
|
||||
import io.netty.util.concurrent.GenericFutureListener;
|
||||
|
||||
import java.util.List;
|
||||
@ -98,16 +100,38 @@ public class Http2ConnectionHandlerTest {
|
||||
@Mock
|
||||
private Http2FrameWriter frameWriter;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Before
|
||||
public void setup() throws Exception {
|
||||
MockitoAnnotations.initMocks(this);
|
||||
|
||||
promise = new DefaultChannelPromise(channel);
|
||||
|
||||
Throwable fakeException = new RuntimeException("Fake exception");
|
||||
when(encoder.connection()).thenReturn(connection);
|
||||
when(decoder.connection()).thenReturn(connection);
|
||||
when(encoder.frameWriter()).thenReturn(frameWriter);
|
||||
when(frameWriter.writeGoAway(eq(ctx), anyInt(), anyInt(), any(ByteBuf.class), eq(promise))).thenReturn(future);
|
||||
doAnswer(new Answer<ChannelFuture>() {
|
||||
@Override
|
||||
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
|
||||
ByteBuf buf = invocation.getArgumentAt(3, ByteBuf.class);
|
||||
buf.release();
|
||||
return future;
|
||||
}
|
||||
}).when(frameWriter).writeGoAway(
|
||||
any(ChannelHandlerContext.class), anyInt(), anyInt(), any(ByteBuf.class), any(ChannelPromise.class));
|
||||
doAnswer(new Answer<ChannelFuture>() {
|
||||
@Override
|
||||
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
|
||||
Object o = invocation.getArguments()[0];
|
||||
if (o instanceof ChannelFutureListener) {
|
||||
((ChannelFutureListener) o).operationComplete(future);
|
||||
}
|
||||
return future;
|
||||
}
|
||||
}).when(future).addListener(any(GenericFutureListener.class));
|
||||
when(future.cause()).thenReturn(fakeException);
|
||||
when(future.channel()).thenReturn(channel);
|
||||
when(channel.isActive()).thenReturn(true);
|
||||
when(connection.remote()).thenReturn(remote);
|
||||
when(connection.local()).thenReturn(local);
|
||||
@ -174,7 +198,7 @@ public class Http2ConnectionHandlerTest {
|
||||
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
|
||||
verify(frameWriter).writeGoAway(eq(ctx), eq(0), eq(PROTOCOL_ERROR.code()),
|
||||
captor.capture(), eq(promise));
|
||||
captor.getValue().release();
|
||||
assertEquals(0, captor.getValue().refCnt());
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -304,11 +328,10 @@ public class Http2ConnectionHandlerTest {
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void canSendGoAwayFrame() throws Exception {
|
||||
ByteBuf data = mock(ByteBuf.class);
|
||||
ByteBuf data = dummyData();
|
||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||
when(future.isDone()).thenReturn(true);
|
||||
when(future.isSuccess()).thenReturn(true);
|
||||
when(frameWriter.writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise))).thenReturn(future);
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock invocation) throws Throwable {
|
||||
@ -322,29 +345,32 @@ public class Http2ConnectionHandlerTest {
|
||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||
verify(ctx).close();
|
||||
assertEquals(0, data.refCnt());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void canSendGoAwayFramesWithDecreasingLastStreamIds() throws Exception {
|
||||
handler = newHandler();
|
||||
ByteBuf data = mock(ByteBuf.class);
|
||||
ByteBuf data = dummyData();
|
||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||
|
||||
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
|
||||
handler.goAway(ctx, STREAM_ID + 2, errorCode, data.retain(), promise);
|
||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID + 2), eq(errorCode), eq(data), eq(promise));
|
||||
verify(connection).goAwaySent(eq(STREAM_ID + 2), eq(errorCode), eq(data));
|
||||
promise = new DefaultChannelPromise(channel);
|
||||
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
|
||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||
assertEquals(0, data.refCnt());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception {
|
||||
handler = newHandler();
|
||||
ByteBuf data = mock(ByteBuf.class);
|
||||
ByteBuf data = dummyData();
|
||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||
|
||||
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
|
||||
handler.goAway(ctx, STREAM_ID, errorCode, data.retain(), promise);
|
||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||
// The frameWriter is only mocked, so it should not have interacted with the promise.
|
||||
@ -355,7 +381,7 @@ public class Http2ConnectionHandlerTest {
|
||||
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
|
||||
assertTrue(promise.isDone());
|
||||
assertFalse(promise.isSuccess());
|
||||
verify(data).release();
|
||||
assertEquals(0, data.refCnt());
|
||||
verifyNoMoreInteractions(frameWriter);
|
||||
}
|
||||
|
||||
@ -365,4 +391,8 @@ public class Http2ConnectionHandlerTest {
|
||||
handler.channelReadComplete(ctx);
|
||||
verify(ctx, times(1)).flush();
|
||||
}
|
||||
|
||||
private ByteBuf dummyData() {
|
||||
return Unpooled.buffer().writeBytes("abcdefgh".getBytes(CharsetUtil.UTF_8));
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user