Fix a bug in NioSocketChannel.doWrite() where flush() triggered from a ChannelFutureListener is ignored
- Fixes #1679
This commit is contained in:
parent
ca29f1a37d
commit
c79a3cdefe
|
@ -142,17 +142,13 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
|
|||
|
||||
@Override
|
||||
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
|
||||
final SelectionKey key = selectionKey();
|
||||
final int interestOps = key.interestOps();
|
||||
int writeSpinCount = -1;
|
||||
|
||||
for (;;) {
|
||||
Object msg = in.current();
|
||||
if (msg == null) {
|
||||
// Wrote all messages.
|
||||
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
|
||||
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
|
||||
}
|
||||
clearOpWrite();
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -186,9 +182,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
|
|||
} else {
|
||||
// Did not write completely.
|
||||
in.progress(flushedAmount);
|
||||
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
|
||||
key.interestOps(interestOps | SelectionKey.OP_WRITE);
|
||||
}
|
||||
setOpWrite();
|
||||
break;
|
||||
}
|
||||
} else if (msg instanceof FileRegion) {
|
||||
|
@ -216,9 +210,7 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
|
|||
} else {
|
||||
// Did not write completely.
|
||||
in.progress(flushedAmount);
|
||||
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
|
||||
key.interestOps(interestOps | SelectionKey.OP_WRITE);
|
||||
}
|
||||
setOpWrite();
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
|
@ -247,26 +239,20 @@ public abstract class AbstractNioByteChannel extends AbstractNioChannel {
|
|||
*/
|
||||
protected abstract int doWriteBytes(ByteBuf buf) throws Exception;
|
||||
|
||||
protected final void updateOpWrite(long expectedWrittenBytes, long writtenBytes, boolean lastSpin) {
|
||||
if (writtenBytes >= expectedWrittenBytes) {
|
||||
final SelectionKey key = selectionKey();
|
||||
final int interestOps = key.interestOps();
|
||||
// Wrote the outbound buffer completely - clear OP_WRITE.
|
||||
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
|
||||
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
|
||||
}
|
||||
} else {
|
||||
// 1) Wrote nothing: buffer is full obviously - set OP_WRITE
|
||||
// 2) Wrote partial data:
|
||||
// a) lastSpin is false: no need to set OP_WRITE because the caller will try again immediately.
|
||||
// b) lastSpin is true: set OP_WRITE because the caller will not try again.
|
||||
if (writtenBytes == 0 || lastSpin) {
|
||||
final SelectionKey key = selectionKey();
|
||||
final int interestOps = key.interestOps();
|
||||
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
|
||||
key.interestOps(interestOps | SelectionKey.OP_WRITE);
|
||||
}
|
||||
}
|
||||
|
||||
protected final void setOpWrite() {
|
||||
final SelectionKey key = selectionKey();
|
||||
final int interestOps = key.interestOps();
|
||||
if ((interestOps & SelectionKey.OP_WRITE) == 0) {
|
||||
key.interestOps(interestOps | SelectionKey.OP_WRITE);
|
||||
}
|
||||
}
|
||||
|
||||
protected final void clearOpWrite() {
|
||||
final SelectionKey key = selectionKey();
|
||||
final int interestOps = key.interestOps();
|
||||
if ((interestOps & SelectionKey.OP_WRITE) != 0) {
|
||||
key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -243,65 +243,75 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
|
|||
|
||||
@Override
|
||||
protected void doWrite(ChannelOutboundBuffer in) throws Exception {
|
||||
// Do non-gathering write for a single buffer case.
|
||||
final int msgCount = in.size();
|
||||
if (msgCount <= 1) {
|
||||
super.doWrite(in);
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure the pending writes are made of ByteBufs only.
|
||||
ByteBuffer[] nioBuffers = in.nioBuffers();
|
||||
if (nioBuffers == null) {
|
||||
super.doWrite(in);
|
||||
return;
|
||||
}
|
||||
|
||||
int nioBufferCnt = in.nioBufferCount();
|
||||
long expectedWrittenBytes = in.nioBufferSize();
|
||||
|
||||
final SocketChannel ch = javaChannel();
|
||||
long writtenBytes = 0;
|
||||
boolean done = false;
|
||||
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
|
||||
final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
|
||||
updateOpWrite(expectedWrittenBytes, localWrittenBytes, i == 0);
|
||||
if (localWrittenBytes == 0) {
|
||||
break;
|
||||
for (;;) {
|
||||
// Do non-gathering write for a single buffer case.
|
||||
final int msgCount = in.size();
|
||||
if (msgCount <= 1) {
|
||||
super.doWrite(in);
|
||||
return;
|
||||
}
|
||||
expectedWrittenBytes -= localWrittenBytes;
|
||||
writtenBytes += localWrittenBytes;
|
||||
if (expectedWrittenBytes == 0) {
|
||||
done = true;
|
||||
break;
|
||||
|
||||
// Ensure the pending writes are made of ByteBufs only.
|
||||
ByteBuffer[] nioBuffers = in.nioBuffers();
|
||||
if (nioBuffers == null) {
|
||||
super.doWrite(in);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (done) {
|
||||
// Release all buffers
|
||||
for (int i = msgCount; i > 0; i --) {
|
||||
in.remove();
|
||||
}
|
||||
} else {
|
||||
// Did not write all buffers completely.
|
||||
// Release the fully written buffers and update the indexes of the partially written buffer.
|
||||
int nioBufferCnt = in.nioBufferCount();
|
||||
long expectedWrittenBytes = in.nioBufferSize();
|
||||
|
||||
for (int i = msgCount; i > 0; i --) {
|
||||
final ByteBuf buf = (ByteBuf) in.current();
|
||||
final int readerIndex = buf.readerIndex();
|
||||
final int readableBytes = buf.writerIndex() - readerIndex;
|
||||
|
||||
if (readableBytes < writtenBytes) {
|
||||
in.remove();
|
||||
writtenBytes -= readableBytes;
|
||||
} else if (readableBytes > writtenBytes) {
|
||||
buf.readerIndex(readerIndex + (int) writtenBytes);
|
||||
in.progress(writtenBytes);
|
||||
break;
|
||||
} else { // readable == writtenBytes
|
||||
in.remove();
|
||||
final SocketChannel ch = javaChannel();
|
||||
long writtenBytes = 0;
|
||||
boolean done = false;
|
||||
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
|
||||
final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
|
||||
if (localWrittenBytes == 0) {
|
||||
break;
|
||||
}
|
||||
expectedWrittenBytes -= localWrittenBytes;
|
||||
writtenBytes += localWrittenBytes;
|
||||
if (expectedWrittenBytes == 0) {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (done) {
|
||||
// Release all buffers
|
||||
for (int i = msgCount; i > 0; i --) {
|
||||
in.remove();
|
||||
}
|
||||
|
||||
// Finish the write loop if no new messages were flushed by in.remove().
|
||||
if (in.isEmpty()) {
|
||||
clearOpWrite();
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Did not write all buffers completely.
|
||||
// Release the fully written buffers and update the indexes of the partially written buffer.
|
||||
|
||||
for (int i = msgCount; i > 0; i --) {
|
||||
final ByteBuf buf = (ByteBuf) in.current();
|
||||
final int readerIndex = buf.readerIndex();
|
||||
final int readableBytes = buf.writerIndex() - readerIndex;
|
||||
|
||||
if (readableBytes < writtenBytes) {
|
||||
in.remove();
|
||||
writtenBytes -= readableBytes;
|
||||
} else if (readableBytes > writtenBytes) {
|
||||
buf.readerIndex(readerIndex + (int) writtenBytes);
|
||||
in.progress(writtenBytes);
|
||||
break;
|
||||
} else { // readable == writtenBytes
|
||||
in.remove();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
setOpWrite();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,13 +16,17 @@
|
|||
package io.netty.channel.nio;
|
||||
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelFutureListener;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.socket.nio.NioServerSocketChannel;
|
||||
import io.netty.util.CharsetUtil;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.DataInput;
|
||||
import java.io.DataInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.net.Socket;
|
||||
import java.net.SocketAddress;
|
||||
|
@ -37,7 +41,7 @@ import static org.junit.Assert.*;
|
|||
public class NioSocketChannelTest {
|
||||
|
||||
/**
|
||||
* Test try to reproduce issue #1600
|
||||
* Reproduces the issue #1600
|
||||
*/
|
||||
@Test
|
||||
public void testFlushCloseReentrance() throws Exception {
|
||||
|
@ -92,4 +96,47 @@ public class NioSocketChannelTest {
|
|||
group.shutdownGracefully().sync();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reproduces the issue #1679
|
||||
*/
|
||||
@Test
|
||||
public void testFlushAfterGatheredFlush() throws Exception {
|
||||
NioEventLoopGroup group = new NioEventLoopGroup(1);
|
||||
try {
|
||||
ServerBootstrap sb = new ServerBootstrap();
|
||||
sb.group(group).channel(NioServerSocketChannel.class);
|
||||
sb.childHandler(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
|
||||
// Trigger a gathering write by writing two buffers.
|
||||
ctx.write(Unpooled.wrappedBuffer(new byte[] { 'a' }));
|
||||
ChannelFuture f = ctx.write(Unpooled.wrappedBuffer(new byte[] { 'b' }));
|
||||
f.addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) throws Exception {
|
||||
// This message must be flushed
|
||||
ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'c'}));
|
||||
}
|
||||
});
|
||||
ctx.flush();
|
||||
}
|
||||
});
|
||||
|
||||
SocketAddress address = sb.bind(0).sync().channel().localAddress();
|
||||
|
||||
Socket s = new Socket();
|
||||
s.connect(address);
|
||||
|
||||
DataInput in = new DataInputStream(s.getInputStream());
|
||||
byte[] buf = new byte[3];
|
||||
in.readFully(buf);
|
||||
|
||||
assertThat(new String(buf, CharsetUtil.US_ASCII), is("abc"));
|
||||
|
||||
s.close();
|
||||
} finally {
|
||||
group.shutdownGracefully().sync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user