diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index ff000f5dc6..251c5aa319 100755 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -37,9 +37,6 @@ import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; /** * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation. @@ -48,6 +45,8 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); + // Buffers to use for Gathering writes private static final ThreadLocal BUFFERS = new ThreadLocal() { @Override @@ -56,9 +55,16 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } }; - private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); + private static ByteBuffer[] getNioBufferArray() { + return BUFFERS.get(); + } - private final SocketChannelConfig config; + private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) { + ByteBuffer[] newArray = new ByteBuffer[array.length << 1]; + System.arraycopy(array, 0, newArray, 0, size); + BUFFERS.set(newArray); + return newArray; + } private static SocketChannel newSocket() { try { @@ -68,6 +74,8 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } } + private final SocketChannelConfig config; + /** * Create a new instance */ @@ -258,71 +266,85 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty @Override protected int doWrite(MessageList msgs, int index) throws Exception { - int size = msgs.size(); - // Check if this can be optimized via gathering writes - if (size > 1 && msgs.containsOnly(ByteBuf.class)) { - MessageList bufs = msgs.cast(); + int size = msgs.size(); - List bufferList = new ArrayList(size); - long expectedWrittenBytes = 0; - long writtenBytes = 0; - for (int i = index; i < size; i++) { - ByteBuf buf = bufs.get(i); - int count = buf.nioBufferCount(); - if (count == 1) { - bufferList.add(buf.nioBuffer()); - } else { - ByteBuffer[] nioBufs = buf.nioBuffers(); - // use Arrays.asList(..) as it may be more efficient then looping. The only downside - // is that it will create one more object to gc - bufferList.addAll(Arrays.asList(nioBufs)); - } - expectedWrittenBytes += buf.readableBytes(); - } + // Do non-gathering write for a single buffer case. + if (size <= 1 || !msgs.containsOnly(ByteBuf.class)) { + return super.doWrite(msgs, index); + } - ByteBuffer[] bufArray = bufferList.toArray(BUFFERS.get()); - boolean done = false; - for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { - final long localWrittenBytes = javaChannel().write(bufArray, 0, bufferList.size()); - updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0); - if (localWrittenBytes == 0) { - break; - } - expectedWrittenBytes -= localWrittenBytes; - writtenBytes += localWrittenBytes; - if (expectedWrittenBytes == 0) { - done = true; - break; - } - } - int writtenBufs = 0; + MessageList bufs = msgs.cast(); - if (done) { - // release buffers - for (int i = index; i < size; i++) { - ByteBuf buf = bufs.get(i); - buf.release(); - writtenBufs++; + ByteBuffer[] nioBuffers = getNioBufferArray(); + int nioBufferCnt = 0; + long expectedWrittenBytes = 0; + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + int count = buf.nioBufferCount(); + if (count == 1) { + if (nioBufferCnt == nioBuffers.length) { + nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt); } + nioBuffers[nioBufferCnt ++] = buf.nioBuffer(); } else { - // not complete written all buffers so release those which was written and update the readerIndex - // of the partial written buffer - for (int i = index; i < size; i++) { - ByteBuf buf = bufs.get(i); - int readable = buf.readableBytes(); - if (readable <= writtenBytes) { - writtenBufs++; - buf.release(); - writtenBytes -= readable; - } else { - // not completly written so adjust readerindex break the loop - buf.readerIndex(buf.readerIndex() + (int) writtenBytes); + ByteBuffer[] nioBufs = buf.nioBuffers(); + if (nioBufferCnt + nioBufs.length == nioBuffers.length + 1) { + nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt); + } + for (ByteBuffer nioBuf: nioBufs) { + if (nioBuf == null) { break; } + nioBuffers[nioBufferCnt ++] = nioBuf; + } + } + expectedWrittenBytes += buf.readableBytes(); + } + + long writtenBytes = 0; + boolean done = false; + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + final long localWrittenBytes = javaChannel().write(nioBuffers, 0, nioBufferCnt); + updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0); + if (localWrittenBytes == 0) { + break; + } + expectedWrittenBytes -= localWrittenBytes; + writtenBytes += localWrittenBytes; + if (expectedWrittenBytes == 0) { + done = true; + break; + } + } + + if (done) { + // release buffers + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + buf.release(); + } + return size - index; + } else { + // Did not write all buffers completely. + // Release the fully written buffers and update the indexes of the partially written buffer. + int writtenBufs = 0; + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + int readable = buf.readableBytes(); + if (readable < writtenBytes) { + writtenBufs ++; + buf.release(); + writtenBytes -= readable; + } else if (readable > writtenBytes) { + buf.readerIndex(buf.readerIndex() + (int) writtenBytes); + break; + } else { // readable == writtenBytes + writtenBufs ++; + buf.release(); + break; } } return writtenBufs; } - return super.doWrite(msgs, index); } }