diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java index 730a9fe205..d53f7b2294 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -113,11 +113,12 @@ public final class PromiseCombiner { * @param aggregatePromise the promise to notify when all combined futures have finished */ public void finish(Promise aggregatePromise) { + requireNonNull(aggregatePromise, "aggregatePromise"); if (doneAdding) { throw new IllegalStateException("Already finished"); } doneAdding = true; - this.aggregatePromise = requireNonNull(aggregatePromise, "aggregatePromise"); + this.aggregatePromise = aggregatePromise; if (doneCount == expectedCount) { tryPromise(); } diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java index 1625a7f3b3..b528f92f99 100644 --- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -15,6 +15,7 @@ */ package io.netty.util.concurrent; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -58,6 +59,18 @@ public class PromiseCombinerTest { combiner = new PromiseCombiner(); } + @Test + public void testNullArgument() { + try { + combiner.finish(null); + Assert.fail(); + } catch (NullPointerException expected) { + // expected + } + combiner.finish(p1); + verify(p1).trySuccess(null); + } + @Test public void testNullAggregatePromise() { combiner.finish(p1);