diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java new file mode 100644 index 0000000000..d84cfc3e74 --- /dev/null +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketGatheringWriteTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.MessageList; +import org.junit.Test; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +public class SocketGatheringWriteTest extends AbstractSocketTest { + + private static final Random random = new Random(); + static final byte[] data = new byte[1048576]; + + static { + random.nextBytes(data); + } + + @Test(timeout = 30000) + public void testGatheringWrite() throws Throwable { + run(); + } + + public void testGatheringWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final TestHandler sh = new TestHandler(); + final TestHandler ch = new TestHandler(); + + cb.handler(ch); + sb.childHandler(sh); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect().sync().channel(); + + MessageList messages = MessageList.newInstance(); + + for (int i = 0; i < data.length;) { + int length = Math.min(random.nextInt(1024 * 64), data.length - i); + ByteBuf buf = Unpooled.wrappedBuffer(data, i, length); + messages.add(buf); + i += length; + } + assertNotEquals(cc.voidPromise(), cc.write(messages).sync()); + + while (sh.counter < data.length) { + if (sh.exception.get() != null) { + break; + } + if (ch.exception.get() != null) { + break; + } + try { + Thread.sleep(50); + } catch (InterruptedException e) { + // Ignore. + } + } + assertEquals(Unpooled.wrappedBuffer(data), sh.received); + sh.channel.close().sync(); + ch.channel.close().sync(); + sc.close().sync(); + + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + assertEquals(0, ch.counter); + } + + private static class TestHandler extends ChannelInboundHandlerAdapter { + volatile Channel channel; + final AtomicReference exception = new AtomicReference(); + volatile int counter; + final ByteBuf received = Unpooled.buffer(); + @Override + public void channelActive(ChannelHandlerContext ctx) + throws Exception { + channel = ctx.channel(); + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageList msgs) throws Exception { + for (int j = 0; j < msgs.size(); j ++) { + ByteBuf in = (ByteBuf) msgs.get(j); + counter += in.readableBytes(); + received.writeBytes(in); + } + msgs.releaseAllAndRecycle(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (exception.compareAndSet(null, cause)) { + ctx.close(); + } + } + } +} diff --git a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java index 5cadab9270..6deb088835 100755 --- a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java @@ -121,53 +121,64 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { @Override protected int doWrite(MessageList msgs, int index) throws Exception { - Object msg = msgs.get(index); - - if (msg instanceof ByteBuf) { - ByteBuf buf = (ByteBuf) msg; - if (!buf.isReadable()) { - buf.release(); - return 1; + int size = msgs.size(); + int writeIndex = index; + for (;;) { + if (writeIndex >= size) { + break; } - boolean done = false; - for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { - int localFlushedAmount = doWriteBytes(buf, i == 0); - if (localFlushedAmount == 0) { - break; - } + Object msg = msgs.get(writeIndex); + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; if (!buf.isReadable()) { - done = true; - break; + buf.release(); + writeIndex++; + continue; } - } - // We may could optimize this to write multiple buffers at once (scattering) - if (done) { - buf.release(); - return 1; - } - } else if (msg instanceof FileRegion) { - FileRegion region = (FileRegion) msg; - boolean done = false; - for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { - long localFlushedAmount = doWriteFileRegion(region, i == 0); - if (localFlushedAmount == 0) { - break; - } - if (region.transfered() >= region.count()) { - done = true; - break; - } - } + boolean done = false; + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + int localFlushedAmount = doWriteBytes(buf, i == 0); + if (localFlushedAmount == 0) { + break; + } - if (done) { - region.release(); - return 1; + if (!buf.isReadable()) { + done = true; + break; + } + } + + if (done) { + buf.release(); + writeIndex++; + } else { + break; + } + } else if (msg instanceof FileRegion) { + FileRegion region = (FileRegion) msg; + boolean done = false; + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + long localFlushedAmount = doWriteFileRegion(region, i == 0); + if (localFlushedAmount == 0) { + break; + } + if (region.transfered() >= region.count()) { + done = true; + break; + } + } + + if (done) { + region.release(); + writeIndex++; + } else { + break; + } + } else { + throw new UnsupportedOperationException("unsupported message type: " + StringUtil.simpleClassName(msg)); } - } else { - throw new UnsupportedOperationException("unsupported message type: " + StringUtil.simpleClassName(msg)); } - - return 0; + return writeIndex - index; } /** @@ -194,6 +205,10 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { */ protected abstract int doWriteBytes(ByteBuf buf, boolean lastSpin) throws Exception; + protected long doWriteBytes(MessageList bufs, int index, boolean lastSpin) throws Exception { + return doWriteBytes(bufs.get(index), lastSpin); + } + protected void updateOpWrite(long expectedWrittenBytes, long writtenBytes, boolean lastSpin) { if (writtenBytes >= expectedWrittenBytes) { final SelectionKey key = selectionKey(); diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java index ef8bcb9534..70dbf61bd2 100755 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.FileRegion; import io.netty.channel.MessageList; import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.util.internal.StringUtil; import java.io.IOException; @@ -156,22 +157,30 @@ public abstract class AbstractOioByteChannel extends AbstractOioChannel { @Override protected int doWrite(MessageList msgs, int index) throws Exception { - Object msg = msgs.get(index); - if (msg instanceof ByteBuf) { - ByteBuf buf = (ByteBuf) msg; - while (buf.isReadable()) { - doWriteBytes(buf); + int size = msgs.size(); + int writeIndex = index; + for (;;) { + if (writeIndex >= size) { + break; + } + Object msg = msgs.get(writeIndex); + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; + while (buf.isReadable()) { + doWriteBytes(buf); + } + buf.release(); + writeIndex++; + } else if (msg instanceof FileRegion) { + FileRegion region = (FileRegion) msg; + doWriteFileRegion(region); + region.release(); + writeIndex++; + } else { + throw new UnsupportedOperationException("unsupported message type: " + StringUtil.simpleClassName(msg)); } - buf.release(); - return 1; - } else if (msg instanceof FileRegion) { - FileRegion region = (FileRegion) msg; - doWriteFileRegion(region); - region.release(); - return 1; - } else { - throw new UnsupportedOperationException(); } + return writeIndex - index; } /**