Make use of gathering writes if a MessageList which only contains ByteBuf msgs is written to a NioSocketChannel

This commit is contained in:
Norman Maurer 2013-06-12 09:45:33 +02:00
parent 2320a13a4e
commit d1a3806ebd
3 changed files with 140 additions and 1 deletions

View File

@ -18,6 +18,7 @@ package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
@ -47,6 +48,19 @@ public class SocketGatheringWriteTest extends AbstractSocketTest {
}
public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testGatheringWrite0(sb, cb, false);
}
@Test(timeout = 30000)
public void testGatheringWriteWithComposite() throws Throwable {
run();
}
public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testGatheringWrite0(sb, cb, true);
}
private static void testGatheringWrite0(ServerBootstrap sb, Bootstrap cb, boolean composite) throws Throwable {
final TestHandler sh = new TestHandler();
final TestHandler ch = new TestHandler();
@ -61,7 +75,18 @@ public class SocketGatheringWriteTest extends AbstractSocketTest {
for (int i = 0; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
if (composite && i % 2 == 0) {
int split = buf.readableBytes() / 2;
int size = buf.readableBytes() - split;
int oldIndex = buf.writerIndex();
buf.writerIndex(split);
ByteBuf buf2 = Unpooled.buffer(size).writeBytes(buf, split, oldIndex - split);
CompositeByteBuf comp = Unpooled.compositeBuffer();
comp.addComponent(buf).addComponent(buf2).writerIndex(length);
messages.add(comp);
} else {
messages.add(buf);
}
i += length;
}
assertNotEquals(cc.voidPromise(), cc.write(messages).sync());

View File

@ -16,6 +16,7 @@
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.util.Recycler;
import io.netty.util.Recycler.Handle;
@ -142,6 +143,7 @@ public final class MessageList<T> implements Iterable<T> {
private T[] elements;
private int size;
private int modifications;
private boolean byteBufsOnly = true;
MessageList(Handle handle) {
this(handle, DEFAULT_INITIAL_CAPACITY);
@ -195,6 +197,9 @@ public final class MessageList<T> implements Iterable<T> {
ensureCapacity(newSize);
elements[oldSize] = value;
size = newSize;
if (byteBufsOnly && !(value instanceof ByteBuf)) {
byteBufsOnly = false;
}
return this;
}
@ -221,6 +226,15 @@ public final class MessageList<T> implements Iterable<T> {
ensureCapacity(newSize);
System.arraycopy(src, srcIdx, elements, oldSize, srcLen);
size = newSize;
if (byteBufsOnly) {
for (int i = srcIdx; i < srcIdx; i++) {
if (!(src[i] instanceof ByteBuf)) {
byteBufsOnly = false;
break;
}
}
}
return this;
}
@ -245,6 +259,7 @@ public final class MessageList<T> implements Iterable<T> {
public MessageList<T> clear() {
modifications++;
Arrays.fill(elements, 0, size, null);
byteBufsOnly = true;
size = 0;
return this;
}
@ -325,6 +340,22 @@ public final class MessageList<T> implements Iterable<T> {
return new MessageListIterator();
}
/**
* Returns {@code true} if all messages contained in this {@link MessageList} are assignment-compatible with the
* object represented by this {@link Class}.
*/
public boolean containsOnly(Class<?> clazz) {
if (clazz == ByteBuf.class) {
return byteBufsOnly;
}
for (int i = 0; i < size; i++) {
if (!clazz.isInstance(elements[i])) {
return false;
}
}
return true;
}
private void ensureCapacity(int capacity) {
if (elements.length >= capacity) {
return;

View File

@ -23,6 +23,7 @@ import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.FileRegion;
import io.netty.channel.MessageList;
import io.netty.channel.nio.AbstractNioByteChannel;
import io.netty.channel.socket.DefaultSocketChannelConfig;
import io.netty.channel.socket.ServerSocketChannel;
@ -33,8 +34,12 @@ import io.netty.util.internal.logging.InternalLoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
@ -43,6 +48,14 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
// Buffers to use for Gathering writes
private static final ThreadLocal<ByteBuffer[]> BUFFERS = new ThreadLocal<ByteBuffer[]>() {
@Override
protected ByteBuffer[] initialValue() {
return new ByteBuffer[128];
}
};
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
private final SocketChannelConfig config;
@ -242,4 +255,74 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
updateOpWrite(expectedWrittenBytes, writtenBytes, lastSpin);
return writtenBytes;
}
@Override
protected int doWrite(MessageList<Object> msgs, int index) throws Exception {
int size = msgs.size();
// Check if this can be optimized via gathering writes
if (size > 1 && msgs.containsOnly(ByteBuf.class)) {
MessageList<ByteBuf> bufs = msgs.cast();
List<ByteBuffer> bufferList = new ArrayList<ByteBuffer>(size);
long expectedWrittenBytes = 0;
long writtenBytes = 0;
for (int i = index; i < size; i++) {
ByteBuf buf = bufs.get(i);
int count = buf.nioBufferCount();
if (count == 1) {
bufferList.add(buf.nioBuffer());
} else {
ByteBuffer[] nioBufs = buf.nioBuffers();
// use Arrays.asList(..) as it may be more efficient then looping. The only downside
// is that it will create one more object to gc
bufferList.addAll(Arrays.asList(nioBufs));
}
expectedWrittenBytes += buf.readableBytes();
}
ByteBuffer[] bufArray = bufferList.toArray(BUFFERS.get());
boolean done = false;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
final long localWrittenBytes = javaChannel().write(bufArray, 0, bufferList.size());
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
if (localWrittenBytes == 0) {
break;
}
expectedWrittenBytes -= localWrittenBytes;
writtenBytes += localWrittenBytes;
if (expectedWrittenBytes == 0) {
done = true;
break;
}
}
int writtenBufs = 0;
if (done) {
// release buffers
for (int i = index; i < size; i++) {
ByteBuf buf = bufs.get(i);
buf.release();
writtenBufs++;
}
} else {
// not complete written all buffers so release those which was written and update the readerIndex
// of the partial written buffer
for (int i = index; i < size; i++) {
ByteBuf buf = bufs.get(i);
int readable = buf.readableBytes();
if (readable <= writtenBytes) {
writtenBufs++;
buf.release();
writtenBytes -= readable;
} else {
// not completly written so adjust readerindex break the loop
buf.readerIndex(buf.readerIndex() + (int) writtenBytes);
break;
}
}
}
return writtenBufs;
}
return super.doWrite(msgs, index);
}
}