Fix a bug in NioSocketChannel.doWrite() where flush() triggered from a ChannelFutureListener is ignored

- Fixes #1679
This commit is contained in:
Trustin Lee 2013-07-31 19:07:38 +09:00
parent ca29f1a37d
commit c79a3cdefe
3 changed files with 128 additions and 85 deletions

View File

@ -142,17 +142,13 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
@Override
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
final SelectionKey key = selectionKey();
final int interestOps = key.interestOps();
int writeSpinCount = -1;
for (;;) {
Object msg = in.current();
if (msg == null) {
// Wrote all messages.
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
}
clearOpWrite();
break;
}
@ -186,9 +182,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
} else {
// Did not write completely.
in.progress(flushedAmount);
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
key.interestOps(interestOps | SelectionKey.OP_WRITE);
}
setOpWrite();
break;
}
} else if (msg instanceof FileRegion) {
@ -216,9 +210,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
} else {
// Did not write completely.
in.progress(flushedAmount);
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
key.interestOps(interestOps | SelectionKey.OP_WRITE);
}
setOpWrite();
break;
}
} else {
@ -247,26 +239,20 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
*/
protected abstract int doWriteBytes(ByteBuf buf) throws Exception;
protected final void updateOpWrite(long expectedWrittenBytes, long writtenBytes, boolean lastSpin) {
if (writtenBytes >= expectedWrittenBytes) {
final SelectionKey key = selectionKey();
final int interestOps = key.interestOps();
// Wrote the outbound buffer completely - clear OP_WRITE.
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
}
} else {
// 1) Wrote nothing: buffer is full obviously - set OP_WRITE
// 2) Wrote partial data:
// a) lastSpin is false: no need to set OP_WRITE because the caller will try again immediately.
// b) lastSpin is true: set OP_WRITE because the caller will not try again.
if (writtenBytes == 0 || lastSpin) {
final SelectionKey key = selectionKey();
final int interestOps = key.interestOps();
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
key.interestOps(interestOps | SelectionKey.OP_WRITE);
}
}
protected final void setOpWrite() {
final SelectionKey key = selectionKey();
final int interestOps = key.interestOps();
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
key.interestOps(interestOps | SelectionKey.OP_WRITE);
}
}
protected final void clearOpWrite() {
final SelectionKey key = selectionKey();
final int interestOps = key.interestOps();
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
}
}
}

View File

@ -243,65 +243,75 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
@Override
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
// Do non-gathering write for a single buffer case.
final int msgCount = in.size();
if (msgCount <= 1) {
super.doWrite(in);
return;
}
// Ensure the pending writes are made of ByteBufs only.
ByteBuffer[] nioBuffers = in.nioBuffers();
if (nioBuffers == null) {
super.doWrite(in);
return;
}
int nioBufferCnt = in.nioBufferCount();
long expectedWrittenBytes = in.nioBufferSize();
final SocketChannel ch = javaChannel();
long writtenBytes = 0;
boolean done = false;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
if (localWrittenBytes == 0) {
break;
for (;;) {
// Do non-gathering write for a single buffer case.
final int msgCount = in.size();
if (msgCount <= 1) {
super.doWrite(in);
return;
}
expectedWrittenBytes -= localWrittenBytes;
writtenBytes += localWrittenBytes;
if (expectedWrittenBytes == 0) {
done = true;
break;
// Ensure the pending writes are made of ByteBufs only.
ByteBuffer[] nioBuffers = in.nioBuffers();
if (nioBuffers == null) {
super.doWrite(in);
return;
}
}
if (done) {
// Release all buffers
for (int i = msgCount; i > 0; i --) {
in.remove();
}
} else {
// Did not write all buffers completely.
// Release the fully written buffers and update the indexes of the partially written buffer.
int nioBufferCnt = in.nioBufferCount();
long expectedWrittenBytes = in.nioBufferSize();
for (int i = msgCount; i > 0; i --) {
final ByteBuf buf = (ByteBuf) in.current();
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes < writtenBytes) {
in.remove();
writtenBytes -= readableBytes;
} else if (readableBytes > writtenBytes) {
buf.readerIndex(readerIndex + (int) writtenBytes);
in.progress(writtenBytes);
break;
} else { // readable == writtenBytes
in.remove();
final SocketChannel ch = javaChannel();
long writtenBytes = 0;
boolean done = false;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
if (localWrittenBytes == 0) {
break;
}
expectedWrittenBytes -= localWrittenBytes;
writtenBytes += localWrittenBytes;
if (expectedWrittenBytes == 0) {
done = true;
break;
}
}
if (done) {
// Release all buffers
for (int i = msgCount; i > 0; i --) {
in.remove();
}
// Finish the write loop if no new messages were flushed by in.remove().
if (in.isEmpty()) {
clearOpWrite();
break;
}
} else {
// Did not write all buffers completely.
// Release the fully written buffers and update the indexes of the partially written buffer.
for (int i = msgCount; i > 0; i --) {
final ByteBuf buf = (ByteBuf) in.current();
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes < writtenBytes) {
in.remove();
writtenBytes -= readableBytes;
} else if (readableBytes > writtenBytes) {
buf.readerIndex(readerIndex + (int) writtenBytes);
in.progress(writtenBytes);
break;
} else { // readable == writtenBytes
in.remove();
break;
}
}
setOpWrite();
break;
}
}
}

View File

@ -16,13 +16,17 @@
package io.netty.channel.nio;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.InputStream;
import java.net.Socket;
import java.net.SocketAddress;
@ -37,7 +41,7 @@ import static org.junit.Assert.*;
public class NioSocketChannelTest {
/**
* Test try to reproduce issue #1600
* Reproduces the issue #1600
*/
@Test
public void testFlushCloseReentrance() throws Exception {
@ -92,4 +96,47 @@ public class NioSocketChannelTest {
group.shutdownGracefully().sync();
}
}
/**
* Reproduces the issue #1679
*/
@Test
public void testFlushAfterGatheredFlush() throws Exception {
NioEventLoopGroup group = new NioEventLoopGroup(1);
try {
ServerBootstrap sb = new ServerBootstrap();
sb.group(group).channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
// Trigger a gathering write by writing two buffers.
ctx.write(Unpooled.wrappedBuffer(new byte[] { 'a' }));
ChannelFuture f = ctx.write(Unpooled.wrappedBuffer(new byte[] { 'b' }));
f.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
// This message must be flushed
ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'c'}));
}
});
ctx.flush();
}
});
SocketAddress address = sb.bind(0).sync().channel().localAddress();
Socket s = new Socket();
s.connect(address);
DataInput in = new DataInputStream(s.getInputStream());
byte[] buf = new byte[3];
in.readFully(buf);
assertThat(new String(buf, CharsetUtil.US_ASCII), is("abc"));
s.close();
} finally {
group.shutdownGracefully().sync();
}
}
}