diff --git a/src/main/java/org/jboss/netty/channel/socket/nio/NioSocketChannel.java b/src/main/java/org/jboss/netty/channel/socket/nio/NioSocketChannel.java index e638931e23..0e692285f9 100644 --- a/src/main/java/org/jboss/netty/channel/socket/nio/NioSocketChannel.java +++ b/src/main/java/org/jboss/netty/channel/socket/nio/NioSocketChannel.java @@ -22,6 +22,8 @@ */ package org.jboss.netty.channel.socket.nio; +import static org.jboss.netty.channel.Channels.*; + import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.channels.SocketChannel; @@ -55,23 +57,12 @@ abstract class NioSocketChannel extends AbstractChannel final Object interestOpsLock = new Object(); final Object writeLock = new Object(); - final AtomicBoolean writeTaskInTaskQueue = new AtomicBoolean(); final Runnable writeTask = new WriteTask(); - final AtomicInteger writeBufferSize = new AtomicInteger(); - final Queue writeBuffer = new WriteBuffer(); + final AtomicBoolean writeTaskInTaskQueue = new AtomicBoolean(); - /** Previous return value of isWritable() */ - boolean oldWritable; - /** - * Set to true if the amount of data in the writeBuffer exceeds - * the high water mark, as specified in NioSocketChannelConfig. - */ - boolean exceededHighWaterMark; - /** - * true if and only if NioWorker is firing an event which might cause - * infinite recursion. - */ - boolean firingEvent; + final Queue writeBuffer = new WriteBuffer(); + final AtomicInteger writeBufferSize = new AtomicInteger(); + final AtomicInteger highWaterMarkCounter = new AtomicInteger(); MessageEvent currentWriteEvent; int currentWriteIndex; @@ -116,11 +107,27 @@ abstract class NioSocketChannel extends AbstractChannel } int interestOps = getRawInterestOps(); - if (writeBufferSize.get() >= getConfig().getWriteBufferHighWaterMark()) { - interestOps |= Channel.OP_WRITE; + int writeBufferSize = this.writeBufferSize.get(); + if (writeBufferSize != 0) { + if (highWaterMarkCounter.get() > 0) { + int lowWaterMark = getConfig().getWriteBufferLowWaterMark(); + if (writeBufferSize >= lowWaterMark) { + interestOps |= Channel.OP_WRITE; + } else { + interestOps &= ~Channel.OP_WRITE; + } + } else { + int highWaterMark = getConfig().getWriteBufferHighWaterMark(); + if (writeBufferSize >= highWaterMark) { + interestOps |= Channel.OP_WRITE; + } else { + interestOps &= ~Channel.OP_WRITE; + } + } } else { interestOps &= ~Channel.OP_WRITE; } + return interestOps; } @@ -152,6 +159,14 @@ abstract class NioSocketChannel extends AbstractChannel } private final class WriteBuffer extends LinkedTransferQueue { + + private final ThreadLocal notifying = new ThreadLocal() { + @Override + protected Boolean initialValue() { + return Boolean.FALSE; + } + }; + WriteBuffer() { super(); } @@ -160,8 +175,22 @@ abstract class NioSocketChannel extends AbstractChannel public boolean offer(MessageEvent e) { boolean success = super.offer(e); assert success; - writeBufferSize.addAndGet( - ((ChannelBuffer) e.getMessage()).readableBytes()); + + int messageSize = ((ChannelBuffer) e.getMessage()).readableBytes(); + int newWriteBufferSize = writeBufferSize.addAndGet(messageSize); + int highWaterMark = getConfig().getWriteBufferHighWaterMark(); + + if (newWriteBufferSize >= highWaterMark) { + if (newWriteBufferSize - messageSize < highWaterMark) { + highWaterMarkCounter.incrementAndGet(); + if (!notifying.get()) { + notifying.set(Boolean.TRUE); + fireChannelInterestChanged( + NioSocketChannel.this, getRawInterestOps() | OP_WRITE); + notifying.set(Boolean.FALSE); + } + } + } return true; } @@ -169,11 +198,20 @@ abstract class NioSocketChannel extends AbstractChannel public MessageEvent poll() { MessageEvent e = super.poll(); if (e != null) { - int newWriteBufferSize = writeBufferSize.addAndGet( - -((ChannelBuffer) e.getMessage()).readableBytes()); - if (newWriteBufferSize == 0 || - newWriteBufferSize < getConfig().getWriteBufferLowWaterMark()) { - exceededHighWaterMark = true; + int messageSize = ((ChannelBuffer) e.getMessage()).readableBytes(); + int newWriteBufferSize = writeBufferSize.addAndGet(-messageSize); + int lowWaterMark = getConfig().getWriteBufferLowWaterMark(); + + if (newWriteBufferSize == 0 || newWriteBufferSize < lowWaterMark) { + if (newWriteBufferSize + messageSize >= lowWaterMark) { + highWaterMarkCounter.decrementAndGet(); + if (!notifying.get()) { + notifying.set(Boolean.TRUE); + fireChannelInterestChanged( + NioSocketChannel.this, getRawInterestOps() & ~OP_WRITE); + notifying.set(Boolean.FALSE); + } + } } } return e; diff --git a/src/main/java/org/jboss/netty/channel/socket/nio/NioWorker.java b/src/main/java/org/jboss/netty/channel/socket/nio/NioWorker.java index c530acb562..2130e844bb 100644 --- a/src/main/java/org/jboss/netty/channel/socket/nio/NioWorker.java +++ b/src/main/java/org/jboss/netty/channel/socket/nio/NioWorker.java @@ -366,6 +366,7 @@ class NioWorker implements Runnable { removeOpWrite = true; break; } + buf = (ChannelBuffer) evt.getMessage(); bufIdx = buf.readerIndex(); } else { @@ -417,42 +418,6 @@ class NioWorker implements Runnable { } else if (removeOpWrite) { setOpWrite(channel, false, mightNeedWakeup); } - - fireChannelInterestChangedIfNecessary(channel, open); - } - } - - private static void fireChannelInterestChangedIfNecessary( - NioSocketChannel channel, boolean open) { - if (channel.firingEvent) { - // Prevent StackOverflowError. - return; - } - - channel.firingEvent = true; - try { - int interestOps = channel.getRawInterestOps(); - boolean wasWritable = channel.oldWritable; - boolean writable = channel.oldWritable = open? channel.isWritable() : false; - if (wasWritable) { - if (writable) { - if (channel.exceededHighWaterMark) { - channel.exceededHighWaterMark = false; - fireChannelInterestChanged(channel, interestOps | Channel.OP_WRITE); - fireChannelInterestChanged(channel, interestOps & ~Channel.OP_WRITE); - } - } else { - fireChannelInterestChanged(channel, interestOps | Channel.OP_WRITE); - } - } else { - if (writable) { - fireChannelInterestChanged(channel, interestOps & ~Channel.OP_WRITE); - } else { - fireChannelInterestChanged(channel, interestOps | Channel.OP_WRITE); - } - } - } finally { - channel.firingEvent = false; } } @@ -584,7 +549,6 @@ class NioWorker implements Runnable { if (changed) { channel.setRawInterestOpsNow(interestOps); - //fireChannelInterestChanged(channel, interestOps); } }