Make PendingWriteQueue.recycle() update its state before triggering an event
Related: #3212 Motivation: PendingWriteQueue.recycle() updates its data structure after triggering a channelWritabilityChanged() event. It causes a rare corruption such as double free when channelWritabilityChanged() method accesses the PendingWriteQueue. Modifications: Update the state of PendingWriteQueue before triggering an event. Result: Fix a rare double-free problem
This commit is contained in:
parent
1ef6f14734
commit
3957a88a94
@ -217,11 +217,11 @@ public final class PendingWriteQueue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void recycle(PendingWrite write) {
|
private void recycle(PendingWrite write) {
|
||||||
PendingWrite next = write.next;
|
final PendingWrite next = write.next;
|
||||||
|
final long writeSize = write.size;
|
||||||
|
|
||||||
buffer.decrementPendingOutboundBytes(write.size);
|
|
||||||
write.recycle();
|
|
||||||
size --;
|
size --;
|
||||||
|
|
||||||
if (next == null) {
|
if (next == null) {
|
||||||
// Handled last PendingWrite so rest head and tail
|
// Handled last PendingWrite so rest head and tail
|
||||||
head = tail = null;
|
head = tail = null;
|
||||||
@ -230,6 +230,9 @@ public final class PendingWriteQueue {
|
|||||||
head = next;
|
head = next;
|
||||||
assert size > 0;
|
assert size > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
write.recycle();
|
||||||
|
buffer.decrementPendingOutboundBytes(writeSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void safeFail(ChannelPromise promise, Throwable cause) {
|
private static void safeFail(ChannelPromise promise, Throwable cause) {
|
||||||
|
@ -20,9 +20,13 @@ import io.netty.buffer.ByteBuf;
|
|||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.embedded.EmbeddedChannel;
|
import io.netty.channel.embedded.EmbeddedChannel;
|
||||||
import io.netty.util.CharsetUtil;
|
import io.netty.util.CharsetUtil;
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.*;
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class PendingWriteQueueTest {
|
public class PendingWriteQueueTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -30,7 +34,7 @@ public class PendingWriteQueueTest {
|
|||||||
assertWrite(new TestHandler() {
|
assertWrite(new TestHandler() {
|
||||||
@Override
|
@Override
|
||||||
public void flush(ChannelHandlerContext ctx) throws Exception {
|
public void flush(ChannelHandlerContext ctx) throws Exception {
|
||||||
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable());
|
assertFalse("Should not be writable anymore", ctx.channel().isWritable());
|
||||||
|
|
||||||
ChannelFuture future = queue.removeAndWrite();
|
ChannelFuture future = queue.removeAndWrite();
|
||||||
future.addListener(new ChannelFutureListener() {
|
future.addListener(new ChannelFutureListener() {
|
||||||
@ -49,7 +53,7 @@ public class PendingWriteQueueTest {
|
|||||||
assertWrite(new TestHandler() {
|
assertWrite(new TestHandler() {
|
||||||
@Override
|
@Override
|
||||||
public void flush(ChannelHandlerContext ctx) throws Exception {
|
public void flush(ChannelHandlerContext ctx) throws Exception {
|
||||||
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable());
|
assertFalse("Should not be writable anymore", ctx.channel().isWritable());
|
||||||
|
|
||||||
ChannelFuture future = queue.removeAndWriteAll();
|
ChannelFuture future = queue.removeAndWriteAll();
|
||||||
future.addListener(new ChannelFutureListener() {
|
future.addListener(new ChannelFutureListener() {
|
||||||
@ -86,6 +90,55 @@ public class PendingWriteQueueTest {
|
|||||||
}, 3);
|
}, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldFireChannelWritabilityChangedAfterRemoval() {
|
||||||
|
final AtomicReference<ChannelHandlerContext> ctxRef = new AtomicReference<ChannelHandlerContext>();
|
||||||
|
final AtomicReference<PendingWriteQueue> queueRef = new AtomicReference<PendingWriteQueue>();
|
||||||
|
final ByteBuf msg = Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII);
|
||||||
|
|
||||||
|
final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() {
|
||||||
|
@Override
|
||||||
|
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
|
||||||
|
ctxRef.set(ctx);
|
||||||
|
queueRef.set(new PendingWriteQueue(ctx));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
|
||||||
|
final PendingWriteQueue queue = queueRef.get();
|
||||||
|
|
||||||
|
final ByteBuf msg = (ByteBuf) queue.current();
|
||||||
|
if (msg == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(msg.refCnt(), is(1));
|
||||||
|
|
||||||
|
// This call will trigger another channelWritabilityChanged() event because the number of
|
||||||
|
// pending bytes will go below the low watermark.
|
||||||
|
//
|
||||||
|
// If PendingWriteQueue.remove() did not remove the current entry before triggering
|
||||||
|
// channelWritabilityChanged() event, we will end up with attempting to remove the same
|
||||||
|
// element twice, resulting in the double release.
|
||||||
|
queue.remove();
|
||||||
|
|
||||||
|
assertThat(msg.refCnt(), is(0));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
channel.config().setWriteBufferLowWaterMark(1);
|
||||||
|
channel.config().setWriteBufferHighWaterMark(3);
|
||||||
|
|
||||||
|
final PendingWriteQueue queue = queueRef.get();
|
||||||
|
|
||||||
|
// Trigger channelWritabilityChanged() by adding a message that's larger than the high watermark.
|
||||||
|
queue.add(msg, channel.newPromise());
|
||||||
|
|
||||||
|
channel.finish();
|
||||||
|
|
||||||
|
assertThat(msg.refCnt(), is(0));
|
||||||
|
}
|
||||||
|
|
||||||
private static void assertWrite(ChannelHandler handler, int count) {
|
private static void assertWrite(ChannelHandler handler, int count) {
|
||||||
final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
|
final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
|
||||||
final EmbeddedChannel channel = new EmbeddedChannel(handler);
|
final EmbeddedChannel channel = new EmbeddedChannel(handler);
|
||||||
@ -96,29 +149,29 @@ public class PendingWriteQueueTest {
|
|||||||
for (int i = 0; i < buffers.length; i++) {
|
for (int i = 0; i < buffers.length; i++) {
|
||||||
buffers[i] = buffer.duplicate().retain();
|
buffers[i] = buffer.duplicate().retain();
|
||||||
}
|
}
|
||||||
Assert.assertTrue(channel.writeOutbound(buffers));
|
assertTrue(channel.writeOutbound(buffers));
|
||||||
Assert.assertTrue(channel.finish());
|
assertTrue(channel.finish());
|
||||||
channel.closeFuture().syncUninterruptibly();
|
channel.closeFuture().syncUninterruptibly();
|
||||||
|
|
||||||
for (int i = 0; i < buffers.length; i++) {
|
for (int i = 0; i < buffers.length; i++) {
|
||||||
assertBuffer(channel, buffer);
|
assertBuffer(channel, buffer);
|
||||||
}
|
}
|
||||||
buffer.release();
|
buffer.release();
|
||||||
Assert.assertNull(channel.readOutbound());
|
assertNull(channel.readOutbound());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) {
|
private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) {
|
||||||
ByteBuf written = (ByteBuf) channel.readOutbound();
|
ByteBuf written = channel.readOutbound();
|
||||||
Assert.assertEquals(buffer, written);
|
assertEquals(buffer, written);
|
||||||
written.release();
|
written.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertQueueEmpty(PendingWriteQueue queue) {
|
private static void assertQueueEmpty(PendingWriteQueue queue) {
|
||||||
Assert.assertTrue(queue.isEmpty());
|
assertTrue(queue.isEmpty());
|
||||||
Assert.assertEquals(0, queue.size());
|
assertEquals(0, queue.size());
|
||||||
Assert.assertNull(queue.current());
|
assertNull(queue.current());
|
||||||
Assert.assertNull(queue.removeAndWrite());
|
assertNull(queue.removeAndWrite());
|
||||||
Assert.assertNull(queue.removeAndWriteAll());
|
assertNull(queue.removeAndWriteAll());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertWriteFails(ChannelHandler handler, int count) {
|
private static void assertWriteFails(ChannelHandler handler, int count) {
|
||||||
@ -129,16 +182,16 @@ public class PendingWriteQueueTest {
|
|||||||
buffers[i] = buffer.duplicate().retain();
|
buffers[i] = buffer.duplicate().retain();
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
Assert.assertFalse(channel.writeOutbound(buffers));
|
assertFalse(channel.writeOutbound(buffers));
|
||||||
Assert.fail();
|
fail();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
Assert.assertTrue(e instanceof TestException);
|
assertTrue(e instanceof TestException);
|
||||||
}
|
}
|
||||||
Assert.assertFalse(channel.finish());
|
assertFalse(channel.finish());
|
||||||
channel.closeFuture().syncUninterruptibly();
|
channel.closeFuture().syncUninterruptibly();
|
||||||
|
|
||||||
buffer.release();
|
buffer.release();
|
||||||
Assert.assertNull(channel.readOutbound());
|
assertNull(channel.readOutbound());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class TestHandler extends ChannelDuplexHandler {
|
private static class TestHandler extends ChannelDuplexHandler {
|
||||||
@ -149,15 +202,15 @@ public class PendingWriteQueueTest {
|
|||||||
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
||||||
super.channelActive(ctx);
|
super.channelActive(ctx);
|
||||||
assertQueueEmpty(queue);
|
assertQueueEmpty(queue);
|
||||||
Assert.assertTrue("Should be writable", ctx.channel().isWritable());
|
assertTrue("Should be writable", ctx.channel().isWritable());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
|
||||||
queue.add(msg, promise);
|
queue.add(msg, promise);
|
||||||
Assert.assertFalse(queue.isEmpty());
|
assertFalse(queue.isEmpty());
|
||||||
Assert.assertEquals(++ expectedSize, queue.size());
|
assertEquals(++expectedSize, queue.size());
|
||||||
Assert.assertNotNull(queue.current());
|
assertNotNull(queue.current());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
Loading…
Reference in New Issue
Block a user