[#5028] Fix re-entrance issue with channelWritabilityChanged(...) and write(...)

Motivation:

When always triggered fireChannelWritabilityChanged() directly when the update the pending bytes in the ChannelOutboundBuffer was made from within the EventLoop. This is problematic as this can cause some re-entrance issue if the user has a custom ChannelOutboundHandler that does multiple writes from within the write(...) method and also has a handler that will intercept the channelWritabilityChanged event and trigger another write when the Channel is writable. This can also easily happen if the user just use a MessageToMessageEncoder subclass and triggers a write from channelWritabilityChanged().

Modification:

- Always invoke the fireChannelWritabilityChanged() later on the EventLoop.
- Only trigger the fireChannelWritabilityChanged() if the channel is still active.
- when write(...) is called from outside the EventLoop we only increase the pending bytes in the outbound buffer (so that Channel.isWritable() is updated directly) but not cause a fireChannelWritabilityChanged(). The fireChannelWritabilityChanged() is then triggered once the task is picked up by the EventLoop as usual.

Result:

No more re-entrance possible because of writes from within channelWritabilityChanged(...) method.
This commit is contained in:
Norman Maurer 2016-04-01 11:45:43 +02:00
parent 13a8ebade4
commit f23e901c99
2 changed files with 47 additions and 37 deletions

View File

@ -96,7 +96,7 @@ public final class ChannelOutboundBuffer {
@SuppressWarnings("UnusedDeclaration") @SuppressWarnings("UnusedDeclaration")
private volatile int unwritable; private volatile int unwritable;
private volatile Runnable fireChannelWritabilityChangedTask; private final Runnable fireChannelWritabilityChangedTask;
static { static {
AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater = AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater =
@ -116,6 +116,7 @@ public final class ChannelOutboundBuffer {
ChannelOutboundBuffer(AbstractChannel channel) { ChannelOutboundBuffer(AbstractChannel channel) {
this.channel = channel; this.channel = channel;
fireChannelWritabilityChangedTask = new ChannelWritabilityChangedTask(channel);
} }
/** /**
@ -138,7 +139,7 @@ public final class ChannelOutboundBuffer {
// increment pending bytes after adding message to the unflushed arrays. // increment pending bytes after adding message to the unflushed arrays.
// See https://github.com/netty/netty/issues/1619 // See https://github.com/netty/netty/issues/1619
incrementPendingOutboundBytes(entry.pendingSize, false); incrementPendingOutboundBytes(entry.pendingSize);
} }
/** /**
@ -161,7 +162,7 @@ public final class ChannelOutboundBuffer {
if (!entry.promise.setUncancellable()) { if (!entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes // Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel(); int pending = entry.cancel();
decrementPendingOutboundBytes(pending, false, true); decrementPendingOutboundBytes(pending);
} }
entry = entry.next; entry = entry.next;
} while (entry != null); } while (entry != null);
@ -176,17 +177,13 @@ public final class ChannelOutboundBuffer {
* This method is thread-safe! * This method is thread-safe!
*/ */
void incrementPendingOutboundBytes(long size) { void incrementPendingOutboundBytes(long size) {
incrementPendingOutboundBytes(size, true);
}
private void incrementPendingOutboundBytes(long size, boolean invokeLater) {
if (size == 0) { if (size == 0) {
return; return;
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
if (newWriteBufferSize > channel.config().getWriteBufferHighWaterMark()) { if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
setUnwritable(invokeLater); setUnwritable();
} }
} }
@ -195,17 +192,17 @@ public final class ChannelOutboundBuffer {
* This method is thread-safe! * This method is thread-safe!
*/ */
void decrementPendingOutboundBytes(long size) { void decrementPendingOutboundBytes(long size) {
decrementPendingOutboundBytes(size, true, true); decrementPendingOutboundBytes(size, true);
} }
private void decrementPendingOutboundBytes(long size, boolean invokeLater, boolean notifyWritability) { private void decrementPendingOutboundBytes(long size, boolean notifyWritability) {
if (size == 0) { if (size == 0) {
return; return;
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
if (notifyWritability && newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) { if (notifyWritability && newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) {
setWritable(invokeLater); setWritable(notifyWritability);
} }
} }
@ -270,7 +267,7 @@ public final class ChannelOutboundBuffer {
// only release message, notify and decrement if it was not canceled before. // only release message, notify and decrement if it was not canceled before.
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
safeSuccess(promise); safeSuccess(promise);
decrementPendingOutboundBytes(size, false, true); decrementPendingOutboundBytes(size);
} }
// recycle the entry // recycle the entry
@ -306,7 +303,7 @@ public final class ChannelOutboundBuffer {
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
safeFail(promise, cause); safeFail(promise, cause);
decrementPendingOutboundBytes(size, false, notifyWritability); decrementPendingOutboundBytes(size, notifyWritability);
} }
// recycle the entry // recycle the entry
@ -528,7 +525,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue & mask; final int newValue = oldValue & mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) { if (oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged(true); fireChannelWritabilityChanged();
} }
break; break;
} }
@ -542,7 +539,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue | mask; final int newValue = oldValue | mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) { if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged(true); fireChannelWritabilityChanged();
} }
break; break;
} }
@ -556,48 +553,36 @@ public final class ChannelOutboundBuffer {
return 1 << index; return 1 << index;
} }
private void setWritable(boolean invokeLater) { private void setWritable(boolean notify) {
for (;;) { for (;;) {
final int oldValue = unwritable; final int oldValue = unwritable;
final int newValue = oldValue & ~1; final int newValue = oldValue & ~1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) { if (notify && oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged(invokeLater); fireChannelWritabilityChanged();
} }
break; break;
} }
} }
} }
private void setUnwritable(boolean invokeLater) { private void setUnwritable() {
for (;;) { for (;;) {
final int oldValue = unwritable; final int oldValue = unwritable;
final int newValue = oldValue | 1; final int newValue = oldValue | 1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) { if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged(invokeLater); fireChannelWritabilityChanged();
} }
break; break;
} }
} }
} }
private void fireChannelWritabilityChanged(boolean invokeLater) { private void fireChannelWritabilityChanged() {
final ChannelPipeline pipeline = channel.pipeline(); // Always invoke it later to prevent re-entrance bug.
if (invokeLater) { // See https://github.com/netty/netty/issues/5028
Runnable task = fireChannelWritabilityChangedTask; channel.eventLoop().execute(fireChannelWritabilityChangedTask);
if (task == null) {
fireChannelWritabilityChangedTask = task = new Runnable() {
@Override
public void run() {
pipeline.fireChannelWritabilityChanged();
}
};
}
channel.eventLoop().execute(task);
} else {
pipeline.fireChannelWritabilityChanged();
}
} }
/** /**
@ -838,4 +823,20 @@ public final class ChannelOutboundBuffer {
return next; return next;
} }
} }
private static final class ChannelWritabilityChangedTask implements Runnable {
private final Channel channel;
ChannelWritabilityChangedTask(Channel channel) {
this.channel = channel;
}
@Override
public void run() {
// Only need to fire if the Channel is still active otherwise it may be confusing to receive such an event.
if (channel.isActive()) {
channel.pipeline().fireChannelWritabilityChanged();
}
}
}
} }

View File

@ -230,6 +230,9 @@ public class ChannelOutboundBufferTest {
// Ensure exceeding the high watermark makes channel unwritable. // Ensure exceeding the high watermark makes channel unwritable.
ch.write(buffer().writeZero(127)); ch.write(buffer().writeZero(127));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure going down to the low watermark makes channel writable again by flushing the first write. // Ensure going down to the low watermark makes channel writable again by flushing the first write.
@ -237,6 +240,9 @@ public class ChannelOutboundBufferTest {
assertThat(ch.unsafe().outboundBuffer().remove(), is(true)); assertThat(ch.unsafe().outboundBuffer().remove(), is(true));
assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(),
is(127L + ChannelOutboundBuffer.CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD)); is(127L + ChannelOutboundBuffer.CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD));
assertThat(buf.toString(), is("false "));
ch.runPendingTasks();
assertThat(buf.toString(), is("false true ")); assertThat(buf.toString(), is("false true "));
safeClose(ch); safeClose(ch);
@ -334,6 +340,9 @@ public class ChannelOutboundBufferTest {
// Trigger channelWritabilityChanged() by writing a lot. // Trigger channelWritabilityChanged() by writing a lot.
ch.write(buffer().writeZero(257)); ch.write(buffer().writeZero(257));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged() // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()