Generate less garbage when performing gathering writes

This commit is contained in:
Trustin Lee 2013-06-13 10:27:10 +09:00
parent 78c6925921
commit 2088d1b491

View File

@ -37,9 +37,6 @@ import java.net.SocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/** /**
* {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation. * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
@ -48,6 +45,8 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static final ChannelMetadata METADATA = new ChannelMetadata(false);
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
// Buffers to use for Gathering writes // Buffers to use for Gathering writes
private static final ThreadLocal<ByteBuffer[]> BUFFERS = new ThreadLocal<ByteBuffer[]>() { private static final ThreadLocal<ByteBuffer[]> BUFFERS = new ThreadLocal<ByteBuffer[]>() {
@Override @Override
@ -56,9 +55,16 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
} }
}; };
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); private static ByteBuffer[] getNioBufferArray() {
return BUFFERS.get();
}
private final SocketChannelConfig config; private static ByteBuffer[] doubleNioBufferArray(ByteBuffer[] array, int size) {
ByteBuffer[] newArray = new ByteBuffer[array.length << 1];
System.arraycopy(array, 0, newArray, 0, size);
BUFFERS.set(newArray);
return newArray;
}
private static SocketChannel newSocket() { private static SocketChannel newSocket() {
try { try {
@ -68,6 +74,8 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
} }
} }
private final SocketChannelConfig config;
/** /**
* Create a new instance * Create a new instance
*/ */
@ -258,71 +266,85 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
@Override @Override
protected int doWrite(MessageList<Object> msgs, int index) throws Exception { protected int doWrite(MessageList<Object> msgs, int index) throws Exception {
int size = msgs.size(); int size = msgs.size();
// Check if this can be optimized via gathering writes
if (size > 1 && msgs.containsOnly(ByteBuf.class)) {
MessageList<ByteBuf> bufs = msgs.cast();
List<ByteBuffer> bufferList = new ArrayList<ByteBuffer>(size); // Do non-gathering write for a single buffer case.
long expectedWrittenBytes = 0; if (size <= 1 || !msgs.containsOnly(ByteBuf.class)) {
long writtenBytes = 0; return super.doWrite(msgs, index);
for (int i = index; i < size; i++) { }
ByteBuf buf = bufs.get(i);
int count = buf.nioBufferCount();
if (count == 1) {
bufferList.add(buf.nioBuffer());
} else {
ByteBuffer[] nioBufs = buf.nioBuffers();
// use Arrays.asList(..) as it may be more efficient then looping. The only downside
// is that it will create one more object to gc
bufferList.addAll(Arrays.asList(nioBufs));
}
expectedWrittenBytes += buf.readableBytes();
}
ByteBuffer[] bufArray = bufferList.toArray(BUFFERS.get()); MessageList<ByteBuf> bufs = msgs.cast();
boolean done = false;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
final long localWrittenBytes = javaChannel().write(bufArray, 0, bufferList.size());
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
if (localWrittenBytes == 0) {
break;
}
expectedWrittenBytes -= localWrittenBytes;
writtenBytes += localWrittenBytes;
if (expectedWrittenBytes == 0) {
done = true;
break;
}
}
int writtenBufs = 0;
if (done) { ByteBuffer[] nioBuffers = getNioBufferArray();
// release buffers int nioBufferCnt = 0;
for (int i = index; i < size; i++) { long expectedWrittenBytes = 0;
ByteBuf buf = bufs.get(i); for (int i = index; i < size; i++) {
buf.release(); ByteBuf buf = bufs.get(i);
writtenBufs++; int count = buf.nioBufferCount();
if (count == 1) {
if (nioBufferCnt == nioBuffers.length) {
nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt);
} }
nioBuffers[nioBufferCnt ++] = buf.nioBuffer();
} else { } else {
// not complete written all buffers so release those which was written and update the readerIndex ByteBuffer[] nioBufs = buf.nioBuffers();
// of the partial written buffer if (nioBufferCnt + nioBufs.length == nioBuffers.length + 1) {
for (int i = index; i < size; i++) { nioBuffers = doubleNioBufferArray(nioBuffers, nioBufferCnt);
ByteBuf buf = bufs.get(i); }
int readable = buf.readableBytes(); for (ByteBuffer nioBuf: nioBufs) {
if (readable <= writtenBytes) { if (nioBuf == null) {
writtenBufs++;
buf.release();
writtenBytes -= readable;
} else {
// not completly written so adjust readerindex break the loop
buf.readerIndex(buf.readerIndex() + (int) writtenBytes);
break; break;
} }
nioBuffers[nioBufferCnt ++] = nioBuf;
}
}
expectedWrittenBytes += buf.readableBytes();
}
long writtenBytes = 0;
boolean done = false;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
final long localWrittenBytes = javaChannel().write(nioBuffers, 0, nioBufferCnt);
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
if (localWrittenBytes == 0) {
break;
}
expectedWrittenBytes -= localWrittenBytes;
writtenBytes += localWrittenBytes;
if (expectedWrittenBytes == 0) {
done = true;
break;
}
}
if (done) {
// release buffers
for (int i = index; i < size; i++) {
ByteBuf buf = bufs.get(i);
buf.release();
}
return size - index;
} else {
// Did not write all buffers completely.
// Release the fully written buffers and update the indexes of the partially written buffer.
int writtenBufs = 0;
for (int i = index; i < size; i++) {
ByteBuf buf = bufs.get(i);
int readable = buf.readableBytes();
if (readable < writtenBytes) {
writtenBufs ++;
buf.release();
writtenBytes -= readable;
} else if (readable > writtenBytes) {
buf.readerIndex(buf.readerIndex() + (int) writtenBytes);
break;
} else { // readable == writtenBytes
writtenBufs ++;
buf.release();
break;
} }
} }
return writtenBufs; return writtenBufs;
} }
return super.doWrite(msgs, index);
} }
} }