From ee9cbda9f00dd616076f044270ca075265c9c068 Mon Sep 17 00:00:00 2001
From: Trustin Lee
Date: Wed, 22 Oct 2014 17:45:28 +0900
Subject: [PATCH] Implement user-defined writability flags
Related: #2945
Motivation:
Some special handlers such as TrafficShapingHandler need to override the
writability of a Channel to throttle the outbound traffic.
Modifications:
Add a new indexed property called 'user-defined writability flag' to
ChannelOutboundBuffer so that a handler can override the writability of
a Channel easily.
Result:
A handler can override the writability of a Channel using an unsafe API.
For example:
Channel ch = ...;
ch.unsafe().outboundBuffer().setUserDefinedWritability(1, false);
---
.../netty/channel/ChannelOutboundBuffer.java | 130 ++++++++++++---
.../channel/ChannelOutboundBufferTest.java | 150 ++++++++++++++++++
2 files changed, 259 insertions(+), 21 deletions(-)
diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java
index e41b3724f6..d9f5447ede 100644
--- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java
+++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java
@@ -36,8 +36,14 @@ import java.util.concurrent.atomic.AtomicLongFieldUpdater;
/**
* (Transport implementors only) an internal data structure used by {@link AbstractChannel} to store its pending
* outbound write requests.
- *
- * All the methods should only be called by the {@link EventLoop} of the {@link Channel}.
+ *
+ * All methods must be called by a transport implementation from an I/O thread, except the following ones:
+ *
+ * - {@link #size()} and {@link #isEmpty()}
+ * - {@link #isWritable()}
+ * - {@link #getUserDefinedWritability(int)} and {@link #setUserDefinedWritability(int, boolean)}
+ *
+ *
*/
public final class ChannelOutboundBuffer {
@@ -70,21 +76,21 @@ public final class ChannelOutboundBuffer {
private static final AtomicLongFieldUpdater TOTAL_PENDING_SIZE_UPDATER;
- @SuppressWarnings("unused")
+ @SuppressWarnings("UnusedDeclaration")
private volatile long totalPendingSize;
- private static final AtomicIntegerFieldUpdater WRITABLE_UPDATER;
+ private static final AtomicIntegerFieldUpdater UNWRITABLE_UPDATER;
- @SuppressWarnings("FieldMayBeFinal")
- private volatile int writable = 1;
+ @SuppressWarnings("UnusedDeclaration")
+ private volatile int unwritable;
static {
- AtomicIntegerFieldUpdater writableUpdater =
- PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "writable");
- if (writableUpdater == null) {
- writableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "writable");
+ AtomicIntegerFieldUpdater unwritableUpdater =
+ PlatformDependent.newAtomicIntegerFieldUpdater(ChannelOutboundBuffer.class, "unwritable");
+ if (unwritableUpdater == null) {
+ unwritableUpdater = AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "unwritable");
}
- WRITABLE_UPDATER = writableUpdater;
+ UNWRITABLE_UPDATER = unwritableUpdater;
AtomicLongFieldUpdater pendingSizeUpdater =
PlatformDependent.newAtomicLongFieldUpdater(ChannelOutboundBuffer.class, "totalPendingSize");
@@ -161,10 +167,8 @@ public final class ChannelOutboundBuffer {
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size);
- if (newWriteBufferSize > channel.config().getWriteBufferHighWaterMark()) {
- if (WRITABLE_UPDATER.compareAndSet(this, 1, 0)) {
- channel.pipeline().fireChannelWritabilityChanged();
- }
+ if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) {
+ setUnwritable();
}
}
@@ -178,10 +182,8 @@ public final class ChannelOutboundBuffer {
}
long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size);
- if (newWriteBufferSize == 0 || newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) {
- if (WRITABLE_UPDATER.compareAndSet(this, 0, 1)) {
- channel.pipeline().fireChannelWritabilityChanged();
- }
+ if (newWriteBufferSize == 0 || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) {
+ setWritable();
}
}
@@ -438,8 +440,94 @@ public final class ChannelOutboundBuffer {
return nioBufferSize;
}
- boolean isWritable() {
- return writable != 0;
+ /**
+ * Returns {@code true} if and only if {@linkplain #totalPendingWriteBytes() the total number of pending bytes} did
+ * not exceed the write watermark of the {@link Channel} and
+ * no {@linkplain #setUserDefinedWritability(int, boolean) user-defined writability flag} has been set to
+ * {@code false}.
+ */
+ public boolean isWritable() {
+ return unwritable == 0;
+ }
+
+ /**
+ * Returns {@code true} if and only if the user-defined writability flag at the specified index is set to
+ * {@code true}.
+ */
+ public boolean getUserDefinedWritability(int index) {
+ return (unwritable & writabilityMask(index)) == 0;
+ }
+
+ /**
+ * Sets a user-defined writability flag at the specified index.
+ */
+ public void setUserDefinedWritability(int index, boolean writable) {
+ if (writable) {
+ setUserDefinedWritability(index);
+ } else {
+ clearUserDefinedWritability(index);
+ }
+ }
+
+ private void setUserDefinedWritability(int index) {
+ final int mask = ~writabilityMask(index);
+ for (;;) {
+ final int oldValue = unwritable;
+ final int newValue = oldValue & mask;
+ if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
+ if (oldValue != 0 && newValue == 0) {
+ channel.pipeline().fireChannelWritabilityChanged();
+ }
+ break;
+ }
+ }
+ }
+
+ private void clearUserDefinedWritability(int index) {
+ final int mask = writabilityMask(index);
+ for (;;) {
+ final int oldValue = unwritable;
+ final int newValue = oldValue | mask;
+ if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
+ if (oldValue == 0 && newValue != 0) {
+ channel.pipeline().fireChannelWritabilityChanged();
+ }
+ break;
+ }
+ }
+ }
+
+ private static int writabilityMask(int index) {
+ if (index < 1 || index > 31) {
+ throw new IllegalArgumentException("index: " + index + " (expected: 1~31)");
+ }
+ return 1 << index;
+ }
+
+ private void setWritable() {
+ for (;;) {
+ final int oldValue = unwritable;
+ final int newValue = oldValue & ~1;
+ if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
+ if (oldValue != 0 && newValue == 0) {
+ channel.pipeline().fireChannelWritabilityChanged();
+ }
+ break;
+ }
+ }
+ }
+
+ private void setUnwritable() {
+ for (;;) {
+ final int oldValue = unwritable;
+ final int newValue = oldValue | 1;
+ if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) {
+ if (oldValue == 0 && newValue != 0) {
+ channel.pipeline().fireChannelWritabilityChanged();
+ }
+ break;
+ }
+ }
}
/**
diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java
index 296a0e7d13..e283a109ed 100644
--- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java
+++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java
@@ -17,6 +17,7 @@ package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import org.junit.Test;
@@ -24,6 +25,7 @@ import java.net.SocketAddress;
import java.nio.ByteBuffer;
import static io.netty.buffer.Unpooled.*;
+import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
public class ChannelOutboundBufferTest {
@@ -203,4 +205,152 @@ public class ChannelOutboundBufferTest {
}
}
}
+
+ @Test
+ public void testWritability() {
+ final StringBuilder buf = new StringBuilder();
+ EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
+ @Override
+ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+ buf.append(ctx.channel().isWritable());
+ buf.append(' ');
+ }
+ });
+
+ ch.config().setWriteBufferLowWaterMark(128);
+ ch.config().setWriteBufferHighWaterMark(256);
+
+ // Ensure exceeding the low watermark does not make channel unwritable.
+ ch.write(buffer().writeZero(128));
+ assertThat(buf.toString(), is(""));
+
+ ch.unsafe().outboundBuffer().addFlush();
+
+ // Ensure exceeding the high watermark makes channel unwritable.
+ ch.write(buffer().writeZero(128));
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure going down to the low watermark makes channel writable again by flushing the first write.
+ assertThat(ch.unsafe().outboundBuffer().remove(), is(true));
+ assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L));
+ assertThat(buf.toString(), is("false true "));
+
+ safeClose(ch);
+ }
+
+ @Test
+ public void testUserDefinedWritability() {
+ final StringBuilder buf = new StringBuilder();
+ EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
+ @Override
+ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+ buf.append(ctx.channel().isWritable());
+ buf.append(' ');
+ }
+ });
+
+ ch.config().setWriteBufferLowWaterMark(128);
+ ch.config().setWriteBufferHighWaterMark(256);
+
+ ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
+
+ // Ensure that the default value of a user-defined writability flag is true.
+ for (int i = 1; i <= 30; i ++) {
+ assertThat(cob.getUserDefinedWritability(i), is(true));
+ }
+
+ // Ensure that setting a user-defined writability flag to false affects channel.isWritable();
+ cob.setUserDefinedWritability(1, false);
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting a user-defined writability flag to true affects channel.isWritable();
+ cob.setUserDefinedWritability(1, true);
+ assertThat(buf.toString(), is("false true "));
+
+ safeClose(ch);
+ }
+
+ @Test
+ public void testUserDefinedWritability2() {
+ final StringBuilder buf = new StringBuilder();
+ EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
+ @Override
+ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+ buf.append(ctx.channel().isWritable());
+ buf.append(' ');
+ }
+ });
+
+ ch.config().setWriteBufferLowWaterMark(128);
+ ch.config().setWriteBufferHighWaterMark(256);
+
+ ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
+
+ // Ensure that setting a user-defined writability flag to false affects channel.isWritable()
+ cob.setUserDefinedWritability(1, false);
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting another user-defined writability flag to false does not trigger
+ // channelWritabilityChanged.
+ cob.setUserDefinedWritability(2, false);
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable()
+ cob.setUserDefinedWritability(1, true);
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting all user-defined writability flags to true affects channel.isWritable()
+ cob.setUserDefinedWritability(2, true);
+ assertThat(buf.toString(), is("false true "));
+
+ safeClose(ch);
+ }
+
+ @Test
+ public void testMixedWritability() {
+ final StringBuilder buf = new StringBuilder();
+ EmbeddedChannel ch = new EmbeddedChannel(new ChannelHandlerAdapter() {
+ @Override
+ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+ buf.append(ctx.channel().isWritable());
+ buf.append(' ');
+ }
+ });
+
+ ch.config().setWriteBufferLowWaterMark(128);
+ ch.config().setWriteBufferHighWaterMark(256);
+
+ ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer();
+
+ // Trigger channelWritabilityChanged() by writing a lot.
+ ch.write(buffer().writeZero(256));
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged()
+ cob.setUserDefinedWritability(1, false);
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChannged()
+ // because of the user-defined writability flag.
+ ch.flush();
+ assertThat(cob.totalPendingWriteBytes(), is(0L));
+ assertThat(buf.toString(), is("false "));
+
+ // Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged()
+ cob.setUserDefinedWritability(1, true);
+ assertThat(buf.toString(), is("false true "));
+
+ safeClose(ch);
+ }
+
+ private static void safeClose(EmbeddedChannel ch) {
+ ch.finish();
+ for (;;) {
+ ByteBuf m = ch.readOutbound();
+ if (m == null) {
+ break;
+ }
+ m.release();
+ }
+ }
}