diff --git a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java index 0017cd5f87..37d37ca435 100644 --- a/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java @@ -721,11 +721,16 @@ public class CompositeByteBuf extends AbstractReferenceCountedByteBuf { @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { - long writtenBytes = out.write(nioBuffers(index, length)); - if (writtenBytes > Integer.MAX_VALUE) { - return Integer.MAX_VALUE; + int count = nioBufferCount(); + if (count == 1) { + return out.write(internalNioBuffer(index, length)); } else { - return (int) writtenBytes; + long writtenBytes = out.write(nioBuffers(index, length)); + if (writtenBytes > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) writtenBytes; + } } } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java index 8b4ac0c952..f5cb599ec5 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java @@ -17,8 +17,13 @@ package io.netty.buffer; import org.junit.Test; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -598,9 +603,237 @@ public abstract class AbstractCompositeByteBufTest extends AbstractByteBufTest { assertEquals(1, buf.numComponents()); } + @Test + public void testGatheringWritesHeap() throws Exception { + testGatheringWrites(buffer().order(order), buffer().order(order)); + } + + @Test + public void testGatheringWritesDirect() throws Exception { + testGatheringWrites(directBuffer().order(order), directBuffer().order(order)); + } + + @Test + public void testGatheringWritesMixes() throws Exception { + testGatheringWrites(buffer().order(order), directBuffer().order(order)); + } + + @Test + public void testGatheringWritesHeapPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order)); + } + + @Test + public void testGatheringWritesDirectPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order)); + } + + @Test + public void testGatheringWritesMixesPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order)); + } + + private static void testGatheringWrites(ByteBuf buf1, ByteBuf buf2) throws Exception { + CompositeByteBuf buf = freeLater(compositeBuffer()); + buf.addComponent(buf1.writeBytes(new byte[]{1, 2})); + buf.addComponent(buf2.writeBytes(new byte[]{1, 2})); + buf.writerIndex(3); + buf.readerIndex(1); + + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + + buf.readBytes(channel, 2); + + byte[] data = new byte[2]; + buf.getBytes(1, data); + assertArrayEquals(data, channel.writtenBytes()); + } + + @Test + public void testGatheringWritesPartialHeap() throws Exception { + testGatheringWritesPartial(buffer().order(order), buffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialDirect() throws Exception { + testGatheringWritesPartial(directBuffer().order(order), directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialMixes() throws Exception { + testGatheringWritesPartial(buffer().order(order), directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialHeapSlice() throws Exception { + testGatheringWritesPartial(buffer().order(order), buffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialDirectSlice() throws Exception { + testGatheringWritesPartial(directBuffer().order(order), directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialMixesSlice() throws Exception { + testGatheringWritesPartial(buffer().order(order), directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialHeapPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialDirectPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialMixesPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialHeapPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialDirectPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialMixesPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), true); + } + + private static void testGatheringWritesPartial(ByteBuf buf1, ByteBuf buf2, boolean slice) throws Exception { + CompositeByteBuf buf = freeLater(compositeBuffer()); + buf1.writeBytes(new byte[]{1, 2, 3, 4}); + buf2.writeBytes(new byte[]{1, 2, 3, 4}); + if (slice) { + buf1 = buf1.readerIndex(1).slice(); + buf2 = buf2.writerIndex(3).slice(); + buf.addComponent(buf1); + buf.addComponent(buf2); + buf.writerIndex(6); + } else { + buf.addComponent(buf1); + buf.addComponent(buf2); + buf.writerIndex(7); + buf.readerIndex(1); + } + + TestGatheringByteChannel channel = new TestGatheringByteChannel(1); + + while (buf.isReadable()) { + System.out.println(buf.nioBuffers().length); + buf.readBytes(channel, buf.readableBytes()); + } + + byte[] data = new byte[6]; + + if (slice) { + buf.getBytes(0, data); + } else { + buf.getBytes(1, data); + } + assertArrayEquals(data, channel.writtenBytes()); + } + + @Test + public void testGatheringWritesSingleHeap() throws Exception { + testGatheringWritesSingleBuf(buffer().order(order)); + } + + @Test + public void testGatheringWritesSingleDirect() throws Exception { + testGatheringWritesSingleBuf(directBuffer().order(order)); + } + + private static void testGatheringWritesSingleBuf(ByteBuf buf1) throws Exception { + CompositeByteBuf buf = freeLater(compositeBuffer()); + buf.addComponent(buf1.writeBytes(new byte[]{1, 2, 3, 4})); + buf.writerIndex(3); + buf.readerIndex(1); + + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + buf.readBytes(channel, 2); + + byte[] data = new byte[2]; + buf.getBytes(1, data); + assertArrayEquals(data, channel.writtenBytes()); + } + @Override @Test public void testInternalNioBuffer() { // ignore } + + private static final class TestGatheringByteChannel implements GatheringByteChannel { + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + private final WritableByteChannel channel = Channels.newChannel(out); + private final int limit; + TestGatheringByteChannel(int limit) { + this.limit = limit; + } + + TestGatheringByteChannel() { + this(Integer.MAX_VALUE); + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + long written = 0; + for (; offset < length; offset++) { + written += write(srcs[offset]); + if (written >= limit) { + break; + } + } + return written; + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return write(srcs, 0, srcs.length); + } + + @Override + public int write(ByteBuffer src) throws IOException { + int oldLimit = src.limit(); + if (limit < src.remaining()) { + src.limit(src.position() + limit); + } + int w = channel.write(src); + src.limit(oldLimit); + return w; + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + + public byte[] writtenBytes() { + return out.toByteArray(); + } + } }