diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java index 3137da21ed..d0ba8646c7 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java @@ -108,7 +108,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE return promise; } - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); for (;;) { ByteBuf nextBuf = nextReadableBuf(channel); boolean compressedEndOfStream = nextBuf == null && endOfStream; diff --git a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java index 6b47b6f8a2..439dc8cb19 100644 --- a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java @@ -132,7 +132,7 @@ public abstract class MessageToMessageEncoder extends ChannelOutboundHandlerA } private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) { - final PromiseCombiner combiner = new PromiseCombiner(); + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); for (int i = 0; i < out.size(); i++) { combiner.add(ctx.write(out.getUnsafe(i))); } diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java index 7a66fa9027..fe54f52c12 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java @@ -22,7 +22,7 @@ import java.util.LinkedHashSet; import java.util.Set; /** - * @deprecated Use {@link PromiseCombiner} + * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. * * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s * into one, by listening to individual {@link Future}s and producing an aggregated result diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java index d53f7b2294..9a7b3a6311 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -29,26 +29,57 @@ import static java.util.Objects.requireNonNull; * {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be * combined have been added, callers must provide an aggregate promise to be notified when all combined promises have * finished via the {@link PromiseCombiner#finish(Promise)} method.

+ * + *

This implementation is NOT thread-safe and all methods must be called + * from the {@link EventExecutor} thread.

*/ public final class PromiseCombiner { private int expectedCount; private int doneCount; - private boolean doneAdding; private Promise aggregatePromise; private Throwable cause; private final GenericFutureListener> listener = new GenericFutureListener>() { @Override - public void operationComplete(Future future) throws Exception { + public void operationComplete(final Future future) { + if (executor.inEventLoop()) { + operationComplete0(future); + } else { + executor.execute(() -> operationComplete0(future)); + } + } + + private void operationComplete0(Future future) { + assert executor.inEventLoop(); ++doneCount; if (!future.isSuccess() && cause == null) { cause = future.cause(); } - if (doneCount == expectedCount && doneAdding) { + if (doneCount == expectedCount && aggregatePromise != null) { tryPromise(); } } }; + private final EventExecutor executor; + + /** + * Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. + */ + @Deprecated + public PromiseCombiner() { + this(ImmediateEventExecutor.INSTANCE); + } + + /** + * The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])} + * and {@link #finish(Promise)} from within the {@link EventExecutor} thread. + * + * @param executor the {@link EventExecutor} to use for notifications. + */ + public PromiseCombiner(EventExecutor executor) { + this.executor = requireNonNull(executor, "executor"); + } + /** * Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the * {@link PromiseCombiner#finish(Promise)} method. @@ -71,6 +102,7 @@ public final class PromiseCombiner { @SuppressWarnings({ "unchecked", "rawtypes" }) public void add(Future future) { checkAddAllowed(); + checkInEventLoop(); ++expectedCount; future.addListener(listener); } @@ -114,22 +146,28 @@ public final class PromiseCombiner { */ public void finish(Promise aggregatePromise) { requireNonNull(aggregatePromise, "aggregatePromise"); - if (doneAdding) { + checkInEventLoop(); + if (this.aggregatePromise != null) { throw new IllegalStateException("Already finished"); } - doneAdding = true; this.aggregatePromise = aggregatePromise; if (doneCount == expectedCount) { tryPromise(); } } + private void checkInEventLoop() { + if (!executor.inEventLoop()) { + throw new IllegalStateException("Must be called from EventExecutor thread"); + } + } + private boolean tryPromise() { return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); } private void checkAddAllowed() { - if (doneAdding) { + if (aggregatePromise != null) { throw new IllegalStateException("Adding promises is not allowed after finished adding"); } } diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java index b528f92f99..a8d5b04922 100644 --- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -19,6 +19,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -26,6 +27,7 @@ import org.mockito.stubbing.Answer; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -56,7 +58,7 @@ public class PromiseCombinerTest { @Before public void setup() { MockitoAnnotations.initMocks(this); - combiner = new PromiseCombiner(); + combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE); } @Test @@ -161,6 +163,38 @@ public class PromiseCombinerTest { verifyFail(p3, e1); } + @Test + public void testEventExecutor() { + EventExecutor executor = mock(EventExecutor.class); + when(executor.inEventLoop()).thenReturn(false); + combiner = new PromiseCombiner(executor); + + Future future = mock(Future.class); + + try { + combiner.add(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + try { + combiner.addAll(future); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + + @SuppressWarnings("unchecked") + Promise promise = (Promise) mock(Promise.class); + try { + combiner.finish(promise); + Assert.fail(); + } catch (IllegalStateException expected) { + // expected + } + } + private static void verifyFail(Promise p, Throwable cause) { verify(p).tryFailure(eq(cause)); } diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java index ea04eb70d7..f58991af80 100644 --- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -127,7 +127,7 @@ public final class PendingWriteQueue { } ChannelPromise p = ctx.newPromise(); - PromiseCombiner combiner = new PromiseCombiner(); + PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); try { // It is possible for some of the written promises to trigger more writes. The new writes // will "revive" the queue, so we need to write them up until the queue is empty.