Implement user-defined writability flags

Related: #2945

Motivation:

Some special handlers such as TrafficShapingHandler need to override the
writability of a Channel to throttle the outbound traffic.

Modifications:

Add a new indexed property called 'user-defined writability flag' to
ChannelOutboundBuffer so that a handler can override the writability of
a Channel easily.

Result:

A handler can override the writability of a Channel using an unsafe API.
For example:

  Channel ch = ...;
  ch.unsafe().outboundBuffer().setUserDefinedWritability(1, false);
This commit is contained in:
Trustin Lee 2014-10-22 17:45:28 +09:00
parent 0666924e8c
commit d59629377c
2 changed files with 259 additions and 21 deletions

View File

@ -36,8 +36,14 @@ import java.util.concurrent.atomic.AtomicLongFieldUpdater;
/** /**
* (Transport implementors only) an internal data structure used by {@link AbstractChannel} to store its pending * (Transport implementors only) an internal data structure used by {@link AbstractChannel} to store its pending
* outbound write requests. * outbound write requests.
* * <p>
* All the methods should only be called by the {@link EventLoop} of the {@link Channel}. * All methods must be called by a transport implementation from an I/O thread, except the following ones:
* <ul>
* <li>{@link #size()} and {@link #isEmpty()}</li>
* <li>{@link #isWritable()}</li>
* <li>{@link #getUserDefinedWritability(int)} and {@link #setUserDefinedWritability(int, boolean)}</li>
* </ul>
* </p>
*/ */
public final class ChannelOutboundBuffer { public final class ChannelOutboundBuffer {
@ -70,21 +76,21 @@ public final class ChannelOutboundBuffer {
private static final AtomicLongFieldUpdater<ChannelOutboundBuffer> TOTAL_PENDING_SIZE_UPDATER; private static final AtomicLongFieldUpdater<ChannelOutboundBuffer> TOTAL_PENDING_SIZE_UPDATER;
@SuppressWarnings("unused") @SuppressWarnings("UnusedDeclaration")
private volatile long totalPendingSize; private volatile long totalPendingSize;
private static final AtomicIntegerFieldUpdater<ChannelOutboundBuffer> WRITABLE_UPDATER; private static final AtomicIntegerFieldUpdater<ChannelOutboundBuffer> UNWRITABLE_UPDATER;
@SuppressWarnings("FieldMayBeFinal") @SuppressWarnings("UnusedDeclaration")
private volatile int writable = 1; private volatile int unwritable;
static { static {
AtomicIntegerFieldUpdater<ChannelOutboundBuffer> writableUpdater = AtomicIntegerFieldUpdater<ChannelOutboundBuffer> unwritableUpdater =
PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "writable"); PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "unwritable");
if (writableUpdater == null) { if (unwritableUpdater == null) {
writableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "writable"); unwritableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "unwritable");
} }
WRITABLE_UPDATER = writableUpdater; UNWRITABLE_UPDATER = unwritableUpdater;
AtomicLongFieldUpdater<ChannelOutboundBuffer> pendingSizeUpdater = AtomicLongFieldUpdater<ChannelOutboundBuffer> pendingSizeUpdater =
PlatformDependent.newAtomicLongFieldUpdater(ChannelOutboundBuffer.class, "totalPendingSize"); PlatformDependent.newAtomicLongFieldUpdater(ChannelOutboundBuffer.class, "totalPendingSize");
@ -161,10 +167,8 @@ public final class ChannelOutboundBuffer {
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
if (newWriteBufferSize > channel.config().getWriteBufferHighWaterMark()) { if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
if (WRITABLE_UPDATER.compareAndSet(this, 1, 0)) { setUnwritable();
channel.pipeline().fireChannelWritabilityChanged();
}
} }
} }
@ -178,10 +182,8 @@ public final class ChannelOutboundBuffer {
} }
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
if (newWriteBufferSize == 0 || newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) { if (newWriteBufferSize == 0 || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) {
if (WRITABLE_UPDATER.compareAndSet(this, 0, 1)) { setWritable();
channel.pipeline().fireChannelWritabilityChanged();
}
} }
} }
@ -438,8 +440,94 @@ public final class ChannelOutboundBuffer {
return nioBufferSize; return nioBufferSize;
} }
boolean isWritable() { /**
return writable != 0; * Returns {@code true} if and only if {@linkplain #totalPendingWriteBytes() the total number of pending bytes} did
* not exceed the write watermark of the {@link Channel} and
* no {@linkplain #setUserDefinedWritability(int, boolean) user-defined writability flag} has been set to
* {@code false}.
*/
public boolean isWritable() {
return unwritable == 0;
}
/**
* Returns {@code true} if and only if the user-defined writability flag at the specified index is set to
* {@code true}.
*/
public boolean getUserDefinedWritability(int index) {
return (unwritable & writabilityMask(index)) == 0;
}
/**
* Sets a user-defined writability flag at the specified index.
*/
public void setUserDefinedWritability(int index, boolean writable) {
if (writable) {
setUserDefinedWritability(index);
} else {
clearUserDefinedWritability(index);
}
}
private void setUserDefinedWritability(int index) {
final int mask = ~writabilityMask(index);
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue & mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) {
channel.pipeline().fireChannelWritabilityChanged();
}
break;
}
}
}
private void clearUserDefinedWritability(int index) {
final int mask = writabilityMask(index);
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue | mask;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) {
channel.pipeline().fireChannelWritabilityChanged();
}
break;
}
}
}
private static int writabilityMask(int index) {
if (index < 1 || index > 31) {
throw new IllegalArgumentException("index: " + index + " (expected: 1~31)");
}
return 1 << index;
}
private void setWritable() {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue & ~1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue != 0 && newValue == 0) {
channel.pipeline().fireChannelWritabilityChanged();
}
break;
}
}
}
private void setUnwritable() {
for (;;) {
final int oldValue = unwritable;
final int newValue = oldValue | 1;
if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
if (oldValue == 0 && newValue != 0) {
channel.pipeline().fireChannelWritabilityChanged();
}
break;
}
}
} }
/** /**

View File

@ -17,6 +17,7 @@ package io.netty.channel;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import org.junit.Test; import org.junit.Test;
@ -24,6 +25,7 @@ import java.net.SocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import static io.netty.buffer.Unpooled.*; import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class ChannelOutboundBufferTest { public class ChannelOutboundBufferTest {
@ -203,4 +205,152 @@ public class ChannelOutboundBufferTest {
} }
} }
} }
@Test
public void testWritability() {
final StringBuilder buf = new StringBuilder();
EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
buf.append(ctx.channel().isWritable());
buf.append(' ');
}
});
ch.config().setWriteBufferLowWaterMark(128);
ch.config().setWriteBufferHighWaterMark(256);
// Ensure exceeding the low watermark does not make channel unwritable.
ch.write(buffer().writeZero(128));
assertThat(buf.toString(), is(""));
ch.unsafe().outboundBuffer().addFlush();
// Ensure exceeding the high watermark makes channel unwritable.
ch.write(buffer().writeZero(128));
assertThat(buf.toString(), is("false "));
// Ensure going down to the low watermark makes channel writable again by flushing the first write.
assertThat(ch.unsafe().outboundBuffer().remove(), is(true));
assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L));
assertThat(buf.toString(), is("false true "));
safeClose(ch);
}
@Test
public void testUserDefinedWritability() {
final StringBuilder buf = new StringBuilder();
EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
buf.append(ctx.channel().isWritable());
buf.append(' ');
}
});
ch.config().setWriteBufferLowWaterMark(128);
ch.config().setWriteBufferHighWaterMark(256);
ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
// Ensure that the default value of a user-defined writability flag is true.
for (int i = 1; i <= 30; i ++) {
assertThat(cob.getUserDefinedWritability(i), is(true));
}
// Ensure that setting a user-defined writability flag to false affects channel.isWritable();
cob.setUserDefinedWritability(1, false);
assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to true affects channel.isWritable();
cob.setUserDefinedWritability(1, true);
assertThat(buf.toString(), is("false true "));
safeClose(ch);
}
@Test
public void testUserDefinedWritability2() {
final StringBuilder buf = new StringBuilder();
EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
buf.append(ctx.channel().isWritable());
buf.append(' ');
}
});
ch.config().setWriteBufferLowWaterMark(128);
ch.config().setWriteBufferHighWaterMark(256);
ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
// Ensure that setting a user-defined writability flag to false affects channel.isWritable()
cob.setUserDefinedWritability(1, false);
assertThat(buf.toString(), is("false "));
// Ensure that setting another user-defined writability flag to false does not trigger
// channelWritabilityChanged.
cob.setUserDefinedWritability(2, false);
assertThat(buf.toString(), is("false "));
// Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable()
cob.setUserDefinedWritability(1, true);
assertThat(buf.toString(), is("false "));
// Ensure that setting all user-defined writability flags to true affects channel.isWritable()
cob.setUserDefinedWritability(2, true);
assertThat(buf.toString(), is("false true "));
safeClose(ch);
}
@Test
public void testMixedWritability() {
final StringBuilder buf = new StringBuilder();
EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
buf.append(ctx.channel().isWritable());
buf.append(' ');
}
});
ch.config().setWriteBufferLowWaterMark(128);
ch.config().setWriteBufferHighWaterMark(256);
ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
// Trigger channelWritabilityChanged() by writing a lot.
ch.write(buffer().writeZero(256));
assertThat(buf.toString(), is("false "));
// Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()
cob.setUserDefinedWritability(1, false);
assertThat(buf.toString(), is("false "));
// Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged()
// because of the user-defined writability flag.
ch.flush();
assertThat(cob.totalPendingWriteBytes(), is(0L));
assertThat(buf.toString(), is("false "));
// Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged()
cob.setUserDefinedWritability(1, true);
assertThat(buf.toString(), is("false true "));
safeClose(ch);
}
private static void safeClose(EmbeddedChannel ch) {
ch.finish();
for (;;) {
ByteBuf m = ch.readOutbound();
if (m == null) {
break;
}
m.release();
}
}
} }