Make PendingWriteQueue.recycle() update its state before triggering an event

Related: #3212

Motivation:

PendingWriteQueue.recycle() updates its data structure after triggering
a channelWritabilityChanged() event. It causes a rare corruption such as
double free when channelWritabilityChanged() method accesses the
PendingWriteQueue.

Modifications:

Update the state of PendingWriteQueue before triggering an event.

Result:

Fix a rare double-free problem
This commit is contained in:
Trustin Lee 2014-12-07 23:24:19 +09:00
parent 1ef6f14734
commit 3957a88a94
2 changed files with 81 additions and 25 deletions

View File

@ -217,11 +217,11 @@ public final class PendingWriteQueue {
} }
private void recycle(PendingWrite write) { private void recycle(PendingWrite write) {
PendingWrite next = write.next; final PendingWrite next = write.next;
final long writeSize = write.size;
buffer.decrementPendingOutboundBytes(write.size);
write.recycle();
size --; size --;
if (next == null) { if (next == null) {
// Handled last PendingWrite so rest head and tail // Handled last PendingWrite so rest head and tail
head = tail = null; head = tail = null;
@ -230,6 +230,9 @@ public final class PendingWriteQueue {
head = next; head = next;
assert size > 0; assert size > 0;
} }
write.recycle();
buffer.decrementPendingOutboundBytes(writeSize);
} }
private static void safeFail(ChannelPromise promise, Throwable cause) { private static void safeFail(ChannelPromise promise, Throwable cause) {

View File

@ -20,9 +20,13 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
public class PendingWriteQueueTest { public class PendingWriteQueueTest {
@Test @Test
@ -30,7 +34,7 @@ public class PendingWriteQueueTest {
assertWrite(new TestHandler() { assertWrite(new TestHandler() {
@Override @Override
public void flush(ChannelHandlerContext ctx) throws Exception { public void flush(ChannelHandlerContext ctx) throws Exception {
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); assertFalse("Should not be writable anymore", ctx.channel().isWritable());
ChannelFuture future = queue.removeAndWrite(); ChannelFuture future = queue.removeAndWrite();
future.addListener(new ChannelFutureListener() { future.addListener(new ChannelFutureListener() {
@ -49,7 +53,7 @@ public class PendingWriteQueueTest {
assertWrite(new TestHandler() { assertWrite(new TestHandler() {
@Override @Override
public void flush(ChannelHandlerContext ctx) throws Exception { public void flush(ChannelHandlerContext ctx) throws Exception {
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); assertFalse("Should not be writable anymore", ctx.channel().isWritable());
ChannelFuture future = queue.removeAndWriteAll(); ChannelFuture future = queue.removeAndWriteAll();
future.addListener(new ChannelFutureListener() { future.addListener(new ChannelFutureListener() {
@ -86,6 +90,55 @@ public class PendingWriteQueueTest {
}, 3); }, 3);
} }
@Test
public void shouldFireChannelWritabilityChangedAfterRemoval() {
final AtomicReference<ChannelHandlerContext> ctxRef = new AtomicReference<ChannelHandlerContext>();
final AtomicReference<PendingWriteQueue> queueRef = new AtomicReference<PendingWriteQueue>();
final ByteBuf msg = Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII);
final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
ctxRef.set(ctx);
queueRef.set(new PendingWriteQueue(ctx));
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
final PendingWriteQueue queue = queueRef.get();
final ByteBuf msg = (ByteBuf) queue.current();
if (msg == null) {
return;
}
assertThat(msg.refCnt(), is(1));
// This call will trigger another channelWritabilityChanged() event because the number of
// pending bytes will go below the low watermark.
//
// If PendingWriteQueue.remove() did not remove the current entry before triggering
// channelWritabilityChanged() event, we will end up with attempting to remove the same
// element twice, resulting in the double release.
queue.remove();
assertThat(msg.refCnt(), is(0));
}
});
channel.config().setWriteBufferLowWaterMark(1);
channel.config().setWriteBufferHighWaterMark(3);
final PendingWriteQueue queue = queueRef.get();
// Trigger channelWritabilityChanged() by adding a message that's larger than the high watermark.
queue.add(msg, channel.newPromise());
channel.finish();
assertThat(msg.refCnt(), is(0));
}
private static void assertWrite(ChannelHandler handler, int count) { private static void assertWrite(ChannelHandler handler, int count) {
final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
final EmbeddedChannel channel = new EmbeddedChannel(handler); final EmbeddedChannel channel = new EmbeddedChannel(handler);
@ -96,29 +149,29 @@ public class PendingWriteQueueTest {
for (int i = 0; i < buffers.length; i++) { for (int i = 0; i < buffers.length; i++) {
buffers[i] = buffer.duplicate().retain(); buffers[i] = buffer.duplicate().retain();
} }
Assert.assertTrue(channel.writeOutbound(buffers)); assertTrue(channel.writeOutbound(buffers));
Assert.assertTrue(channel.finish()); assertTrue(channel.finish());
channel.closeFuture().syncUninterruptibly(); channel.closeFuture().syncUninterruptibly();
for (int i = 0; i < buffers.length; i++) { for (int i = 0; i < buffers.length; i++) {
assertBuffer(channel, buffer); assertBuffer(channel, buffer);
} }
buffer.release(); buffer.release();
Assert.assertNull(channel.readOutbound()); assertNull(channel.readOutbound());
} }
private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) { private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) {
ByteBuf written = (ByteBuf) channel.readOutbound(); ByteBuf written = channel.readOutbound();
Assert.assertEquals(buffer, written); assertEquals(buffer, written);
written.release(); written.release();
} }
private static void assertQueueEmpty(PendingWriteQueue queue) { private static void assertQueueEmpty(PendingWriteQueue queue) {
Assert.assertTrue(queue.isEmpty()); assertTrue(queue.isEmpty());
Assert.assertEquals(0, queue.size()); assertEquals(0, queue.size());
Assert.assertNull(queue.current()); assertNull(queue.current());
Assert.assertNull(queue.removeAndWrite()); assertNull(queue.removeAndWrite());
Assert.assertNull(queue.removeAndWriteAll()); assertNull(queue.removeAndWriteAll());
} }
private static void assertWriteFails(ChannelHandler handler, int count) { private static void assertWriteFails(ChannelHandler handler, int count) {
@ -129,16 +182,16 @@ public class PendingWriteQueueTest {
buffers[i] = buffer.duplicate().retain(); buffers[i] = buffer.duplicate().retain();
} }
try { try {
Assert.assertFalse(channel.writeOutbound(buffers)); assertFalse(channel.writeOutbound(buffers));
Assert.fail(); fail();
} catch (Exception e) { } catch (Exception e) {
Assert.assertTrue(e instanceof TestException); assertTrue(e instanceof TestException);
} }
Assert.assertFalse(channel.finish()); assertFalse(channel.finish());
channel.closeFuture().syncUninterruptibly(); channel.closeFuture().syncUninterruptibly();
buffer.release(); buffer.release();
Assert.assertNull(channel.readOutbound()); assertNull(channel.readOutbound());
} }
private static class TestHandler extends ChannelDuplexHandler { private static class TestHandler extends ChannelDuplexHandler {
@ -149,15 +202,15 @@ public class PendingWriteQueueTest {
public void channelActive(ChannelHandlerContext ctx) throws Exception { public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx); super.channelActive(ctx);
assertQueueEmpty(queue); assertQueueEmpty(queue);
Assert.assertTrue("Should be writable", ctx.channel().isWritable()); assertTrue("Should be writable", ctx.channel().isWritable());
} }
@Override @Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
queue.add(msg, promise); queue.add(msg, promise);
Assert.assertFalse(queue.isEmpty()); assertFalse(queue.isEmpty());
Assert.assertEquals(++ expectedSize, queue.size()); assertEquals(++expectedSize, queue.size());
Assert.assertNotNull(queue.current()); assertNotNull(queue.current());
} }
@Override @Override