From 46ea0d4e7b8380cfc8c3f5508ac2a3896262e08e Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Thu, 18 Jul 2013 23:14:39 +0900 Subject: [PATCH] Implement gathering writes in NioSocketChannel - Add some support methods in ChannelOutputBuffer --- .../netty/channel/ChannelOutboundBuffer.java | 92 +++++++++++++++++ .../channel/socket/nio/NioSocketChannel.java | 99 ++++--------------- 2 files changed, 110 insertions(+), 81 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index 5beb91b518..4c6004554c 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -19,10 +19,12 @@ */ package io.netty.channel; +import io.netty.buffer.ByteBuf; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.nio.ByteBuffer; import java.util.Arrays; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; @@ -42,6 +44,10 @@ public final class ChannelOutboundBuffer { private int head; private int tail; + private ByteBuffer[] nioBuffers; + private int nioBufferCount; + private long nioBufferSize; + // Unflushed messages are stored in an array list. private Object[] unflushed; private ChannelPromise[] unflushedPromises; @@ -90,6 +96,8 @@ public final class ChannelOutboundBuffer { flushedProgresses = new long[initialCapacity]; flushedTotals = new long[initialCapacity]; + nioBuffers = new ByteBuffer[initialCapacity]; + unflushed = new Object[initialCapacity]; unflushedPromises = new ChannelPromise[initialCapacity]; unflushedTotals = new long[initialCapacity]; @@ -293,6 +301,90 @@ public final class ChannelOutboundBuffer { return true; } + public ByteBuffer[] nioBuffers() { + ByteBuffer[] nioBuffers = this.nioBuffers; + long nioBufferSize = 0; + int nioBufferCount = 0; + + final int mask = flushed.length - 1; + + Object m; + int i = head; + while ((m = flushed[i]) != null) { + if (!(m instanceof ByteBuf)) { + this.nioBufferCount = 0; + this.nioBufferSize = 0; + return null; + } + + ByteBuf buf = (ByteBuf) m; + + final int readerIndex = buf.readerIndex(); + final int readableBytes = buf.writerIndex() - readerIndex; + + if (readableBytes > 0) { + nioBufferSize += readableBytes; + + if (buf.isDirect()) { + int count = buf.nioBufferCount(); + if (count == 1) { + if (nioBufferCount == nioBuffers.length) { + this.nioBuffers = nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount); + } + nioBuffers[nioBufferCount ++] = buf.internalNioBuffer(readerIndex, readableBytes); + } else { + ByteBuffer[] nioBufs = buf.nioBuffers(); + if (nioBufferCount + nioBufs.length == nioBuffers.length + 1) { + this.nioBuffers = nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount); + } + for (ByteBuffer nioBuf: nioBufs) { + if (nioBuf == null) { + break; + } + nioBuffers[nioBufferCount ++] = nioBuf; + } + } + } else { + ByteBuf directBuf = channel.alloc().directBuffer(readableBytes); + directBuf.writeBytes(buf, readerIndex, readableBytes); + buf.release(); + flushed[i] = directBuf; + if (nioBufferCount == nioBuffers.length) { + nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount); + } + nioBuffers[nioBufferCount ++] = directBuf.internalNioBuffer(0, readableBytes); + } + } + + i = i + 1 & mask; + } + + this.nioBufferCount = nioBufferCount; + this.nioBufferSize = nioBufferSize; + + return nioBuffers; + } + + private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) { + int newCapacity = array.length << 1; + if (newCapacity < 0) { + throw new IllegalStateException(); + } + + ByteBuffer[] newArray = new ByteBuffer[newCapacity]; + System.arraycopy(array, 0, newArray, 0, size); + + return newArray; + } + + public int nioBufferCount() { + return nioBufferCount; + } + + public long nioBufferSize() { + return nioBufferSize; + } + boolean getWritable() { return WRITABLE_UPDATER.get(this) == 1; } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index 18df56a0d6..5157ca5d3f 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -47,25 +47,6 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); - // Buffers to use for Gathering writes - private static final ThreadLocal BUFFERS = new ThreadLocal() { - @Override - protected ByteBuffer[] initialValue() { - return new ByteBuffer[128]; - } - }; - - private static ByteBuffer[] getNioBufferArray() { - return BUFFERS.get(); - } - - private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) { - ByteBuffer[] newArray = new ByteBuffer[array.length << 1]; - System.arraycopy(array, 0, newArray, 0, size); - BUFFERS.set(newArray); - return newArray; - } - private static SocketChannel newSocket() { try { return SocketChannel.open(); @@ -262,62 +243,23 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { - // FIXME: Re-enable gathering write. - super.doWrite(in); - - /* // Do non-gathering write for a single buffer case. - if (in.size() <= 1) { + final int msgCount = in.size(); + if (msgCount <= 1) { super.doWrite(in); return; } - ByteBuffer[] nioBuffers = getNioBufferArray(); - int nioBufferCnt = 0; - long expectedWrittenBytes = 0; - for (int i = startIndex; i < msgsLength; i++) { - Object m = msgs[i]; - if (!(m instanceof ByteBuf)) { - return super.doWrite(msgs, msgsLength, startIndex); - } - - ByteBuf buf = (ByteBuf) m; - - int readerIndex = buf.readerIndex(); - int readableBytes = buf.readableBytes(); - expectedWrittenBytes += readableBytes; - - if (buf.isDirect()) { - int count = buf.nioBufferCount(); - if (count == 1) { - if (nioBufferCnt == nioBuffers.length) { - nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt); - } - nioBuffers[nioBufferCnt ++] = buf.internalNioBuffer(readerIndex, readableBytes); - } else { - ByteBuffer[] nioBufs = buf.nioBuffers(); - if (nioBufferCnt + nioBufs.length == nioBuffers.length + 1) { - nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt); - } - for (ByteBuffer nioBuf: nioBufs) { - if (nioBuf == null) { - break; - } - nioBuffers[nioBufferCnt ++] = nioBuf; - } - } - } else { - ByteBuf directBuf = alloc().directBuffer(readableBytes); - directBuf.writeBytes(buf, readerIndex, readableBytes); - buf.release(); - msgs[i] = directBuf; - if (nioBufferCnt == nioBuffers.length) { - nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt); - } - nioBuffers[nioBufferCnt ++] = directBuf.internalNioBuffer(0, readableBytes); - } + // Ensure the pending writes are made of ByteBufs only. + ByteBuffer[] nioBuffers = in.nioBuffers(); + if (nioBuffers == null) { + super.doWrite(in); + return; } + int nioBufferCnt = in.nioBufferCount(); + long expectedWrittenBytes = in.nioBufferSize(); + final SocketChannel ch = javaChannel(); long writtenBytes = 0; boolean done = false; @@ -336,35 +278,30 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } if (done) { - // release buffers - for (int i = startIndex; i < msgsLength; i++) { - ((ReferenceCounted) msgs[i]).release(); + // Release all buffers + for (int i = msgCount; i > 0; i --) { + in.remove(); } - return msgsLength - startIndex; } else { // Did not write all buffers completely. // Release the fully written buffers and update the indexes of the partially written buffer. - int writtenBufs = 0; - for (int i = startIndex; i < msgsLength; i++) { - final ByteBuf buf = (ByteBuf) msgs[i]; + + for (int i = msgCount; i > 0; i --) { + final ByteBuf buf = (ByteBuf) in.current(); final int readerIndex = buf.readerIndex(); final int readableBytes = buf.writerIndex() - readerIndex; if (readableBytes < writtenBytes) { - writtenBufs ++; - buf.release(); + in.remove(); writtenBytes -= readableBytes; } else if (readableBytes > writtenBytes) { buf.readerIndex(readerIndex + (int) writtenBytes); break; } else { // readable == writtenBytes - writtenBufs ++; - buf.release(); + in.remove(); break; } } - return writtenBufs; } - */ } }