diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java index 6624f05db1..38b9de14cb 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -112,11 +112,12 @@ public final class PromiseCombiner { * @param aggregatePromise the promise to notify when all combined futures have finished */ public void finish(Promise aggregatePromise) { + ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); if (doneAdding) { throw new IllegalStateException("Already finished"); } doneAdding = true; - this.aggregatePromise = ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + this.aggregatePromise = aggregatePromise; if (doneCount == expectedCount) { tryPromise(); } diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java index b46fa41024..d77aab70ce 100644 --- a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -15,6 +15,7 @@ */ package io.netty.util.concurrent; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -58,6 +59,18 @@ public class PromiseCombinerTest { combiner = new PromiseCombiner(); } + @Test + public void testNullArgument() { + try { + combiner.finish(null); + Assert.fail(); + } catch (NullPointerException expected) { + // expected + } + combiner.finish(p1); + verify(p1).trySuccess(null); + } + @Test public void testNullAggregatePromise() { combiner.finish(p1);