From ee9cbda9f00dd616076f044270ca075265c9c068 Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Wed, 22 Oct 2014 17:45:28 +0900 Subject: [PATCH] Implement user-defined writability flags Related: #2945 Motivation: Some special handlers such as TrafficShapingHandler need to override the writability of a Channel to throttle the outbound traffic. Modifications: Add a new indexed property called 'user-defined writability flag' to ChannelOutboundBuffer so that a handler can override the writability of a Channel easily. Result: A handler can override the writability of a Channel using an unsafe API. For example: Channel ch = ...; ch.unsafe().outboundBuffer().setUserDefinedWritability(1, false); --- .../netty/channel/ChannelOutboundBuffer.java | 130 ++++++++++++--- .../channel/ChannelOutboundBufferTest.java | 150 ++++++++++++++++++ 2 files changed, 259 insertions(+), 21 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index e41b3724f6..d9f5447ede 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -36,8 +36,14 @@ import java.util.concurrent.atomic.AtomicLongFieldUpdater; /** * (Transport implementors only) an internal data structure used by {@link AbstractChannel} to store its pending * outbound write requests. - * - * All the methods should only be called by the {@link EventLoop} of the {@link Channel}. + *

+ * All methods must be called by a transport implementation from an I/O thread, except the following ones: + *

+ *

*/ public final class ChannelOutboundBuffer { @@ -70,21 +76,21 @@ public final class ChannelOutboundBuffer { private static final AtomicLongFieldUpdater TOTAL_PENDING_SIZE_UPDATER; - @SuppressWarnings("unused") + @SuppressWarnings("UnusedDeclaration") private volatile long totalPendingSize; - private static final AtomicIntegerFieldUpdater WRITABLE_UPDATER; + private static final AtomicIntegerFieldUpdater UNWRITABLE_UPDATER; - @SuppressWarnings("FieldMayBeFinal") - private volatile int writable = 1; + @SuppressWarnings("UnusedDeclaration") + private volatile int unwritable; static { - AtomicIntegerFieldUpdater writableUpdater = - PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "writable"); - if (writableUpdater == null) { - writableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "writable"); + AtomicIntegerFieldUpdater unwritableUpdater = + PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "unwritable"); + if (unwritableUpdater == null) { + unwritableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "unwritable"); } - WRITABLE_UPDATER = writableUpdater; + UNWRITABLE_UPDATER = unwritableUpdater; AtomicLongFieldUpdater pendingSizeUpdater = PlatformDependent.newAtomicLongFieldUpdater(ChannelOutboundBuffer.class, "totalPendingSize"); @@ -161,10 +167,8 @@ public final class ChannelOutboundBuffer { } long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); - if (newWriteBufferSize > channel.config().getWriteBufferHighWaterMark()) { - if (WRITABLE_UPDATER.compareAndSet(this, 1, 0)) { - channel.pipeline().fireChannelWritabilityChanged(); - } + if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) { + setUnwritable(); } } @@ -178,10 +182,8 @@ public final class ChannelOutboundBuffer { } long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); - if (newWriteBufferSize == 0 || newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) { - if (WRITABLE_UPDATER.compareAndSet(this, 0, 1)) { - channel.pipeline().fireChannelWritabilityChanged(); - } + if (newWriteBufferSize == 0 || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) { + setWritable(); } } @@ -438,8 +440,94 @@ public final class ChannelOutboundBuffer { return nioBufferSize; } - boolean isWritable() { - return writable != 0; + /** + * Returns {@code true} if and only if {@linkplain #totalPendingWriteBytes() the total number of pending bytes} did + * not exceed the write watermark of the {@link Channel} and + * no {@linkplain #setUserDefinedWritability(int, boolean) user-defined writability flag} has been set to + * {@code false}. + */ + public boolean isWritable() { + return unwritable == 0; + } + + /** + * Returns {@code true} if and only if the user-defined writability flag at the specified index is set to + * {@code true}. + */ + public boolean getUserDefinedWritability(int index) { + return (unwritable & writabilityMask(index)) == 0; + } + + /** + * Sets a user-defined writability flag at the specified index. + */ + public void setUserDefinedWritability(int index, boolean writable) { + if (writable) { + setUserDefinedWritability(index); + } else { + clearUserDefinedWritability(index); + } + } + + private void setUserDefinedWritability(int index) { + final int mask = ~writabilityMask(index); + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue & mask; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + channel.pipeline().fireChannelWritabilityChanged(); + } + break; + } + } + } + + private void clearUserDefinedWritability(int index) { + final int mask = writabilityMask(index); + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue | mask; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0 && newValue != 0) { + channel.pipeline().fireChannelWritabilityChanged(); + } + break; + } + } + } + + private static int writabilityMask(int index) { + if (index < 1 || index > 31) { + throw new IllegalArgumentException("index: " + index + " (expected: 1~31)"); + } + return 1 << index; + } + + private void setWritable() { + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue & ~1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + channel.pipeline().fireChannelWritabilityChanged(); + } + break; + } + } + } + + private void setUnwritable() { + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue | 1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0 && newValue != 0) { + channel.pipeline().fireChannelWritabilityChanged(); + } + break; + } + } } /** diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java index 296a0e7d13..e283a109ed 100644 --- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -17,6 +17,7 @@ package io.netty.channel; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.CharsetUtil; import org.junit.Test; @@ -24,6 +25,7 @@ import java.net.SocketAddress; import java.nio.ByteBuffer; import static io.netty.buffer.Unpooled.*; +import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; public class ChannelOutboundBufferTest { @@ -203,4 +205,152 @@ public class ChannelOutboundBufferTest { } } } + + @Test + public void testWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + // Ensure exceeding the low watermark does not make channel unwritable. + ch.write(buffer().writeZero(128)); + assertThat(buf.toString(), is("")); + + ch.unsafe().outboundBuffer().addFlush(); + + // Ensure exceeding the high watermark makes channel unwritable. + ch.write(buffer().writeZero(128)); + assertThat(buf.toString(), is("false ")); + + // Ensure going down to the low watermark makes channel writable again by flushing the first write. + assertThat(ch.unsafe().outboundBuffer().remove(), is(true)); + assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L)); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testUserDefinedWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Ensure that the default value of a user-defined writability flag is true. + for (int i = 1; i <= 30; i ++) { + assertThat(cob.getUserDefinedWritability(i), is(true)); + } + + // Ensure that setting a user-defined writability flag to false affects channel.isWritable(); + cob.setUserDefinedWritability(1, false); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting a user-defined writability flag to true affects channel.isWritable(); + cob.setUserDefinedWritability(1, true); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testUserDefinedWritability2() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Ensure that setting a user-defined writability flag to false affects channel.isWritable() + cob.setUserDefinedWritability(1, false); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting another user-defined writability flag to false does not trigger + // channelWritabilityChanged. + cob.setUserDefinedWritability(2, false); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable() + cob.setUserDefinedWritability(1, true); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting all user-defined writability flags to true affects channel.isWritable() + cob.setUserDefinedWritability(2, true); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testMixedWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Trigger channelWritabilityChanged() by writing a lot. + ch.write(buffer().writeZero(256)); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged() + cob.setUserDefinedWritability(1, false); + assertThat(buf.toString(), is("false ")); + + // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged() + // because of the user-defined writability flag. + ch.flush(); + assertThat(cob.totalPendingWriteBytes(), is(0L)); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged() + cob.setUserDefinedWritability(1, true); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + private static void safeClose(EmbeddedChannel ch) { + ch.finish(); + for (;;) { + ByteBuf m = ch.readOutbound(); + if (m == null) { + break; + } + m.release(); + } + } }