[#5028] Fix re-entrance issue with channelWritabilityChanged(...) and write(...)

Motivation:

When always triggered fireChannelWritabilityChanged() directly when the update the pending bytes in the ChannelOutboundBuffer was made from within the EventLoop. This is problematic as this can cause some re-entrance issue if the user has a custom ChannelOutboundHandler that does multiple writes from within the write(...) method and also has a handler that will intercept the channelWritabilityChanged event and trigger another write when the Channel is writable. This can also easily happen if the user just use a MessageToMessageEncoder subclass and triggers a write from channelWritabilityChanged().

Beside this we also triggered fireChannelWritabilityChanged() too often when a user did a write from outside the EventLoop. In this case we increased the pending bytes of the outboundbuffer before scheduled the actual write and decreased again before the write then takes place. Both of this may trigger a fireChannelWritabilityChanged() event which then may be re-triggered once the actual write ends again in the ChannelOutboundBuffer.

The third gotcha was that a user may get multiple events even if the writability of the channel not changed.

Modification:

- Always invoke the fireChannelWritabilityChanged() later on the EventLoop.
- Only trigger the fireChannelWritabilityChanged() if the channel is still active and if the writability of the channel changed. No need to cause events that were already triggered without a real writability change.
- when write(...) is called from outside the EventLoop we only increase the pending bytes in the outbound buffer (so that Channel.isWritable() is updated directly) but not cause a fireChannelWritabilityChanged(). The fireChannelWritabilityChanged() is then triggered once the task is picked up by the EventLoop as usual.

Result:

No more re-entrance possible because of writes from within channelWritabilityChanged(...) method and no events without a real writability change.
This commit is contained in:
Norman Maurer 2016-04-01 11:45:43 +02:00
parent ea3ffb8536
commit d0943dcd30
6 changed files with 214 additions and 236 deletions

View File

@ -89,7 +89,7 @@ public final class ChannelOutboundBuffer {
@SuppressWarnings("UnusedDeclaration")
private volatile int unwritable;
private volatile Runnable fireChannelWritabilityChangedTask;
private final Runnable fireChannelWritabilityChangedTask;
static {
AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater =
@ -107,8 +107,9 @@ public final class ChannelOutboundBuffer {
TOTAL_PENDING_SIZE_UPDATER = pendingSizeUpdater;
}
ChannelOutboundBuffer(AbstractChannel channel) {
ChannelOutboundBuffer(final AbstractChannel channel) {
this.channel = channel;
fireChannelWritabilityChangedTask = new ChannelWritabilityChangedTask(channel);
}
/**
@ -131,7 +132,7 @@ public final class ChannelOutboundBuffer {
// increment pending bytes after adding message to the unflushed arrays.
// See https://github.com/netty/netty/issues/1619
incrementPendingOutboundBytes(size, false);
incrementPendingOutboundBytes(size, true);
}
/**
@ -154,7 +155,7 @@ public final class ChannelOutboundBuffer {
if (!entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel();
decrementPendingOutboundBytes(pending, false, true);
decrementPendingOutboundBytes(pending, true);
}
entry = entry.next;
} while (entry != null);
@ -168,18 +169,14 @@ public final class ChannelOutboundBuffer {
* Increment the pending bytes which will be written at some point.
* This method is thread-safe!
*/
void incrementPendingOutboundBytes(long size) {
incrementPendingOutboundBytes(size, true);
}
private void incrementPendingOutboundBytes(long size, boolean invokeLater) {
void incrementPendingOutboundBytes(long size, boolean notifyWritability) {
if (size == 0) {
return;
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
setUnwritable(invokeLater);
setUnwritable(notifyWritability);
}
}
@ -187,19 +184,15 @@ public final class ChannelOutboundBuffer {
* Decrement the pending bytes which will be written at some point.
* This method is thread-safe!
*/
void decrementPendingOutboundBytes(long size) {
decrementPendingOutboundBytes(size, true, true);
}
private void decrementPendingOutboundBytes(long size, boolean invokeLater, boolean notifyWritability) {
void decrementPendingOutboundBytes(long size, boolean notifyWritability) {
if (size == 0) {
return;
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
if (notifyWritability && (newWriteBufferSize == 0
|| newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark())) {
setWritable(invokeLater);
if (newWriteBufferSize == 0
|| newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) {
setWritable(notifyWritability);
}
}
@ -264,7 +257,7 @@ public final class ChannelOutboundBuffer {
// only release message, notify and decrement if it was not canceled before.
ReferenceCountUtil.safeRelease(msg);
safeSuccess(promise);
decrementPendingOutboundBytes(size, false, true);
decrementPendingOutboundBytes(size, true);
}
// recycle the entry
@ -300,7 +293,7 @@ public final class ChannelOutboundBuffer {
ReferenceCountUtil.safeRelease(msg);
safeFail(promise, cause);
decrementPendingOutboundBytes(size, false, notifyWritability);
decrementPendingOutboundBytes(size, notifyWritability);
}
// recycle the entry
@ -522,7 +515,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue & mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged(true);
fireChannelWritabilityChanged();
}
break;
}
@ -536,7 +529,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue | mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged(true);
fireChannelWritabilityChanged();
}
break;
}
@ -550,48 +543,36 @@ public final class ChannelOutboundBuffer {
return 1 << index;
}
private void setWritable(boolean invokeLater) {
private void setWritable(boolean notify) {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue & ~1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged(invokeLater);
if (notify && oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged();
}
break;
}
}
}
private void setUnwritable(boolean invokeLater) {
private void setUnwritable(boolean notify) {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue | 1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged(invokeLater);
if (notify && oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged();
}
break;
}
}
}
private void fireChannelWritabilityChanged(boolean invokeLater) {
final ChannelPipeline pipeline = channel.pipeline();
if (invokeLater) {
Runnable task = fireChannelWritabilityChangedTask;
if (task == null) {
fireChannelWritabilityChangedTask = task = new Runnable() {
@Override
public void run() {
pipeline.fireChannelWritabilityChanged();
}
};
}
channel.eventLoop().execute(task);
} else {
pipeline.fireChannelWritabilityChanged();
}
private void fireChannelWritabilityChanged() {
// Always invoke it later to prevent re-entrance bug.
// See https://github.com/netty/netty/issues/5028
channel.eventLoop().execute(fireChannelWritabilityChangedTask);
}
/**
@ -862,4 +843,25 @@ public final class ChannelOutboundBuffer {
return next;
}
}
private static final class ChannelWritabilityChangedTask implements Runnable {
private final Channel channel;
private boolean writable = true;
ChannelWritabilityChangedTask(Channel channel) {
this.channel = channel;
}
@Override
public void run() {
if (channel.isActive()) {
boolean newWritable = channel.isWritable();
if (writable != newWritable) {
writable = newWritable;
channel.pipeline().fireChannelWritabilityChanged();
}
}
}
}
}

View File

@ -470,7 +470,9 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker {
// Check for null as it may be set to null if the channel is closed already
if (buffer != null) {
task.size = ((AbstractChannel) ctx.channel()).estimatorHandle().size(msg) + WRITE_TASK_OVERHEAD;
buffer.incrementPendingOutboundBytes(task.size);
// We increment the pending bytes but NOT call fireChannelWritabilityChanged() because this
// will be done automaticaly once we add the message to the ChannelOutboundBuffer.
buffer.incrementPendingOutboundBytes(task.size, false);
} else {
task.size = 0;
}
@ -491,7 +493,9 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker {
ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer();
// Check for null as it may be set to null if the channel is closed already
if (ESTIMATE_TASK_SIZE_ON_SUBMIT && buffer != null) {
buffer.decrementPendingOutboundBytes(size);
// We decrement the pending bytes but NOT call fireChannelWritabilityChanged() because this
// will be done automaticaly once we pick up the messages out of the buffer to actually write these.
buffer.decrementPendingOutboundBytes(size, false);
}
invokeWriteNow(ctx, msg, promise);
} finally {

View File

@ -94,7 +94,7 @@ public final class PendingWriteQueue {
// if the channel was already closed when constructing the PendingWriteQueue.
// See https://github.com/netty/netty/issues/3967
if (buffer != null) {
buffer.incrementPendingOutboundBytes(write.size);
buffer.incrementPendingOutboundBytes(write.size, true);
}
}
@ -264,7 +264,7 @@ public final class PendingWriteQueue {
// if the channel was already closed when constructing the PendingWriteQueue.
// See https://github.com/netty/netty/issues/3967
if (buffer != null) {
buffer.decrementPendingOutboundBytes(writeSize);
buffer.decrementPendingOutboundBytes(writeSize, true);
}
}

View File

@ -1,90 +0,0 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import static org.junit.Assert.*;
class BaseChannelTest {
private final LoggingHandler loggingHandler;
BaseChannelTest() {
loggingHandler = new LoggingHandler();
}
ServerBootstrap getLocalServerBootstrap() {
EventLoopGroup serverGroup = new DefaultEventLoopGroup();
ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup);
sb.channel(LocalServerChannel.class);
sb.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
}
});
return sb;
}
Bootstrap getLocalClientBootstrap() {
EventLoopGroup clientGroup = new DefaultEventLoopGroup();
Bootstrap cb = new Bootstrap();
cb.channel(LocalChannel.class);
cb.group(clientGroup);
cb.handler(loggingHandler);
return cb;
}
static ByteBuf createTestBuf(int len) {
ByteBuf buf = Unpooled.buffer(len, len);
buf.setIndex(0, len);
return buf;
}
void assertLog(String firstExpected, String... otherExpected) {
String actual = loggingHandler.getLog();
if (firstExpected.equals(actual)) {
return;
}
for (String e: otherExpected) {
if (e.equals(actual)) {
return;
}
}
// Let the comparison fail with the first expectation.
assertEquals(firstExpected, actual);
}
void clearLog() {
loggingHandler.clear();
}
void setInterest(LoggingHandler.Event... events) {
loggingHandler.setInterest(events);
}
}

View File

@ -229,11 +229,17 @@ public class ChannelOutboundBufferTest {
// Ensure exceeding the high watermark makes channel unwritable.
ch.write(buffer().writeZero(128));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false "));
// Ensure going down to the low watermark makes channel writable again by flushing the first write.
assertThat(ch.unsafe().outboundBuffer().remove(), is(true));
assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L));
assertThat(buf.toString(), is("false "));
ch.runPendingTasks();
assertThat(buf.toString(), is("false true "));
safeClose(ch);
@ -331,6 +337,9 @@ public class ChannelOutboundBufferTest {
// Trigger channelWritabilityChanged() by writing a lot.
ch.write(buffer().writeZero(256));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()

View File

@ -17,97 +17,100 @@ package io.netty.channel;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.LoggingHandler.Event;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.Assert.*;
public class ReentrantChannelTest extends BaseChannelTest {
public class ReentrantChannelTest {
@Test
public void testWritabilityChanged() throws Exception {
private EventLoopGroup clientGroup;
private EventLoopGroup serverGroup;
private LoggingHandler loggingHandler;
LocalAddress addr = new LocalAddress("testWritabilityChanged");
private ServerBootstrap getLocalServerBootstrap() {
ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup);
sb.channel(LocalServerChannel.class);
sb.childHandler(new ChannelInboundHandlerAdapter());
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
setInterest(Event.WRITE, Event.FLUSH, Event.WRITABILITY);
Channel clientChannel = cb.connect(addr).sync().channel();
clientChannel.config().setWriteBufferLowWaterMark(512);
clientChannel.config().setWriteBufferHighWaterMark(1024);
// What is supposed to happen from this point:
//
// 1. Because this write attempt has been made from a non-I/O thread,
// ChannelOutboundBuffer.pendingWriteBytes will be increased before
// write() event is really evaluated.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became unwritable.
//
// 2. The write() event is handled by the pipeline in an I/O thread.
// -> write() will be triggered.
//
// 3. Once the write() event is handled, ChannelOutboundBuffer.pendingWriteBytes
// will be decreased.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became writable again.
//
// 4. The message is added to the ChannelOutboundBuffer and thus
// pendingWriteBytes will be increased again.
// -> channelWritabilityChanged() will be triggered.
//
// 5. The flush() event causes the write request in theChannelOutboundBuffer
// to be removed.
// -> flush() and channelWritabilityChanged() will be triggered.
//
// Note that the channelWritabilityChanged() in the step 4 can occur between
// the flush() and the channelWritabilityChanged() in the stap 5, because
// the flush() is invoked from a non-I/O thread while the other are from
// an I/O thread.
ChannelFuture future = clientChannel.write(createTestBuf(2000));
clientChannel.flush();
future.sync();
clientChannel.close().sync();
assertLog(
// Case 1:
"WRITABILITY: writable=false\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n");
return sb;
}
private Bootstrap getLocalClientBootstrap() {
Bootstrap cb = new Bootstrap();
cb.group(clientGroup);
cb.channel(LocalChannel.class);
cb.handler(loggingHandler);
return cb;
}
private static ByteBuf createTestBuf(int len) {
ByteBuf buf = Unpooled.buffer(len, len);
buf.setIndex(0, len);
return buf;
}
private void assertLog(String firstExpected, String... otherExpected) {
String actual = loggingHandler.getLog();
if (firstExpected.equals(actual)) {
return;
}
for (String e: otherExpected) {
if (e.equals(actual)) {
return;
}
}
// Let the comparison fail with the first expectation.
assertEquals(firstExpected, actual);
}
private void setInterest(LoggingHandler.Event... events) {
loggingHandler.setInterest(events);
}
@Before
public void setup() {
loggingHandler = new LoggingHandler();
clientGroup = new DefaultEventLoop();
serverGroup = new DefaultEventLoop();
}
@After
public void teardown() {
clientGroup.shutdownGracefully();
serverGroup.shutdownGracefully();
}
/**
* Similar to {@link #testWritabilityChanged()} with slight variation.
*/
@Test
public void testFlushInWritabilityChanged() throws Exception {
public void testWritabilityChangedWriteAndFlush() throws Exception {
testWritabilityChanged0(true);
}
@Test
public void testWritabilityChangedWriteThenFlush() throws Exception {
testWritabilityChanged0(false);
}
private void testWritabilityChanged0(boolean writeAndFlush) throws Exception {
LocalAddress addr = new LocalAddress("testFlushInWritabilityChanged");
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Channel serverChannel = sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -117,38 +120,37 @@ public class ReentrantChannelTest extends BaseChannelTest {
clientChannel.config().setWriteBufferLowWaterMark(512);
clientChannel.config().setWriteBufferHighWaterMark(1024);
clientChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
if (!ctx.channel().isWritable()) {
ctx.channel().flush();
}
ctx.fireChannelWritabilityChanged();
}
});
assertTrue(clientChannel.isWritable());
clientChannel.write(createTestBuf(2000)).sync();
if (writeAndFlush) {
clientChannel.writeAndFlush(createTestBuf(2000)).sync();
} else {
ChannelFuture future = clientChannel.write(createTestBuf(2000));
clientChannel.flush();
future.sync();
}
clientChannel.close().sync();
serverChannel.close().sync();
// Because of timing of the scheduling we either should see:
// - WRITE, FLUSH
// - WRITE, WRITABILITY: writable=false, FLUSH
// - WRITE, WRITABILITY: writable=false, FLUSH, WRITABILITY: writable=true
//
// This is the case as between the write and flush the EventLoop may already pick up the pending writes and
// put these into the ChannelOutboundBuffer. Once the flush then happen from outside the EventLoop we may be
// able to flush everything and also schedule the writabilityChanged event before the actual close takes
// place which means we may see another writability changed event to inform the channel is writable again.
assertLog(
// Case 1:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"FLUSH\n",
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"FLUSH\n",
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n");
"WRITABILITY: writable=true\n"
);
}
@Test
@ -157,7 +159,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
LocalAddress addr = new LocalAddress("testWriteFlushPingPong");
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Channel serverChannel = sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -191,6 +193,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
clientChannel.writeAndFlush(createTestBuf(2000));
clientChannel.close().sync();
serverChannel.close().sync();
assertLog(
"WRITE\n" +
@ -214,7 +217,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
LocalAddress addr = new LocalAddress("testCloseInFlush");
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Channel serverChannel = sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -239,6 +242,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
clientChannel.write(createTestBuf(2000)).sync();
clientChannel.closeFuture().sync();
serverChannel.close().sync();
assertLog("WRITE\nFLUSH\nCLOSE\n");
}
@ -249,7 +253,7 @@ public class ReentrantChannelTest extends BaseChannelTest {
LocalAddress addr = new LocalAddress("testFlushFailure");
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Channel serverChannel = sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -279,7 +283,56 @@ public class ReentrantChannelTest extends BaseChannelTest {
}
clientChannel.closeFuture().sync();
serverChannel.close().sync();
assertLog("WRITE\nCLOSE\n");
}
// Test for https://github.com/netty/netty/issues/5028
@Test
public void testWriteRentrance() {
EmbeddedChannel channel = new EmbeddedChannel(new ChannelDuplexHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
for (int i = 0; i < 3; i++) {
ctx.write(i);
}
ctx.write(3, promise);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
if (ctx.channel().isWritable()) {
ctx.channel().writeAndFlush(-1);
}
}
});
channel.config().setMessageSizeEstimator(new MessageSizeEstimator() {
@Override
public Handle newHandle() {
return new Handle() {
@Override
public int size(Object msg) {
// Each message will just increase the pending bytes by 1.
return 1;
}
};
}
});
channel.config().setWriteBufferLowWaterMark(3);
channel.config().setWriteBufferHighWaterMark(4);
channel.writeOutbound(-1);
assertTrue(channel.finish());
assertSequenceOutbound(channel);
assertSequenceOutbound(channel);
assertNull(channel.readOutbound());
}
private static void assertSequenceOutbound(EmbeddedChannel channel) {
assertEquals(0, channel.readOutbound());
assertEquals(1, channel.readOutbound());
assertEquals(2, channel.readOutbound());
assertEquals(3, channel.readOutbound());
}
}