Trigger channelWritabilityChanged() later to avoid reentrance

Related: #3212

Motivation:

When SslHandler and ChunkedWriteHandler exists in a pipeline together,
it is possible that ChunkedWriteHandler.channelWritabilityChanged()
invokes SslHandler.flush() and vice versa. Because they can feed each
other (i.e. ChunkedWriteHandler.channelWritabilityChanged() ->
SslHandler.flush() -> ChunkedWriteHandler.channelWritabilityChanged() ->
..), they can fall into an inconsistent state due to reentrance (e.g.
bad MAC record at the remote peer due to incorrect ordering.)

Modifications:

- Trigger channelWritabilityChanged() using EventLoop.execute() when
  there's a chance where channelWritabilityChanged() can cause a
  reentrance issue
- Fix test failures caused by the modification

Result:

Fix the handler reentrance issues related with a
channelWritabilityChanged() event
This commit is contained in:
Trustin Lee 2014-12-10 18:36:53 +09:00
parent 7f9fb95702
commit 85ec4d9cc4
4 changed files with 124 additions and 27 deletions

View File

@ -84,6 +84,8 @@ public final class ChannelOutboundBuffer {
@SuppressWarnings("UnusedDeclaration") @SuppressWarnings("UnusedDeclaration")
private volatile int unwritable; private volatile int unwritable;
private volatile Runnable fireChannelWritabilityChangedTask;
static { static {
AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater = AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater =
PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "unwritable"); PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "unwritable");
@ -124,7 +126,7 @@ public final class ChannelOutboundBuffer {
// increment pending bytes after adding message to the unflushed arrays. // increment pending bytes after adding message to the unflushed arrays.
// See https://github.com/netty/netty/issues/1619 // See https://github.com/netty/netty/issues/1619
incrementPendingOutboundBytes(size); incrementPendingOutboundBytes(size, false);
} }
/** /**
@ -147,7 +149,7 @@ public final class ChannelOutboundBuffer {
if (!entry.promise.setUncancellable()) { if (!entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes // Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel(); int pending = entry.cancel();
decrementPendingOutboundBytes(pending); decrementPendingOutboundBytes(pending, false);
} }
entry = entry.next; entry = entry.next;
} while (entry != null); } while (entry != null);
@ -162,13 +164,17 @@ public final class ChannelOutboundBuffer {
* This method is thread-safe! * This method is thread-safe!
*/ */
void incrementPendingOutboundBytes(long size) { void incrementPendingOutboundBytes(long size) {
incrementPendingOutboundBytes(size, true);
}
private void incrementPendingOutboundBytes(long size, boolean invokeLater) {
if (size == 0) { if (size == 0) {
return; return;
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) { if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
setUnwritable(); setUnwritable(invokeLater);
} }
} }
@ -177,13 +183,17 @@ public final class ChannelOutboundBuffer {
* This method is thread-safe! * This method is thread-safe!
*/ */
void decrementPendingOutboundBytes(long size) { void decrementPendingOutboundBytes(long size) {
decrementPendingOutboundBytes(size, true);
}
private void decrementPendingOutboundBytes(long size, boolean invokeLater) {
if (size == 0) { if (size == 0) {
return; return;
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
if (newWriteBufferSize == 0 || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) { if (newWriteBufferSize == 0 || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) {
setWritable(); setWritable(invokeLater);
} }
} }
@ -247,7 +257,7 @@ public final class ChannelOutboundBuffer {
// only release message, notify and decrement if it was not canceled before. // only release message, notify and decrement if it was not canceled before.
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
safeSuccess(promise); safeSuccess(promise);
decrementPendingOutboundBytes(size); decrementPendingOutboundBytes(size, false);
} }
// recycle the entry // recycle the entry
@ -278,7 +288,7 @@ public final class ChannelOutboundBuffer {
ReferenceCountUtil.safeRelease(msg); ReferenceCountUtil.safeRelease(msg);
safeFail(promise, cause); safeFail(promise, cause);
decrementPendingOutboundBytes(size); decrementPendingOutboundBytes(size, false);
} }
// recycle the entry // recycle the entry
@ -476,7 +486,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue & mask; final int newValue = oldValue & mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) { if (oldValue != 0 && newValue == 0) {
channel.pipeline().fireChannelWritabilityChanged(); fireChannelWritabilityChanged(true);
} }
break; break;
} }
@ -490,7 +500,7 @@ public final class ChannelOutboundBuffer {
final int newValue = oldValue | mask; final int newValue = oldValue | mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) { if (oldValue == 0 && newValue != 0) {
channel.pipeline().fireChannelWritabilityChanged(); fireChannelWritabilityChanged(true);
} }
break; break;
} }
@ -504,32 +514,50 @@ public final class ChannelOutboundBuffer {
return 1 << index; return 1 << index;
} }
private void setWritable() { private void setWritable(boolean invokeLater) {
for (;;) { for (;;) {
final int oldValue = unwritable; final int oldValue = unwritable;
final int newValue = oldValue & ~1; final int newValue = oldValue & ~1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) { if (oldValue != 0 && newValue == 0) {
channel.pipeline().fireChannelWritabilityChanged(); fireChannelWritabilityChanged(invokeLater);
} }
break; break;
} }
} }
} }
private void setUnwritable() { private void setUnwritable(boolean invokeLater) {
for (;;) { for (;;) {
final int oldValue = unwritable; final int oldValue = unwritable;
final int newValue = oldValue | 1; final int newValue = oldValue | 1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) { if (oldValue == 0 && newValue != 0) {
channel.pipeline().fireChannelWritabilityChanged(); fireChannelWritabilityChanged(invokeLater);
} }
break; break;
} }
} }
} }
private void fireChannelWritabilityChanged(boolean invokeLater) {
final ChannelPipeline pipeline = channel.pipeline();
if (invokeLater) {
Runnable task = fireChannelWritabilityChangedTask;
if (task == null) {
fireChannelWritabilityChangedTask = task = new Runnable() {
@Override
public void run() {
pipeline.fireChannelWritabilityChanged();
}
};
}
channel.eventLoop().execute(task);
} else {
pipeline.fireChannelWritabilityChanged();
}
}
/** /**
* Returns the number of flushed messages in this {@link ChannelOutboundBuffer}. * Returns the number of flushed messages in this {@link ChannelOutboundBuffer}.
*/ */

View File

@ -64,9 +64,19 @@ class BaseChannelTest {
return buf; return buf;
} }
void assertLog(String expected) { void assertLog(String firstExpected, String... otherExpected) {
String actual = loggingHandler.getLog(); String actual = loggingHandler.getLog();
assertEquals(expected, actual); if (firstExpected.equals(actual)) {
return;
}
for (String e: otherExpected) {
if (e.equals(actual)) {
return;
}
}
// Let the comparison fail with the first expectation.
assertEquals(firstExpected, actual);
} }
void clearLog() { void clearLog() {

View File

@ -261,10 +261,12 @@ public class ChannelOutboundBufferTest {
// Ensure that setting a user-defined writability flag to false affects channel.isWritable(); // Ensure that setting a user-defined writability flag to false affects channel.isWritable();
cob.setUserDefinedWritability(1, false); cob.setUserDefinedWritability(1, false);
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to true affects channel.isWritable(); // Ensure that setting a user-defined writability flag to true affects channel.isWritable();
cob.setUserDefinedWritability(1, true); cob.setUserDefinedWritability(1, true);
ch.runPendingTasks();
assertThat(buf.toString(), is("false true ")); assertThat(buf.toString(), is("false true "));
safeClose(ch); safeClose(ch);
@ -288,19 +290,23 @@ public class ChannelOutboundBufferTest {
// Ensure that setting a user-defined writability flag to false affects channel.isWritable() // Ensure that setting a user-defined writability flag to false affects channel.isWritable()
cob.setUserDefinedWritability(1, false); cob.setUserDefinedWritability(1, false);
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure that setting another user-defined writability flag to false does not trigger // Ensure that setting another user-defined writability flag to false does not trigger
// channelWritabilityChanged. // channelWritabilityChanged.
cob.setUserDefinedWritability(2, false); cob.setUserDefinedWritability(2, false);
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable() // Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable()
cob.setUserDefinedWritability(1, true); cob.setUserDefinedWritability(1, true);
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure that setting all user-defined writability flags to true affects channel.isWritable() // Ensure that setting all user-defined writability flags to true affects channel.isWritable()
cob.setUserDefinedWritability(2, true); cob.setUserDefinedWritability(2, true);
ch.runPendingTasks();
assertThat(buf.toString(), is("false true ")); assertThat(buf.toString(), is("false true "));
safeClose(ch); safeClose(ch);
@ -328,6 +334,7 @@ public class ChannelOutboundBufferTest {
// Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged() // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()
cob.setUserDefinedWritability(1, false); cob.setUserDefinedWritability(1, false);
ch.runPendingTasks();
assertThat(buf.toString(), is("false ")); assertThat(buf.toString(), is("false "));
// Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged() // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged()
@ -338,6 +345,7 @@ public class ChannelOutboundBufferTest {
// Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged() // Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged()
cob.setUserDefinedWritability(1, true); cob.setUserDefinedWritability(1, true);
ch.runPendingTasks();
assertThat(buf.toString(), is("false true ")); assertThat(buf.toString(), is("false true "));
safeClose(ch); safeClose(ch);

View File

@ -45,21 +45,62 @@ public class ReentrantChannelTest extends BaseChannelTest {
clientChannel.config().setWriteBufferLowWaterMark(512); clientChannel.config().setWriteBufferLowWaterMark(512);
clientChannel.config().setWriteBufferHighWaterMark(1024); clientChannel.config().setWriteBufferHighWaterMark(1024);
// What is supposed to happen from this point:
//
// 1. Because this write attempt has been made from a non-I/O thread,
// ChannelOutboundBuffer.pendingWriteBytes will be increased before
// write() event is really evaluated.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became unwritable.
//
// 2. The write() event is handled by the pipeline in an I/O thread.
// -> write() will be triggered.
//
// 3. Once the write() event is handled, ChannelOutboundBuffer.pendingWriteBytes
// will be decreased.
// -> channelWritabilityChanged() will be triggered,
// because the Channel became writable again.
//
// 4. The message is added to the ChannelOutboundBuffer and thus
// pendingWriteBytes will be increased again.
// -> channelWritabilityChanged() will be triggered.
//
// 5. The flush() event causes the write request in theChannelOutboundBuffer
// to be removed.
// -> flush() and channelWritabilityChanged() will be triggered.
//
// Note that the channelWritabilityChanged() in the step 4 can occur between
// the flush() and the channelWritabilityChanged() in the stap 5, because
// the flush() is invoked from a non-I/O thread while the other are from
// an I/O thread.
ChannelFuture future = clientChannel.write(createTestBuf(2000)); ChannelFuture future = clientChannel.write(createTestBuf(2000));
clientChannel.flush(); clientChannel.flush();
future.sync(); future.sync();
clientChannel.close().sync(); clientChannel.close().sync();
assertLog( assertLog(
// Case 1:
"WRITABILITY: writable=false\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"WRITABILITY: writable=true\n" +
"WRITE\n" + "WRITE\n" +
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n"); "WRITABILITY: writable=true\n");
} }
/**
* Similar to {@link #testWritabilityChanged()} with slight variation.
*/
@Test @Test
public void testFlushInWritabilityChanged() throws Exception { public void testFlushInWritabilityChanged() throws Exception {
@ -87,16 +128,26 @@ public class ReentrantChannelTest extends BaseChannelTest {
}); });
assertTrue(clientChannel.isWritable()); assertTrue(clientChannel.isWritable());
clientChannel.write(createTestBuf(2000)).sync(); clientChannel.write(createTestBuf(2000)).sync();
clientChannel.close().sync(); clientChannel.close().sync();
assertLog( assertLog(
// Case 1:
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITE\n" +
"WRITABILITY: writable=false\n" +
"WRITABILITY: writable=false\n" +
"FLUSH\n" +
"WRITABILITY: writable=true\n",
// Case 2:
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITE\n" + "WRITE\n" +
"WRITABILITY: writable=false\n" + "WRITABILITY: writable=false\n" +
"FLUSH\n" + "FLUSH\n" +
"WRITABILITY: writable=true\n" +
"WRITABILITY: writable=true\n"); "WRITABILITY: writable=true\n");
} }