Guard against re-entrance in PendingWriteQueue.

Motivation:

PendingWriteQueue should guard against re-entrant writes once
removeAndFailAll() is run.

Modifications:

removeAndFailAll() should repeat until the queue is finally empty.

Result:

assertEmpty() will always hold.
This commit is contained in:
Max Ng 2016-03-07 15:58:38 -08:00 committed by Norman Maurer
parent 16ba763436
commit ec946f8a3e
2 changed files with 65 additions and 12 deletions

View File

@ -73,6 +73,8 @@ public final class PendingWriteQueue {
if (promise == null) {
throw new NullPointerException("promise");
}
// It is possible for writes to be triggered from removeAndFailAll(). To preserve ordering,
// we should add them to the queue and let removeAndFailAll() fail them later.
int messageSize = estimatorHandle.size(msg);
if (messageSize < 0) {
// Size may be unknow so just use 0
@ -104,18 +106,20 @@ public final class PendingWriteQueue {
if (cause == null) {
throw new NullPointerException("cause");
}
// Guard against re-entrance by directly reset
PendingWrite write = head;
head = tail = null;
size = 0;
// It is possible for some of the failed promises to trigger more writes. The new writes
// will "revive" the queue, so we need to clean them up until the queue is empty.
for (PendingWrite write = head; write != null; write = head) {
head = tail = null;
size = 0;
while (write != null) {
PendingWrite next = write.next;
ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise;
recycle(write, false);
safeFail(promise, cause);
write = next;
while (write != null) {
PendingWrite next = write.next;
ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise;
recycle(write, false);
safeFail(promise, cause);
write = next;
}
}
assertEmpty();
}

View File

@ -22,6 +22,9 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.Matchers.*;
@ -195,7 +198,7 @@ public class PendingWriteQueueTest {
}
@Test
public void testRemoveAndFailAllReentrance() {
public void testRemoveAndFailAllReentrantFailAll() {
EmbeddedChannel channel = new EmbeddedChannel();
final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext());
@ -211,11 +214,57 @@ public class PendingWriteQueueTest {
ChannelPromise promise2 = channel.newPromise();
queue.add(2L, promise2);
queue.removeAndFailAll(new Exception());
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertTrue(promise2.isDone());
assertFalse(promise2.isSuccess());
assertFalse(channel.finish());
}
@Test
public void testRemoveAndFailAllReentrantWrite() {
final List<Integer> failOrder = Collections.synchronizedList(new ArrayList<Integer>());
EmbeddedChannel channel = new EmbeddedChannel();
final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext());
ChannelPromise promise = channel.newPromise();
final ChannelPromise promise3 = channel.newPromise();
promise3.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
failOrder.add(3);
}
});
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
failOrder.add(1);
queue.add(3L, promise3);
}
});
queue.add(1L, promise);
ChannelPromise promise2 = channel.newPromise();
promise2.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
failOrder.add(2);
}
});
queue.add(2L, promise2);
queue.removeAndFailAll(new Exception());
assertTrue(promise.isDone());
assertFalse(promise.isSuccess());
assertTrue(promise2.isDone());
assertFalse(promise2.isSuccess());
assertTrue(promise3.isDone());
assertFalse(promise3.isSuccess());
assertFalse(channel.finish());
assertEquals(1, (int) failOrder.get(0));
assertEquals(2, (int) failOrder.get(1));
assertEquals(3, (int) failOrder.get(2));
}
@Test
public void testRemoveAndWriteAllReentrance() {
EmbeddedChannel channel = new EmbeddedChannel();