HTTP/2 GOAWAY Reference Count Issue
Motiviation: The Http2ConnectionHandler is incrementing the reference count in the goAway method for the debugData buffer after it has already been sent and maybe consumed. This may result in an IllegalRefCountException to be thrown. The unit tests also encounter buffer leaks because they have not been updated to invoke the listener which releases the buffer in the goAway method. Modifications: - The retain() call should be before the frameWriter().writeGoAway(...) call - The unit tests which call goAway must also invoke the operationComplete(..) method for the listener. Result: No IllegalRefCountException. Less buffer leaks in tests.
This commit is contained in:
parent
f6299d942c
commit
fd22af8377
@ -18,6 +18,7 @@ package io.netty.handler.codec.http2;
|
|||||||
import static io.netty.util.CharsetUtil.UTF_8;
|
import static io.netty.util.CharsetUtil.UTF_8;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
@ -135,10 +136,9 @@ public final class Http2CodecUtil {
|
|||||||
return Unpooled.EMPTY_BUFFER;
|
return Unpooled.EMPTY_BUFFER;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the debug message.
|
// Create the debug message. `* 3` because UTF-8 max character consumes 3 bytes.
|
||||||
byte[] msg = cause.getMessage().getBytes(UTF_8);
|
ByteBuf debugData = ctx.alloc().buffer(cause.getMessage().length() * 3);
|
||||||
ByteBuf debugData = ctx.alloc().buffer(msg.length);
|
ByteBufUtil.writeUtf8(debugData, cause.getMessage());
|
||||||
debugData.writeBytes(msg);
|
|
||||||
return debugData;
|
return debugData;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ import static io.netty.handler.codec.http2.Http2Exception.isStreamError;
|
|||||||
import static io.netty.util.CharsetUtil.UTF_8;
|
import static io.netty.util.CharsetUtil.UTF_8;
|
||||||
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
||||||
import static java.lang.String.format;
|
import static java.lang.String.format;
|
||||||
|
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
@ -36,8 +35,6 @@ import io.netty.channel.ChannelPromise;
|
|||||||
import io.netty.handler.codec.ByteToMessageDecoder;
|
import io.netty.handler.codec.ByteToMessageDecoder;
|
||||||
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
|
import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException;
|
||||||
import io.netty.handler.codec.http2.Http2Exception.StreamException;
|
import io.netty.handler.codec.http2.Http2Exception.StreamException;
|
||||||
import io.netty.util.CharsetUtil;
|
|
||||||
import io.netty.util.concurrent.GenericFutureListener;
|
|
||||||
import io.netty.util.internal.logging.InternalLogger;
|
import io.netty.util.internal.logging.InternalLogger;
|
||||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||||
|
|
||||||
@ -575,45 +572,26 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
|
|||||||
}
|
}
|
||||||
connection.goAwaySent(lastStreamId, errorCode, debugData);
|
connection.goAwaySent(lastStreamId, errorCode, debugData);
|
||||||
|
|
||||||
|
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and
|
||||||
|
// result in an IllegalRefCountException.
|
||||||
|
debugData.retain();
|
||||||
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
|
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
|
||||||
|
|
||||||
// Need to retain the buffer so that it's available when the future completes.
|
if (future.isDone()) {
|
||||||
debugData.retain();
|
processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
|
||||||
future.addListener(new GenericFutureListener<ChannelFuture>() {
|
} else {
|
||||||
@Override
|
future.addListener(new ChannelFutureListener() {
|
||||||
public void operationComplete(ChannelFuture future) throws Exception {
|
@Override
|
||||||
try {
|
public void operationComplete(ChannelFuture future) throws Exception {
|
||||||
if (future.isSuccess()) {
|
processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
|
||||||
if (errorCode != NO_ERROR.code()) {
|
|
||||||
if (logger.isDebugEnabled()) {
|
|
||||||
logger.debug(
|
|
||||||
format("Sent GOAWAY: lastStreamId '%d', errorCode '%d', " +
|
|
||||||
"debugData '%s'. Forcing shutdown of the connection.",
|
|
||||||
lastStreamId, errorCode, debugData.toString(UTF_8)),
|
|
||||||
future.cause());
|
|
||||||
}
|
|
||||||
ctx.close();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (logger.isErrorEnabled()) {
|
|
||||||
logger.error(
|
|
||||||
format("Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', " +
|
|
||||||
"debugData '%s'. Forcing shutdown of the connection.",
|
|
||||||
lastStreamId, errorCode, debugData.toString(UTF_8)), future.cause());
|
|
||||||
}
|
|
||||||
ctx.close();
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
// We're done with the debug data now.
|
|
||||||
debugData.release();
|
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
});
|
}
|
||||||
|
|
||||||
return future;
|
return future;
|
||||||
} catch (Http2Exception e) {
|
} catch (Throwable cause) { // Make sure to catch Throwable because we are doing a retain() in this method.
|
||||||
debugData.release();
|
debugData.release();
|
||||||
return promise.setFailure(e);
|
return promise.setFailure(cause);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -635,6 +613,35 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
|
|||||||
return connection.isServer() ? connectionPrefaceBuf() : null;
|
return connection.isServer() ? connectionPrefaceBuf() : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void processGoAwayWriteResult(final ChannelHandlerContext ctx, final int lastStreamId,
|
||||||
|
final long errorCode, final ByteBuf debugData, ChannelFuture future) {
|
||||||
|
try {
|
||||||
|
if (future.isSuccess()) {
|
||||||
|
if (errorCode != NO_ERROR.code()) {
|
||||||
|
if (logger.isDebugEnabled()) {
|
||||||
|
logger.debug(
|
||||||
|
format("Sent GOAWAY: lastStreamId '%d', errorCode '%d', " +
|
||||||
|
"debugData '%s'. Forcing shutdown of the connection.",
|
||||||
|
lastStreamId, errorCode, debugData.toString(UTF_8)),
|
||||||
|
future.cause());
|
||||||
|
}
|
||||||
|
ctx.close();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (logger.isErrorEnabled()) {
|
||||||
|
logger.error(
|
||||||
|
format("Sending GOAWAY failed: lastStreamId '%d', errorCode '%d', " +
|
||||||
|
"debugData '%s'. Forcing shutdown of the connection.",
|
||||||
|
lastStreamId, errorCode, debugData.toString(UTF_8)), future.cause());
|
||||||
|
}
|
||||||
|
ctx.close();
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
// We're done with the debug data now.
|
||||||
|
debugData.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Closes the channel when the future completes.
|
* Closes the channel when the future completes.
|
||||||
*/
|
*/
|
||||||
|
@ -21,8 +21,9 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
|
|||||||
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
|
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
|
||||||
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
|
import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED;
|
||||||
import static io.netty.util.CharsetUtil.UTF_8;
|
import static io.netty.util.CharsetUtil.UTF_8;
|
||||||
import static org.junit.Assert.assertNull;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
import static org.mockito.Matchers.any;
|
import static org.mockito.Matchers.any;
|
||||||
import static org.mockito.Matchers.anyBoolean;
|
import static org.mockito.Matchers.anyBoolean;
|
||||||
@ -30,7 +31,6 @@ import static org.mockito.Matchers.anyInt;
|
|||||||
import static org.mockito.Matchers.anyLong;
|
import static org.mockito.Matchers.anyLong;
|
||||||
import static org.mockito.Matchers.eq;
|
import static org.mockito.Matchers.eq;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
@ -41,9 +41,11 @@ import io.netty.buffer.Unpooled;
|
|||||||
import io.netty.buffer.UnpooledByteBufAllocator;
|
import io.netty.buffer.UnpooledByteBufAllocator;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.channel.DefaultChannelPromise;
|
import io.netty.channel.DefaultChannelPromise;
|
||||||
|
import io.netty.util.CharsetUtil;
|
||||||
import io.netty.util.concurrent.GenericFutureListener;
|
import io.netty.util.concurrent.GenericFutureListener;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@ -98,16 +100,38 @@ public class Http2ConnectionHandlerTest {
|
|||||||
@Mock
|
@Mock
|
||||||
private Http2FrameWriter frameWriter;
|
private Http2FrameWriter frameWriter;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
@Before
|
@Before
|
||||||
public void setup() throws Exception {
|
public void setup() throws Exception {
|
||||||
MockitoAnnotations.initMocks(this);
|
MockitoAnnotations.initMocks(this);
|
||||||
|
|
||||||
promise = new DefaultChannelPromise(channel);
|
promise = new DefaultChannelPromise(channel);
|
||||||
|
|
||||||
|
Throwable fakeException = new RuntimeException("Fake exception");
|
||||||
when(encoder.connection()).thenReturn(connection);
|
when(encoder.connection()).thenReturn(connection);
|
||||||
when(decoder.connection()).thenReturn(connection);
|
when(decoder.connection()).thenReturn(connection);
|
||||||
when(encoder.frameWriter()).thenReturn(frameWriter);
|
when(encoder.frameWriter()).thenReturn(frameWriter);
|
||||||
when(frameWriter.writeGoAway(eq(ctx), anyInt(), anyInt(), any(ByteBuf.class), eq(promise))).thenReturn(future);
|
doAnswer(new Answer<ChannelFuture>() {
|
||||||
|
@Override
|
||||||
|
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
|
||||||
|
ByteBuf buf = invocation.getArgumentAt(3, ByteBuf.class);
|
||||||
|
buf.release();
|
||||||
|
return future;
|
||||||
|
}
|
||||||
|
}).when(frameWriter).writeGoAway(
|
||||||
|
any(ChannelHandlerContext.class), anyInt(), anyInt(), any(ByteBuf.class), any(ChannelPromise.class));
|
||||||
|
doAnswer(new Answer<ChannelFuture>() {
|
||||||
|
@Override
|
||||||
|
public ChannelFuture answer(InvocationOnMock invocation) throws Throwable {
|
||||||
|
Object o = invocation.getArguments()[0];
|
||||||
|
if (o instanceof ChannelFutureListener) {
|
||||||
|
((ChannelFutureListener) o).operationComplete(future);
|
||||||
|
}
|
||||||
|
return future;
|
||||||
|
}
|
||||||
|
}).when(future).addListener(any(GenericFutureListener.class));
|
||||||
|
when(future.cause()).thenReturn(fakeException);
|
||||||
|
when(future.channel()).thenReturn(channel);
|
||||||
when(channel.isActive()).thenReturn(true);
|
when(channel.isActive()).thenReturn(true);
|
||||||
when(connection.remote()).thenReturn(remote);
|
when(connection.remote()).thenReturn(remote);
|
||||||
when(connection.local()).thenReturn(local);
|
when(connection.local()).thenReturn(local);
|
||||||
@ -174,7 +198,7 @@ public class Http2ConnectionHandlerTest {
|
|||||||
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
|
ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
|
||||||
verify(frameWriter).writeGoAway(eq(ctx), eq(0), eq(PROTOCOL_ERROR.code()),
|
verify(frameWriter).writeGoAway(eq(ctx), eq(0), eq(PROTOCOL_ERROR.code()),
|
||||||
captor.capture(), eq(promise));
|
captor.capture(), eq(promise));
|
||||||
captor.getValue().release();
|
assertEquals(0, captor.getValue().refCnt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -304,11 +328,10 @@ public class Http2ConnectionHandlerTest {
|
|||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
@Test
|
@Test
|
||||||
public void canSendGoAwayFrame() throws Exception {
|
public void canSendGoAwayFrame() throws Exception {
|
||||||
ByteBuf data = mock(ByteBuf.class);
|
ByteBuf data = dummyData();
|
||||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||||
when(future.isDone()).thenReturn(true);
|
when(future.isDone()).thenReturn(true);
|
||||||
when(future.isSuccess()).thenReturn(true);
|
when(future.isSuccess()).thenReturn(true);
|
||||||
when(frameWriter.writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise))).thenReturn(future);
|
|
||||||
doAnswer(new Answer<Void>() {
|
doAnswer(new Answer<Void>() {
|
||||||
@Override
|
@Override
|
||||||
public Void answer(InvocationOnMock invocation) throws Throwable {
|
public Void answer(InvocationOnMock invocation) throws Throwable {
|
||||||
@ -322,29 +345,32 @@ public class Http2ConnectionHandlerTest {
|
|||||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||||
verify(ctx).close();
|
verify(ctx).close();
|
||||||
|
assertEquals(0, data.refCnt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void canSendGoAwayFramesWithDecreasingLastStreamIds() throws Exception {
|
public void canSendGoAwayFramesWithDecreasingLastStreamIds() throws Exception {
|
||||||
handler = newHandler();
|
handler = newHandler();
|
||||||
ByteBuf data = mock(ByteBuf.class);
|
ByteBuf data = dummyData();
|
||||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||||
|
|
||||||
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
|
handler.goAway(ctx, STREAM_ID + 2, errorCode, data.retain(), promise);
|
||||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID + 2), eq(errorCode), eq(data), eq(promise));
|
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID + 2), eq(errorCode), eq(data), eq(promise));
|
||||||
verify(connection).goAwaySent(eq(STREAM_ID + 2), eq(errorCode), eq(data));
|
verify(connection).goAwaySent(eq(STREAM_ID + 2), eq(errorCode), eq(data));
|
||||||
|
promise = new DefaultChannelPromise(channel);
|
||||||
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
|
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
|
||||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||||
|
assertEquals(0, data.refCnt());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception {
|
public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception {
|
||||||
handler = newHandler();
|
handler = newHandler();
|
||||||
ByteBuf data = mock(ByteBuf.class);
|
ByteBuf data = dummyData();
|
||||||
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
long errorCode = Http2Error.INTERNAL_ERROR.code();
|
||||||
|
|
||||||
handler.goAway(ctx, STREAM_ID, errorCode, data, promise);
|
handler.goAway(ctx, STREAM_ID, errorCode, data.retain(), promise);
|
||||||
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
|
||||||
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
|
||||||
// The frameWriter is only mocked, so it should not have interacted with the promise.
|
// The frameWriter is only mocked, so it should not have interacted with the promise.
|
||||||
@ -355,7 +381,7 @@ public class Http2ConnectionHandlerTest {
|
|||||||
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
|
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
|
||||||
assertTrue(promise.isDone());
|
assertTrue(promise.isDone());
|
||||||
assertFalse(promise.isSuccess());
|
assertFalse(promise.isSuccess());
|
||||||
verify(data).release();
|
assertEquals(0, data.refCnt());
|
||||||
verifyNoMoreInteractions(frameWriter);
|
verifyNoMoreInteractions(frameWriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -365,4 +391,8 @@ public class Http2ConnectionHandlerTest {
|
|||||||
handler.channelReadComplete(ctx);
|
handler.channelReadComplete(ctx);
|
||||||
verify(ctx, times(1)).flush();
|
verify(ctx, times(1)).flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ByteBuf dummyData() {
|
||||||
|
return Unpooled.buffer().writeBytes("abcdefgh".getBytes(CharsetUtil.UTF_8));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user