[#3367] Fix re-entrance bug in PendingWriteQueue

Motivation:

Because of a re-entrance bug in PendingWriteQueue it was possible to get the queue corrupted and also trigger an IllegalStateException caused by multiple recycling of the internal PendingWrite objects.

Modifications:

- Correctly guard against re-entrance

Result:

No more IllegalStateException possible
This commit is contained in:
Norman Maurer 2015-01-30 20:27:27 +01:00
parent 261a30d8af
commit ca8b937c14
2 changed files with 82 additions and 19 deletions

View File

@ -99,12 +99,16 @@ public final class PendingWriteQueue {
if (cause == null) { if (cause == null) {
throw new NullPointerException("cause"); throw new NullPointerException("cause");
} }
// Guard against re-entrance by directly reset
PendingWrite write = head; PendingWrite write = head;
head = tail = null;
size = 0;
while (write != null) { while (write != null) {
PendingWrite next = write.next; PendingWrite next = write.next;
ReferenceCountUtil.safeRelease(write.msg); ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise; ChannelPromise promise = write.promise;
recycle(write); recycle(write, false);
safeFail(promise, cause); safeFail(promise, cause);
write = next; write = next;
} }
@ -121,13 +125,14 @@ public final class PendingWriteQueue {
throw new NullPointerException("cause"); throw new NullPointerException("cause");
} }
PendingWrite write = head; PendingWrite write = head;
if (write == null) { if (write == null) {
return; return;
} }
ReferenceCountUtil.safeRelease(write.msg); ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise; ChannelPromise promise = write.promise;
safeFail(promise, cause); safeFail(promise, cause);
recycle(write); recycle(write, true);
} }
/** /**
@ -139,22 +144,28 @@ public final class PendingWriteQueue {
*/ */
public ChannelFuture removeAndWriteAll() { public ChannelFuture removeAndWriteAll() {
assert ctx.executor().inEventLoop(); assert ctx.executor().inEventLoop();
if (size == 1) {
// No need to use ChannelPromiseAggregator for this case.
return removeAndWrite();
}
PendingWrite write = head; PendingWrite write = head;
if (write == null) { if (write == null) {
// empty so just return null // empty so just return null
return null; return null;
} }
if (size == 1) {
// No need to use ChannelPromiseAggregator for this case. // Guard against re-entrance by directly reset
return removeAndWrite(); head = tail = null;
} size = 0;
ChannelPromise p = ctx.newPromise(); ChannelPromise p = ctx.newPromise();
ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p); ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p);
while (write != null) { while (write != null) {
PendingWrite next = write.next; PendingWrite next = write.next;
Object msg = write.msg; Object msg = write.msg;
ChannelPromise promise = write.promise; ChannelPromise promise = write.promise;
recycle(write); recycle(write, false);
ctx.write(msg, promise); ctx.write(msg, promise);
aggregator.add(promise); aggregator.add(promise);
write = next; write = next;
@ -182,7 +193,7 @@ public final class PendingWriteQueue {
} }
Object msg = write.msg; Object msg = write.msg;
ChannelPromise promise = write.promise; ChannelPromise promise = write.promise;
recycle(write); recycle(write, true);
return ctx.write(msg, promise); return ctx.write(msg, promise);
} }
@ -200,7 +211,7 @@ public final class PendingWriteQueue {
} }
ChannelPromise promise = write.promise; ChannelPromise promise = write.promise;
ReferenceCountUtil.safeRelease(write.msg); ReferenceCountUtil.safeRelease(write.msg);
recycle(write); recycle(write, true);
return promise; return promise;
} }
@ -216,19 +227,21 @@ public final class PendingWriteQueue {
return write.msg; return write.msg;
} }
private void recycle(PendingWrite write) { private void recycle(PendingWrite write, boolean update) {
final PendingWrite next = write.next; final PendingWrite next = write.next;
final long writeSize = write.size; final long writeSize = write.size;
size --; if (update) {
if (next == null) {
if (next == null) { // Handled last PendingWrite so rest head and tail
// Handled last PendingWrite so rest head and tail // Guard against re-entrance by directly reset
head = tail = null; head = tail = null;
assert size == 0; size = 0;
} else { } else {
head = next; head = next;
assert size > 0; size --;
assert size > 0;
}
} }
write.recycle(); write.recycle();

View File

@ -194,6 +194,56 @@ public class PendingWriteQueueTest {
assertNull(channel.readOutbound()); assertNull(channel.readOutbound());
} }
@Test
public void testRemoveAndFailAllReentrance() {
EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter());
final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext());
ChannelPromise promise = channel.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
queue.removeAndFailAll(new IllegalStateException());
}
});
queue.add(1L, promise);
ChannelPromise promise2 = channel.newPromise();
queue.add(2L, promise2);
queue.removeAndFailAll(new Exception());
assertFalse(promise.isSuccess());
assertFalse(promise2.isSuccess());
assertFalse(channel.finish());
}
@Test
public void testRemoveAndWriteAllReentrance() {
EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter());
final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext());
ChannelPromise promise = channel.newPromise();
promise.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
queue.removeAndWriteAll();
}
});
queue.add(1L, promise);
ChannelPromise promise2 = channel.newPromise();
queue.add(2L, promise2);
queue.removeAndWriteAll();
channel.flush();
assertTrue(promise.isSuccess());
assertTrue(promise2.isSuccess());
assertTrue(channel.finish());
assertEquals(1L, channel.readOutbound());
assertEquals(2L, channel.readOutbound());
assertNull(channel.readOutbound());
assertNull(channel.readInbound());
}
private static class TestHandler extends ChannelDuplexHandler { private static class TestHandler extends ChannelDuplexHandler {
protected PendingWriteQueue queue; protected PendingWriteQueue queue;
private int expectedSize; private int expectedSize;