Guard against re-entrance in PendingWriteQueue.removeAndWriteAll()

Motivation:

PendingWriteQueue should guard against re-entrant writes once removeAndWriteAll() is run.

Modifications:

Continue writing until queue is empty.

Result:

Correctly guard against re-entrance.
This commit is contained in:
Norman Maurer 2016-08-12 08:53:39 +02:00
parent 382f8eacaf
commit 6d70f4b38a
2 changed files with 79 additions and 45 deletions

View File

@ -109,6 +109,48 @@ public final class PendingWriteQueue {
}
}
/**
* Remove all pending write operation and performs them via
* {@link ChannelHandlerContext#write(Object, ChannelPromise)}.
*
* @return {@link ChannelFuture} if something was written and {@code null}
* if the {@link PendingWriteQueue} is empty.
*/
public ChannelFuture removeAndWriteAll() {
assert ctx.executor().inEventLoop();
if (isEmpty()) {
return null;
}
ChannelPromise p = ctx.newPromise();
PromiseCombiner combiner = new PromiseCombiner();
try {
// It is possible for some of the written promises to trigger more writes. The new writes
// will "revive" the queue, so we need to write them up until the queue is empty.
for (PendingWrite write = head; write != null; write = head) {
head = tail = null;
size = 0;
bytes = 0;
while (write != null) {
PendingWrite next = write.next;
Object msg = write.msg;
ChannelPromise promise = write.promise;
recycle(write, false);
combiner.add(promise);
ctx.write(msg, promise);
write = next;
}
}
combiner.finish(p);
} catch (Throwable cause) {
p.setFailure(cause);
}
assertEmpty();
return p;
}
/**
* Remove all pending write operation and fail them with the given {@link Throwable}. The message will be released
* via {@link ReferenceCountUtil#safeRelease(Object)}.
@ -156,51 +198,6 @@ public final class PendingWriteQueue {
recycle(write, true);
}
/**
* Remove all pending write operation and performs them via
* {@link ChannelHandlerContext#write(Object, ChannelPromise)}.
*
* @return {@link ChannelFuture} if something was written and {@code null}
* if the {@link PendingWriteQueue} is empty.
*/
public ChannelFuture removeAndWriteAll() {
assert ctx.executor().inEventLoop();
if (size == 1) {
// No need to use ChannelPromiseAggregator for this case.
return removeAndWrite();
}
PendingWrite write = head;
if (write == null) {
// empty so just return null
return null;
}
// Guard against re-entrance by directly reset
head = tail = null;
size = 0;
bytes = 0;
ChannelPromise p = ctx.newPromise();
PromiseCombiner combiner = new PromiseCombiner();
try {
while (write != null) {
PendingWrite next = write.next;
Object msg = write.msg;
ChannelPromise promise = write.promise;
recycle(write, false);
combiner.add(promise);
ctx.write(msg, promise);
write = next;
}
assertEmpty();
combiner.finish(p);
} catch (Throwable cause) {
p.setFailure(cause);
}
return p;
}
private void assertEmpty() {
assert tail == null && head == null && size == 0;
}

View File

@ -227,6 +227,43 @@ public class PendingWriteQueueTest {
assertFalse(channel.finish());
}
@Test
public void testRemoveAndWriteAllReentrantWrite() {
EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
// Convert to writeAndFlush(...) so the promise will be notified by the transport.
ctx.writeAndFlush(msg, promise);
}
}, new ChannelOutboundHandlerAdapter());
final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().lastContext());
ChannelPromise promise = channel.newPromise();
final ChannelPromise promise3 = channel.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
queue.add(3L, promise3);
}
});
queue.add(1L, promise);
ChannelPromise promise2 = channel.newPromise();
queue.add(2L, promise2);
queue.removeAndWriteAll();
assertTrue(promise.isDone());
assertTrue(promise.isSuccess());
assertTrue(promise2.isDone());
assertTrue(promise2.isSuccess());
assertTrue(promise3.isDone());
assertTrue(promise3.isSuccess());
assertTrue(channel.finish());
assertEquals(1L, channel.readOutbound());
assertEquals(2L, channel.readOutbound());
assertEquals(3L, channel.readOutbound());
}
@Test
public void testRemoveAndFailAllReentrantWrite() {
final List<Integer> failOrder = Collections.synchronizedList(new ArrayList<Integer>());