diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java
index 3137da21ed..d0ba8646c7 100644
--- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java
+++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java
@@ -108,7 +108,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
return promise;
}
- PromiseCombiner combiner = new PromiseCombiner();
+ PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
for (;;) {
ByteBuf nextBuf = nextReadableBuf(channel);
boolean compressedEndOfStream = nextBuf == null && endOfStream;
diff --git a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java
index 6b47b6f8a2..439dc8cb19 100644
--- a/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java
+++ b/codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java
@@ -132,7 +132,7 @@ public abstract class MessageToMessageEncoder extends ChannelOutboundHandlerA
}
private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) {
- final PromiseCombiner combiner = new PromiseCombiner();
+ final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
for (int i = 0; i < out.size(); i++) {
combiner.add(ctx.write(out.getUnsafe(i)));
}
diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java
index 7a66fa9027..fe54f52c12 100644
--- a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java
+++ b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java
@@ -22,7 +22,7 @@ import java.util.LinkedHashSet;
import java.util.Set;
/**
- * @deprecated Use {@link PromiseCombiner}
+ * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
*
* {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s
* into one, by listening to individual {@link Future}s and producing an aggregated result
diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java
index d53f7b2294..9a7b3a6311 100644
--- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java
+++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java
@@ -29,26 +29,57 @@ import static java.util.Objects.requireNonNull;
* {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be
* combined have been added, callers must provide an aggregate promise to be notified when all combined promises have
* finished via the {@link PromiseCombiner#finish(Promise)} method.
+ *
+ * This implementation is NOT thread-safe and all methods must be called
+ * from the {@link EventExecutor} thread.
*/
public final class PromiseCombiner {
private int expectedCount;
private int doneCount;
- private boolean doneAdding;
private Promise aggregatePromise;
private Throwable cause;
private final GenericFutureListener> listener = new GenericFutureListener>() {
@Override
- public void operationComplete(Future> future) throws Exception {
+ public void operationComplete(final Future> future) {
+ if (executor.inEventLoop()) {
+ operationComplete0(future);
+ } else {
+ executor.execute(() -> operationComplete0(future));
+ }
+ }
+
+ private void operationComplete0(Future> future) {
+ assert executor.inEventLoop();
++doneCount;
if (!future.isSuccess() && cause == null) {
cause = future.cause();
}
- if (doneCount == expectedCount && doneAdding) {
+ if (doneCount == expectedCount && aggregatePromise != null) {
tryPromise();
}
}
};
+ private final EventExecutor executor;
+
+ /**
+ * Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
+ */
+ @Deprecated
+ public PromiseCombiner() {
+ this(ImmediateEventExecutor.INSTANCE);
+ }
+
+ /**
+ * The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])}
+ * and {@link #finish(Promise)} from within the {@link EventExecutor} thread.
+ *
+ * @param executor the {@link EventExecutor} to use for notifications.
+ */
+ public PromiseCombiner(EventExecutor executor) {
+ this.executor = requireNonNull(executor, "executor");
+ }
+
/**
* Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the
* {@link PromiseCombiner#finish(Promise)} method.
@@ -71,6 +102,7 @@ public final class PromiseCombiner {
@SuppressWarnings({ "unchecked", "rawtypes" })
public void add(Future future) {
checkAddAllowed();
+ checkInEventLoop();
++expectedCount;
future.addListener(listener);
}
@@ -114,22 +146,28 @@ public final class PromiseCombiner {
*/
public void finish(Promise aggregatePromise) {
requireNonNull(aggregatePromise, "aggregatePromise");
- if (doneAdding) {
+ checkInEventLoop();
+ if (this.aggregatePromise != null) {
throw new IllegalStateException("Already finished");
}
- doneAdding = true;
this.aggregatePromise = aggregatePromise;
if (doneCount == expectedCount) {
tryPromise();
}
}
+ private void checkInEventLoop() {
+ if (!executor.inEventLoop()) {
+ throw new IllegalStateException("Must be called from EventExecutor thread");
+ }
+ }
+
private boolean tryPromise() {
return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause);
}
private void checkAddAllowed() {
- if (doneAdding) {
+ if (aggregatePromise != null) {
throw new IllegalStateException("Adding promises is not allowed after finished adding");
}
}
diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java
index b528f92f99..a8d5b04922 100644
--- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java
+++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java
@@ -19,6 +19,7 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@@ -26,6 +27,7 @@ import org.mockito.stubbing.Answer;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -56,7 +58,7 @@ public class PromiseCombinerTest {
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
- combiner = new PromiseCombiner();
+ combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE);
}
@Test
@@ -161,6 +163,38 @@ public class PromiseCombinerTest {
verifyFail(p3, e1);
}
+ @Test
+ public void testEventExecutor() {
+ EventExecutor executor = mock(EventExecutor.class);
+ when(executor.inEventLoop()).thenReturn(false);
+ combiner = new PromiseCombiner(executor);
+
+ Future> future = mock(Future.class);
+
+ try {
+ combiner.add(future);
+ Assert.fail();
+ } catch (IllegalStateException expected) {
+ // expected
+ }
+
+ try {
+ combiner.addAll(future);
+ Assert.fail();
+ } catch (IllegalStateException expected) {
+ // expected
+ }
+
+ @SuppressWarnings("unchecked")
+ Promise promise = (Promise) mock(Promise.class);
+ try {
+ combiner.finish(promise);
+ Assert.fail();
+ } catch (IllegalStateException expected) {
+ // expected
+ }
+ }
+
private static void verifyFail(Promise p, Throwable cause) {
verify(p).tryFailure(eq(cause));
}
diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java
index ea04eb70d7..f58991af80 100644
--- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java
+++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java
@@ -127,7 +127,7 @@ public final class PendingWriteQueue {
}
ChannelPromise p = ctx.newPromise();
- PromiseCombiner combiner = new PromiseCombiner();
+ PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
try {
// It is possible for some of the written promises to trigger more writes. The new writes
// will "revive" the queue, so we need to write them up until the queue is empty.