diff --git a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java index 6444d30fbc..7dbebc0bf3 100644 --- a/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java +++ b/transport/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java @@ -142,17 +142,13 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { - final SelectionKey key = selectionKey(); - final int interestOps = key.interestOps(); int writeSpinCount = -1; for (;;) { Object msg = in.current(); if (msg == null) { // Wrote all messages. - if ((interestOps & SelectionKey.OP_WRITE) != 0) { - key.interestOps(interestOps & ~SelectionKey.OP_WRITE); - } + clearOpWrite(); break; } @@ -186,9 +182,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { } else { // Did not write completely. in.progress(flushedAmount); - if ((interestOps & SelectionKey.OP_WRITE) == 0) { - key.interestOps(interestOps | SelectionKey.OP_WRITE); - } + setOpWrite(); break; } } else if (msg instanceof FileRegion) { @@ -216,9 +210,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { } else { // Did not write completely. in.progress(flushedAmount); - if ((interestOps & SelectionKey.OP_WRITE) == 0) { - key.interestOps(interestOps | SelectionKey.OP_WRITE); - } + setOpWrite(); break; } } else { @@ -247,26 +239,20 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel { */ protected abstract int doWriteBytes(ByteBuf buf) throws Exception; - protected final void updateOpWrite(long expectedWrittenBytes, long writtenBytes, boolean lastSpin) { - if (writtenBytes >= expectedWrittenBytes) { - final SelectionKey key = selectionKey(); - final int interestOps = key.interestOps(); - // Wrote the outbound buffer completely - clear OP_WRITE. - if ((interestOps & SelectionKey.OP_WRITE) != 0) { - key.interestOps(interestOps & ~SelectionKey.OP_WRITE); - } - } else { - // 1) Wrote nothing: buffer is full obviously - set OP_WRITE - // 2) Wrote partial data: - // a) lastSpin is false: no need to set OP_WRITE because the caller will try again immediately. - // b) lastSpin is true: set OP_WRITE because the caller will not try again. - if (writtenBytes == 0 || lastSpin) { - final SelectionKey key = selectionKey(); - final int interestOps = key.interestOps(); - if ((interestOps & SelectionKey.OP_WRITE) == 0) { - key.interestOps(interestOps | SelectionKey.OP_WRITE); - } - } + + protected final void setOpWrite() { + final SelectionKey key = selectionKey(); + final int interestOps = key.interestOps(); + if ((interestOps & SelectionKey.OP_WRITE) == 0) { + key.interestOps(interestOps | SelectionKey.OP_WRITE); + } + } + + protected final void clearOpWrite() { + final SelectionKey key = selectionKey(); + final int interestOps = key.interestOps(); + if ((interestOps & SelectionKey.OP_WRITE) != 0) { + key.interestOps(interestOps & ~SelectionKey.OP_WRITE); } } } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index 4531b24348..71630e5603 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -243,65 +243,75 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { - // Do non-gathering write for a single buffer case. - final int msgCount = in.size(); - if (msgCount <= 1) { - super.doWrite(in); - return; - } - - // Ensure the pending writes are made of ByteBufs only. - ByteBuffer[] nioBuffers = in.nioBuffers(); - if (nioBuffers == null) { - super.doWrite(in); - return; - } - - int nioBufferCnt = in.nioBufferCount(); - long expectedWrittenBytes = in.nioBufferSize(); - - final SocketChannel ch = javaChannel(); - long writtenBytes = 0; - boolean done = false; - for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { - final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt); - updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0); - if (localWrittenBytes == 0) { - break; + for (;;) { + // Do non-gathering write for a single buffer case. + final int msgCount = in.size(); + if (msgCount <= 1) { + super.doWrite(in); + return; } - expectedWrittenBytes -= localWrittenBytes; - writtenBytes += localWrittenBytes; - if (expectedWrittenBytes == 0) { - done = true; - break; + + // Ensure the pending writes are made of ByteBufs only. + ByteBuffer[] nioBuffers = in.nioBuffers(); + if (nioBuffers == null) { + super.doWrite(in); + return; } - } - if (done) { - // Release all buffers - for (int i = msgCount; i > 0; i --) { - in.remove(); - } - } else { - // Did not write all buffers completely. - // Release the fully written buffers and update the indexes of the partially written buffer. + int nioBufferCnt = in.nioBufferCount(); + long expectedWrittenBytes = in.nioBufferSize(); - for (int i = msgCount; i > 0; i --) { - final ByteBuf buf = (ByteBuf) in.current(); - final int readerIndex = buf.readerIndex(); - final int readableBytes = buf.writerIndex() - readerIndex; - - if (readableBytes < writtenBytes) { - in.remove(); - writtenBytes -= readableBytes; - } else if (readableBytes > writtenBytes) { - buf.readerIndex(readerIndex + (int) writtenBytes); - in.progress(writtenBytes); - break; - } else { // readable == writtenBytes - in.remove(); + final SocketChannel ch = javaChannel(); + long writtenBytes = 0; + boolean done = false; + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt); + if (localWrittenBytes == 0) { break; } + expectedWrittenBytes -= localWrittenBytes; + writtenBytes += localWrittenBytes; + if (expectedWrittenBytes == 0) { + done = true; + break; + } + } + + if (done) { + // Release all buffers + for (int i = msgCount; i > 0; i --) { + in.remove(); + } + + // Finish the write loop if no new messages were flushed by in.remove(). + if (in.isEmpty()) { + clearOpWrite(); + break; + } + } else { + // Did not write all buffers completely. + // Release the fully written buffers and update the indexes of the partially written buffer. + + for (int i = msgCount; i > 0; i --) { + final ByteBuf buf = (ByteBuf) in.current(); + final int readerIndex = buf.readerIndex(); + final int readableBytes = buf.writerIndex() - readerIndex; + + if (readableBytes < writtenBytes) { + in.remove(); + writtenBytes -= readableBytes; + } else if (readableBytes > writtenBytes) { + buf.readerIndex(readerIndex + (int) writtenBytes); + in.progress(writtenBytes); + break; + } else { // readable == writtenBytes + in.remove(); + break; + } + } + + setOpWrite(); + break; } } } diff --git a/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java b/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java index d9682f386d..0844167b8c 100644 --- a/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java +++ b/transport/src/test/java/io/netty/channel/nio/NioSocketChannelTest.java @@ -16,13 +16,17 @@ package io.netty.channel.nio; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.CharsetUtil; import org.junit.Test; +import java.io.DataInput; +import java.io.DataInputStream; import java.io.InputStream; import java.net.Socket; import java.net.SocketAddress; @@ -37,7 +41,7 @@ import static org.junit.Assert.*; public class NioSocketChannelTest { /** - * Test try to reproduce issue #1600 + * Reproduces the issue #1600 */ @Test public void testFlushCloseReentrance() throws Exception { @@ -92,4 +96,47 @@ public class NioSocketChannelTest { group.shutdownGracefully().sync(); } } + + /** + * Reproduces the issue #1679 + */ + @Test + public void testFlushAfterGatheredFlush() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group).channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + // Trigger a gathering write by writing two buffers. + ctx.write(Unpooled.wrappedBuffer(new byte[] { 'a' })); + ChannelFuture f = ctx.write(Unpooled.wrappedBuffer(new byte[] { 'b' })); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // This message must be flushed + ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'c'})); + } + }); + ctx.flush(); + } + }); + + SocketAddress address = sb.bind(0).sync().channel().localAddress(); + + Socket s = new Socket(); + s.connect(address); + + DataInput in = new DataInputStream(s.getInputStream()); + byte[] buf = new byte[3]; + in.readFully(buf); + + assertThat(new String(buf, CharsetUtil.US_ASCII), is("abc")); + + s.close(); + } finally { + group.shutdownGracefully().sync(); + } + } }