From d0943dcd30b08eb4043aeb88fd983bcebf8c3432 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 1 Apr 2016 11:45:43 +0200 Subject: [PATCH] [#5028] Fix re-entrance issue with channelWritabilityChanged(...) and write(...) Motivation: When always triggered fireChannelWritabilityChanged() directly when the update the pending bytes in the ChannelOutboundBuffer was made from within the EventLoop. This is problematic as this can cause some re-entrance issue if the user has a custom ChannelOutboundHandler that does multiple writes from within the write(...) method and also has a handler that will intercept the channelWritabilityChanged event and trigger another write when the Channel is writable. This can also easily happen if the user just use a MessageToMessageEncoder subclass and triggers a write from channelWritabilityChanged(). Beside this we also triggered fireChannelWritabilityChanged() too often when a user did a write from outside the EventLoop. In this case we increased the pending bytes of the outboundbuffer before scheduled the actual write and decreased again before the write then takes place. Both of this may trigger a fireChannelWritabilityChanged() event which then may be re-triggered once the actual write ends again in the ChannelOutboundBuffer. The third gotcha was that a user may get multiple events even if the writability of the channel not changed. Modification: - Always invoke the fireChannelWritabilityChanged() later on the EventLoop. - Only trigger the fireChannelWritabilityChanged() if the channel is still active and if the writability of the channel changed. No need to cause events that were already triggered without a real writability change. - when write(...) is called from outside the EventLoop we only increase the pending bytes in the outbound buffer (so that Channel.isWritable() is updated directly) but not cause a fireChannelWritabilityChanged(). The fireChannelWritabilityChanged() is then triggered once the task is picked up by the EventLoop as usual. Result: No more re-entrance possible because of writes from within channelWritabilityChanged(...) method and no events without a real writability change. --- .../netty/channel/ChannelOutboundBuffer.java | 90 +++---- .../channel/DefaultChannelHandlerInvoker.java | 8 +- .../io/netty/channel/PendingWriteQueue.java | 4 +- .../io/netty/channel/BaseChannelTest.java | 90 ------- .../channel/ChannelOutboundBufferTest.java | 9 + .../netty/channel/ReentrantChannelTest.java | 249 +++++++++++------- 6 files changed, 214 insertions(+), 236 deletions(-) delete mode 100644 transport/src/test/java/io/netty/channel/BaseChannelTest.java diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index 5b9a3f5911..071cbb03f1 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -89,7 +89,7 @@ public final class ChannelOutboundBuffer { @SuppressWarnings("UnusedDeclaration") private volatile int unwritable; - private volatile Runnable fireChannelWritabilityChangedTask; + private final Runnable fireChannelWritabilityChangedTask; static { AtomicIntegerFieldUpdater unwritableUpdater = @@ -107,8 +107,9 @@ public final class ChannelOutboundBuffer { TOTAL_PENDING_SIZE_UPDATER = pendingSizeUpdater; } - ChannelOutboundBuffer(AbstractChannel channel) { + ChannelOutboundBuffer(final AbstractChannel channel) { this.channel = channel; + fireChannelWritabilityChangedTask = new ChannelWritabilityChangedTask(channel); } /** @@ -131,7 +132,7 @@ public final class ChannelOutboundBuffer { // increment pending bytes after adding message to the unflushed arrays. // See https://github.com/netty/netty/issues/1619 - incrementPendingOutboundBytes(size, false); + incrementPendingOutboundBytes(size, true); } /** @@ -154,7 +155,7 @@ public final class ChannelOutboundBuffer { if (!entry.promise.setUncancellable()) { // Was cancelled so make sure we free up memory and notify about the freed bytes int pending = entry.cancel(); - decrementPendingOutboundBytes(pending, false, true); + decrementPendingOutboundBytes(pending, true); } entry = entry.next; } while (entry != null); @@ -168,18 +169,14 @@ public final class ChannelOutboundBuffer { * Increment the pending bytes which will be written at some point. * This method is thread-safe! */ - void incrementPendingOutboundBytes(long size) { - incrementPendingOutboundBytes(size, true); - } - - private void incrementPendingOutboundBytes(long size, boolean invokeLater) { + void incrementPendingOutboundBytes(long size, boolean notifyWritability) { if (size == 0) { return; } long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); if (newWriteBufferSize >= channel.config().getWriteBufferHighWaterMark()) { - setUnwritable(invokeLater); + setUnwritable(notifyWritability); } } @@ -187,19 +184,15 @@ public final class ChannelOutboundBuffer { * Decrement the pending bytes which will be written at some point. * This method is thread-safe! */ - void decrementPendingOutboundBytes(long size) { - decrementPendingOutboundBytes(size, true, true); - } - - private void decrementPendingOutboundBytes(long size, boolean invokeLater, boolean notifyWritability) { + void decrementPendingOutboundBytes(long size, boolean notifyWritability) { if (size == 0) { return; } long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); - if (notifyWritability && (newWriteBufferSize == 0 - || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark())) { - setWritable(invokeLater); + if (newWriteBufferSize == 0 + || newWriteBufferSize <= channel.config().getWriteBufferLowWaterMark()) { + setWritable(notifyWritability); } } @@ -264,7 +257,7 @@ public final class ChannelOutboundBuffer { // only release message, notify and decrement if it was not canceled before. ReferenceCountUtil.safeRelease(msg); safeSuccess(promise); - decrementPendingOutboundBytes(size, false, true); + decrementPendingOutboundBytes(size, true); } // recycle the entry @@ -300,7 +293,7 @@ public final class ChannelOutboundBuffer { ReferenceCountUtil.safeRelease(msg); safeFail(promise, cause); - decrementPendingOutboundBytes(size, false, notifyWritability); + decrementPendingOutboundBytes(size, notifyWritability); } // recycle the entry @@ -522,7 +515,7 @@ public final class ChannelOutboundBuffer { final int newValue = oldValue & mask; if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (oldValue != 0 && newValue == 0) { - fireChannelWritabilityChanged(true); + fireChannelWritabilityChanged(); } break; } @@ -536,7 +529,7 @@ public final class ChannelOutboundBuffer { final int newValue = oldValue | mask; if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { if (oldValue == 0 && newValue != 0) { - fireChannelWritabilityChanged(true); + fireChannelWritabilityChanged(); } break; } @@ -550,48 +543,36 @@ public final class ChannelOutboundBuffer { return 1 << index; } - private void setWritable(boolean invokeLater) { + private void setWritable(boolean notify) { for (;;) { final int oldValue = unwritable; final int newValue = oldValue & ~1; if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { - if (oldValue != 0 && newValue == 0) { - fireChannelWritabilityChanged(invokeLater); + if (notify && oldValue != 0 && newValue == 0) { + fireChannelWritabilityChanged(); } break; } } } - private void setUnwritable(boolean invokeLater) { + private void setUnwritable(boolean notify) { for (;;) { final int oldValue = unwritable; final int newValue = oldValue | 1; if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { - if (oldValue == 0 && newValue != 0) { - fireChannelWritabilityChanged(invokeLater); + if (notify && oldValue == 0 && newValue != 0) { + fireChannelWritabilityChanged(); } break; } } } - private void fireChannelWritabilityChanged(boolean invokeLater) { - final ChannelPipeline pipeline = channel.pipeline(); - if (invokeLater) { - Runnable task = fireChannelWritabilityChangedTask; - if (task == null) { - fireChannelWritabilityChangedTask = task = new Runnable() { - @Override - public void run() { - pipeline.fireChannelWritabilityChanged(); - } - }; - } - channel.eventLoop().execute(task); - } else { - pipeline.fireChannelWritabilityChanged(); - } + private void fireChannelWritabilityChanged() { + // Always invoke it later to prevent re-entrance bug. + // See https://github.com/netty/netty/issues/5028 + channel.eventLoop().execute(fireChannelWritabilityChangedTask); } /** @@ -862,4 +843,25 @@ public final class ChannelOutboundBuffer { return next; } } + + private static final class ChannelWritabilityChangedTask implements Runnable { + private final Channel channel; + private boolean writable = true; + + ChannelWritabilityChangedTask(Channel channel) { + this.channel = channel; + } + + @Override + public void run() { + if (channel.isActive()) { + boolean newWritable = channel.isWritable(); + + if (writable != newWritable) { + writable = newWritable; + channel.pipeline().fireChannelWritabilityChanged(); + } + } + } + } } diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java index 9a72c10e4c..bee3a9e0d5 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java @@ -470,7 +470,9 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { // Check for null as it may be set to null if the channel is closed already if (buffer != null) { task.size = ((AbstractChannel) ctx.channel()).estimatorHandle().size(msg) + WRITE_TASK_OVERHEAD; - buffer.incrementPendingOutboundBytes(task.size); + // We increment the pending bytes but NOT call fireChannelWritabilityChanged() because this + // will be done automaticaly once we add the message to the ChannelOutboundBuffer. + buffer.incrementPendingOutboundBytes(task.size, false); } else { task.size = 0; } @@ -491,7 +493,9 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer(); // Check for null as it may be set to null if the channel is closed already if (ESTIMATE_TASK_SIZE_ON_SUBMIT && buffer != null) { - buffer.decrementPendingOutboundBytes(size); + // We decrement the pending bytes but NOT call fireChannelWritabilityChanged() because this + // will be done automaticaly once we pick up the messages out of the buffer to actually write these. + buffer.decrementPendingOutboundBytes(size, false); } invokeWriteNow(ctx, msg, promise); } finally { diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java index 3a49baaaa3..18dc49b4e0 100644 --- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -94,7 +94,7 @@ public final class PendingWriteQueue { // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 if (buffer != null) { - buffer.incrementPendingOutboundBytes(write.size); + buffer.incrementPendingOutboundBytes(write.size, true); } } @@ -264,7 +264,7 @@ public final class PendingWriteQueue { // if the channel was already closed when constructing the PendingWriteQueue. // See https://github.com/netty/netty/issues/3967 if (buffer != null) { - buffer.decrementPendingOutboundBytes(writeSize); + buffer.decrementPendingOutboundBytes(writeSize, true); } } diff --git a/transport/src/test/java/io/netty/channel/BaseChannelTest.java b/transport/src/test/java/io/netty/channel/BaseChannelTest.java deleted file mode 100644 index 46d37eecaa..0000000000 --- a/transport/src/test/java/io/netty/channel/BaseChannelTest.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2012 The Netty Project - * - * The Netty Project licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ -package io.netty.channel; - - -import io.netty.bootstrap.Bootstrap; -import io.netty.bootstrap.ServerBootstrap; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.local.LocalChannel; -import io.netty.channel.local.LocalServerChannel; - -import static org.junit.Assert.*; - -class BaseChannelTest { - - private final LoggingHandler loggingHandler; - - BaseChannelTest() { - loggingHandler = new LoggingHandler(); - } - - ServerBootstrap getLocalServerBootstrap() { - EventLoopGroup serverGroup = new DefaultEventLoopGroup(); - ServerBootstrap sb = new ServerBootstrap(); - sb.group(serverGroup); - sb.channel(LocalServerChannel.class); - sb.childHandler(new ChannelInitializer() { - @Override - public void initChannel(LocalChannel ch) throws Exception { - } - }); - - return sb; - } - - Bootstrap getLocalClientBootstrap() { - EventLoopGroup clientGroup = new DefaultEventLoopGroup(); - Bootstrap cb = new Bootstrap(); - cb.channel(LocalChannel.class); - cb.group(clientGroup); - - cb.handler(loggingHandler); - - return cb; - } - - static ByteBuf createTestBuf(int len) { - ByteBuf buf = Unpooled.buffer(len, len); - buf.setIndex(0, len); - return buf; - } - - void assertLog(String firstExpected, String... otherExpected) { - String actual = loggingHandler.getLog(); - if (firstExpected.equals(actual)) { - return; - } - for (String e: otherExpected) { - if (e.equals(actual)) { - return; - } - } - - // Let the comparison fail with the first expectation. - assertEquals(firstExpected, actual); - } - - void clearLog() { - loggingHandler.clear(); - } - - void setInterest(LoggingHandler.Event... events) { - loggingHandler.setInterest(events); - } - -} diff --git a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java index 23ee48ab23..c2751be694 100644 --- a/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java +++ b/transport/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -229,11 +229,17 @@ public class ChannelOutboundBufferTest { // Ensure exceeding the high watermark makes channel unwritable. ch.write(buffer().writeZero(128)); + assertThat(buf.toString(), is("")); + + ch.runPendingTasks(); assertThat(buf.toString(), is("false ")); // Ensure going down to the low watermark makes channel writable again by flushing the first write. assertThat(ch.unsafe().outboundBuffer().remove(), is(true)); assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), is(128L)); + assertThat(buf.toString(), is("false ")); + + ch.runPendingTasks(); assertThat(buf.toString(), is("false true ")); safeClose(ch); @@ -331,6 +337,9 @@ public class ChannelOutboundBufferTest { // Trigger channelWritabilityChanged() by writing a lot. ch.write(buffer().writeZero(256)); + assertThat(buf.toString(), is("")); + + ch.runPendingTasks(); assertThat(buf.toString(), is("false ")); // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged() diff --git a/transport/src/test/java/io/netty/channel/ReentrantChannelTest.java b/transport/src/test/java/io/netty/channel/ReentrantChannelTest.java index 0152eb392a..8b5ab33202 100644 --- a/transport/src/test/java/io/netty/channel/ReentrantChannelTest.java +++ b/transport/src/test/java/io/netty/channel/ReentrantChannelTest.java @@ -17,97 +17,100 @@ package io.netty.channel; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.LoggingHandler.Event; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import java.nio.channels.ClosedChannelException; +import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.*; -public class ReentrantChannelTest extends BaseChannelTest { +public class ReentrantChannelTest { - @Test - public void testWritabilityChanged() throws Exception { + private EventLoopGroup clientGroup; + private EventLoopGroup serverGroup; + private LoggingHandler loggingHandler; - LocalAddress addr = new LocalAddress("testWritabilityChanged"); + private ServerBootstrap getLocalServerBootstrap() { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(serverGroup); + sb.channel(LocalServerChannel.class); + sb.childHandler(new ChannelInboundHandlerAdapter()); - ServerBootstrap sb = getLocalServerBootstrap(); - sb.bind(addr).sync().channel(); - - Bootstrap cb = getLocalClientBootstrap(); - - setInterest(Event.WRITE, Event.FLUSH, Event.WRITABILITY); - - Channel clientChannel = cb.connect(addr).sync().channel(); - clientChannel.config().setWriteBufferLowWaterMark(512); - clientChannel.config().setWriteBufferHighWaterMark(1024); - - // What is supposed to happen from this point: - // - // 1. Because this write attempt has been made from a non-I/O thread, - // ChannelOutboundBuffer.pendingWriteBytes will be increased before - // write() event is really evaluated. - // -> channelWritabilityChanged() will be triggered, - // because the Channel became unwritable. - // - // 2. The write() event is handled by the pipeline in an I/O thread. - // -> write() will be triggered. - // - // 3. Once the write() event is handled, ChannelOutboundBuffer.pendingWriteBytes - // will be decreased. - // -> channelWritabilityChanged() will be triggered, - // because the Channel became writable again. - // - // 4. The message is added to the ChannelOutboundBuffer and thus - // pendingWriteBytes will be increased again. - // -> channelWritabilityChanged() will be triggered. - // - // 5. The flush() event causes the write request in theChannelOutboundBuffer - // to be removed. - // -> flush() and channelWritabilityChanged() will be triggered. - // - // Note that the channelWritabilityChanged() in the step 4 can occur between - // the flush() and the channelWritabilityChanged() in the stap 5, because - // the flush() is invoked from a non-I/O thread while the other are from - // an I/O thread. - - ChannelFuture future = clientChannel.write(createTestBuf(2000)); - - clientChannel.flush(); - future.sync(); - - clientChannel.close().sync(); - - assertLog( - // Case 1: - "WRITABILITY: writable=false\n" + - "WRITE\n" + - "WRITABILITY: writable=false\n" + - "WRITABILITY: writable=false\n" + - "FLUSH\n" + - "WRITABILITY: writable=true\n", - // Case 2: - "WRITABILITY: writable=false\n" + - "WRITE\n" + - "WRITABILITY: writable=false\n" + - "FLUSH\n" + - "WRITABILITY: writable=true\n" + - "WRITABILITY: writable=true\n"); + return sb; + } + + private Bootstrap getLocalClientBootstrap() { + Bootstrap cb = new Bootstrap(); + cb.group(clientGroup); + cb.channel(LocalChannel.class); + cb.handler(loggingHandler); + + return cb; + } + + private static ByteBuf createTestBuf(int len) { + ByteBuf buf = Unpooled.buffer(len, len); + buf.setIndex(0, len); + return buf; + } + + private void assertLog(String firstExpected, String... otherExpected) { + String actual = loggingHandler.getLog(); + if (firstExpected.equals(actual)) { + return; + } + for (String e: otherExpected) { + if (e.equals(actual)) { + return; + } + } + + // Let the comparison fail with the first expectation. + assertEquals(firstExpected, actual); + } + + private void setInterest(LoggingHandler.Event... events) { + loggingHandler.setInterest(events); + } + + @Before + public void setup() { + loggingHandler = new LoggingHandler(); + clientGroup = new DefaultEventLoop(); + serverGroup = new DefaultEventLoop(); + } + + @After + public void teardown() { + clientGroup.shutdownGracefully(); + serverGroup.shutdownGracefully(); } - /** - * Similar to {@link #testWritabilityChanged()} with slight variation. - */ @Test - public void testFlushInWritabilityChanged() throws Exception { + public void testWritabilityChangedWriteAndFlush() throws Exception { + testWritabilityChanged0(true); + } + @Test + public void testWritabilityChangedWriteThenFlush() throws Exception { + testWritabilityChanged0(false); + } + + private void testWritabilityChanged0(boolean writeAndFlush) throws Exception { LocalAddress addr = new LocalAddress("testFlushInWritabilityChanged"); - ServerBootstrap sb = getLocalServerBootstrap(); - sb.bind(addr).sync().channel(); + Channel serverChannel = sb.bind(addr).sync().channel(); Bootstrap cb = getLocalClientBootstrap(); @@ -117,38 +120,37 @@ public class ReentrantChannelTest extends BaseChannelTest { clientChannel.config().setWriteBufferLowWaterMark(512); clientChannel.config().setWriteBufferHighWaterMark(1024); - clientChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { - @Override - public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { - if (!ctx.channel().isWritable()) { - ctx.channel().flush(); - } - ctx.fireChannelWritabilityChanged(); - } - }); - assertTrue(clientChannel.isWritable()); - - clientChannel.write(createTestBuf(2000)).sync(); + if (writeAndFlush) { + clientChannel.writeAndFlush(createTestBuf(2000)).sync(); + } else { + ChannelFuture future = clientChannel.write(createTestBuf(2000)); + clientChannel.flush(); + future.sync(); + } clientChannel.close().sync(); + serverChannel.close().sync(); + // Because of timing of the scheduling we either should see: + // - WRITE, FLUSH + // - WRITE, WRITABILITY: writable=false, FLUSH + // - WRITE, WRITABILITY: writable=false, FLUSH, WRITABILITY: writable=true + // + // This is the case as between the write and flush the EventLoop may already pick up the pending writes and + // put these into the ChannelOutboundBuffer. Once the flush then happen from outside the EventLoop we may be + // able to flush everything and also schedule the writabilityChanged event before the actual close takes + // place which means we may see another writability changed event to inform the channel is writable again. assertLog( - // Case 1: - "WRITABILITY: writable=false\n" + - "FLUSH\n" + + "WRITE\n" + + "FLUSH\n", "WRITE\n" + "WRITABILITY: writable=false\n" + - "WRITABILITY: writable=false\n" + - "FLUSH\n" + - "WRITABILITY: writable=true\n", - // Case 2: - "WRITABILITY: writable=false\n" + - "FLUSH\n" + + "FLUSH\n", "WRITE\n" + "WRITABILITY: writable=false\n" + "FLUSH\n" + - "WRITABILITY: writable=true\n" + - "WRITABILITY: writable=true\n"); + "WRITABILITY: writable=true\n" + ); } @Test @@ -157,7 +159,7 @@ public class ReentrantChannelTest extends BaseChannelTest { LocalAddress addr = new LocalAddress("testWriteFlushPingPong"); ServerBootstrap sb = getLocalServerBootstrap(); - sb.bind(addr).sync().channel(); + Channel serverChannel = sb.bind(addr).sync().channel(); Bootstrap cb = getLocalClientBootstrap(); @@ -191,6 +193,7 @@ public class ReentrantChannelTest extends BaseChannelTest { clientChannel.writeAndFlush(createTestBuf(2000)); clientChannel.close().sync(); + serverChannel.close().sync(); assertLog( "WRITE\n" + @@ -214,7 +217,7 @@ public class ReentrantChannelTest extends BaseChannelTest { LocalAddress addr = new LocalAddress("testCloseInFlush"); ServerBootstrap sb = getLocalServerBootstrap(); - sb.bind(addr).sync().channel(); + Channel serverChannel = sb.bind(addr).sync().channel(); Bootstrap cb = getLocalClientBootstrap(); @@ -239,6 +242,7 @@ public class ReentrantChannelTest extends BaseChannelTest { clientChannel.write(createTestBuf(2000)).sync(); clientChannel.closeFuture().sync(); + serverChannel.close().sync(); assertLog("WRITE\nFLUSH\nCLOSE\n"); } @@ -249,7 +253,7 @@ public class ReentrantChannelTest extends BaseChannelTest { LocalAddress addr = new LocalAddress("testFlushFailure"); ServerBootstrap sb = getLocalServerBootstrap(); - sb.bind(addr).sync().channel(); + Channel serverChannel = sb.bind(addr).sync().channel(); Bootstrap cb = getLocalClientBootstrap(); @@ -279,7 +283,56 @@ public class ReentrantChannelTest extends BaseChannelTest { } clientChannel.closeFuture().sync(); + serverChannel.close().sync(); assertLog("WRITE\nCLOSE\n"); } + + // Test for https://github.com/netty/netty/issues/5028 + @Test + public void testWriteRentrance() { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + for (int i = 0; i < 3; i++) { + ctx.write(i); + } + ctx.write(3, promise); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isWritable()) { + ctx.channel().writeAndFlush(-1); + } + } + }); + + channel.config().setMessageSizeEstimator(new MessageSizeEstimator() { + @Override + public Handle newHandle() { + return new Handle() { + @Override + public int size(Object msg) { + // Each message will just increase the pending bytes by 1. + return 1; + } + }; + } + }); + channel.config().setWriteBufferLowWaterMark(3); + channel.config().setWriteBufferHighWaterMark(4); + channel.writeOutbound(-1); + assertTrue(channel.finish()); + assertSequenceOutbound(channel); + assertSequenceOutbound(channel); + assertNull(channel.readOutbound()); + } + + private static void assertSequenceOutbound(EmbeddedChannel channel) { + assertEquals(0, channel.readOutbound()); + assertEquals(1, channel.readOutbound()); + assertEquals(2, channel.readOutbound()); + assertEquals(3, channel.readOutbound()); + } }