Revert "[#5028] Fix re-entrance issue with channelWritabilityChanged(...) and write(...)"

Motivation:
Revert d0943dcd30. Delaying the notification of writability change may lead to notification being missed. This is a ABA type of concurrency problem.

Modifications:
- Revert d0943dcd30.

Result:
channelWritabilityChange will be called on every change, and will not be suppressed due to ABA scenario.
This commit is contained in:
Scott Mitchell 2016-04-08 14:59:42 -07:00 committed by Norman Maurer
parent 69070c37ba
commit abce89d1bc
6 changed files with 242 additions and 220 deletions

View File

@ -89,7 +89,7 @@ public final class ChannelOutboundBuffer {
@SuppressWarnings("UnusedDeclaration")
private volatile int unwritable;
private final Runnable fireChannelWritabilityChangedTask;
private volatile Runnable fireChannelWritabilityChangedTask;
static {
AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater =
@ -107,9 +107,8 @@ public final class ChannelOutboundBuffer {
TOTAL_PENDING_SIZE_UPDATER = pendingSizeUpdater;
}
ChannelOutboundBuffer(final AbstractChannel channel) {
ChannelOutboundBuffer(AbstractChannel channel) {
this.channel = channel;
fireChannelWritabilityChangedTask = new ChannelWritabilityChangedTask(channel);
}
/**
@ -132,7 +131,7 @@ public final class ChannelOutboundBuffer {
// increment pending bytes after adding message to the unflushed arrays.
// See https://github.com/netty/netty/issues/1619
incrementPendingOutboundBytes(size, true);
incrementPendingOutboundBytes(size, false);
}
/**
@ -155,7 +154,7 @@ public final class ChannelOutboundBuffer {
if (!entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel();
decrementPendingOutboundBytes(pending, true);
decrementPendingOutboundBytes(pending, false, true);
}
entry = entry.next;
} while (entry != null);
@ -169,14 +168,18 @@ public final class ChannelOutboundBuffer {
* Increment the pending bytes which will be written at some point.
* This method is thread-safe!
*/
void incrementPendingOutboundBytes(long size, boolean notifyWritability) {
void incrementPendingOutboundBytes(long size) {
incrementPendingOutboundBytes(size, true);
}
private void incrementPendingOutboundBytes(long size, boolean invokeLater) {
if (size == 0) {
return;
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
setUnwritable(notifyWritability);
setUnwritable(invokeLater);
}
}
@ -184,15 +187,19 @@ public final class ChannelOutboundBuffer {
* Decrement the pending bytes which will be written at some point.
* This method is thread-safe!
*/
void decrementPendingOutboundBytes(long size, boolean notifyWritability) {
void decrementPendingOutboundBytes(long size) {
decrementPendingOutboundBytes(size, true, true);
}
private void decrementPendingOutboundBytes(long size, boolean invokeLater, boolean notifyWritability) {
if (size == 0) {
return;
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
if (newWriteBufferSize == 0
|| newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) {
setWritable(notifyWritability);
if (notifyWritability && (newWriteBufferSize == 0
|| newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark())) {
setWritable(invokeLater);
}
}
@ -257,7 +264,7 @@ public final class ChannelOutboundBuffer {
// only release message, notify and decrement if it was not canceled before.
ReferenceCountUtil.safeRelease(msg);
safeSuccess(promise);
decrementPendingOutboundBytes(size, true);
decrementPendingOutboundBytes(size, false, true);
}
// recycle the entry
@ -293,7 +300,7 @@ public final class ChannelOutboundBuffer {
ReferenceCountUtil.safeRelease(msg);
safeFail(promise, cause);
decrementPendingOutboundBytes(size, notifyWritability);
decrementPendingOutboundBytes(size, false, notifyWritability);
}
// recycle the entry
@ -515,7 +522,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue & mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged();
fireChannelWritabilityChanged(true);
}
break;
}
@ -529,7 +536,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue | mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged();
fireChannelWritabilityChanged(true);
}
break;
}
@ -543,36 +550,48 @@ public final class ChannelOutboundBuffer {
return 1 << index;
}
private void setWritable(boolean notify) {
private void setWritable(boolean invokeLater) {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue & ~1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (notify && oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged();
if (oldValue != 0 && newValue == 0) {
fireChannelWritabilityChanged(invokeLater);
}
break;
}
}
}
private void setUnwritable(boolean notify) {
private void setUnwritable(boolean invokeLater) {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue | 1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (notify && oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged();
if (oldValue == 0 && newValue != 0) {
fireChannelWritabilityChanged(invokeLater);
}
break;
}
}
}
private void fireChannelWritabilityChanged() {
// Always invoke it later to prevent re-entrance bug.
// See https://github.com/netty/netty/issues/5028
channel.eventLoop().execute(fireChannelWritabilityChangedTask);
private void fireChannelWritabilityChanged(boolean invokeLater) {
final ChannelPipeline pipeline = channel.pipeline();
if (invokeLater) {
Runnable task = fireChannelWritabilityChangedTask;
if (task == null) {
fireChannelWritabilityChangedTask = task = new Runnable() {
@Override
public void run() {
pipeline.fireChannelWritabilityChanged();
}
};
}
channel.eventLoop().execute(task);
} else {
pipeline.fireChannelWritabilityChanged();
}
}
/**
@ -843,25 +862,4 @@ public final class ChannelOutboundBuffer {
return next;
}
}
private static final class ChannelWritabilityChangedTask implements Runnable {
private final Channel channel;
private boolean writable = true;
ChannelWritabilityChangedTask(Channel channel) {
this.channel = channel;
}
@Override
public void run() {
if (channel.isActive()) {
boolean newWritable = channel.isWritable();
if (writable != newWritable) {
writable = newWritable;
channel.pipeline().fireChannelWritabilityChanged();
}
}
}
}
}

View File

@ -470,9 +470,7 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker {
// Check for null as it may be set to null if the channel is closed already
if (buffer != null) {
task.size = ((AbstractChannel) ctx.channel()).estimatorHandle().size(msg) + WRITE_TASK_OVERHEAD;
// We increment the pending bytes but NOT call fireChannelWritabilityChanged() because this
// will be done automaticaly once we add the message to the ChannelOutboundBuffer.
buffer.incrementPendingOutboundBytes(task.size, false);
buffer.incrementPendingOutboundBytes(task.size);
} else {
task.size = 0;
}
@ -493,9 +491,7 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker {
ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer();
// Check for null as it may be set to null if the channel is closed already
if (ESTIMATE_TASK_SIZE_ON_SUBMIT && buffer != null) {
// We decrement the pending bytes but NOT call fireChannelWritabilityChanged() because this
// will be done automaticaly once we pick up the messages out of the buffer to actually write these.
buffer.decrementPendingOutboundBytes(size, false);
buffer.decrementPendingOutboundBytes(size);
}
invokeWriteNow(ctx, msg, promise);
} finally {

View File

@ -94,7 +94,7 @@ public final class PendingWriteQueue {
// if the channel was already closed when constructing the PendingWriteQueue.
// See https://github.com/netty/netty/issues/3967
if (buffer != null) {
buffer.incrementPendingOutboundBytes(write.size, true);
buffer.incrementPendingOutboundBytes(write.size);
}
}
@ -264,7 +264,7 @@ public final class PendingWriteQueue {
// if the channel was already closed when constructing the PendingWriteQueue.
// See https://github.com/netty/netty/issues/3967
if (buffer != null) {
buffer.decrementPendingOutboundBytes(writeSize, true);
buffer.decrementPendingOutboundBytes(writeSize);
}
}

View File

@ -0,0 +1,90 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import static org.junit.Assert.*;
class BaseChannelTest {
private final LoggingHandler loggingHandler;
BaseChannelTest() {
loggingHandler = new LoggingHandler();
}
ServerBootstrap getLocalServerBootstrap() {
EventLoopGroup serverGroup = new DefaultEventLoopGroup();
ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup);
sb.channel(LocalServerChannel.class);
sb.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
}
});
return sb;
}
Bootstrap getLocalClientBootstrap() {
EventLoopGroup clientGroup = new DefaultEventLoopGroup();
Bootstrap cb = new Bootstrap();
cb.channel(LocalChannel.class);
cb.group(clientGroup);
cb.handler(loggingHandler);
return cb;
}
static ByteBuf createTestBuf(int len) {
ByteBuf buf = Unpooled.buffer(len, len);
buf.setIndex(0, len);
return buf;
}
void assertLog(String firstExpected, String... otherExpected) {
String actual = loggingHandler.getLog();
if (firstExpected.equals(actual)) {
return;
}
for (String e: otherExpected) {
if (e.equals(actual)) {
return;
}
}
// Let the comparison fail with the first expectation.
assertEquals(firstExpected, actual);
}
void clearLog() {
loggingHandler.clear();
}
void setInterest(LoggingHandler.Event... events) {
loggingHandler.setInterest(events);
}
}

View File

@ -229,17 +229,11 @@ public class ChannelOutboundBufferTest {
// Ensure exceeding the high watermark makes channel unwritable.
ch.write(buffer().writeZero(128));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false "));
// Ensure going down to the low watermark makes channel writable again by flushing the first write.
assertThat(ch.unsafe().outboundBuffer().remove(), is(true));
assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L));
assertThat(buf.toString(), is("false "));
ch.runPendingTasks();
assertThat(buf.toString(), is("false true "));
safeClose(ch);
@ -337,9 +331,6 @@ public class ChannelOutboundBufferTest {
// Trigger channelWritabilityChanged() by writing a lot.
ch.write(buffer().writeZero(256));
assertThat(buf.toString(), is(""));
ch.runPendingTasks();
assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()

View File

@ -17,100 +17,25 @@ package io.netty.channel;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.LoggingHandler.Event;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.Assert.*;
public class ReentrantChannelTest {
private EventLoopGroup clientGroup;
private EventLoopGroup serverGroup;
private LoggingHandler loggingHandler;
private ServerBootstrap getLocalServerBootstrap() {
ServerBootstrap sb = new ServerBootstrap();
sb.group(serverGroup);
sb.channel(LocalServerChannel.class);
sb.childHandler(new ChannelInboundHandlerAdapter());
return sb;
}
private Bootstrap getLocalClientBootstrap() {
Bootstrap cb = new Bootstrap();
cb.group(clientGroup);
cb.channel(LocalChannel.class);
cb.handler(loggingHandler);
return cb;
}
private static ByteBuf createTestBuf(int len) {
ByteBuf buf = Unpooled.buffer(len, len);
buf.setIndex(0, len);
return buf;
}
private void assertLog(String firstExpected, String... otherExpected) {
String actual = loggingHandler.getLog();
if (firstExpected.equals(actual)) {
return;
}
for (String e: otherExpected) {
if (e.equals(actual)) {
return;
}
}
// Let the comparison fail with the first expectation.
assertEquals(firstExpected, actual);
}
private void setInterest(LoggingHandler.Event... events) {
loggingHandler.setInterest(events);
}
@Before
public void setup() {
loggingHandler = new LoggingHandler();
clientGroup = new DefaultEventLoop();
serverGroup = new DefaultEventLoop();
}
@After
public void teardown() {
clientGroup.shutdownGracefully();
serverGroup.shutdownGracefully();
}
public class ReentrantChannelTest extends BaseChannelTest {
@Test
public void testWritabilityChangedWriteAndFlush() throws Exception {
testWritabilityChanged0(true);
}
public void testWritabilityChanged() throws Exception {
@Test
public void testWritabilityChangedWriteThenFlush() throws Exception {
testWritabilityChanged0(false);
}
LocalAddress addr = new LocalAddress("testWritabilityChanged");
private void testWritabilityChanged0(boolean writeAndFlush) throws Exception {
LocalAddress addr = new LocalAddress("testFlushInWritabilityChanged");
ServerBootstrap sb = getLocalServerBootstrap();
Channel serverChannel = sb.bind(addr).sync().channel();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -120,37 +45,110 @@ public class ReentrantChannelTest {
clientChannel.config().setWriteBufferLowWaterMark(512);
clientChannel.config().setWriteBufferHighWaterMark(1024);
assertTrue(clientChannel.isWritable());
if (writeAndFlush) {
clientChannel.writeAndFlush(createTestBuf(2000)).sync();
} else {
ChannelFuture future = clientChannel.write(createTestBuf(2000));
clientChannel.flush();
future.sync();
}
clientChannel.close().sync();
serverChannel.close().sync();
// Because of timing of the scheduling we either should see:
// - WRITE, FLUSH
// - WRITE, WRITABILITY: writable=false, FLUSH
// - WRITE, WRITABILITY: writable=false, FLUSH, WRITABILITY: writable=true
// What is supposed to happen from this point:
//
// This is the case as between the write and flush the EventLoop may already pick up the pending writes and
// put these into the ChannelOutboundBuffer. Once the flush then happen from outside the EventLoop we may be
// able to flush everything and also schedule the writabilityChanged event before the actual close takes
// place which means we may see another writability changed event to inform the channel is writable again.
// 1. Because this write attempt has been made from a non-I/O thread,
// ChannelOutboundBuffer.pendingWriteBytes will be increased before
// write() event is really evaluated.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became unwritable.
//
// 2. The write() event is handled by the pipeline in an I/O thread.
// -> write() will be triggered.
//
// 3. Once the write() event is handled, ChannelOutboundBuffer.pendingWriteBytes
// will be decreased.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became writable again.
//
// 4. The message is added to the ChannelOutboundBuffer and thus
// pendingWriteBytes will be increased again.
// -> channelWritabilityChanged() will be triggered.
//
// 5. The flush() event causes the write request in theChannelOutboundBuffer
// to be removed.
// -> flush() and channelWritabilityChanged() will be triggered.
//
// Note that the channelWritabilityChanged() in the step 4 can occur between
// the flush() and the channelWritabilityChanged() in the stap 5, because
// the flush() is invoked from a non-I/O thread while the other are from
// an I/O thread.
ChannelFuture future = clientChannel.write(createTestBuf(2000));
clientChannel.flush();
future.sync();
clientChannel.close().sync();
assertLog(
"WRITE\n" +
"FLUSH\n",
// Case 1:
"WRITABILITY: writable=false\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n",
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n"
);
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n");
}
/**
* Similar to {@link #testWritabilityChanged()} with slight variation.
*/
@Test
public void testFlushInWritabilityChanged() throws Exception {
LocalAddress addr = new LocalAddress("testFlushInWritabilityChanged");
ServerBootstrap sb = getLocalServerBootstrap();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
setInterest(Event.WRITE, Event.FLUSH, Event.WRITABILITY);
Channel clientChannel = cb.connect(addr).sync().channel();
clientChannel.config().setWriteBufferLowWaterMark(512);
clientChannel.config().setWriteBufferHighWaterMark(1024);
clientChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
if (!ctx.channel().isWritable()) {
ctx.channel().flush();
}
ctx.fireChannelWritabilityChanged();
}
});
assertTrue(clientChannel.isWritable());
clientChannel.write(createTestBuf(2000)).sync();
clientChannel.close().sync();
assertLog(
// Case 1:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n");
}
@Test
@ -159,7 +157,7 @@ public class ReentrantChannelTest {
LocalAddress addr = new LocalAddress("testWriteFlushPingPong");
ServerBootstrap sb = getLocalServerBootstrap();
Channel serverChannel = sb.bind(addr).sync().channel();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -193,7 +191,6 @@ public class ReentrantChannelTest {
clientChannel.writeAndFlush(createTestBuf(2000));
clientChannel.close().sync();
serverChannel.close().sync();
assertLog(
"WRITE\n" +
@ -217,7 +214,7 @@ public class ReentrantChannelTest {
LocalAddress addr = new LocalAddress("testCloseInFlush");
ServerBootstrap sb = getLocalServerBootstrap();
Channel serverChannel = sb.bind(addr).sync().channel();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -242,7 +239,6 @@ public class ReentrantChannelTest {
clientChannel.write(createTestBuf(2000)).sync();
clientChannel.closeFuture().sync();
serverChannel.close().sync();
assertLog("WRITE\nFLUSH\nCLOSE\n");
}
@ -253,7 +249,7 @@ public class ReentrantChannelTest {
LocalAddress addr = new LocalAddress("testFlushFailure");
ServerBootstrap sb = getLocalServerBootstrap();
Channel serverChannel = sb.bind(addr).sync().channel();
sb.bind(addr).sync().channel();
Bootstrap cb = getLocalClientBootstrap();
@ -283,56 +279,7 @@ public class ReentrantChannelTest {
}
clientChannel.closeFuture().sync();
serverChannel.close().sync();
assertLog("WRITE\nCLOSE\n");
}
// Test for https://github.com/netty/netty/issues/5028
@Test
public void testWriteRentrance() {
EmbeddedChannel channel = new EmbeddedChannel(new ChannelDuplexHandler() {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
for (int i = 0; i < 3; i++) {
ctx.write(i);
}
ctx.write(3, promise);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
if (ctx.channel().isWritable()) {
ctx.channel().writeAndFlush(-1);
}
}
});
channel.config().setMessageSizeEstimator(new MessageSizeEstimator() {
@Override
public Handle newHandle() {
return new Handle() {
@Override
public int size(Object msg) {
// Each message will just increase the pending bytes by 1.
return 1;
}
};
}
});
channel.config().setWriteBufferLowWaterMark(3);
channel.config().setWriteBufferHighWaterMark(4);
channel.writeOutbound(-1);
assertTrue(channel.finish());
assertSequenceOutbound(channel);
assertSequenceOutbound(channel);
assertNull(channel.readOutbound());
}
private static void assertSequenceOutbound(EmbeddedChannel channel) {
assertEquals(0, channel.readOutbound());
assertEquals(1, channel.readOutbound());
assertEquals(2, channel.readOutbound());
assertEquals(3, channel.readOutbound());
}
}