diff --git a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java index 6e68f6a841..5316f2125a 100644 --- a/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java +++ b/codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java @@ -19,11 +19,11 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; -import io.netty.channel.ChannelPromiseAggregator; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.compression.ZlibCodecFactory; import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.util.concurrent.PromiseCombiner; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_ENCODING; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; @@ -106,9 +106,7 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE return promise; } - ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(promise); - ChannelPromise bufPromise = ctx.newPromise(); - aggregator.add(bufPromise); + PromiseCombiner combiner = new PromiseCombiner(); for (;;) { ByteBuf nextBuf = nextReadableBuf(channel); boolean compressedEndOfStream = nextBuf == null && endOfStream; @@ -117,16 +115,8 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE compressedEndOfStream = nextBuf == null; } - final ChannelPromise nextPromise; - if (nextBuf != null) { - // We have to add the nextPromise to the aggregator before doing the current write. This is so - // completing the current write before the next write is done won't complete the aggregate promise - nextPromise = ctx.newPromise(); - aggregator.add(nextPromise); - } else { - nextPromise = null; - } - + ChannelPromise bufPromise = ctx.newPromise(); + combiner.add(bufPromise); super.writeData(ctx, streamId, buf, padding, compressedEndOfStream, bufPromise); if (nextBuf == null) { break; @@ -134,14 +124,16 @@ public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionE padding = 0; // Padding is only communicated once on the first iteration buf = nextBuf; - bufPromise = nextPromise; } - return promise; + combiner.finish(promise); + } catch (Throwable cause) { + promise.tryFailure(cause); } finally { if (endOfStream) { cleanup(stream, channel); } } + return promise; } @Override diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java index 8efd0bcae8..7d908c0379 100644 --- a/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java +++ b/common/src/main/java/io/netty/util/concurrent/PromiseAggregator.java @@ -20,6 +20,8 @@ import java.util.LinkedHashSet; import java.util.Set; /** + * @deprecated Use {@link PromiseCombiner} + * * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s * into one, by listening to individual {@link Future}s and producing an aggregated result * (success/failure) when all {@link Future}s have completed. @@ -27,6 +29,7 @@ import java.util.Set; * @param the type of value returned by the {@link Future} * @param the type of {@link Future} */ +@Deprecated public class PromiseAggregator> implements GenericFutureListener { private final Promise aggregatePromise; diff --git a/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java new file mode 100644 index 0000000000..58ccae0077 --- /dev/null +++ b/common/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -0,0 +1,75 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; + +public final class PromiseCombiner { + private int expectedCount; + private int doneCount; + private boolean doneAdding; + private Promise aggregatePromise; + private Throwable cause; + private final GenericFutureListener> listener = new GenericFutureListener>() { + @Override + public void operationComplete(Future future) throws Exception { + ++doneCount; + if (!future.isSuccess() && cause == null) { + cause = future.cause(); + } + if (doneCount == expectedCount && doneAdding) { + tryPromise(); + } + } + }; + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void add(Promise promise) { + checkAddAllowed(); + ++expectedCount; + promise.addListener(listener); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void addAll(Promise... promises) { + checkAddAllowed(); + expectedCount += promises.length; + for (Promise promise : promises) { + promise.addListener(listener); + } + } + + public void finish(Promise aggregatePromise) { + if (doneAdding) { + throw new IllegalStateException("Already finished"); + } + doneAdding = true; + this.aggregatePromise = ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + if (doneCount == expectedCount) { + tryPromise(); + } + } + + private boolean tryPromise() { + return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); + } + + private void checkAddAllowed() { + if (doneAdding) { + throw new IllegalStateException("Adding promises is not allowed after finished adding"); + } + } +} diff --git a/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java new file mode 100644 index 0000000000..b392585f59 --- /dev/null +++ b/common/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -0,0 +1,194 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class PromiseCombinerTest { + @Mock + private Promise p1; + private GenericFutureListener> l1; + private GenericFutureListenerConsumer l1Consumer = new GenericFutureListenerConsumer() { + @Override + public void accept(GenericFutureListener> listener) { + l1 = listener; + } + }; + @Mock + private Promise p2; + private GenericFutureListener> l2; + private GenericFutureListenerConsumer l2Consumer = new GenericFutureListenerConsumer() { + @Override + public void accept(GenericFutureListener> listener) { + l2 = listener; + } + }; + @Mock + private Promise p3; + private PromiseCombiner combiner; + + @Before + public void setup() { + MockitoAnnotations.initMocks(this); + combiner = new PromiseCombiner(); + } + + @Test + public void testNullAggregatePromise() { + combiner.finish(p1); + verify(p1).trySuccess(any(Void.class)); + } + + @Test(expected = NullPointerException.class) + public void testAddNullPromise() { + combiner.add(null); + } + + @Test(expected = NullPointerException.class) + public void testAddAllNullPromise() { + combiner.addAll(null); + } + + @Test(expected = IllegalStateException.class) + public void testAddAfterFinish() { + combiner.finish(p1); + combiner.add(p2); + } + + @SuppressWarnings("unchecked") + @Test(expected = IllegalStateException.class) + public void testAddAllAfterFinish() { + combiner.finish(p1); + combiner.addAll(p2); + } + + @SuppressWarnings("unchecked") + @Test(expected = IllegalStateException.class) + public void testFinishCalledTwiceThrows() { + combiner.finish(p1); + combiner.finish(p1); + } + + @Test + public void testAddAllSuccess() throws Exception { + mockSuccessPromise(p1, l1Consumer); + mockSuccessPromise(p2, l2Consumer); + combiner.addAll(p1, p2); + combiner.finish(p3); + l1.operationComplete(p1); + verifyNotCompleted(p3); + l2.operationComplete(p2); + verifySuccess(p3); + } + + @Test + public void testAddSuccess() throws Exception { + mockSuccessPromise(p1, l1Consumer); + mockSuccessPromise(p2, l2Consumer); + combiner.add(p1); + l1.operationComplete(p1); + combiner.add(p2); + l2.operationComplete(p2); + verifyNotCompleted(p3); + combiner.finish(p3); + verifySuccess(p3); + } + + @Test + public void testAddAllFail() throws Exception { + RuntimeException e1 = new RuntimeException("fake exception 1"); + RuntimeException e2 = new RuntimeException("fake exception 2"); + mockFailedPromise(p1, e1, l1Consumer); + mockFailedPromise(p2, e2, l2Consumer); + combiner.addAll(p1, p2); + combiner.finish(p3); + l1.operationComplete(p1); + verifyNotCompleted(p3); + l2.operationComplete(p2); + verifyFail(p3, e1); + } + + @Test + public void testAddFail() throws Exception { + RuntimeException e1 = new RuntimeException("fake exception 1"); + RuntimeException e2 = new RuntimeException("fake exception 2"); + mockFailedPromise(p1, e1, l1Consumer); + mockFailedPromise(p2, e2, l2Consumer); + combiner.add(p1); + l1.operationComplete(p1); + combiner.add(p2); + l2.operationComplete(p2); + verifyNotCompleted(p3); + combiner.finish(p3); + verifyFail(p3, e1); + } + + private void verifyFail(Promise p, Throwable cause) { + verify(p).tryFailure(eq(cause)); + } + + private void verifySuccess(Promise p) { + verify(p).trySuccess(any(Void.class)); + } + + private void verifyNotCompleted(Promise p) { + verify(p, never()).trySuccess(any(Void.class)); + verify(p, never()).tryFailure(any(Throwable.class)); + verify(p, never()).setSuccess(any(Void.class)); + verify(p, never()).setFailure(any(Throwable.class)); + } + + private void mockSuccessPromise(Promise p, GenericFutureListenerConsumer consumer) { + when(p.isDone()).thenReturn(true); + when(p.isSuccess()).thenReturn(true); + mockListener(p, consumer); + } + + private void mockFailedPromise(Promise p, Throwable cause, GenericFutureListenerConsumer consumer) { + when(p.isDone()).thenReturn(true); + when(p.isSuccess()).thenReturn(false); + when(p.cause()).thenReturn(cause); + mockListener(p, consumer); + } + + @SuppressWarnings("unchecked") + private void mockListener(final Promise p, final GenericFutureListenerConsumer consumer) { + doAnswer(new Answer>() { + @SuppressWarnings("unchecked") + @Override + public Promise answer(InvocationOnMock invocation) throws Throwable { + consumer.accept(invocation.getArgumentAt(0, GenericFutureListener.class)); + return p; + } + }).when(p).addListener(any(GenericFutureListener.class)); + } + + interface GenericFutureListenerConsumer { + void accept(GenericFutureListener> listener); + } +} diff --git a/transport/src/main/java/io/netty/channel/ChannelPromiseAggregator.java b/transport/src/main/java/io/netty/channel/ChannelPromiseAggregator.java index b882e7dbee..5ee15092fb 100644 --- a/transport/src/main/java/io/netty/channel/ChannelPromiseAggregator.java +++ b/transport/src/main/java/io/netty/channel/ChannelPromiseAggregator.java @@ -17,12 +17,16 @@ package io.netty.channel; import io.netty.util.concurrent.PromiseAggregator; +import io.netty.util.concurrent.PromiseCombiner; /** + * @deprecated Use {@link PromiseCombiner} + * * Class which is used to consolidate multiple channel futures into one, by * listening to the individual futures and producing an aggregated result * (success/failure) when all futures have completed. */ +@Deprecated public final class ChannelPromiseAggregator extends PromiseAggregator implements ChannelFutureListener { diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java index 7b20b2bdf8..3a49baaaa3 100644 --- a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -17,6 +17,7 @@ package io.netty.channel; import io.netty.util.Recycler; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.PromiseCombiner; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -169,17 +170,22 @@ public final class PendingWriteQueue { size = 0; ChannelPromise p = ctx.newPromise(); - ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p); - while (write != null) { - PendingWrite next = write.next; - Object msg = write.msg; - ChannelPromise promise = write.promise; - recycle(write, false); - ctx.write(msg, promise); - aggregator.add(promise); - write = next; + PromiseCombiner combiner = new PromiseCombiner(); + try { + while (write != null) { + PendingWrite next = write.next; + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write, false); + combiner.add(promise); + ctx.write(msg, promise); + write = next; + } + assertEmpty(); + combiner.finish(p); + } catch (Throwable cause) { + p.setFailure(cause); } - assertEmpty(); return p; }