Tighten up contract of PromiseCombiner and so make it more safe to use (#8886)
Motivation: PromiseCombiner is not thread-safe and even assumes all added Futures are using the same EventExecutor. This is kind of fragile as we do not enforce this. We need to enforce this contract to ensure it's safe to use and easy to spot concurrency problems. Modifications: - Add new contructor to PromiseCombiner that takes an EventExecutor and deprecate the old non-arg constructor. - Check if methods are called from within the EventExecutor thread and if not fail - Correctly dispatch on the right EventExecutor if the Future uses a different EventExecutor to eliminate concurrency issues. Result: More safe use of PromiseCombiner + enforce correct usage / contract.
This commit is contained in:
parent
89139aa3f8
commit
f58d074caf
@ -108,7 +108,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE
|
|||||||
return promise;
|
return promise;
|
||||||
}
|
}
|
||||||
|
|
||||||
PromiseCombiner combiner = new PromiseCombiner();
|
PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
|
||||||
for (;;) {
|
for (;;) {
|
||||||
ByteBuf nextBuf = nextReadableBuf(channel);
|
ByteBuf nextBuf = nextReadableBuf(channel);
|
||||||
boolean compressedEndOfStream = nextBuf == null && endOfStream;
|
boolean compressedEndOfStream = nextBuf == null && endOfStream;
|
||||||
|
@ -132,7 +132,7 @@ public abstract class MessageToMessageEncoder<I> extends ChannelOutboundHandlerA
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) {
|
private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) {
|
||||||
final PromiseCombiner combiner = new PromiseCombiner();
|
final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
|
||||||
for (int i = 0; i < out.size(); i++) {
|
for (int i = 0; i < out.size(); i++) {
|
||||||
combiner.add(ctx.write(out.getUnsafe(i)));
|
combiner.add(ctx.write(out.getUnsafe(i)));
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ import java.util.LinkedHashSet;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated Use {@link PromiseCombiner}
|
* @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
|
||||||
*
|
*
|
||||||
* {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s
|
* {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s
|
||||||
* into one, by listening to individual {@link Future}s and producing an aggregated result
|
* into one, by listening to individual {@link Future}s and producing an aggregated result
|
||||||
|
@ -29,26 +29,57 @@ import static java.util.Objects.requireNonNull;
|
|||||||
* {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be
|
* {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be
|
||||||
* combined have been added, callers must provide an aggregate promise to be notified when all combined promises have
|
* combined have been added, callers must provide an aggregate promise to be notified when all combined promises have
|
||||||
* finished via the {@link PromiseCombiner#finish(Promise)} method.</p>
|
* finished via the {@link PromiseCombiner#finish(Promise)} method.</p>
|
||||||
|
*
|
||||||
|
* <p>This implementation is <strong>NOT</strong> thread-safe and all methods must be called
|
||||||
|
* from the {@link EventExecutor} thread.</p>
|
||||||
*/
|
*/
|
||||||
public final class PromiseCombiner {
|
public final class PromiseCombiner {
|
||||||
private int expectedCount;
|
private int expectedCount;
|
||||||
private int doneCount;
|
private int doneCount;
|
||||||
private boolean doneAdding;
|
|
||||||
private Promise<Void> aggregatePromise;
|
private Promise<Void> aggregatePromise;
|
||||||
private Throwable cause;
|
private Throwable cause;
|
||||||
private final GenericFutureListener<Future<?>> listener = new GenericFutureListener<Future<?>>() {
|
private final GenericFutureListener<Future<?>> listener = new GenericFutureListener<Future<?>>() {
|
||||||
@Override
|
@Override
|
||||||
public void operationComplete(Future<?> future) throws Exception {
|
public void operationComplete(final Future<?> future) {
|
||||||
|
if (executor.inEventLoop()) {
|
||||||
|
operationComplete0(future);
|
||||||
|
} else {
|
||||||
|
executor.execute(() -> operationComplete0(future));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void operationComplete0(Future<?> future) {
|
||||||
|
assert executor.inEventLoop();
|
||||||
++doneCount;
|
++doneCount;
|
||||||
if (!future.isSuccess() && cause == null) {
|
if (!future.isSuccess() && cause == null) {
|
||||||
cause = future.cause();
|
cause = future.cause();
|
||||||
}
|
}
|
||||||
if (doneCount == expectedCount && doneAdding) {
|
if (doneCount == expectedCount && aggregatePromise != null) {
|
||||||
tryPromise();
|
tryPromise();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
private final EventExecutor executor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}.
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
|
public PromiseCombiner() {
|
||||||
|
this(ImmediateEventExecutor.INSTANCE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])}
|
||||||
|
* and {@link #finish(Promise)} from within the {@link EventExecutor} thread.
|
||||||
|
*
|
||||||
|
* @param executor the {@link EventExecutor} to use for notifications.
|
||||||
|
*/
|
||||||
|
public PromiseCombiner(EventExecutor executor) {
|
||||||
|
this.executor = requireNonNull(executor, "executor");
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the
|
* Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the
|
||||||
* {@link PromiseCombiner#finish(Promise)} method.
|
* {@link PromiseCombiner#finish(Promise)} method.
|
||||||
@ -71,6 +102,7 @@ public final class PromiseCombiner {
|
|||||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||||
public void add(Future future) {
|
public void add(Future future) {
|
||||||
checkAddAllowed();
|
checkAddAllowed();
|
||||||
|
checkInEventLoop();
|
||||||
++expectedCount;
|
++expectedCount;
|
||||||
future.addListener(listener);
|
future.addListener(listener);
|
||||||
}
|
}
|
||||||
@ -114,22 +146,28 @@ public final class PromiseCombiner {
|
|||||||
*/
|
*/
|
||||||
public void finish(Promise<Void> aggregatePromise) {
|
public void finish(Promise<Void> aggregatePromise) {
|
||||||
requireNonNull(aggregatePromise, "aggregatePromise");
|
requireNonNull(aggregatePromise, "aggregatePromise");
|
||||||
if (doneAdding) {
|
checkInEventLoop();
|
||||||
|
if (this.aggregatePromise != null) {
|
||||||
throw new IllegalStateException("Already finished");
|
throw new IllegalStateException("Already finished");
|
||||||
}
|
}
|
||||||
doneAdding = true;
|
|
||||||
this.aggregatePromise = aggregatePromise;
|
this.aggregatePromise = aggregatePromise;
|
||||||
if (doneCount == expectedCount) {
|
if (doneCount == expectedCount) {
|
||||||
tryPromise();
|
tryPromise();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void checkInEventLoop() {
|
||||||
|
if (!executor.inEventLoop()) {
|
||||||
|
throw new IllegalStateException("Must be called from EventExecutor thread");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private boolean tryPromise() {
|
private boolean tryPromise() {
|
||||||
return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause);
|
return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void checkAddAllowed() {
|
private void checkAddAllowed() {
|
||||||
if (doneAdding) {
|
if (aggregatePromise != null) {
|
||||||
throw new IllegalStateException("Adding promises is not allowed after finished adding");
|
throw new IllegalStateException("Adding promises is not allowed after finished adding");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import org.junit.Assert;
|
|||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.Mockito;
|
||||||
import org.mockito.MockitoAnnotations;
|
import org.mockito.MockitoAnnotations;
|
||||||
import org.mockito.invocation.InvocationOnMock;
|
import org.mockito.invocation.InvocationOnMock;
|
||||||
import org.mockito.stubbing.Answer;
|
import org.mockito.stubbing.Answer;
|
||||||
@ -26,6 +27,7 @@ import org.mockito.stubbing.Answer;
|
|||||||
import static org.mockito.Mockito.any;
|
import static org.mockito.Mockito.any;
|
||||||
import static org.mockito.Mockito.eq;
|
import static org.mockito.Mockito.eq;
|
||||||
import static org.mockito.Mockito.doAnswer;
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.never;
|
||||||
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.verify;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
@ -56,7 +58,7 @@ public class PromiseCombinerTest {
|
|||||||
@Before
|
@Before
|
||||||
public void setup() {
|
public void setup() {
|
||||||
MockitoAnnotations.initMocks(this);
|
MockitoAnnotations.initMocks(this);
|
||||||
combiner = new PromiseCombiner();
|
combiner = new PromiseCombiner(ImmediateEventExecutor.INSTANCE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -161,6 +163,38 @@ public class PromiseCombinerTest {
|
|||||||
verifyFail(p3, e1);
|
verifyFail(p3, e1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEventExecutor() {
|
||||||
|
EventExecutor executor = mock(EventExecutor.class);
|
||||||
|
when(executor.inEventLoop()).thenReturn(false);
|
||||||
|
combiner = new PromiseCombiner(executor);
|
||||||
|
|
||||||
|
Future<?> future = mock(Future.class);
|
||||||
|
|
||||||
|
try {
|
||||||
|
combiner.add(future);
|
||||||
|
Assert.fail();
|
||||||
|
} catch (IllegalStateException expected) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
combiner.addAll(future);
|
||||||
|
Assert.fail();
|
||||||
|
} catch (IllegalStateException expected) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Promise<Void> promise = (Promise<Void>) mock(Promise.class);
|
||||||
|
try {
|
||||||
|
combiner.finish(promise);
|
||||||
|
Assert.fail();
|
||||||
|
} catch (IllegalStateException expected) {
|
||||||
|
// expected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void verifyFail(Promise<Void> p, Throwable cause) {
|
private static void verifyFail(Promise<Void> p, Throwable cause) {
|
||||||
verify(p).tryFailure(eq(cause));
|
verify(p).tryFailure(eq(cause));
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ public final class PendingWriteQueue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ChannelPromise p = ctx.newPromise();
|
ChannelPromise p = ctx.newPromise();
|
||||||
PromiseCombiner combiner = new PromiseCombiner();
|
PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
|
||||||
try {
|
try {
|
||||||
// It is possible for some of the written promises to trigger more writes. The new writes
|
// It is possible for some of the written promises to trigger more writes. The new writes
|
||||||
// will "revive" the queue, so we need to write them up until the queue is empty.
|
// will "revive" the queue, so we need to write them up until the queue is empty.
|
||||||
|
Loading…
Reference in New Issue
Block a user