From 309ee68c217b4d6a5e83a15dc0b9d7b713d3fd6c Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Thu, 6 Feb 2014 16:06:30 -0800 Subject: [PATCH] Fix a race condition in DefaultPromise .. which occurs when a user adds a listener from different threads after the promise is done and the notifications for the listeners, that were added before the promise is done, is in progress. For instance: Thread-1: p.addListener(listenerA); Thread-1: p.setSuccess(null); Thread-2: p.addListener(listenerB); Thread-2: p.executor.execute(taskNotifyListenerB); Thread-1: p.executor.execute(taskNotifyListenerA); taskNotifyListenerB should not really notify listenerB until taskNotifyListenerA is finished. To fix this issue: - Change the semantic of (listeners == null) to determine if the early listeners [1] were notified - If a late listener is added before the early listeners are notified, the notification of the late listener is deferred until the early listeners are notified (i.e. until listeners == null) - The late listeners with deferred notifications are stored in a lazily instantiated queue to preserve ordering, and then are notified once the early listeners are notified. [1] the listeners that were added before the promise is done [2] the listeners that were added after the promise is done --- .../netty/util/concurrent/DefaultPromise.java | 126 ++++++++++++++++-- .../util/concurrent/DefaultPromiseTest.java | 74 ++++++++++ 2 files changed, 186 insertions(+), 14 deletions(-) diff --git a/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java b/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java index 082583d665..11cfd52491 100644 --- a/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java +++ b/common/src/main/java/io/netty/util/concurrent/DefaultPromise.java @@ -22,16 +22,15 @@ import io.netty.util.internal.StringUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.ArrayDeque; import java.util.concurrent.CancellationException; import java.util.concurrent.TimeUnit; import static java.util.concurrent.TimeUnit.*; - public class DefaultPromise extends AbstractFuture implements Promise { - private static final InternalLogger logger = - InternalLoggerFactory.getInstance(DefaultPromise.class); + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultPromise.class); private static final int MAX_LISTENER_STACK_DEPTH = 8; private static final ThreadLocal LISTENER_STACK_DEPTH = new ThreadLocal() { @@ -51,7 +50,19 @@ public class DefaultPromise extends AbstractFuture implements Promise { private final EventExecutor executor; private volatile Object result; - private Object listeners; // Can be ChannelFutureListener or DefaultFutureListeners + + /** + * One or more listeners. Can be a {@link GenericFutureListener} or a {@link DefaultFutureListeners}. + * If {@code null}, it means either 1) no listeners were added yet or 2) all listeners were notified. + */ + private Object listeners; + + /** + * The list of the listeners that were added after the promise is done. Initially {@code null} and lazily + * instantiated when the late listener is scheduled to be notified later. Also used as a cached {@link Runnable} + * that performs the notification of the listeners it contains. + */ + private LateListeners lateListeners; private short waiters; @@ -127,7 +138,7 @@ public class DefaultPromise extends AbstractFuture implements Promise { } if (isDone()) { - notifyListener(executor(), this, listener); + notifyLateListener(listener); return this; } @@ -149,7 +160,7 @@ public class DefaultPromise extends AbstractFuture implements Promise { } } - notifyListener(executor(), this, listener); + notifyLateListener(listener); return this; } @@ -541,8 +552,6 @@ public class DefaultPromise extends AbstractFuture implements Promise { return; } - this.listeners = null; - EventExecutor executor = executor(); if (executor.inEventLoop()) { final Integer stackDepth = LISTENER_STACK_DEPTH.get(); @@ -556,6 +565,7 @@ public class DefaultPromise extends AbstractFuture implements Promise { final GenericFutureListener> l = (GenericFutureListener>) listeners; notifyListener0(this, l); + this.listeners = null; } } finally { LISTENER_STACK_DEPTH.set(stackDepth); @@ -571,6 +581,7 @@ public class DefaultPromise extends AbstractFuture implements Promise { @Override public void run() { notifyListeners0(DefaultPromise.this, dfl); + DefaultPromise.this.listeners = null; } }); } else { @@ -581,6 +592,7 @@ public class DefaultPromise extends AbstractFuture implements Promise { @Override public void run() { notifyListener0(DefaultPromise.this, l); + DefaultPromise.this.listeners = null; } }); } @@ -597,6 +609,40 @@ public class DefaultPromise extends AbstractFuture implements Promise { } } + /** + * Notifies the specified listener which were added after this promise is already done. + * This method ensures that the specified listener is not notified until {@link #listeners} becomes {@code null} + * to avoid the case where the late listeners are notified even before the early listeners are notified. + */ + private void notifyLateListener(final GenericFutureListener l) { + final EventExecutor executor = executor(); + if (executor.inEventLoop()) { + if (listeners == null && lateListeners == null) { + final Integer stackDepth = LISTENER_STACK_DEPTH.get(); + if (stackDepth < MAX_LISTENER_STACK_DEPTH) { + LISTENER_STACK_DEPTH.set(stackDepth + 1); + try { + notifyListener0(this, l); + } finally { + LISTENER_STACK_DEPTH.set(stackDepth); + } + } + } else { + LateListeners lateListeners = this.lateListeners; + if (lateListeners == null) { + this.lateListeners = lateListeners = new LateListeners(); + } + lateListeners.add(l); + execute(executor, lateListeners); + } + } else { + // Add the late listener to lateListeners in the executor thread for thread safety. + // We could just make LateListeners extend ConcurrentLinkedQueue, but it's an overkill considering + // that most asynchronous applications won't execute this code path. + execute(executor, new LateListenerNotifier(l)); + } + } + protected static void notifyListener( final EventExecutor eventExecutor, final Future future, final GenericFutureListener l) { @@ -613,13 +659,17 @@ public class DefaultPromise extends AbstractFuture implements Promise { } } + execute(eventExecutor, new Runnable() { + @Override + public void run() { + notifyListener0(future, l); + } + }); + } + + private static void execute(EventExecutor executor, Runnable task) { try { - eventExecutor.execute(new Runnable() { - @Override - public void run() { - notifyListener0(future, l); - } - }); + executor.execute(task); } catch (Throwable t) { logger.error("Failed to notify a listener. Event loop shut down?", t); } @@ -780,4 +830,52 @@ public class DefaultPromise extends AbstractFuture implements Promise { } return buf; } + + private final class LateListeners extends ArrayDeque> implements Runnable { + + private static final long serialVersionUID = -687137418080392244L; + + LateListeners() { + super(2); + } + + @Override + public void run() { + if (listeners == null) { + for (;;) { + GenericFutureListener l = poll(); + if (l == null) { + break; + } + notifyListener0(DefaultPromise.this, l); + } + } else { + // Reschedule until the initial notification is done to avoid the race condition + // where the notification is made in an incorrect order. + executor().execute(this); + } + } + } + + private final class LateListenerNotifier implements Runnable { + private GenericFutureListener l; + + LateListenerNotifier(GenericFutureListener l) { + this.l = l; + } + + @Override + public void run() { + LateListeners lateListeners = DefaultPromise.this.lateListeners; + if (l != null) { + if (lateListeners == null) { + DefaultPromise.this.lateListeners = lateListeners = new LateListeners(); + } + lateListeners.add(l); + l = null; + } + + lateListeners.run(); + } + } } diff --git a/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java b/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java index bb8060909f..519802c449 100644 --- a/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java +++ b/common/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java @@ -18,6 +18,10 @@ package io.netty.util.concurrent; import org.junit.Test; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; + import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.*; @@ -74,4 +78,74 @@ public class DefaultPromiseTest { assertThat(a.isSuccess(), is(true)); } } + + @Test + public void testListenerNotifyOrder() throws Exception { + SingleThreadEventExecutor executor = + new SingleThreadEventExecutor(null, Executors.defaultThreadFactory(), true) { + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + }; + + final BlockingQueue> listeners = new LinkedBlockingQueue>(); + int runs = 20000; + + for (int i = 0; i < runs; i++) { + final Promise promise = new DefaultPromise(executor); + final FutureListener listener1 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener2 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener4 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener3 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + // Ensure listener4 is notified *after* this method returns to maintain the order. + future.addListener(listener4); + listeners.add(this); + } + }; + + GlobalEventExecutor.INSTANCE.execute(new Runnable() { + @Override + public void run() { + promise.setSuccess(null); + } + }); + + promise.addListener(listener1).addListener(listener2).addListener(listener3); + + assertSame("Fail during run " + i + " / " + runs, listener1, listeners.take()); + assertSame("Fail during run " + i + " / " + runs, listener2, listeners.take()); + assertSame("Fail during run " + i + " / " + runs, listener3, listeners.take()); + assertSame("Fail during run " + i + " / " + runs, listener4, listeners.take()); + assertTrue("Fail during run " + i + " / " + runs, listeners.isEmpty()); + } + executor.shutdownGracefully().sync(); + } }