diff --git a/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java b/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java index 80bf9098d3..c4198e7173 100644 --- a/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java +++ b/transport/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java @@ -140,6 +140,7 @@ public abstract class AbstractCoalescingBufferQueue { // Use isEmpty rather than readableBytes==0 as we may have a promise associated with an empty buffer. if (bufAndListenerPairs.isEmpty()) { + assert readableBytes == 0; return removeEmptyValue(); } bytes = Math.min(bytes, readableBytes); @@ -221,7 +222,6 @@ public abstract class AbstractCoalescingBufferQueue { * @param ctx The context to write all elements to. */ public final void writeAndRemoveAll(ChannelHandlerContext ctx) { - decrementReadableBytes(readableBytes); Throwable pending = null; ByteBuf previousBuf = null; for (;;) { @@ -229,6 +229,7 @@ public abstract class AbstractCoalescingBufferQueue { try { if (entry == null) { if (previousBuf != null) { + decrementReadableBytes(previousBuf.readableBytes()); ctx.write(previousBuf, ctx.voidPromise()); } break; @@ -236,13 +237,16 @@ public abstract class AbstractCoalescingBufferQueue { if (entry instanceof ByteBuf) { if (previousBuf != null) { + decrementReadableBytes(previousBuf.readableBytes()); ctx.write(previousBuf, ctx.voidPromise()); } previousBuf = (ByteBuf) entry; } else if (entry instanceof ChannelPromise) { + decrementReadableBytes(previousBuf.readableBytes()); ctx.write(previousBuf, (ChannelPromise) entry); previousBuf = null; } else { + decrementReadableBytes(previousBuf.readableBytes()); ctx.write(previousBuf).addListener((ChannelFutureListener) entry); previousBuf = null; } @@ -326,7 +330,6 @@ public abstract class AbstractCoalescingBufferQueue { } private void releaseAndCompleteAll(ChannelFuture future) { - decrementReadableBytes(readableBytes); Throwable pending = null; for (;;) { Object entry = bufAndListenerPairs.poll(); @@ -335,7 +338,9 @@ public abstract class AbstractCoalescingBufferQueue { } try { if (entry instanceof ByteBuf) { - safeRelease(entry); + ByteBuf buffer = (ByteBuf) entry; + decrementReadableBytes(buffer.readableBytes()); + safeRelease(buffer); } else { ((ChannelFutureListener) entry).operationComplete(future); } diff --git a/transport/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java b/transport/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java new file mode 100644 index 0000000000..122dbb0f26 --- /dev/null +++ b/transport/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The Netty Project + + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCountUtil; +import org.junit.Test; + +import java.nio.channels.ClosedChannelException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class AbstractCoalescingBufferQueueTest { + + // See https://github.com/netty/netty/issues/10286 + @Test + public void testDecrementAllWhenWriteAndRemoveAll() { + testDecrementAll(true); + } + + // See https://github.com/netty/netty/issues/10286 + @Test + public void testDecrementAllWhenReleaseAndFailAll() { + testDecrementAll(false); + } + + private static void testDecrementAll(boolean write) { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.setSuccess(); + } + }, new ChannelHandlerAdapter() { }); + final AbstractCoalescingBufferQueue queue = new AbstractCoalescingBufferQueue(channel, 128) { + @Override + protected ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + return composeIntoComposite(alloc, cumulation, next); + } + + @Override + protected ByteBuf removeEmptyValue() { + return Unpooled.EMPTY_BUFFER; + } + }; + + final byte[] bytes = new byte[128]; + queue.add(Unpooled.wrappedBuffer(bytes), new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + queue.add(Unpooled.wrappedBuffer(bytes)); + assertEquals(bytes.length, queue.readableBytes()); + } + }); + + assertEquals(bytes.length, queue.readableBytes()); + + ChannelHandlerContext ctx = channel.pipeline().lastContext(); + if (write) { + queue.writeAndRemoveAll(ctx); + } else { + queue.releaseAndFailAll(ctx, new ClosedChannelException()); + } + ByteBuf buffer = queue.remove(channel.alloc(), 128, channel.newPromise()); + assertFalse(buffer.isReadable()); + buffer.release(); + + assertTrue(queue.isEmpty()); + assertEquals(0, queue.readableBytes()); + + assertFalse(channel.finish()); + } +}