From d1a3806ebdf8b1272b42fc1c684a5e290bee6a24 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 12 Jun 2013 09:45:33 +0200 Subject: [PATCH] Make use of gathering writes if a MessageList which only contains ByteBuf msgs is written to a NioSocketChannel --- .../socket/SocketGatheringWriteTest.java | 27 +++++- .../java/io/netty/channel/MessageList.java | 31 +++++++ .../channel/socket/nio/NioSocketChannel.java | 83 +++++++++++++++++++ 3 files changed, 140 insertions(+), 1 deletion(-) diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java index d84cfc3e74..e044d58150 100644 --- a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java @@ -18,6 +18,7 @@ package io.netty.testsuite.transport.socket; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; @@ -47,6 +48,19 @@ public class SocketGatheringWriteTest extends AbstractSocketTest { } public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable { + testGatheringWrite0(sb, cb, false); + } + + @Test(timeout = 30000) + public void testGatheringWriteWithComposite() throws Throwable { + run(); + } + + public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable { + testGatheringWrite0(sb, cb, true); + } + + private static void testGatheringWrite0(ServerBootstrap sb, Bootstrap cb, boolean composite) throws Throwable { final TestHandler sh = new TestHandler(); final TestHandler ch = new TestHandler(); @@ -61,7 +75,18 @@ public class SocketGatheringWriteTest extends AbstractSocketTest { for (int i = 0; i < data.length;) { int length = Math.min(random.nextInt(1024 * 64), data.length - i); ByteBuf buf = Unpooled.wrappedBuffer(data, i, length); - messages.add(buf); + if (composite && i % 2 == 0) { + int split = buf.readableBytes() / 2; + int size = buf.readableBytes() - split; + int oldIndex = buf.writerIndex(); + buf.writerIndex(split); + ByteBuf buf2 = Unpooled.buffer(size).writeBytes(buf, split, oldIndex - split); + CompositeByteBuf comp = Unpooled.compositeBuffer(); + comp.addComponent(buf).addComponent(buf2).writerIndex(length); + messages.add(comp); + } else { + messages.add(buf); + } i += length; } assertNotEquals(cc.voidPromise(), cc.write(messages).sync()); diff --git a/transport/src/main/java/io/netty/channel/MessageList.java b/transport/src/main/java/io/netty/channel/MessageList.java index 92e0c3e8eb..d6ba0b5287 100644 --- a/transport/src/main/java/io/netty/channel/MessageList.java +++ b/transport/src/main/java/io/netty/channel/MessageList.java @@ -16,6 +16,7 @@ package io.netty.channel; +import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; @@ -142,6 +143,7 @@ public final class MessageList implements Iterable { private T[] elements; private int size; private int modifications; + private boolean byteBufsOnly = true; MessageList(Handle handle) { this(handle, DEFAULT_INITIAL_CAPACITY); @@ -195,6 +197,9 @@ public final class MessageList implements Iterable { ensureCapacity(newSize); elements[oldSize] = value; size = newSize; + if (byteBufsOnly && !(value instanceof ByteBuf)) { + byteBufsOnly = false; + } return this; } @@ -221,6 +226,15 @@ public final class MessageList implements Iterable { ensureCapacity(newSize); System.arraycopy(src, srcIdx, elements, oldSize, srcLen); size = newSize; + if (byteBufsOnly) { + for (int i = srcIdx; i < srcIdx; i++) { + if (!(src[i] instanceof ByteBuf)) { + byteBufsOnly = false; + break; + } + } + } + return this; } @@ -245,6 +259,7 @@ public final class MessageList implements Iterable { public MessageList clear() { modifications++; Arrays.fill(elements, 0, size, null); + byteBufsOnly = true; size = 0; return this; } @@ -325,6 +340,22 @@ public final class MessageList implements Iterable { return new MessageListIterator(); } + /** + * Returns {@code true} if all messages contained in this {@link MessageList} are assignment-compatible with the + * object represented by this {@link Class}. + */ + public boolean containsOnly(Class clazz) { + if (clazz == ByteBuf.class) { + return byteBufsOnly; + } + for (int i = 0; i < size; i++) { + if (!clazz.isInstance(elements[i])) { + return false; + } + } + return true; + } + private void ensureCapacity(int capacity) { if (elements.length >= capacity) { return; diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index 0f2541c8e8..ff000f5dc6 100755 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelMetadata; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.channel.FileRegion; +import io.netty.channel.MessageList; import io.netty.channel.nio.AbstractNioByteChannel; import io.netty.channel.socket.DefaultSocketChannelConfig; import io.netty.channel.socket.ServerSocketChannel; @@ -33,8 +34,12 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; /** * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation. @@ -43,6 +48,14 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty private static final ChannelMetadata METADATA = new ChannelMetadata(false); + // Buffers to use for Gathering writes + private static final ThreadLocal BUFFERS = new ThreadLocal() { + @Override + protected ByteBuffer[] initialValue() { + return new ByteBuffer[128]; + } + }; + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); private final SocketChannelConfig config; @@ -242,4 +255,74 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty updateOpWrite(expectedWrittenBytes, writtenBytes, lastSpin); return writtenBytes; } + + @Override + protected int doWrite(MessageList msgs, int index) throws Exception { + int size = msgs.size(); + // Check if this can be optimized via gathering writes + if (size > 1 && msgs.containsOnly(ByteBuf.class)) { + MessageList bufs = msgs.cast(); + + List bufferList = new ArrayList(size); + long expectedWrittenBytes = 0; + long writtenBytes = 0; + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + int count = buf.nioBufferCount(); + if (count == 1) { + bufferList.add(buf.nioBuffer()); + } else { + ByteBuffer[] nioBufs = buf.nioBuffers(); + // use Arrays.asList(..) as it may be more efficient then looping. The only downside + // is that it will create one more object to gc + bufferList.addAll(Arrays.asList(nioBufs)); + } + expectedWrittenBytes += buf.readableBytes(); + } + + ByteBuffer[] bufArray = bufferList.toArray(BUFFERS.get()); + boolean done = false; + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + final long localWrittenBytes = javaChannel().write(bufArray, 0, bufferList.size()); + updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0); + if (localWrittenBytes == 0) { + break; + } + expectedWrittenBytes -= localWrittenBytes; + writtenBytes += localWrittenBytes; + if (expectedWrittenBytes == 0) { + done = true; + break; + } + } + int writtenBufs = 0; + + if (done) { + // release buffers + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + buf.release(); + writtenBufs++; + } + } else { + // not complete written all buffers so release those which was written and update the readerIndex + // of the partial written buffer + for (int i = index; i < size; i++) { + ByteBuf buf = bufs.get(i); + int readable = buf.readableBytes(); + if (readable <= writtenBytes) { + writtenBufs++; + buf.release(); + writtenBytes -= readable; + } else { + // not completly written so adjust readerindex break the loop + buf.readerIndex(buf.readerIndex() + (int) writtenBytes); + break; + } + } + } + return writtenBufs; + } + return super.doWrite(msgs, index); + } }