Fix race when changing channel writability

Motivation:

A rare issue was seen wherein a channel would become
persistently unwritable, even though the underlying
connection was working fine, as evidenced by the fact
that heartbeats were still being received from the
other side.

Examination of WriteRequestQueue and the threading
model around it revealed the following race:

T1: offer MessageEvent
    go over high water mark
    context switch before calling setUnwritable()

T2: poll multiple MessageEvents
    go under low water mark
    call setWritable()

T1: call setUnwritable()

At which point the channel is incorrectly marked as
unwritable and will remain so until another back-and-forth
across the low water mark, which may never happen if
the application uses writability to apply backpressure
on the writer.
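
In code terms this is a check-then-act race between two
updates that are individually atomic but not jointly so.
A condensed sketch of the offer() side (the poll() side
mirrors it around the low water mark):

    // T1, in offer(), possibly a non-io thread:
    int newWriteBufferSize = writeBufferSize.addAndGet(messageSize);
    if (newWriteBufferSize >= highWaterMark
            && newWriteBufferSize - messageSize < highWaterMark) {
        // a context switch here lets T2 drain the buffer below
        // the low water mark and call setWritable() first...
        setUnwritable(); // ...so this stale update wins
    }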

The commit log shows that this issue has been present
since commit 2127c8eafe, i.e. version 3.10, when the
introduction of user-defined writability decoupled
isWritable from the buffer size.

Solution:

Using a mutex to protect both the buffer size and the
writability bit was ruled out due to the serious
performance implications.

Always deferring the call to setUnwritable() to the io
thread was also ruled out, as it could subtly break
backward compatibility: existing code may expect
isWritable to be updated immediately after a write()
from a non-io thread.

Instead, the buffer size is re-checked against the high
water mark immediately before firing the interestChanged
event, and the writability change is rolled back if
necessary.
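
Concretely, when offer() trips the high water mark from a
non-io thread, the notification is scheduled on the io
thread, where the buffer size is re-checked (condensed
from the AbstractNioChannel change below):

    worker.executeInIoThread(new Runnable() {
        @Override
        public void run() {
            // Still over the mark: fire the unwritable notification.
            // Otherwise the race hit and setWritable() rolls the bit
            // back; it returns true only if it changed the state, in
            // which case handlers are notified to re-read isWritable().
            if (writeBufferSize.get() >= highWaterMark || setWritable()) {
                fireChannelInterestChanged(AbstractNioChannel.this);
            }
        }
    });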

In addition, a stronger invariant is enforced on poll()
to make it easier to reason about the possible water
mark/writability transitions. This requires minor changes
to NioChannelTest, which manually calls poll().

Author:       Hugues Bruant, 2016-01-20 21:48:03 -05:00
Committed by: Norman Maurer
Commit:       6baa592772 (parent: a1801ac5b2)

4 changed files with 83 additions and 90 deletions

src/main/java/org/jboss/netty/channel/socket/nio/AbstractNioChannel.java

@@ -28,8 +28,6 @@ import org.jboss.netty.util.internal.ThreadLocalBoolean;
 import java.net.InetSocketAddress;
 import java.nio.channels.SelectableChannel;
 import java.nio.channels.WritableByteChannel;
-import java.util.Collection;
-import java.util.Iterator;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -63,7 +61,7 @@ abstract class AbstractNioChannel<C extends SelectableChannel & WritableByteChan
     /**
      * Queue of write {@link MessageEvent}s.
      */
-    final Queue<MessageEvent> writeBufferQueue = new WriteRequestQueue();
+    final WriteRequestQueue writeBufferQueue = new WriteRequestQueue();
 
     /**
      * Keeps track of the number of bytes that the {@link WriteRequestQueue} currently
@@ -173,7 +171,7 @@ abstract class AbstractNioChannel<C extends SelectableChannel & WritableByteChan
     abstract InetSocketAddress getRemoteSocketAddress() throws Exception;
 
-    private final class WriteRequestQueue implements Queue<MessageEvent> {
+    final class WriteRequestQueue {
 
         private final ThreadLocalBoolean notifying = new ThreadLocalBoolean();
         private final Queue<MessageEvent> queue;
@@ -182,88 +180,64 @@ abstract class AbstractNioChannel<C extends SelectableChannel & WritableByteChan
             queue = new ConcurrentLinkedQueue<MessageEvent>();
         }
 
-        public MessageEvent remove() {
-            return queue.remove();
-        }
-
-        public MessageEvent element() {
-            return queue.element();
-        }
-
-        public MessageEvent peek() {
-            return queue.peek();
-        }
-
-        public int size() {
-            return queue.size();
-        }
-
         public boolean isEmpty() {
             return queue.isEmpty();
         }
 
-        public Iterator<MessageEvent> iterator() {
-            return queue.iterator();
-        }
-
-        public Object[] toArray() {
-            return queue.toArray();
-        }
-
-        public <T> T[] toArray(T[] a) {
-            return queue.toArray(a);
-        }
-
-        public boolean containsAll(Collection<?> c) {
-            return queue.containsAll(c);
-        }
-
-        public boolean addAll(Collection<? extends MessageEvent> c) {
-            return queue.addAll(c);
-        }
-
-        public boolean removeAll(Collection<?> c) {
-            return queue.removeAll(c);
-        }
-
-        public boolean retainAll(Collection<?> c) {
-            return queue.retainAll(c);
-        }
-
-        public void clear() {
-            queue.clear();
-        }
-
-        public boolean add(MessageEvent e) {
-            return queue.add(e);
-        }
-
-        public boolean remove(Object o) {
-            return queue.remove(o);
-        }
-
-        public boolean contains(Object o) {
-            return queue.contains(o);
-        }
-
         public boolean offer(MessageEvent e) {
             boolean success = queue.offer(e);
             assert success;
 
             int messageSize = getMessageSize(e);
             int newWriteBufferSize = writeBufferSize.addAndGet(messageSize);
-            int highWaterMark = getConfig().getWriteBufferHighWaterMark();
+            final int highWaterMark = getConfig().getWriteBufferHighWaterMark();
 
             if (newWriteBufferSize >= highWaterMark) {
                 if (newWriteBufferSize - messageSize < highWaterMark) {
                     highWaterMarkCounter.incrementAndGet();
                     if (setUnwritable()) {
-                        if (!isIoThread(AbstractNioChannel.this)) {
-                            fireChannelInterestChangedLater(AbstractNioChannel.this);
-                        } else if (!notifying.get()) {
-                            notifying.set(Boolean.TRUE);
-                            fireChannelInterestChanged(AbstractNioChannel.this);
-                            notifying.set(Boolean.FALSE);
+                        if (isIoThread(AbstractNioChannel.this)) {
+                            if (!notifying.get()) {
+                                notifying.set(Boolean.TRUE);
+                                fireChannelInterestChanged(AbstractNioChannel.this);
+                                notifying.set(Boolean.FALSE);
+                            }
+                        } else {
+                            // Adjusting writability requires careful synchronization.
+                            //
+                            // Consider for instance:
+                            //
+                            //  T1 repeated offer: go above high water mark
+                            //     context switch *before* calling setUnwritable
+                            //  T2 repeated poll: go under low water mark
+                            //     context switch *after* calling setWritable
+                            //  T1 setUnwritable
+                            //
                            // At this point the channel is incorrectly marked unwritable and would
+                            // remain so until the high water mark were exceeded again, which may
+                            // never happen if the application did control flow based on writability.
+                            //
+                            // The simplest solution would be to use a mutex to protect both the
+                            // buffer size and the writability bit, however that would impose a
+                            // serious performance penalty.
+                            //
+                            // A better approach would be to always call setUnwritable in the io
+                            // thread, which does not impact performance as the interestChanged
+                            // event has to be fired from there anyway.
+                            //
+                            // However this could break code which expects isWritable to immediately
+                            // be updated after a write() from a non-io thread.
+                            //
+                            // Instead, we re-check the buffer size before firing the interest
+                            // changed event and revert the change to the writability bit if needed.
+                            worker.executeInIoThread(new Runnable() {
+                                @Override
+                                public void run() {
+                                    if (writeBufferSize.get() >= highWaterMark || setWritable()) {
+                                        fireChannelInterestChanged(AbstractNioChannel.this);
+                                    }
+                                }
+                            });
                        }
                    }
                }
            }
@@ -281,10 +255,13 @@ abstract class AbstractNioChannel<C extends SelectableChannel & WritableByteChan
             if (newWriteBufferSize == 0 || newWriteBufferSize < lowWaterMark) {
                 if (newWriteBufferSize + messageSize >= lowWaterMark) {
                     highWaterMarkCounter.decrementAndGet();
-                    if (isConnected() && setWritable()) {
-                        if (!isIoThread(AbstractNioChannel.this)) {
-                            fireChannelInterestChangedLater(AbstractNioChannel.this);
-                        } else if (!notifying.get()) {
+                    if (isConnected()) {
+                        // write events are only ever processed from the channel io thread
+                        // except when cleaning up the buffer, which only happens on close()
+                        // changing that would require additional synchronization around
+                        // writeBufferSize and writability changes.
+                        assert isIoThread(AbstractNioChannel.this);
+                        if (setWritable()) {
                             notifying.set(Boolean.TRUE);
                             fireChannelInterestChanged(AbstractNioChannel.this);
                             notifying.set(Boolean.FALSE);

src/main/java/org/jboss/netty/channel/socket/nio/AbstractNioWorker.java

@@ -19,6 +19,7 @@ import org.jboss.netty.channel.Channel;
 import org.jboss.netty.channel.ChannelFuture;
 import org.jboss.netty.channel.MessageEvent;
 import org.jboss.netty.channel.socket.Worker;
+import org.jboss.netty.channel.socket.nio.AbstractNioChannel.WriteRequestQueue;
 import org.jboss.netty.channel.socket.nio.SocketSendBufferPool.SendBuffer;
 import org.jboss.netty.util.ThreadNameDeterminer;
 import org.jboss.netty.util.ThreadRenamingRunnable;
@@ -34,7 +35,6 @@ import java.nio.channels.WritableByteChannel;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.Executor;
@@ -170,7 +170,7 @@ abstract class AbstractNioWorker extends AbstractNioSelector implements Worker {
         final SocketSendBufferPool sendBufferPool = this.sendBufferPool;
         final WritableByteChannel ch = channel.channel;
-        final Queue<MessageEvent> writeBuffer = channel.writeBufferQueue;
+        final WriteRequestQueue writeBuffer = channel.writeBufferQueue;
         final int writeSpinCount = channel.getConfig().getWriteSpinCount();
         List<Throwable> causes = null;
@@ -418,7 +418,7 @@ abstract class AbstractNioWorker extends AbstractNioSelector implements Worker {
             fireExceptionCaught = true;
         }
 
-        Queue<MessageEvent> writeBuffer = channel.writeBufferQueue;
+        WriteRequestQueue writeBuffer = channel.writeBufferQueue;
         for (;;) {
             evt = writeBuffer.poll();
             if (evt == null) {

src/main/java/org/jboss/netty/channel/socket/nio/NioDatagramWorker.java

@@ -22,6 +22,7 @@ import org.jboss.netty.channel.ChannelException;
 import org.jboss.netty.channel.ChannelFuture;
 import org.jboss.netty.channel.MessageEvent;
 import org.jboss.netty.channel.ReceiveBufferSizePredictor;
+import org.jboss.netty.channel.socket.nio.AbstractNioChannel.WriteRequestQueue;
 
 import java.io.IOException;
 import java.net.SocketAddress;
@@ -240,7 +241,7 @@ public class NioDatagramWorker extends AbstractNioWorker {
         final SocketSendBufferPool sendBufferPool = this.sendBufferPool;
         final DatagramChannel ch = ((NioDatagramChannel) channel).getDatagramChannel();
-        final Queue<MessageEvent> writeBuffer = channel.writeBufferQueue;
+        final WriteRequestQueue writeBuffer = channel.writeBufferQueue;
         final int writeSpinCount = channel.getConfig().getWriteSpinCount();
 
         synchronized (channel.writeLock) {
             // inform the channel that write is in-progress

src/test/java/org/jboss/netty/channel/socket/nio/NioChannelTest.java

@@ -33,6 +33,7 @@ import org.junit.Test;
 import java.nio.channels.SocketChannel;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.*;
@@ -93,6 +94,20 @@ public class NioChannelTest {
             }, true);
             return f;
         }
+
+        MessageEvent poll() {
+            final DefaultChannelFuture f = new DefaultChannelFuture(this, false);
+            final AtomicReference<MessageEvent> me = new AtomicReference<MessageEvent>();
+            worker.executeInIoThread(new Runnable() {
+                @Override
+                public void run() {
+                    me.set(writeBufferQueue.poll());
+                    f.setSuccess();
+                }
+            });
+            f.syncUninterruptibly();
+            return me.get();
+        }
     }
 
     private static ChannelBuffer writeZero(int size) {
@@ -129,17 +144,17 @@ public class NioChannelTest {
         channel.flushIoTasks().await();
         assertEquals("false ", buf.toString());
         // Ensure going down to the low watermark makes channel writable again by flushing the first write.
-        assertNotNull(channel.writeBufferQueue.poll());
+        assertNotNull(channel.poll());
         assertEquals(128, channel.writeBufferSize.get());
         // once more since in Netty 3.9, the check is < and not <=
-        assertNotNull(channel.writeBufferQueue.poll());
+        assertNotNull(channel.poll());
         assertEquals(64, channel.writeBufferSize.get());
         assertEquals(0, channel.getInterestOps() & Channel.OP_WRITE);
         channel.flushIoTasks().await();
         assertEquals("false true ", buf.toString());
         while (! channel.writeBufferQueue.isEmpty()) {
-            channel.writeBufferQueue.poll();
+            channel.poll();
         }
 
         worker.shutdown();
         executor.shutdown();
@@ -178,7 +193,7 @@ public class NioChannelTest {
         assertEquals("false true ", buf.toString());
         while (! channel.writeBufferQueue.isEmpty()) {
-            channel.writeBufferQueue.poll();
+            channel.poll();
         }
 
         worker.shutdown();
         executor.shutdown();
@@ -228,7 +243,7 @@ public class NioChannelTest {
         assertEquals("false true ", buf.toString());
         while (! channel.writeBufferQueue.isEmpty()) {
-            channel.writeBufferQueue.poll();
+            channel.poll();
         }
 
         worker.shutdown();
         executor.shutdown();
@@ -268,7 +283,7 @@ public class NioChannelTest {
         assertEquals("false ", buf.toString());
         // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChanged()
         // because of the user-defined writability flag.
-        assertNotNull(channel.writeBufferQueue.poll());
+        assertNotNull(channel.poll());
         assertEquals(0, channel.writeBufferSize.get());
         assertTrue((channel.getInterestOps() & Channel.OP_WRITE) != 0);
         channel.flushIoTasks().await();
@@ -296,8 +311,8 @@ public class NioChannelTest {
         assertTrue((channel.getInterestOps() & Channel.OP_WRITE) != 0);
         channel.flushIoTasks().await();
         assertEquals("false true false ", buf.toString());
-        // Ensure reducing the totalPendingWriteBytes down to zero does trigger channelWritabilityChannged()
-        assertNotNull(channel.writeBufferQueue.poll());
+        // Ensure reducing the totalPendingWriteBytes down to zero does trigger channelWritabilityChanged()
+        assertNotNull(channel.poll());
         assertEquals(0, channel.writeBufferSize.get());
         assertEquals(0, channel.getInterestOps() & Channel.OP_WRITE);
         channel.flushIoTasks().await();
@@ -321,7 +336,7 @@ public class NioChannelTest {
         assertEquals("false true false true false ", buf.toString());
         // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged()
         // because of the user-defined writability flag.
-        assertNotNull(channel.writeBufferQueue.poll());
+        assertNotNull(channel.poll());
         assertEquals(0, channel.writeBufferSize.get());
         assertEquals(0, channel.getInterestOps() & Channel.OP_WRITE);
         channel.flushIoTasks().await();
@@ -341,7 +356,7 @@ public class NioChannelTest {
         assertEquals("false true false true false true false ", buf.toString());
         // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged()
         // because of the user-defined writability flag.
-        assertNotNull(channel.writeBufferQueue.poll());
+        assertNotNull(channel.poll());
         assertEquals(0, channel.writeBufferSize.get());
         assertTrue((channel.getInterestOps() & Channel.OP_WRITE) != 0);
         channel.flushIoTasks().await();
@@ -353,7 +368,7 @@ public class NioChannelTest {
         assertEquals("false true false true false true false true ", buf.toString());
         while (! channel.writeBufferQueue.isEmpty()) {
-            channel.writeBufferQueue.poll();
+            channel.poll();
         }
 
         worker.shutdown();
         executor.shutdown();