Implement gathering writes in NioSocketChannel

- Add some support methods in ChannelOutputBuffer
This commit is contained in:
Trustin Lee 2013-07-18 23:14:39 +09:00
parent 4f0a952241
commit 46ea0d4e7b
2 changed files with 110 additions and 81 deletions

View File

@ -19,10 +19,12 @@
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
@ -42,6 +44,10 @@ public final class ChannelOutboundBuffer {
private int head;
private int tail;
private ByteBuffer[] nioBuffers;
private int nioBufferCount;
private long nioBufferSize;
// Unflushed messages are stored in an array list.
private Object[] unflushed;
private ChannelPromise[] unflushedPromises;
@ -90,6 +96,8 @@ public final class ChannelOutboundBuffer {
flushedProgresses = new long[initialCapacity];
flushedTotals = new long[initialCapacity];
nioBuffers = new ByteBuffer[initialCapacity];
unflushed = new Object[initialCapacity];
unflushedPromises = new ChannelPromise[initialCapacity];
unflushedTotals = new long[initialCapacity];
@ -293,6 +301,90 @@ public final class ChannelOutboundBuffer {
return true;
}
public ByteBuffer[] nioBuffers() {
ByteBuffer[] nioBuffers = this.nioBuffers;
long nioBufferSize = 0;
int nioBufferCount = 0;
final int mask = flushed.length - 1;
Object m;
int i = head;
while ((m = flushed[i]) != null) {
if (!(m instanceof ByteBuf)) {
this.nioBufferCount = 0;
this.nioBufferSize = 0;
return null;
}
ByteBuf buf = (ByteBuf) m;
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes > 0) {
nioBufferSize += readableBytes;
if (buf.isDirect()) {
int count = buf.nioBufferCount();
if (count == 1) {
if (nioBufferCount == nioBuffers.length) {
this.nioBuffers = nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount);
}
nioBuffers[nioBufferCount ++] = buf.internalNioBuffer(readerIndex, readableBytes);
} else {
ByteBuffer[] nioBufs = buf.nioBuffers();
if (nioBufferCount + nioBufs.length == nioBuffers.length + 1) {
this.nioBuffers = nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount);
}
for (ByteBuffer nioBuf: nioBufs) {
if (nioBuf == null) {
break;
}
nioBuffers[nioBufferCount ++] = nioBuf;
}
}
} else {
ByteBuf directBuf = channel.alloc().directBuffer(readableBytes);
directBuf.writeBytes(buf, readerIndex, readableBytes);
buf.release();
flushed[i] = directBuf;
if (nioBufferCount == nioBuffers.length) {
nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCount);
}
nioBuffers[nioBufferCount ++] = directBuf.internalNioBuffer(0, readableBytes);
}
}
i = i + 1 & mask;
}
this.nioBufferCount = nioBufferCount;
this.nioBufferSize = nioBufferSize;
return nioBuffers;
}
private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) {
int newCapacity = array.length << 1;
if (newCapacity < 0) {
throw new IllegalStateException();
}
ByteBuffer[] newArray = new ByteBuffer[newCapacity];
System.arraycopy(array, 0, newArray, 0, size);
return newArray;
}
public int nioBufferCount() {
return nioBufferCount;
}
public long nioBufferSize() {
return nioBufferSize;
}
boolean getWritable() {
return WRITABLE_UPDATER.get(this) == 1;
}

View File

@ -47,25 +47,6 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
// Buffers to use for Gathering writes
private static final ThreadLocal<ByteBuffer[]> BUFFERS = new ThreadLocal<ByteBuffer[]>() {
@Override
protected ByteBuffer[] initialValue() {
return new ByteBuffer[128];
}
};
private static ByteBuffer[] getNioBufferArray() {
return BUFFERS.get();
}
private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) {
ByteBuffer[] newArray = new ByteBuffer[array.length << 1];
System.arraycopy(array, 0, newArray, 0, size);
BUFFERS.set(newArray);
return newArray;
}
private static SocketChannel newSocket() {
try {
return SocketChannel.open();
@ -262,62 +243,23 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
@Override
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
// FIXME: Re-enable gathering write.
super.doWrite(in);
/*
// Do non-gathering write for a single buffer case.
if (in.size() <= 1) {
final int msgCount = in.size();
if (msgCount <= 1) {
super.doWrite(in);
return;
}
ByteBuffer[] nioBuffers = getNioBufferArray();
int nioBufferCnt = 0;
long expectedWrittenBytes = 0;
for (int i = startIndex; i < msgsLength; i++) {
Object m = msgs[i];
if (!(m instanceof ByteBuf)) {
return super.doWrite(msgs, msgsLength, startIndex);
}
ByteBuf buf = (ByteBuf) m;
int readerIndex = buf.readerIndex();
int readableBytes = buf.readableBytes();
expectedWrittenBytes += readableBytes;
if (buf.isDirect()) {
int count = buf.nioBufferCount();
if (count == 1) {
if (nioBufferCnt == nioBuffers.length) {
nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt);
}
nioBuffers[nioBufferCnt ++] = buf.internalNioBuffer(readerIndex, readableBytes);
} else {
ByteBuffer[] nioBufs = buf.nioBuffers();
if (nioBufferCnt + nioBufs.length == nioBuffers.length + 1) {
nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt);
}
for (ByteBuffer nioBuf: nioBufs) {
if (nioBuf == null) {
break;
}
nioBuffers[nioBufferCnt ++] = nioBuf;
}
}
} else {
ByteBuf directBuf = alloc().directBuffer(readableBytes);
directBuf.writeBytes(buf, readerIndex, readableBytes);
buf.release();
msgs[i] = directBuf;
if (nioBufferCnt == nioBuffers.length) {
nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt);
}
nioBuffers[nioBufferCnt ++] = directBuf.internalNioBuffer(0, readableBytes);
}
// Ensure the pending writes are made of ByteBufs only.
ByteBuffer[] nioBuffers = in.nioBuffers();
if (nioBuffers == null) {
super.doWrite(in);
return;
}
int nioBufferCnt = in.nioBufferCount();
long expectedWrittenBytes = in.nioBufferSize();
final SocketChannel ch = javaChannel();
long writtenBytes = 0;
boolean done = false;
@ -336,35 +278,30 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
}
if (done) {
// release buffers
for (int i = startIndex; i < msgsLength; i++) {
((ReferenceCounted) msgs[i]).release();
// Release all buffers
for (int i = msgCount; i > 0; i --) {
in.remove();
}
return msgsLength - startIndex;
} else {
// Did not write all buffers completely.
// Release the fully written buffers and update the indexes of the partially written buffer.
int writtenBufs = 0;
for (int i = startIndex; i < msgsLength; i++) {
final ByteBuf buf = (ByteBuf) msgs[i];
for (int i = msgCount; i > 0; i --) {
final ByteBuf buf = (ByteBuf) in.current();
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes < writtenBytes) {
writtenBufs ++;
buf.release();
in.remove();
writtenBytes -= readableBytes;
} else if (readableBytes > writtenBytes) {
buf.readerIndex(readerIndex + (int) writtenBytes);
break;
} else { // readable == writtenBytes
writtenBufs ++;
buf.release();
in.remove();
break;
}
}
return writtenBufs;
}
*/
}
}