Tighten up contract of PromiseCombiner and so make it more safe to use (#8886)

Motivation:

PromiseCombiner is not thread-safe and even assumes all added Futures are using the same EventExecutor. This is kind of fragile as we do not enforce this. We need to enforce this contract to ensure it's safe to use and easy to spot concurrency problems.

Modifications:

- Add new contructor to PromiseCombiner that takes an EventExecutor and deprecate the old non-arg constructor.
- Check if methods are called from within the EventExecutor thread and if not fail
- Correctly dispatch on the right EventExecutor if the Future uses a different EventExecutor to eliminate concurrency issues.

Result:

More safe use of PromiseCombiner + enforce correct usage / contract.
This commit is contained in:
Norman Maurer 2019-02-28 20:32:04 +01:00
parent 89139aa3f8
commit f58d074caf
6 changed files with 83 additions and 11 deletions

View File

@ -108,7 +108,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
return promise; return promise;
} }
PromiseCombiner combiner = new PromiseCombiner(); PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
for (;;) { for (;;) {
ByteBuf nextBuf = nextReadableBuf(channel); ByteBuf nextBuf = nextReadableBuf(channel);
boolean compressedEndOfStream = nextBuf == null && endOfStream; boolean compressedEndOfStream = nextBuf == null && endOfStream;

View File

@ -132,7 +132,7 @@ public abstract class MessageToMessageEncoder<I> extends ChannelOutboundHandlerA
} }
private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) { private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) {
final PromiseCombiner combiner = new PromiseCombiner(); final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
for (int i = 0; i < out.size(); i++) { for (int i = 0; i < out.size(); i++) {
combiner.add(ctx.write(out.getUnsafe(i))); combiner.add(ctx.write(out.getUnsafe(i)));
} }

View File

@ -22,7 +22,7 @@ import java.util.LinkedHashSet;
import java.util.Set; import java.util.Set;
/** /**
* @deprecated Use {@link PromiseCombiner} * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
* *
* {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s
* into one, by listening to individual {@link Future}s and producing an aggregated result * into one, by listening to individual {@link Future}s and producing an aggregated result

View File

@ -29,26 +29,57 @@ import static java.util.Objects.requireNonNull;
* {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be * {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be
* combined have been added, callers must provide an aggregate promise to be notified when all combined promises have * combined have been added, callers must provide an aggregate promise to be notified when all combined promises have
* finished via the {@link PromiseCombiner#finish(Promise)} method.</p> * finished via the {@link PromiseCombiner#finish(Promise)} method.</p>
*
* <p>This implementation is <strong>NOT</strong> thread-safe and all methods must be called
* from the {@link EventExecutor} thread.</p>
*/ */
public final class PromiseCombiner { public final class PromiseCombiner {
private int expectedCount; private int expectedCount;
private int doneCount; private int doneCount;
private boolean doneAdding;
private Promise<Void> aggregatePromise; private Promise<Void> aggregatePromise;
private Throwable cause; private Throwable cause;
private final GenericFutureListener<Future<?>> listener = new GenericFutureListener<Future<?>>() { private final GenericFutureListener<Future<?>> listener = new GenericFutureListener<Future<?>>() {
@Override @Override
public void operationComplete(Future<?> future) throws Exception { public void operationComplete(final Future<?> future) {
if (executor.inEventLoop()) {
operationComplete0(future);
} else {
executor.execute(() -> operationComplete0(future));
}
}
private void operationComplete0(Future<?> future) {
assert executor.inEventLoop();
++doneCount; ++doneCount;
if (!future.isSuccess() && cause == null) { if (!future.isSuccess() && cause == null) {
cause = future.cause(); cause = future.cause();
} }
if (doneCount == expectedCount && doneAdding) { if (doneCount == expectedCount && aggregatePromise != null) {
tryPromise(); tryPromise();
} }
} }
}; };
private final EventExecutor executor;
/**
* Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
*/
@Deprecated
public PromiseCombiner() {
this(ImmediateEventExecutor.INSTANCE);
}
/**
* The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])}
* and {@link #finish(Promise)} from within the {@link EventExecutor} thread.
*
* @param executor the {@link EventExecutor} to use for notifications.
*/
public PromiseCombiner(EventExecutor executor) {
this.executor = requireNonNull(executor, "executor");
}
/** /**
* Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the * Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the
* {@link PromiseCombiner#finish(Promise)} method. * {@link PromiseCombiner#finish(Promise)} method.
@ -71,6 +102,7 @@ public final class PromiseCombiner {
@SuppressWarnings({ "unchecked", "rawtypes" }) @SuppressWarnings({ "unchecked", "rawtypes" })
public void add(Future future) { public void add(Future future) {
checkAddAllowed(); checkAddAllowed();
checkInEventLoop();
++expectedCount; ++expectedCount;
future.addListener(listener); future.addListener(listener);
} }
@ -114,22 +146,28 @@ public final class PromiseCombiner {
*/ */
public void finish(Promise<Void> aggregatePromise) { public void finish(Promise<Void> aggregatePromise) {
requireNonNull(aggregatePromise, "aggregatePromise"); requireNonNull(aggregatePromise, "aggregatePromise");
if (doneAdding) { checkInEventLoop();
if (this.aggregatePromise != null) {
throw new IllegalStateException("Already finished"); throw new IllegalStateException("Already finished");
} }
doneAdding = true;
this.aggregatePromise = aggregatePromise; this.aggregatePromise = aggregatePromise;
if (doneCount == expectedCount) { if (doneCount == expectedCount) {
tryPromise(); tryPromise();
} }
} }
private void checkInEventLoop() {
if (!executor.inEventLoop()) {
throw new IllegalStateException("Must be called from EventExecutor thread");
}
}
private boolean tryPromise() { private boolean tryPromise() {
return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause);
} }
private void checkAddAllowed() { private void checkAddAllowed() {
if (doneAdding) { if (aggregatePromise != null) {
throw new IllegalStateException("Adding promises is not allowed after finished adding"); throw new IllegalStateException("Adding promises is not allowed after finished adding");
} }
} }

View File

@ -19,6 +19,7 @@ import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
@ -26,6 +27,7 @@ import org.mockito.stubbing.Answer;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -56,7 +58,7 @@ public class PromiseCombinerTest {
@Before @Before
public void setup() { public void setup() {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
combiner = new PromiseCombiner(); combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE);
} }
@Test @Test
@ -161,6 +163,38 @@ public class PromiseCombinerTest {
verifyFail(p3, e1); verifyFail(p3, e1);
} }
@Test
public void testEventExecutor() {
EventExecutor executor = mock(EventExecutor.class);
when(executor.inEventLoop()).thenReturn(false);
combiner = new PromiseCombiner(executor);
Future<?> future = mock(Future.class);
try {
combiner.add(future);
Assert.fail();
} catch (IllegalStateException expected) {
// expected
}
try {
combiner.addAll(future);
Assert.fail();
} catch (IllegalStateException expected) {
// expected
}
@SuppressWarnings("unchecked")
Promise<Void> promise = (Promise<Void>) mock(Promise.class);
try {
combiner.finish(promise);
Assert.fail();
} catch (IllegalStateException expected) {
// expected
}
}
private static void verifyFail(Promise<Void> p, Throwable cause) { private static void verifyFail(Promise<Void> p, Throwable cause) {
verify(p).tryFailure(eq(cause)); verify(p).tryFailure(eq(cause));
} }

View File

@ -127,7 +127,7 @@ public final class PendingWriteQueue {
} }
ChannelPromise p = ctx.newPromise(); ChannelPromise p = ctx.newPromise();
PromiseCombiner combiner = new PromiseCombiner(); PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
try { try {
// It is possible for some of the written promises to trigger more writes. The new writes // It is possible for some of the written promises to trigger more writes. The new writes
// will "revive" the queue, so we need to write them up until the queue is empty. // will "revive" the queue, so we need to write them up until the queue is empty.