Fix a bug in AioSocketChannel where recursive doBeginRead() is allowed unexpectedly

This commit is contained in:
Trustin Lee 2013-01-01 00:08:58 +09:00
parent e0a6dc0ac3
commit 1e9652b47a

View File

@ -20,12 +20,12 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFlushPromiseNotifier;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.FileRegion;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.channel.socket.SocketChannel;
import java.io.IOException;
@ -69,6 +69,8 @@ public class AioSocketChannel extends AbstractAioChannel implements SocketChanne
private volatile boolean outputShutdown;
private boolean readInProgress;
private boolean inDoBeginRead;
private boolean readAgain;
private boolean writeInProgress;
private boolean inDoFlushByteBuffer;
@ -311,33 +313,58 @@ public class AioSocketChannel extends AbstractAioChannel implements SocketChanne
@Override
protected void doBeginRead() {
if (inDoBeginRead) {
readAgain = true;
return;
}
if (readInProgress || inputShutdown) {
return;
}
ByteBuf byteBuf = pipeline().inboundByteBuffer();
if (!byteBuf.readable()) {
byteBuf.discardReadBytes();
}
inDoBeginRead = true;
try {
for (;;) {
ByteBuf byteBuf = pipeline().inboundByteBuffer();
if (!byteBuf.readable()) {
byteBuf.discardReadBytes();
}
expandReadBuffer(byteBuf);
expandReadBuffer(byteBuf);
readInProgress = true;
if (byteBuf.nioBufferCount() == 1) {
// Get a ByteBuffer view on the ByteBuf
ByteBuffer buffer = byteBuf.nioBuffer(byteBuf.writerIndex(), byteBuf.writableBytes());
javaChannel().read(
buffer, config.getReadTimeout(), TimeUnit.MILLISECONDS, this, READ_HANDLER);
} else {
ByteBuffer[] buffers = byteBuf.nioBuffers(byteBuf.writerIndex(), byteBuf.writableBytes());
if (buffers.length == 1) {
javaChannel().read(
buffers[0], config.getReadTimeout(), TimeUnit.MILLISECONDS, this, READ_HANDLER);
} else {
javaChannel().read(
buffers, 0, buffers.length, config.getReadTimeout(), TimeUnit.MILLISECONDS,
this, SCATTERING_READ_HANDLER);
readInProgress = true;
if (byteBuf.nioBufferCount() == 1) {
// Get a ByteBuffer view on the ByteBuf
ByteBuffer buffer = byteBuf.nioBuffer(byteBuf.writerIndex(), byteBuf.writableBytes());
javaChannel().read(
buffer, config.getReadTimeout(), TimeUnit.MILLISECONDS, this, READ_HANDLER);
} else {
ByteBuffer[] buffers = byteBuf.nioBuffers(byteBuf.writerIndex(), byteBuf.writableBytes());
if (buffers.length == 1) {
javaChannel().read(
buffers[0], config.getReadTimeout(), TimeUnit.MILLISECONDS, this, READ_HANDLER);
} else {
javaChannel().read(
buffers, 0, buffers.length, config.getReadTimeout(), TimeUnit.MILLISECONDS,
this, SCATTERING_READ_HANDLER);
}
}
if (readInProgress) {
// JDK decided to read data (or notify handler) later.
break;
}
if (readAgain) {
// User requested the read operation.
readAgain = false;
continue;
}
break;
}
} finally {
inDoBeginRead = false;
}
}
@ -362,8 +389,8 @@ public class AioSocketChannel extends AbstractAioChannel implements SocketChanne
return;
}
// Notify flush futures only when the handler is called outside of unsafe().flushNow()
// because flushNow() will do that for us.
// Update the write counter and notify flush futures only when the handler is called outside of
// unsafe().flushNow() because flushNow() will do that for us.
ChannelFlushPromiseNotifier notifier = channel.flushFutureNotifier;
notifier.increaseWriteCounter(writtenBytes);
notifier.notifyFlushFutures();