Respect MAX_MESSAGES_PER_READ in LocalChannel / LocalServerChannel. (#7885)

Motivation:

LocalChannel / LocalServerChannel did not respect read limits and just always read all of the messages.

Modifications:

- Correct respect MAX_MESSAGES_PER_READ settings
- Add unit tests

Result:

Fixes https://github.com/netty/netty/issues/7880.
This commit is contained in:
Norman Maurer 2018-04-26 07:58:56 +02:00 committed by GitHub
parent eaf1771336
commit f4d7e8de14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 219 additions and 48 deletions

View File

@ -25,6 +25,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.PreferHeapByteBufAllocator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
@ -67,17 +68,13 @@ public class LocalChannel extends AbstractChannel {
private final Runnable readTask = new Runnable() {
@Override
public void run() {
ChannelPipeline pipeline = pipeline();
for (;;) {
Object m = inboundBuffer.poll();
if (m == null) {
break;
}
pipeline.fireChannelRead(m);
// Ensure the inboundBuffer is not empty as readInbound() will always call fireChannelReadComplete()
if (!inboundBuffer.isEmpty()) {
readInbound();
}
pipeline.fireChannelReadComplete();
}
};
private final Runnable shutdownHook = new Runnable() {
@Override
public void run() {
@ -295,13 +292,27 @@ public class LocalChannel extends AbstractChannel {
((SingleThreadEventExecutor) eventLoop()).removeShutdownHook(shutdownHook);
}
private void readInbound() {
RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();
handle.reset(config());
ChannelPipeline pipeline = pipeline();
do {
Object received = inboundBuffer.poll();
if (received == null) {
break;
}
pipeline.fireChannelRead(received);
} while (handle.continueReading());
pipeline.fireChannelReadComplete();
}
@Override
protected void doBeginRead() throws Exception {
if (readInProgress) {
return;
}
ChannelPipeline pipeline = pipeline();
Queue<Object> inboundBuffer = this.inboundBuffer;
if (inboundBuffer.isEmpty()) {
readInProgress = true;
@ -313,14 +324,7 @@ public class LocalChannel extends AbstractChannel {
if (stackDepth < MAX_READER_STACK_DEPTH) {
threadLocals.setLocalChannelReaderStackDepth(stackDepth + 1);
try {
for (;;) {
Object received = inboundBuffer.poll();
if (received == null) {
break;
}
pipeline.fireChannelRead(received);
}
pipeline.fireChannelReadComplete();
readInbound();
} finally {
threadLocals.setLocalChannelReaderStackDepth(stackDepth);
}
@ -435,19 +439,11 @@ public class LocalChannel extends AbstractChannel {
FINISH_READ_FUTURE_UPDATER.compareAndSet(peer, peerFinishReadFuture, null);
}
}
ChannelPipeline peerPipeline = peer.pipeline();
// We should only set readInProgress to false if there is any data that was read as otherwise we may miss to
// forward data later on.
if (peer.readInProgress && !peer.inboundBuffer.isEmpty()) {
peer.readInProgress = false;
for (;;) {
Object received = peer.inboundBuffer.poll();
if (received == null) {
break;
}
peerPipeline.fireChannelRead(received);
}
peerPipeline.fireChannelReadComplete();
peer.readInbound();
}
}

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.PreferHeapByteBufAllocator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.ServerChannel;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.concurrent.SingleThreadEventExecutor;
@ -126,15 +127,7 @@ public class LocalServerChannel extends AbstractServerChannel {
return;
}
ChannelPipeline pipeline = pipeline();
for (;;) {
Object m = inboundBuffer.poll();
if (m == null) {
break;
}
pipeline.fireChannelRead(m);
}
pipeline.fireChannelReadComplete();
readInbound();
}
LocalChannel serve(final LocalChannel peer) {
@ -143,15 +136,30 @@ public class LocalServerChannel extends AbstractServerChannel {
serve0(child);
} else {
eventLoop().execute(new Runnable() {
@Override
public void run() {
serve0(child);
}
@Override
public void run() {
serve0(child);
}
});
}
return child;
}
private void readInbound() {
RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();
handle.reset(config());
ChannelPipeline pipeline = pipeline();
do {
Object m = inboundBuffer.poll();
if (m == null) {
break;
}
pipeline.fireChannelRead(m);
} while (handle.continueReading());
pipeline.fireChannelReadComplete();
}
/**
* A factory method for {@link LocalChannel}s. Users may override it
* to create custom instances of {@link LocalChannel}s.
@ -164,15 +172,8 @@ public class LocalServerChannel extends AbstractServerChannel {
inboundBuffer.add(child);
if (acceptInProgress) {
acceptInProgress = false;
ChannelPipeline pipeline = pipeline();
for (;;) {
Object m = inboundBuffer.poll();
if (m == null) {
break;
}
pipeline.fireChannelRead(m);
}
pipeline.fireChannelReadComplete();
readInbound();
}
}
}

View File

@ -46,11 +46,13 @@ import java.net.ConnectException;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
@ -1020,4 +1022,176 @@ public class LocalChannelTest {
closeChannel(sc);
}
}
@Test(timeout = 5000)
public void testMaxMessagesPerReadRespectedWithAutoReadSharedGroup() throws Exception {
testMaxMessagesPerReadRespected(sharedGroup, sharedGroup, true);
}
@Test(timeout = 5000)
public void testMaxMessagesPerReadRespectedWithoutAutoReadSharedGroup() throws Exception {
testMaxMessagesPerReadRespected(sharedGroup, sharedGroup, false);
}
@Test(timeout = 5000)
public void testMaxMessagesPerReadRespectedWithAutoReadDifferentGroup() throws Exception {
testMaxMessagesPerReadRespected(group1, group2, true);
}
@Test(timeout = 5000)
public void testMaxMessagesPerReadRespectedWithoutAutoReadDifferentGroup() throws Exception {
testMaxMessagesPerReadRespected(group1, group2, false);
}
private static void testMaxMessagesPerReadRespected(
EventLoopGroup serverGroup, EventLoopGroup clientGroup, final boolean autoRead) throws Exception {
final CountDownLatch countDownLatch = new CountDownLatch(5);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
cb.group(serverGroup)
.channel(LocalChannel.class)
.option(ChannelOption.AUTO_READ, autoRead)
.option(ChannelOption.MAX_MESSAGES_PER_READ, 1)
.handler(new ChannelReadHandler(countDownLatch, autoRead));
sb.group(clientGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(final ChannelHandlerContext ctx) {
for (int i = 0; i < 10; i++) {
ctx.write(i);
}
ctx.flush();
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(TEST_ADDRESS).sync().channel();
cc = cb.connect(TEST_ADDRESS).sync().channel();
countDownLatch.await();
} finally {
closeChannel(cc);
closeChannel(sc);
}
}
@Test(timeout = 5000)
public void testServerMaxMessagesPerReadRespectedWithAutoReadSharedGroup() throws Exception {
testServerMaxMessagesPerReadRespected(sharedGroup, sharedGroup, true);
}
@Test(timeout = 5000)
public void testServerMaxMessagesPerReadRespectedWithoutAutoReadSharedGroup() throws Exception {
testServerMaxMessagesPerReadRespected(sharedGroup, sharedGroup, false);
}
@Test(timeout = 5000)
public void testServerMaxMessagesPerReadRespectedWithAutoReadDifferentGroup() throws Exception {
testServerMaxMessagesPerReadRespected(group1, group2, true);
}
@Test(timeout = 5000)
public void testServerMaxMessagesPerReadRespectedWithoutAutoReadDifferentGroup() throws Exception {
testServerMaxMessagesPerReadRespected(group1, group2, false);
}
private void testServerMaxMessagesPerReadRespected(
EventLoopGroup serverGroup, EventLoopGroup clientGroup, final boolean autoRead) throws Exception {
final CountDownLatch countDownLatch = new CountDownLatch(5);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
cb.group(serverGroup)
.channel(LocalChannel.class)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
// NOOP
}
});
sb.group(clientGroup)
.channel(LocalServerChannel.class)
.option(ChannelOption.AUTO_READ, autoRead)
.option(ChannelOption.MAX_MESSAGES_PER_READ, 1)
.handler(new ChannelReadHandler(countDownLatch, autoRead))
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
// NOOP
}
});
Channel sc = null;
Channel cc = null;
try {
// Start server
sc = sb.bind(TEST_ADDRESS).sync().channel();
for (int i = 0; i < 5; i++) {
try {
cc = cb.connect(TEST_ADDRESS).sync().channel();
} finally {
closeChannel(cc);
}
}
countDownLatch.await();
} finally {
closeChannel(sc);
}
}
private static final class ChannelReadHandler extends ChannelInboundHandlerAdapter {
private final CountDownLatch latch;
private final boolean autoRead;
private int read;
ChannelReadHandler(CountDownLatch latch, boolean autoRead) {
this.latch = latch;
this.autoRead = autoRead;
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
if (!autoRead) {
ctx.read();
}
ctx.fireChannelActive();
}
@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) {
assertEquals(0, read);
read++;
ctx.fireChannelRead(msg);
}
@Override
public void channelReadComplete(final ChannelHandlerContext ctx) {
assertEquals(1, read);
latch.countDown();
if (latch.getCount() > 0) {
if (!autoRead) {
// The read will be scheduled 100ms in the future to ensure we not receive any
// channelRead calls in the meantime.
ctx.executor().schedule(new Runnable() {
@Override
public void run() {
read = 0;
ctx.read();
}
}, 100, TimeUnit.MILLISECONDS);
} else {
read = 0;
}
}
ctx.fireChannelReadComplete();
}
}
}