Make use of gathering writes if a MessageList which only contains ByteBuf msgs is written to a NioSocketChannel
This commit is contained in:
parent
2320a13a4e
commit
d1a3806ebd
@ -18,6 +18,7 @@ package io.netty.testsuite.transport.socket;
|
|||||||
import io.netty.bootstrap.Bootstrap;
|
import io.netty.bootstrap.Bootstrap;
|
||||||
import io.netty.bootstrap.ServerBootstrap;
|
import io.netty.bootstrap.ServerBootstrap;
|
||||||
import io.netty.buffer.ByteBuf;
|
import io.netty.buffer.ByteBuf;
|
||||||
|
import io.netty.buffer.CompositeByteBuf;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
@ -47,6 +48,19 @@ public class SocketGatheringWriteTest extends AbstractSocketTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
||||||
|
testGatheringWrite0(sb, cb, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 30000)
|
||||||
|
public void testGatheringWriteWithComposite() throws Throwable {
|
||||||
|
run();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGatheringWriteWithComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
||||||
|
testGatheringWrite0(sb, cb, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void testGatheringWrite0(ServerBootstrap sb, Bootstrap cb, boolean composite) throws Throwable {
|
||||||
final TestHandler sh = new TestHandler();
|
final TestHandler sh = new TestHandler();
|
||||||
final TestHandler ch = new TestHandler();
|
final TestHandler ch = new TestHandler();
|
||||||
|
|
||||||
@ -61,7 +75,18 @@ public class SocketGatheringWriteTest extends AbstractSocketTest {
|
|||||||
for (int i = 0; i < data.length;) {
|
for (int i = 0; i < data.length;) {
|
||||||
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
|
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
|
||||||
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
|
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
|
||||||
messages.add(buf);
|
if (composite && i % 2 == 0) {
|
||||||
|
int split = buf.readableBytes() / 2;
|
||||||
|
int size = buf.readableBytes() - split;
|
||||||
|
int oldIndex = buf.writerIndex();
|
||||||
|
buf.writerIndex(split);
|
||||||
|
ByteBuf buf2 = Unpooled.buffer(size).writeBytes(buf, split, oldIndex - split);
|
||||||
|
CompositeByteBuf comp = Unpooled.compositeBuffer();
|
||||||
|
comp.addComponent(buf).addComponent(buf2).writerIndex(length);
|
||||||
|
messages.add(comp);
|
||||||
|
} else {
|
||||||
|
messages.add(buf);
|
||||||
|
}
|
||||||
i += length;
|
i += length;
|
||||||
}
|
}
|
||||||
assertNotEquals(cc.voidPromise(), cc.write(messages).sync());
|
assertNotEquals(cc.voidPromise(), cc.write(messages).sync());
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
package io.netty.channel;
|
package io.netty.channel;
|
||||||
|
|
||||||
|
import io.netty.buffer.ByteBuf;
|
||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.util.Recycler;
|
import io.netty.util.Recycler;
|
||||||
import io.netty.util.Recycler.Handle;
|
import io.netty.util.Recycler.Handle;
|
||||||
@ -142,6 +143,7 @@ public final class MessageList<T> implements Iterable<T> {
|
|||||||
private T[] elements;
|
private T[] elements;
|
||||||
private int size;
|
private int size;
|
||||||
private int modifications;
|
private int modifications;
|
||||||
|
private boolean byteBufsOnly = true;
|
||||||
|
|
||||||
MessageList(Handle handle) {
|
MessageList(Handle handle) {
|
||||||
this(handle, DEFAULT_INITIAL_CAPACITY);
|
this(handle, DEFAULT_INITIAL_CAPACITY);
|
||||||
@ -195,6 +197,9 @@ public final class MessageList<T> implements Iterable<T> {
|
|||||||
ensureCapacity(newSize);
|
ensureCapacity(newSize);
|
||||||
elements[oldSize] = value;
|
elements[oldSize] = value;
|
||||||
size = newSize;
|
size = newSize;
|
||||||
|
if (byteBufsOnly && !(value instanceof ByteBuf)) {
|
||||||
|
byteBufsOnly = false;
|
||||||
|
}
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,6 +226,15 @@ public final class MessageList<T> implements Iterable<T> {
|
|||||||
ensureCapacity(newSize);
|
ensureCapacity(newSize);
|
||||||
System.arraycopy(src, srcIdx, elements, oldSize, srcLen);
|
System.arraycopy(src, srcIdx, elements, oldSize, srcLen);
|
||||||
size = newSize;
|
size = newSize;
|
||||||
|
if (byteBufsOnly) {
|
||||||
|
for (int i = srcIdx; i < srcIdx; i++) {
|
||||||
|
if (!(src[i] instanceof ByteBuf)) {
|
||||||
|
byteBufsOnly = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,6 +259,7 @@ public final class MessageList<T> implements Iterable<T> {
|
|||||||
public MessageList<T> clear() {
|
public MessageList<T> clear() {
|
||||||
modifications++;
|
modifications++;
|
||||||
Arrays.fill(elements, 0, size, null);
|
Arrays.fill(elements, 0, size, null);
|
||||||
|
byteBufsOnly = true;
|
||||||
size = 0;
|
size = 0;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
@ -325,6 +340,22 @@ public final class MessageList<T> implements Iterable<T> {
|
|||||||
return new MessageListIterator();
|
return new MessageListIterator();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns {@code true} if all messages contained in this {@link MessageList} are assignment-compatible with the
|
||||||
|
* object represented by this {@link Class}.
|
||||||
|
*/
|
||||||
|
public boolean containsOnly(Class<?> clazz) {
|
||||||
|
if (clazz == ByteBuf.class) {
|
||||||
|
return byteBufsOnly;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
if (!clazz.isInstance(elements[i])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private void ensureCapacity(int capacity) {
|
private void ensureCapacity(int capacity) {
|
||||||
if (elements.length >= capacity) {
|
if (elements.length >= capacity) {
|
||||||
return;
|
return;
|
||||||
|
@ -23,6 +23,7 @@ import io.netty.channel.ChannelMetadata;
|
|||||||
import io.netty.channel.ChannelPromise;
|
import io.netty.channel.ChannelPromise;
|
||||||
import io.netty.channel.EventLoop;
|
import io.netty.channel.EventLoop;
|
||||||
import io.netty.channel.FileRegion;
|
import io.netty.channel.FileRegion;
|
||||||
|
import io.netty.channel.MessageList;
|
||||||
import io.netty.channel.nio.AbstractNioByteChannel;
|
import io.netty.channel.nio.AbstractNioByteChannel;
|
||||||
import io.netty.channel.socket.DefaultSocketChannelConfig;
|
import io.netty.channel.socket.DefaultSocketChannelConfig;
|
||||||
import io.netty.channel.socket.ServerSocketChannel;
|
import io.netty.channel.socket.ServerSocketChannel;
|
||||||
@ -33,8 +34,12 @@ import io.netty.util.internal.logging.InternalLoggerFactory;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.InetSocketAddress;
|
import java.net.InetSocketAddress;
|
||||||
import java.net.SocketAddress;
|
import java.net.SocketAddress;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.SelectionKey;
|
import java.nio.channels.SelectionKey;
|
||||||
import java.nio.channels.SocketChannel;
|
import java.nio.channels.SocketChannel;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
|
* {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
|
||||||
@ -43,6 +48,14 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
|
|||||||
|
|
||||||
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
|
||||||
|
|
||||||
|
// Buffers to use for Gathering writes
|
||||||
|
private static final ThreadLocal<ByteBuffer[]> BUFFERS = new ThreadLocal<ByteBuffer[]>() {
|
||||||
|
@Override
|
||||||
|
protected ByteBuffer[] initialValue() {
|
||||||
|
return new ByteBuffer[128];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
|
private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
|
||||||
|
|
||||||
private final SocketChannelConfig config;
|
private final SocketChannelConfig config;
|
||||||
@ -242,4 +255,74 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
|
|||||||
updateOpWrite(expectedWrittenBytes, writtenBytes, lastSpin);
|
updateOpWrite(expectedWrittenBytes, writtenBytes, lastSpin);
|
||||||
return writtenBytes;
|
return writtenBytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected int doWrite(MessageList<Object> msgs, int index) throws Exception {
|
||||||
|
int size = msgs.size();
|
||||||
|
// Check if this can be optimized via gathering writes
|
||||||
|
if (size > 1 && msgs.containsOnly(ByteBuf.class)) {
|
||||||
|
MessageList<ByteBuf> bufs = msgs.cast();
|
||||||
|
|
||||||
|
List<ByteBuffer> bufferList = new ArrayList<ByteBuffer>(size);
|
||||||
|
long expectedWrittenBytes = 0;
|
||||||
|
long writtenBytes = 0;
|
||||||
|
for (int i = index; i < size; i++) {
|
||||||
|
ByteBuf buf = bufs.get(i);
|
||||||
|
int count = buf.nioBufferCount();
|
||||||
|
if (count == 1) {
|
||||||
|
bufferList.add(buf.nioBuffer());
|
||||||
|
} else {
|
||||||
|
ByteBuffer[] nioBufs = buf.nioBuffers();
|
||||||
|
// use Arrays.asList(..) as it may be more efficient then looping. The only downside
|
||||||
|
// is that it will create one more object to gc
|
||||||
|
bufferList.addAll(Arrays.asList(nioBufs));
|
||||||
|
}
|
||||||
|
expectedWrittenBytes += buf.readableBytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
ByteBuffer[] bufArray = bufferList.toArray(BUFFERS.get());
|
||||||
|
boolean done = false;
|
||||||
|
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
|
||||||
|
final long localWrittenBytes = javaChannel().write(bufArray, 0, bufferList.size());
|
||||||
|
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
|
||||||
|
if (localWrittenBytes == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
expectedWrittenBytes -= localWrittenBytes;
|
||||||
|
writtenBytes += localWrittenBytes;
|
||||||
|
if (expectedWrittenBytes == 0) {
|
||||||
|
done = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int writtenBufs = 0;
|
||||||
|
|
||||||
|
if (done) {
|
||||||
|
// release buffers
|
||||||
|
for (int i = index; i < size; i++) {
|
||||||
|
ByteBuf buf = bufs.get(i);
|
||||||
|
buf.release();
|
||||||
|
writtenBufs++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// not complete written all buffers so release those which was written and update the readerIndex
|
||||||
|
// of the partial written buffer
|
||||||
|
for (int i = index; i < size; i++) {
|
||||||
|
ByteBuf buf = bufs.get(i);
|
||||||
|
int readable = buf.readableBytes();
|
||||||
|
if (readable <= writtenBytes) {
|
||||||
|
writtenBufs++;
|
||||||
|
buf.release();
|
||||||
|
writtenBytes -= readable;
|
||||||
|
} else {
|
||||||
|
// not completly written so adjust readerindex break the loop
|
||||||
|
buf.readerIndex(buf.readerIndex() + (int) writtenBytes);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return writtenBufs;
|
||||||
|
}
|
||||||
|
return super.doWrite(msgs, index);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user