Fix a race condition in DefaultPromise

.. which occurs when a user adds a listener from different threads after the promise is done and the notifications for the listeners, that were added before the promise is done, is in progress.  For instance:

   Thread-1: p.addListener(listenerA);
   Thread-1: p.setSuccess(null);
   Thread-2: p.addListener(listenerB);
   Thread-2: p.executor.execute(taskNotifyListenerB);
   Thread-1: p.executor.execute(taskNotifyListenerA);

taskNotifyListenerB should not really notify listenerB until taskNotifyListenerA is finished.

To fix this issue:

- Change the semantic of (listeners == null) to determine if the early
  listeners [1] were notified
- If a late listener is added before the early listeners are notified,
  the notification of the late listener is deferred until the early
  listeners are notified (i.e. until listeners == null)
- The late listeners with deferred notifications are stored in a lazily
  instantiated queue to preserve ordering, and then are notified once
  the early listeners are notified.

[1] the listeners that were added before the promise is done
[2] the listeners that were added after the promise is done
This commit is contained in:
Trustin Lee 2014-02-06 16:06:30 -08:00
parent 2b769c6daf
commit 4628d9a672
2 changed files with 186 additions and 14 deletions

View File

@ -22,16 +22,15 @@ import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.util.ArrayDeque;
import java.util.concurrent.CancellationException;
import java.util.concurrent.TimeUnit;
import static java.util.concurrent.TimeUnit.*;
public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(DefaultPromise.class);
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultPromise.class);
private static final int MAX_LISTENER_STACK_DEPTH = 8;
private static final ThreadLocal<Integer> LISTENER_STACK_DEPTH = new ThreadLocal<Integer>() {
@ -51,7 +50,19 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
private final EventExecutor executor;
private volatile Object result;
private Object listeners; // Can be ChannelFutureListener or DefaultFutureListeners
/**
* One or more listeners. Can be a {@link GenericFutureListener} or a {@link DefaultFutureListeners}.
* If {@code null}, it means either 1) no listeners were added yet or 2) all listeners were notified.
*/
private Object listeners;
/**
* The list of the listeners that were added after the promise is done. Initially {@code null} and lazily
* instantiated when the late listener is scheduled to be notified later. Also used as a cached {@link Runnable}
* that performs the notification of the listeners it contains.
*/
private LateListeners lateListeners;
private short waiters;
@ -127,7 +138,7 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
}
if (isDone()) {
notifyListener(executor(), this, listener);
notifyLateListener(listener);
return this;
}
@ -149,7 +160,7 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
}
}
notifyListener(executor(), this, listener);
notifyLateListener(listener);
return this;
}
@ -541,8 +552,6 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
return;
}
this.listeners = null;
EventExecutor executor = executor();
if (executor.inEventLoop()) {
final Integer stackDepth = LISTENER_STACK_DEPTH.get();
@ -556,6 +565,7 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
final GenericFutureListener<? extends Future<V>> l =
(GenericFutureListener<? extends Future<V>>) listeners;
notifyListener0(this, l);
this.listeners = null;
}
} finally {
LISTENER_STACK_DEPTH.set(stackDepth);
@ -571,6 +581,7 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
@Override
public void run() {
notifyListeners0(DefaultPromise.this, dfl);
DefaultPromise.this.listeners = null;
}
});
} else {
@ -581,6 +592,7 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
@Override
public void run() {
notifyListener0(DefaultPromise.this, l);
DefaultPromise.this.listeners = null;
}
});
}
@ -597,6 +609,40 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
}
}
/**
* Notifies the specified listener which were added after this promise is already done.
* This method ensures that the specified listener is not notified until {@link #listeners} becomes {@code null}
* to avoid the case where the late listeners are notified even before the early listeners are notified.
*/
private void notifyLateListener(final GenericFutureListener<?> l) {
final EventExecutor executor = executor();
if (executor.inEventLoop()) {
if (listeners == null && lateListeners == null) {
final Integer stackDepth = LISTENER_STACK_DEPTH.get();
if (stackDepth < MAX_LISTENER_STACK_DEPTH) {
LISTENER_STACK_DEPTH.set(stackDepth + 1);
try {
notifyListener0(this, l);
} finally {
LISTENER_STACK_DEPTH.set(stackDepth);
}
}
} else {
LateListeners lateListeners = this.lateListeners;
if (lateListeners == null) {
this.lateListeners = lateListeners = new LateListeners();
}
lateListeners.add(l);
execute(executor, lateListeners);
}
} else {
// Add the late listener to lateListeners in the executor thread for thread safety.
// We could just make LateListeners extend ConcurrentLinkedQueue, but it's an overkill considering
// that most asynchronous applications won't execute this code path.
execute(executor, new LateListenerNotifier(l));
}
}
protected static void notifyListener(
final EventExecutor eventExecutor, final Future<?> future, final GenericFutureListener<?> l) {
@ -613,13 +659,17 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
}
}
execute(eventExecutor, new Runnable() {
@Override
public void run() {
notifyListener0(future, l);
}
});
}
private static void execute(EventExecutor executor, Runnable task) {
try {
eventExecutor.execute(new Runnable() {
@Override
public void run() {
notifyListener0(future, l);
}
});
executor.execute(task);
} catch (Throwable t) {
logger.error("Failed to notify a listener. Event loop shut down?", t);
}
@ -780,4 +830,52 @@ public class DefaultPromise<V> extends AbstractFuture<V> implements Promise<V> {
}
return buf;
}
private final class LateListeners extends ArrayDeque<GenericFutureListener<?>> implements Runnable {
private static final long serialVersionUID = -687137418080392244L;
LateListeners() {
super(2);
}
@Override
public void run() {
if (listeners == null) {
for (;;) {
GenericFutureListener<?> l = poll();
if (l == null) {
break;
}
notifyListener0(DefaultPromise.this, l);
}
} else {
// Reschedule until the initial notification is done to avoid the race condition
// where the notification is made in an incorrect order.
executor().execute(this);
}
}
}
private final class LateListenerNotifier implements Runnable {
private GenericFutureListener<?> l;
LateListenerNotifier(GenericFutureListener<?> l) {
this.l = l;
}
@Override
public void run() {
LateListeners lateListeners = DefaultPromise.this.lateListeners;
if (l != null) {
if (lateListeners == null) {
DefaultPromise.this.lateListeners = lateListeners = new LateListeners();
}
lateListeners.add(l);
l = null;
}
lateListeners.run();
}
}
}

View File

@ -18,6 +18,10 @@ package io.netty.util.concurrent;
import org.junit.Test;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
@ -74,4 +78,74 @@ public class DefaultPromiseTest {
assertThat(a.isSuccess(), is(true));
}
}
@Test
public void testListenerNotifyOrder() throws Exception {
SingleThreadEventExecutor executor =
new SingleThreadEventExecutor(null, Executors.defaultThreadFactory(), true) {
@Override
protected void run() {
for (;;) {
Runnable task = takeTask();
if (task != null) {
task.run();
updateLastExecutionTime();
}
if (confirmShutdown()) {
break;
}
}
}
};
final BlockingQueue<FutureListener<Void>> listeners = new LinkedBlockingQueue<FutureListener<Void>>();
int runs = 20000;
for (int i = 0; i < runs; i++) {
final Promise<Void> promise = new DefaultPromise<Void>(executor);
final FutureListener<Void> listener1 = new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
listeners.add(this);
}
};
final FutureListener<Void> listener2 = new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
listeners.add(this);
}
};
final FutureListener<Void> listener4 = new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
listeners.add(this);
}
};
final FutureListener<Void> listener3 = new FutureListener<Void>() {
@Override
public void operationComplete(Future<Void> future) throws Exception {
// Ensure listener4 is notified *after* this method returns to maintain the order.
future.addListener(listener4);
listeners.add(this);
}
};
GlobalEventExecutor.INSTANCE.execute(new Runnable() {
@Override
public void run() {
promise.setSuccess(null);
}
});
promise.addListener(listener1).addListener(listener2).addListener(listener3);
assertSame("Fail during run " + i + " / " + runs, listener1, listeners.take());
assertSame("Fail during run " + i + " / " + runs, listener2, listeners.take());
assertSame("Fail during run " + i + " / " + runs, listener3, listeners.take());
assertSame("Fail during run " + i + " / " + runs, listener4, listeners.take());
assertTrue("Fail during run " + i + " / " + runs, listeners.isEmpty());
}
executor.shutdownGracefully().sync();
}
}