Ensure we only call debugData.release() if we called debugData.retain()

Motivation:

We need to ensure we only call debugData.release() if we called debugData.retain(), otherwise we my see an IllegalReferenceCountException.

Modifications:

Only call release() if we called retain().

Result:

No more IllegalReferenceCountException possible.
This commit is contained in:
Norman Maurer 2016-09-11 10:40:10 +02:00
parent 51629245d0
commit 4beb075dd3
2 changed files with 22 additions and 15 deletions

View File

@ -668,6 +668,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
@Override
public ChannelFuture goAway(final ChannelHandlerContext ctx, final int lastStreamId, final long errorCode,
final ByteBuf debugData, ChannelPromise promise) {
boolean release = false;
try {
final Http2Connection connection = connection();
if (connection.goAwaySent() && lastStreamId > connection.remote().lastStreamKnownByPeer()) {
@ -680,6 +681,7 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
// Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and
// result in an IllegalRefCountException.
debugData.retain();
release = true;
ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
if (future.isDone()) {
@ -695,7 +697,9 @@ public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http
return future;
} catch (Throwable cause) { // Make sure to catch Throwable because we are doing a retain() in this method.
debugData.release();
if (release) {
debugData.release();
}
return promise.setFailure(cause);
}
}

View File

@ -50,7 +50,6 @@ import static io.netty.handler.codec.http2.Http2Stream.State.IDLE;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
@ -419,21 +418,25 @@ public class Http2ConnectionHandlerTest {
public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception {
handler = newHandler();
ByteBuf data = dummyData();
long errorCode = Http2Error.INTERNAL_ERROR.code();
try {
long errorCode = Http2Error.INTERNAL_ERROR.code();
handler.goAway(ctx, STREAM_ID, errorCode, data.retain(), promise);
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
// The frameWriter is only mocked, so it should not have interacted with the promise.
assertFalse(promise.isDone());
handler.goAway(ctx, STREAM_ID, errorCode, data.retain(), promise);
verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data));
verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise));
// The frameWriter is only mocked, so it should not have interacted with the promise.
assertFalse(promise.isDone());
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertEquals(0, data.refCnt());
verifyNoMoreInteractions(frameWriter);
when(connection.goAwaySent()).thenReturn(true);
when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID);
handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise);
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertEquals(1, data.refCnt());
verifyNoMoreInteractions(frameWriter);
} finally {
data.release();
}
}
@Test