diff --git a/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java b/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java index df7547545d..a798563c34 100644 --- a/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java +++ b/common/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java @@ -23,17 +23,21 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import java.lang.Thread.State; import java.util.ArrayList; +import java.util.Collection; import java.util.LinkedHashSet; import java.util.List; import java.util.Queue; import java.util.Set; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -739,6 +743,39 @@ public abstract class SingleThreadEventExecutor extends AbstractScheduledEventEx } } + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + throwIfInEventLoop("invokeAny"); + return super.invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + throwIfInEventLoop("invokeAny"); + return super.invokeAny(tasks, timeout, unit); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + throwIfInEventLoop("invokeAll"); + return super.invokeAll(tasks); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { + throwIfInEventLoop("invokeAll"); + return super.invokeAll(tasks, timeout, unit); + } + + private void throwIfInEventLoop(String method) { + if (inEventLoop()) { + throw new RejectedExecutionException("Calling " + method + " from within the EventLoop is not allowed"); + } + } + /** * Returns the {@link ThreadProperties} of the {@link Thread} that powers the {@link SingleThreadEventExecutor}. * If the {@link SingleThreadEventExecutor} is not started yet, this operation will start it and block until the diff --git a/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java b/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java index 7a9e20b9ea..3be7dc8f83 100644 --- a/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java +++ b/common/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java @@ -18,6 +18,12 @@ package io.netty.util.concurrent; import org.junit.Assert; import org.junit.Test; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; public class SingleThreadEventExecutorTest { @@ -49,4 +55,75 @@ public class SingleThreadEventExecutorTest { Assert.assertTrue(threadProperties.stackTrace().length > 0); executor.shutdownGracefully(); } + + @Test(expected = RejectedExecutionException.class, timeout = 3000) + public void testInvokeAnyInEventLoop() { + testInvokeInEventLoop(true, false); + } + + @Test(expected = RejectedExecutionException.class, timeout = 3000) + public void testInvokeAnyInEventLoopWithTimeout() { + testInvokeInEventLoop(true, true); + } + + @Test(expected = RejectedExecutionException.class, timeout = 3000) + public void testInvokeAllInEventLoop() { + testInvokeInEventLoop(false, false); + } + + @Test(expected = RejectedExecutionException.class, timeout = 3000) + public void testInvokeAllInEventLoopWithTimeout() { + testInvokeInEventLoop(false, true); + } + + private static void testInvokeInEventLoop(final boolean any, final boolean timeout) { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor(null, + Executors.defaultThreadFactory(), false) { + @Override + protected void run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + try { + final Promise promise = executor.newPromise(); + executor.execute(new Runnable() { + @Override + public void run() { + try { + Set> set = Collections.>singleton(new Callable() { + @Override + public Boolean call() throws Exception { + promise.setFailure(new AssertionError("Should never execute the Callable")); + return Boolean.TRUE; + } + }); + if (any) { + if (timeout) { + executor.invokeAny(set, 10, TimeUnit.SECONDS); + } else { + executor.invokeAny(set); + } + } else { + if (timeout) { + executor.invokeAll(set, 10, TimeUnit.SECONDS); + } else { + executor.invokeAll(set); + } + } + promise.setFailure(new AssertionError("Should never reach here")); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }); + promise.syncUninterruptibly(); + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + } }