From f58d074cafad6da630922e421ca82cef54592b20 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 28 Feb 2019 20:32:04 +0100 Subject: [PATCH] Tighten up contract of PromiseCombiner and so make it more safe to use (#8886) Motivation: PromiseCombiner is not thread-safe and even assumes all added Futures are using the same EventExecutor. This is kind of fragile as we do not enforce this. We need to enforce this contract to ensure it's safe to use and easy to spot concurrency problems. Modifications: - Add new contructor to PromiseCombiner that takes an EventExecutor and deprecate the old non-arg constructor. - Check if methods are called from within the EventExecutor thread and if not fail - Correctly dispatch on the right EventExecutor if the Future uses a different EventExecutor to eliminate concurrency issues. Result: More safe use of PromiseCombiner + enforce correct usage / contract. --- .../CompressorHttp2ConnectionEncoder.java | 2 +- .../codec/MessageToMessageEncoder.java | 2 +- .../util/concurrent/PromiseAggregator.java | 2 +- .../util/concurrent/PromiseCombiner.java | 50 ++++++++++++++++--- .../util/concurrent/PromiseCombinerTest.java | 36 ++++++++++++- .../io/netty/channel/PendingWriteQueue.java | 2 +- 6 files changed, 83 insertions(+), 11 deletions(-) diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java index 3137da21ed..d0ba8646c7 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java @@ -108,7 +108,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE return promise; } - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); for (;;) { ByteBuf nextBuf = nextReadableBuf(channel); boolean compressedEndOfStream = nextBuf == null && endOfStream; diff --git a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java index 6b47b6f8a2..439dc8cb19 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java @@ -132,7 +132,7 @@ public abstract class MessageToMessageEncoder extends ChannelOutboundHandlerA } private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) { - final PromiseCombiner combiner = new PromiseCombiner(); + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); for (int i = 0; i < out.size(); i++) { combiner.add(ctx.write(out.getUnsafe(i))); } diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java index 7a66fa9027..fe54f52c12 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java @@ -22,7 +22,7 @@ import java.util.LinkedHashSet; import java.util.Set; /** - * @deprecated Use {@link PromiseCombiner} + * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. * * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s * into one, by listening to individual {@link Future}s and producing an aggregated result diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java index d53f7b2294..9a7b3a6311 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -29,26 +29,57 @@ import static java.util.Objects.requireNonNull; * {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be * combined have been added, callers must provide an aggregate promise to be notified when all combined promises have * finished via the {@link PromiseCombiner#finish(Promise)} method.

+ * + *

This implementation is NOT thread-safe and all methods must be called + * from the {@link EventExecutor} thread.

*/ public final class PromiseCombiner { private int expectedCount; private int doneCount; - private boolean doneAdding; private Promise aggregatePromise; private Throwable cause; private final GenericFutureListener> listener = new GenericFutureListener>() { @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(final Future future) { + if (executor.inEventLoop()) { + operationComplete0(future); + } else { + executor.execute(() -> operationComplete0(future)); + } + } + + private void operationComplete0(Future future) { + assert executor.inEventLoop(); ++doneCount; if (!future.isSuccess() && cause == null) { cause = future.cause(); } - if (doneCount == expectedCount && doneAdding) { + if (doneCount == expectedCount && aggregatePromise != null) { tryPromise(); } } }; + private final EventExecutor executor; + + /** + * Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. + */ + @Deprecated + public PromiseCombiner() { + this(ImmediateEventExecutor.INSTANCE); + } + + /** + * The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])} + * and {@link #finish(Promise)} from within the {@link EventExecutor} thread. + * + * @param executor the {@link EventExecutor} to use for notifications. + */ + public PromiseCombiner(EventExecutor executor) { + this.executor = requireNonNull(executor, "executor"); + } + /** * Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the * {@link PromiseCombiner#finish(Promise)} method. @@ -71,6 +102,7 @@ public final class PromiseCombiner { @SuppressWarnings({ "unchecked", "rawtypes" }) public void add(Future future) { checkAddAllowed(); + checkInEventLoop(); ++expectedCount; future.addListener(listener); } @@ -114,22 +146,28 @@ public final class PromiseCombiner { */ public void finish(Promise aggregatePromise) { requireNonNull(aggregatePromise, "aggregatePromise"); - if (doneAdding) { + checkInEventLoop(); + if (this.aggregatePromise != null) { throw new IllegalStateException("Already finished"); } - doneAdding = true; this.aggregatePromise = aggregatePromise; if (doneCount == expectedCount) { tryPromise(); } } + private void checkInEventLoop() { + if (!executor.inEventLoop()) { + throw new IllegalStateException("Must be called from EventExecutor thread"); + } + } + private boolean tryPromise() { return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); } private void checkAddAllowed() { - if (doneAdding) { + if (aggregatePromise != null) { throw new IllegalStateException("Adding promises is not allowed after finished adding"); } } diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java index b528f92f99..a8d5b04922 100644 --- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -19,6 +19,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -26,6 +27,7 @@ import org.mockito.stubbing.Answer; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -56,7 +58,7 @@ public class PromiseCombinerTest { @Before public void setup() { MockitoAnnotations.initMocks(this); - combiner = new PromiseCombiner(); + combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE); } @Test @@ -161,6 +163,38 @@ public class PromiseCombinerTest { verifyFail(p3, e1); } + @Test + public void testEventExecutor() { + EventExecutor executor = mock(EventExecutor.class); + when(executor.inEventLoop()).thenReturn(false); + combiner = new PromiseCombiner(executor); + + Future future = mock(Future.class); + + try { + combiner.add(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + try { + combiner.addAll(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + @SuppressWarnings("unchecked") + Promise promise = (Promise) mock(Promise.class); + try { + combiner.finish(promise); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + } + private static void verifyFail(Promise p, Throwable cause) { verify(p).tryFailure(eq(cause)); } diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java index ea04eb70d7..f58991af80 100644 --- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -127,7 +127,7 @@ public final class PendingWriteQueue { } ChannelPromise p = ctx.newPromise(); - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); try { // It is possible for some of the written promises to trigger more writes. The new writes // will "revive" the queue, so we need to write them up until the queue is empty.