DefaultPromise StackOverflowError protection updates

Modifications:
DefaultPromise provides a ThreadLocal queue to protect against StackOverflowError because of executors which may immediately execute runnables instead of queue them (i.e. ImmediateEventExecutor). However this may be better addressed by fixing these executors to protect against StackOverflowError instead of just fixing for a single use case. Also the most commonly used executors already provide the desired behavior and don't need the additional overhead of a ThreadLocal queue in DefaultPromise.

Modifications:
- Remove ThreadLocal queue from DefaultPromise
- Change ImmediateEventExecutor so it maintains a queue of runnables if reentrant condition occurs

Result:
DefaultPromise StackOverflowError code is simpler, and ImmediateEventExecutor protects against StackOverflowError.
This commit is contained in:
Scott Mitchell 2016-05-25 12:52:09 -07:00
parent af7b0a04a0
commit f6ad9df8ac
3 changed files with 169 additions and 126 deletions

View File

@ -24,8 +24,6 @@ import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
@ -42,26 +40,6 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
private static final Signal SUCCESS = Signal.valueOf(DefaultPromise.class, "SUCCESS");
private static final Signal UNCANCELLABLE = Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE");
private static final CauseHolder CANCELLATION_CAUSE_HOLDER = new CauseHolder(new CancellationException());
private static final FastThreadLocal<Queue<DefaultPromise<?>>> STACK_OVERFLOW_DELAYED_PROMISES =
new FastThreadLocal<Queue<DefaultPromise<?>>>() {
@Override
protected Queue<DefaultPromise<?>> initialValue() throws Exception {
return new ArrayDeque<DefaultPromise<?>>(2);
}
};
/**
* This queue will hold pairs of {@code <Future, GenericFutureListener>} which are always inserted sequentially.
* <p>
* This is only used for static utility method {@link #notifyListener(EventExecutor, Future, GenericFutureListener)}
* and not instances of {@link DefaultPromise}.
*/
private static final FastThreadLocal<Queue<Object>> STACK_OVERFLOW_DELAYED_FUTURES =
new FastThreadLocal<Queue<Object>>() {
@Override
protected Queue<Object> initialValue() throws Exception {
return new ArrayDeque<Object>(2);
}
};
static {
AtomicReferenceFieldUpdater<DefaultPromise, Object> updater =
@ -97,12 +75,19 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
* It is preferable to use {@link EventExecutor#newPromise()} to create a new promise
*
* @param executor
* the {@link EventExecutor} which is used to notify the promise once it is complete
* the {@link EventExecutor} which is used to notify the promise once it is complete.
* It is assumed this executor will protect against {@link StackOverflowError} exceptions.
* The executor may be used to avoid {@link StackOverflowError} by executing a {@link Runnable} if the stack
* depth exceeds a threshold.
*
*/
public DefaultPromise(EventExecutor executor) {
this.executor = checkNotNull(executor, "executor");
}
/**
* See {@link #executor()} for expectations of the executor.
*/
protected DefaultPromise() {
// only for subclasses
executor = null;
@ -391,6 +376,14 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
return buf;
}
/**
* Get the executor used to notify listeners when this promise is complete.
* <p>
* It is assumed this executor will protect against {@link StackOverflowError} exceptions.
* The executor may be used to avoid {@link StackOverflowError} by executing a {@link Runnable} if the stack
* depth exceeds a threshold.
* @return The executor used to notify listeners when this promise is complete.
*/
protected EventExecutor executor() {
return executor;
}
@ -413,18 +406,10 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
*/
protected static void notifyListener(
EventExecutor eventExecutor, final Future<?> future, final GenericFutureListener<?> listener) {
checkNotNull(eventExecutor, "eventExecutor");
checkNotNull(future, "future");
checkNotNull(listener, "listener");
if (eventExecutor.inEventLoop()) {
notifyListenerWithStackOverFlowProtection(future, listener);
} else {
safeExecute(eventExecutor, new OneTimeTask() {
@Override
public void run() {
notifyListenerWithStackOverFlowProtection(future, listener);
}
});
}
notifyListenerWithStackOverFlowProtection(eventExecutor, future, listener);
}
private void notifyListeners() {
@ -432,87 +417,61 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
if (listeners == null) {
return;
}
EventExecutor executor = executor();
if (executor.inEventLoop()) {
notifyListenersWithStackOverFlowProtection();
} else {
safeExecute(executor, new OneTimeTask() {
@Override
public void run() {
notifyListenersWithStackOverFlowProtection();
}
});
}
notifyListenersWithStackOverFlowProtection();
}
private void notifyListenersWithStackOverFlowProtection() {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListenersNow();
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
if (stackDepth == 0) {
// We want to force all notifications to occur at the same depth on the stack as the initial method
// invocation. If we leave the stackDepth at 0 then notifyListeners could occur from a
// delayedPromise's stack which could lead to a stack overflow. So force the stack depth to 1 here,
// and later set it back to 0 after all delayedPromises have been notified.
threadLocals.setFutureListenerStackDepth(1);
try {
Queue<DefaultPromise<?>> delayedPromiseQueue = STACK_OVERFLOW_DELAYED_PROMISES.get();
DefaultPromise<?> delayedPromise;
while ((delayedPromise = delayedPromiseQueue.poll()) != null) {
delayedPromise.notifyListenersWithStackOverFlowProtection();
}
} finally {
threadLocals.setFutureListenerStackDepth(0);
}
EventExecutor executor = executor();
if (executor.inEventLoop()) {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListenersNow();
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
}
return;
}
} else {
STACK_OVERFLOW_DELAYED_PROMISES.get().add(this);
}
safeExecute(executor, new OneTimeTask() {
@Override
public void run() {
notifyListenersNow();
}
});
}
/**
* The logic in this method should be identical to {@link #notifyListenersWithStackOverFlowProtection()} but
* cannot share code because the listener(s) can not be queued for an instance of {@link DefaultPromise} since the
* cannot share code because the listener(s) cannot be cached for an instance of {@link DefaultPromise} since the
* listener(s) may be changed and is protected by a synchronized operation.
*/
private static void notifyListenerWithStackOverFlowProtection(Future<?> future, GenericFutureListener<?> listener) {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListener0(future, listener);
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
if (stackDepth == 0) {
// We want to force all notifications to occur at the same depth on the stack as the initial method
// invocation. If we leave the stackDepth at 0 then notifyListeners could occur from a
// delayedPromise's stack which could lead to a stack overflow. So force the stack depth to 1 here,
// and later set it back to 0 after all delayedPromises have been notified.
threadLocals.setFutureListenerStackDepth(1);
try {
Queue<Object> delayedFutureQueue = STACK_OVERFLOW_DELAYED_FUTURES.get();
Object delayedFuture;
while ((delayedFuture = delayedFutureQueue.poll()) != null) {
notifyListenerWithStackOverFlowProtection((Future<?>) delayedFuture,
(GenericFutureListener<?>) delayedFutureQueue.poll());
}
} finally {
threadLocals.setFutureListenerStackDepth(0);
}
private static void notifyListenerWithStackOverFlowProtection(final EventExecutor executor,
final Future<?> future,
final GenericFutureListener<?> listener) {
if (executor.inEventLoop()) {
final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
final int stackDepth = threadLocals.futureListenerStackDepth();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
threadLocals.setFutureListenerStackDepth(stackDepth + 1);
try {
notifyListener0(future, listener);
} finally {
threadLocals.setFutureListenerStackDepth(stackDepth);
}
return;
}
} else {
Queue<Object> delayedFutureQueue = STACK_OVERFLOW_DELAYED_FUTURES.get();
delayedFutureQueue.add(future);
delayedFutureQueue.add(listener);
}
safeExecute(executor, new OneTimeTask() {
@Override
public void run() {
notifyListener0(future, listener);
}
});
}
private void notifyListenersNow() {

View File

@ -15,21 +15,46 @@
*/
package io.netty.util.concurrent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
/**
* {@link AbstractEventExecutor} which execute tasks in the callers thread.
* Executes {@link Runnable} objects in the caller's thread. If the {@link #execute(Runnable)} is reentrant it will be
* queued until the original {@link Runnable} finishes execution.
* <p>
* All {@link Throwable} objects thrown from {@link #execute(Runnable)} will be swallowed and logged. This is to ensure
* that all queued {@link Runnable} objects have the chance to be run.
*/
public final class ImmediateEventExecutor extends AbstractEventExecutor {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(ImmediateEventExecutor.class);
public static final ImmediateEventExecutor INSTANCE = new ImmediateEventExecutor();
/**
* A Runnable will be queued if we are executing a Runnable. This is to prevent a {@link StackOverflowError}.
*/
private static final FastThreadLocal<Queue<Runnable>> DELAYED_RUNNABLES = new FastThreadLocal<Queue<Runnable>>() {
@Override
protected Queue<Runnable> initialValue() throws Exception {
return new ArrayDeque<Runnable>();
}
};
/**
* Set to {@code true} if we are executing a runnable.
*/
private static final FastThreadLocal<Boolean> RUNNING = new FastThreadLocal<Boolean>() {
@Override
protected Boolean initialValue() throws Exception {
return false;
}
};
private final Future<?> terminationFuture = new FailedFuture<Object>(
GlobalEventExecutor.INSTANCE, new UnsupportedOperationException());
private ImmediateEventExecutor() {
// Singleton
}
private ImmediateEventExecutor() { }
@Override
public boolean inEventLoop() {
@ -80,7 +105,27 @@ public final class ImmediateEventExecutor extends AbstractEventExecutor {
if (command == null) {
throw new NullPointerException("command");
}
command.run();
if (!RUNNING.get()) {
RUNNING.set(true);
try {
command.run();
} catch (Throwable cause) {
logger.info("Throwable caught while executing Runnable {}", command, cause);
} finally {
Queue<Runnable> delayedRunnables = DELAYED_RUNNABLES.get();
Runnable runnable;
while ((runnable = delayedRunnables.poll()) != null) {
try {
runnable.run();
} catch (Throwable cause) {
logger.info("Throwable caught while executing Runnable {}", runnable, cause);
}
}
RUNNING.set(false);
}
} else {
DELAYED_RUNNABLES.get().add(command);
}
}
@Override

View File

@ -32,6 +32,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import static java.lang.Math.max;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
@ -56,6 +57,10 @@ public class DefaultPromiseTest {
findStackOverflowDepth();
}
private static int stackOverflowTestDepth() {
return max(stackOverflowDepth << 1, stackOverflowDepth);
}
@Test(expected = CancellationException.class)
public void testCancellationExceptionIsThrownWhenBlockingGet() throws InterruptedException, ExecutionException {
final Promise<Void> promise = new DefaultPromise<Void>(ImmediateEventExecutor.INSTANCE);
@ -72,18 +77,19 @@ public class DefaultPromiseTest {
}
@Test
public void testNoStackOverflowErrorWithImmediateEventExecutorA() throws Exception {
testStackOverFlowErrorChainedFuturesA(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE),
ImmediateEventExecutor.INSTANCE);
public void testStackOverflowWithImmediateEventExecutorA() throws Exception {
testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, true);
testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, false);
}
@Test
public void testNoStackOverflowErrorWithDefaultEventExecutorA() throws Exception {
public void testNoStackOverflowWithDefaultEventExecutorA() throws Exception {
ExecutorService executorService = Executors.newSingleThreadExecutor();
try {
EventExecutor executor = new DefaultEventExecutor(executorService);
try {
testStackOverFlowErrorChainedFuturesA(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE), executor);
testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), executor, true);
testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), executor, false);
} finally {
executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
}
@ -93,18 +99,19 @@ public class DefaultPromiseTest {
}
@Test
public void testNoStackOverflowErrorWithImmediateEventExecutorB() throws Exception {
testStackOverFlowErrorChainedFuturesB(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE),
ImmediateEventExecutor.INSTANCE);
public void testNoStackOverflowWithImmediateEventExecutorB() throws Exception {
testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, true);
testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, false);
}
@Test
public void testNoStackOverflowErrorWithDefaultEventExecutorB() throws Exception {
public void testNoStackOverflowWithDefaultEventExecutorB() throws Exception {
ExecutorService executorService = Executors.newSingleThreadExecutor();
try {
EventExecutor executor = new DefaultEventExecutor(executorService);
try {
testStackOverFlowErrorChainedFuturesB(Math.min(stackOverflowDepth << 1, Integer.MAX_VALUE), executor);
testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), executor, true);
testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), executor, false);
} finally {
executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
}
@ -197,10 +204,31 @@ public class DefaultPromiseTest {
testLateListenerIsOrderedCorrectly(fakeException());
}
private void testStackOverFlowErrorChainedFuturesA(int promiseChainLength, EventExecutor executor)
private void testStackOverFlowChainedFuturesA(int promiseChainLength, final EventExecutor executor,
boolean runTestInExecutorThread)
throws InterruptedException {
final Promise<Void>[] p = new DefaultPromise[promiseChainLength];
final CountDownLatch latch = new CountDownLatch(promiseChainLength);
if (runTestInExecutorThread) {
executor.execute(new Runnable() {
@Override
public void run() {
testStackOverFlowChainedFuturesA(executor, p, latch);
}
});
} else {
testStackOverFlowChainedFuturesA(executor, p, latch);
}
assertTrue(latch.await(2, TimeUnit.SECONDS));
for (int i = 0; i < p.length; ++i) {
assertTrue("index " + i, p[i].isSuccess());
}
}
private void testStackOverFlowChainedFuturesA(EventExecutor executor, final Promise<Void>[] p,
final CountDownLatch latch) {
for (int i = 0; i < p.length; i ++) {
final int finalI = i;
p[i] = new DefaultPromise<Void>(executor);
@ -216,17 +244,33 @@ public class DefaultPromiseTest {
}
p[0].setSuccess(null);
}
latch.await(2, TimeUnit.SECONDS);
private void testStackOverFlowChainedFuturesB(int promiseChainLength, final EventExecutor executor,
boolean runTestInExecutorThread)
throws InterruptedException {
final Promise<Void>[] p = new DefaultPromise[promiseChainLength];
final CountDownLatch latch = new CountDownLatch(promiseChainLength);
if (runTestInExecutorThread) {
executor.execute(new Runnable() {
@Override
public void run() {
testStackOverFlowChainedFuturesA(executor, p, latch);
}
});
} else {
testStackOverFlowChainedFuturesA(executor, p, latch);
}
assertTrue(latch.await(2, TimeUnit.SECONDS));
for (int i = 0; i < p.length; ++i) {
assertTrue("index " + i, p[i].isSuccess());
}
}
private void testStackOverFlowErrorChainedFuturesB(int promiseChainLength, EventExecutor executor)
throws InterruptedException {
final Promise<Void>[] p = new DefaultPromise[promiseChainLength];
final CountDownLatch latch = new CountDownLatch(promiseChainLength);
private void testStackOverFlowChainedFuturesB(EventExecutor executor, final Promise<Void>[] p,
final CountDownLatch latch) {
for (int i = 0; i < p.length; i ++) {
final int finalI = i;
p[i] = new DefaultPromise<Void>(executor);
@ -247,11 +291,6 @@ public class DefaultPromiseTest {
}
p[0].setSuccess(null);
latch.await(2, TimeUnit.SECONDS);
for (int i = 0; i < p.length; ++i) {
assertTrue("index " + i, p[i].isSuccess());
}
}
/**