DefaultPromise StackOverflowError protection

Motivation:
f2ed3e6ce8039d142e4c047fcc9cf09409105243 removed the previous mechanism for StackOverflowError because it didn't work in all cases (i.e. ImmediateExecutor). However if a chain of listeners which complete other promises is formed there is still a possibility of a StackOverflowError.

Modifications:
- Use a ThreadLocal to save any DefaultPromises which could not be notified due to the stack being too large. After the first DefaultPromise on the stack completes notification this ThreadLocal should be used to notify any DefaultPromises which have not yet been notified.

Result:
DefaultPromise has StackOverflowError protection that works with all EventExecutor types.
This commit is contained in:
Scott Mitchell 2016-05-23 18:36:35 -07:00
parent f5d58e2e1a
commit b3c56d5f69
2 changed files with 221 additions and 70 deletions

View File

@ -24,6 +24,8 @@ import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
@ -40,6 +42,26 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
private static final Signal SUCCESS = Signal.valueOf(DefaultPromise.class, "SUCCESS"); private static final Signal SUCCESS = Signal.valueOf(DefaultPromise.class, "SUCCESS");
private static final Signal UNCANCELLABLE = Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE"); private static final Signal UNCANCELLABLE = Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE");
private static final CauseHolder CANCELLATION_CAUSE_HOLDER = new CauseHolder(new CancellationException()); private static final CauseHolder CANCELLATION_CAUSE_HOLDER = new CauseHolder(new CancellationException());
private static final FastThreadLocal<Queue<DefaultPromise<?>>> STACK_OVERFLOW_DELAYED_PROMISES =
new FastThreadLocal<Queue<DefaultPromise<?>>>() {
@Override
protected Queue<DefaultPromise<?>> initialValue() throws Exception {
return new ArrayDeque<DefaultPromise<?>>(2);
}
};
/**
* This queue will hold pairs of {@code <Future, GenericFutureListener>} which are always inserted sequentially.
* <p>
* This is only used for static utility method {@link #notifyListener(EventExecutor, Future, GenericFutureListener)}
* and not instances of {@link DefaultPromise}.
*/
private static final FastThreadLocal<Queue<Object>> STACK_OVERFLOW_DELAYED_FUTURES =
new FastThreadLocal<Queue<Object>>() {
@Override
protected Queue<Object> initialValue() throws Exception {
return new ArrayDeque<Object>(2);
}
};
static { static {
AtomicReferenceFieldUpdater<DefaultPromise, Object> updater = AtomicReferenceFieldUpdater<DefaultPromise, Object> updater =
@ -385,33 +407,25 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
* <p> * <p>
* This method has a fixed depth of {@link #MAX_LISTENER_STACK_DEPTH} that will limit recursion to prevent * This method has a fixed depth of {@link #MAX_LISTENER_STACK_DEPTH} that will limit recursion to prevent
* {@link StackOverflowError} and will stop notifying listeners added after this threshold is exceeded. * {@link StackOverflowError} and will stop notifying listeners added after this threshold is exceeded.
* @param eventExecutor the executor to use to notify the listener {@code l}. * @param eventExecutor the executor to use to notify the listener {@code listener}.
* @param future the future that is complete. * @param future the future that is complete.
* @param l the listener to notify. * @param listener the listener to notify.
*/ */
protected static void notifyListener( protected static void notifyListener(
EventExecutor eventExecutor, final Future<?> future, final GenericFutureListener<?> l) { EventExecutor eventExecutor, final Future<?> future, final GenericFutureListener<?> listener) {
checkNotNull(future, "future");
checkNotNull(listener, "listener");
if (eventExecutor.inEventLoop()) { if (eventExecutor.inEventLoop()) {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get(); notifyListenerWithStackOverFlowProtection(future, listener);
final int stackDepth = threadLocals.futureListenerStackDepth(); } else {
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListener0(future, l);
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
}
return;
}
}
safeExecute(eventExecutor, new OneTimeTask() { safeExecute(eventExecutor, new OneTimeTask() {
@Override @Override
public void run() { public void run() {
notifyListener0(future, l); notifyListenerWithStackOverFlowProtection(future, listener);
} }
}); });
} }
}
private void notifyListeners() { private void notifyListeners() {
// Modifications to listeners should be done in a synchronized block before this, and should be visible here. // Modifications to listeners should be done in a synchronized block before this, and should be visible here.
@ -420,18 +434,88 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
} }
EventExecutor executor = executor(); EventExecutor executor = executor();
if (executor.inEventLoop()) { if (executor.inEventLoop()) {
notifyListeners0(); notifyListenersWithStackOverFlowProtection();
return; } else {
}
safeExecute(executor, new OneTimeTask() { safeExecute(executor, new OneTimeTask() {
@Override @Override
public void run() { public void run() {
notifyListeners0(); notifyListenersWithStackOverFlowProtection();
} }
}); });
} }
}
private void notifyListeners0() { private void notifyListenersWithStackOverFlowProtection() {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListenersNow();
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
if (stackDepth == 0) {
// We want to force all notifications to occur at the same depth on the stack as the initial method
// invocation. If we leave the stackDepth at 0 then notifyListeners could occur from a
// delayedPromise's stack which could lead to a stack overflow. So force the stack depth to 1 here,
// and later set it back to 0 after all delayedPromises have been notified.
threadLocals.setFutureListenerStackDepth(1);
try {
Queue<DefaultPromise<?>> delayedPromiseQueue = STACK_OVERFLOW_DELAYED_PROMISES.get();
DefaultPromise<?> delayedPromise;
while ((delayedPromise = delayedPromiseQueue.poll()) != null) {
delayedPromise.notifyListenersWithStackOverFlowProtection();
}
} finally {
threadLocals.setFutureListenerStackDepth(0);
}
}
}
} else {
STACK_OVERFLOW_DELAYED_PROMISES.get().add(this);
}
}
/**
* The logic in this method should be identical to {@link #notifyListenersWithStackOverFlowProtection()} but
* cannot share code because the listener(s) can not be queued for an instance of {@link DefaultPromise} since the
* listener(s) may be changed and is protected by a synchronized operation.
*/
private static void notifyListenerWithStackOverFlowProtection(Future<?> future, GenericFutureListener<?> listener) {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListener0(future, listener);
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
if (stackDepth == 0) {
// We want to force all notifications to occur at the same depth on the stack as the initial method
// invocation. If we leave the stackDepth at 0 then notifyListeners could occur from a
// delayedPromise's stack which could lead to a stack overflow. So force the stack depth to 1 here,
// and later set it back to 0 after all delayedPromises have been notified.
threadLocals.setFutureListenerStackDepth(1);
try {
Queue<Object> delayedFutureQueue = STACK_OVERFLOW_DELAYED_FUTURES.get();
Object delayedFuture;
while ((delayedFuture = delayedFutureQueue.poll()) != null) {
notifyListenerWithStackOverFlowProtection((Future<?>) delayedFuture,
(GenericFutureListener<?>) delayedFutureQueue.poll());
}
} finally {
threadLocals.setFutureListenerStackDepth(0);
}
}
}
} else {
Queue<Object> delayedFutureQueue = STACK_OVERFLOW_DELAYED_FUTURES.get();
delayedFutureQueue.add(future);
delayedFutureQueue.add(listener);
}
}
private void notifyListenersNow() {
Object listeners; Object listeners;
synchronized (this) { synchronized (this) {
// Only proceed if there are listeners to notify and we are not already notifying listeners. // Only proceed if there are listeners to notify and we are not already notifying listeners.

View File

@ -16,26 +16,45 @@
package io.netty.util.concurrent; package io.netty.util.concurrent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.BeforeClass;
import org.junit.Test; import org.junit.Test;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public class DefaultPromiseTest { public class DefaultPromiseTest {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultPromiseTest.class);
private static int stackOverflowDepth;
@BeforeClass
public static void beforeClass() {
try {
findStackOverflowDepth();
throw new IllegalStateException("Expected StackOverflowError but didn't get it?!");
} catch (StackOverflowError e) {
logger.debug("StackOverflowError depth: {}", stackOverflowDepth);
}
}
private static void findStackOverflowDepth() {
++stackOverflowDepth;
findStackOverflowDepth();
}
@Test(expected = CancellationException.class) @Test(expected = CancellationException.class)
public void testCancellationExceptionIsThrownWhenBlockingGet() throws InterruptedException, ExecutionException { public void testCancellationExceptionIsThrownWhenBlockingGet() throws InterruptedException, ExecutionException {
@ -54,52 +73,43 @@ public class DefaultPromiseTest {
@Test @Test
public void testNoStackOverflowErrorWithImmediateEventExecutorA() throws Exception { public void testNoStackOverflowErrorWithImmediateEventExecutorA() throws Exception {
final Promise<Void>[] p = new DefaultPromise[128]; testStackOverFlowErrorChainedFuturesA(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE),
for (int i = 0; i < p.length; i ++) { ImmediateEventExecutor.INSTANCE);
final int finalI = i;
p[i] = new DefaultPromise<Void>(ImmediateEventExecutor.INSTANCE);
p[i].addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
if (finalI + 1 < p.length) {
p[finalI + 1].setSuccess(null);
}
}
});
} }
p[0].setSuccess(null); @Test
public void testNoStackOverflowErrorWithDefaultEventExecutorA() throws Exception {
for (Promise<Void> a: p) { ExecutorService executorService = Executors.newSingleThreadExecutor();
assertThat(a.isSuccess(), is(true)); try {
EventExecutor executor = new DefaultEventExecutor(executorService);
try {
testStackOverFlowErrorChainedFuturesA(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE), executor);
} finally {
executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
}
} finally {
executorService.shutdown();
} }
} }
@Test @Test
public void testNoStackOverflowErrorWithImmediateEventExecutorB() throws Exception { public void testNoStackOverflowErrorWithImmediateEventExecutorB() throws Exception {
final Promise<Void>[] p = new DefaultPromise[128]; testStackOverFlowErrorChainedFuturesB(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE),
for (int i = 0; i < p.length; i ++) { ImmediateEventExecutor.INSTANCE);
final int finalI = i;
p[i] = new DefaultPromise<Void>(ImmediateEventExecutor.INSTANCE);
p[i].addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
future.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
if (finalI + 1 < p.length) {
p[finalI + 1].setSuccess(null);
}
}
});
}
});
} }
p[0].setSuccess(null); @Test
public void testNoStackOverflowErrorWithDefaultEventExecutorB() throws Exception {
for (Promise<Void> a: p) { ExecutorService executorService = Executors.newSingleThreadExecutor();
assertThat(a.isSuccess(), is(true)); try {
EventExecutor executor = new DefaultEventExecutor(executorService);
try {
testStackOverFlowErrorChainedFuturesB(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE), executor);
} finally {
executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
}
} finally {
executorService.shutdown();
} }
} }
@ -187,6 +197,63 @@ public class DefaultPromiseTest {
testLateListenerIsOrderedCorrectly(fakeException()); testLateListenerIsOrderedCorrectly(fakeException());
} }
private void testStackOverFlowErrorChainedFuturesA(int promiseChainLength, EventExecutor executor)
throws InterruptedException {
final Promise<Void>[] p = new DefaultPromise[promiseChainLength];
final CountDownLatch latch = new CountDownLatch(promiseChainLength);
for (int i = 0; i < p.length; i ++) {
final int finalI = i;
p[i] = new DefaultPromise<Void>(executor);
p[i].addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
if (finalI + 1 < p.length) {
p[finalI + 1].setSuccess(null);
}
latch.countDown();
}
});
}
p[0].setSuccess(null);
latch.await(2, TimeUnit.SECONDS);
for (int i = 0; i < p.length; ++i) {
assertTrue("index " + i, p[i].isSuccess());
}
}
private void testStackOverFlowErrorChainedFuturesB(int promiseChainLength, EventExecutor executor)
throws InterruptedException {
final Promise<Void>[] p = new DefaultPromise[promiseChainLength];
final CountDownLatch latch = new CountDownLatch(promiseChainLength);
for (int i = 0; i < p.length; i ++) {
final int finalI = i;
p[i] = new DefaultPromise<Void>(executor);
p[i].addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
future.addListener(new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
if (finalI + 1 < p.length) {
p[finalI + 1].setSuccess(null);
}
latch.countDown();
}
});
}
});
}
p[0].setSuccess(null);
latch.await(2, TimeUnit.SECONDS);
for (int i = 0; i < p.length; ++i) {
assertTrue("index " + i, p[i].isSuccess());
}
}
/** /**
* This test is mean to simulate the following sequence of events, which all take place on the I/O thread: * This test is mean to simulate the following sequence of events, which all take place on the I/O thread:
* <ol> * <ol>