From 7764b0e3de3280e164826945f5c3dc670d672e00 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Sun, 10 Aug 2014 13:40:41 +0200 Subject: [PATCH] [#2752] Add PendingWriteQueue for queue up writes Motivation: Sometimes ChannelHandler need to queue writes to some point and then process these. We currently have no datastructure for this so the user will use an Queue or something like this. The problem is with this Channel.isWritable() will not work as expected and so the user risk to write to fast. That's exactly what happened in our SslHandler. For this purpose we need to add a special datastructure which will also take care of update the Channel and so be sure that Channel.isWritable() works as expected. Modifications: - Add PendingWriteQueue which can be used for this purpose - Make use of PendingWriteQueue in SslHandler Result: It is now possible to queue writes in a ChannelHandler and still have Channel.isWritable() working as expected. This also fixes #2752. --- .../java/io/netty/handler/ssl/SslHandler.java | 60 ++-- .../netty/channel/ChannelOutboundBuffer.java | 4 +- .../io/netty/channel/PendingWriteQueue.java | 278 ++++++++++++++++++ .../netty/channel/PendingWriteQueueTest.java | 170 +++++++++++ 4 files changed, 468 insertions(+), 44 deletions(-) create mode 100644 transport/src/main/java/io/netty/channel/PendingWriteQueue.java create mode 100644 transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index b9277da171..c9bd6756c3 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -27,13 +27,13 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.PendingWriteQueue; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.internal.EmptyArrays; -import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -48,8 +48,6 @@ import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; -import java.util.ArrayDeque; -import java.util.Deque; import java.util.List; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; @@ -201,9 +199,10 @@ public class SslHandler extends ByteToMessageDecoder { private final boolean startTls; private boolean sentFirstMessage; private boolean flushedBeforeHandshakeDone; + private PendingWriteQueue pendingUnencryptedWrites; + private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); - private final Deque pendingUnencryptedWrites = new ArrayDeque(); /** * Set by wrap*() methods when something is produced. @@ -343,12 +342,9 @@ public class SslHandler extends ByteToMessageDecoder { @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - for (;;) { - PendingWrite write = pendingUnencryptedWrites.poll(); - if (write == null) { - break; - } - write.failAndRecycle(new ChannelException("Pending write on removal of SslHandler")); + if (!pendingUnencryptedWrites.isEmpty()) { + // Check if queue is not empty first because create a new ChannelException is expensive + pendingUnencryptedWrites.removeAndFailAll(new ChannelException("Pending write on removal of SslHandler")); } } @@ -366,7 +362,7 @@ public class SslHandler extends ByteToMessageDecoder { @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - pendingUnencryptedWrites.add(PendingWrite.newInstance(msg, promise)); + pendingUnencryptedWrites.add(msg, promise); } @Override @@ -375,18 +371,12 @@ public class SslHandler extends ByteToMessageDecoder { // created with startTLS flag turned on. if (startTls && !sentFirstMessage) { sentFirstMessage = true; - for (;;) { - PendingWrite pendingWrite = pendingUnencryptedWrites.poll(); - if (pendingWrite == null) { - break; - } - ctx.write(pendingWrite.msg(), (ChannelPromise) pendingWrite.recycleAndGet()); - } + pendingUnencryptedWrites.removeAndWriteAll(); ctx.flush(); return; } if (pendingUnencryptedWrites.isEmpty()) { - pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); + pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.voidPromise()); } if (!handshakePromise.isDone()) { flushedBeforeHandshakeDone = true; @@ -400,18 +390,17 @@ public class SslHandler extends ByteToMessageDecoder { ChannelPromise promise = null; try { for (;;) { - PendingWrite pending = pendingUnencryptedWrites.peek(); - if (pending == null) { + Object msg = pendingUnencryptedWrites.current(); + if (msg == null) { break; } - if (!(pending.msg() instanceof ByteBuf)) { - ctx.write(pending.msg(), (ChannelPromise) pending.recycleAndGet()); - pendingUnencryptedWrites.remove(); + if (!(msg instanceof ByteBuf)) { + pendingUnencryptedWrites.removeAndWrite(); continue; } - ByteBuf buf = (ByteBuf) pending.msg(); + ByteBuf buf = (ByteBuf) msg; if (out == null) { out = allocateOutNetBuf(ctx, buf.readableBytes()); } @@ -419,9 +408,7 @@ public class SslHandler extends ByteToMessageDecoder { SSLEngineResult result = wrap(engine, buf, out); if (!buf.isReadable()) { - buf.release(); - promise = (ChannelPromise) pending.recycleAndGet(); - pendingUnencryptedWrites.remove(); + promise = pendingUnencryptedWrites.remove(); } else { promise = null; } @@ -429,13 +416,7 @@ public class SslHandler extends ByteToMessageDecoder { if (result.getStatus() == Status.CLOSED) { // SSLEngine has been closed already. // Any further write attempts should be denied. - for (;;) { - PendingWrite w = pendingUnencryptedWrites.poll(); - if (w == null) { - break; - } - w.failAndRecycle(SSLENGINE_CLOSED); - } + pendingUnencryptedWrites.removeAndFailAll(SSLENGINE_CLOSED); return; } else { switch (result.getHandshakeStatus()) { @@ -1037,13 +1018,7 @@ public class SslHandler extends ByteToMessageDecoder { } } notifyHandshakeFailure(cause); - for (;;) { - PendingWrite write = pendingUnencryptedWrites.poll(); - if (write == null) { - break; - } - write.failAndRecycle(cause); - } + pendingUnencryptedWrites.removeAndFailAll(cause); } private void notifyHandshakeFailure(Throwable cause) { @@ -1075,6 +1050,7 @@ public class SslHandler extends ByteToMessageDecoder { @Override public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; + pendingUnencryptedWrites = new PendingWriteQueue(ctx); if (ctx.channel().isActive() && engine.getUseClientMode()) { // channelActive() event has been fired already, which means this.channelActive() will diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index f289a86243..cdfeb7aaf7 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -155,7 +155,7 @@ public final class ChannelOutboundBuffer { * Increment the pending bytes which will be written at some point. * This method is thread-safe! */ - void incrementPendingOutboundBytes(int size) { + void incrementPendingOutboundBytes(long size) { if (size == 0) { return; } @@ -172,7 +172,7 @@ public final class ChannelOutboundBuffer { * Decrement the pending bytes which will be written at some point. * This method is thread-safe! */ - void decrementPendingOutboundBytes(int size) { + void decrementPendingOutboundBytes(long size) { if (size == 0) { return; } diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java new file mode 100644 index 0000000000..a1a0e56a8c --- /dev/null +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -0,0 +1,278 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.Recycler; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * A queue of write operations which are pending for later execution. It also updates the + * {@linkplain Channel#isWritable() writability} of the associated {@link Channel}, so that + * the pending write operations are also considered to determine the writability. + */ +public final class PendingWriteQueue { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PendingWriteQueue.class); + + private final ChannelHandlerContext ctx; + private final ChannelOutboundBuffer buffer; + private final MessageSizeEstimator.Handle estimatorHandle; + + // head and tail pointers for the linked-list structure. If empty head and tail are null. + private PendingWrite head; + private PendingWrite tail; + private int size; + + public PendingWriteQueue(ChannelHandlerContext ctx) { + if (ctx == null) { + throw new NullPointerException("ctx"); + } + this.ctx = ctx; + buffer = ctx.channel().unsafe().outboundBuffer(); + estimatorHandle = ctx.channel().config().getMessageSizeEstimator().newHandle(); + } + + /** + * Returns {@code true} if there are no pending write operations left in this queue. + */ + public boolean isEmpty() { + assert ctx.executor().inEventLoop(); + return head == null; + } + + /** + * Returns the number of pending write operations. + */ + public int size() { + assert ctx.executor().inEventLoop(); + return size; + } + + /** + * Add the given {@code msg} and {@link ChannelPromise}. + */ + public void add(Object msg, ChannelPromise promise) { + assert ctx.executor().inEventLoop(); + if (msg == null) { + throw new NullPointerException("msg"); + } + if (promise == null) { + throw new NullPointerException("promise"); + } + int messageSize = estimatorHandle.size(msg); + if (messageSize < 0) { + // Size may be unknow so just use 0 + messageSize = 0; + } + PendingWrite write = PendingWrite.newInstance(msg, messageSize, promise); + PendingWrite currentTail = tail; + if (currentTail == null) { + tail = head = write; + } else { + currentTail.next = write; + tail = write; + } + size ++; + buffer.incrementPendingOutboundBytes(write.size); + } + + /** + * Remove all pending write operation and fail them with the given {@link Throwable}. The message will be released + * via {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFailAll(Throwable cause) { + assert ctx.executor().inEventLoop(); + if (cause == null) { + throw new NullPointerException("cause"); + } + PendingWrite write = head; + while (write != null) { + PendingWrite next = write.next; + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + recycle(write); + safeFail(promise, cause); + write = next; + } + assertEmpty(); + } + + /** + * Remove a pending write operation and fail it with the given {@link Throwable}. The message will be released via + * {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFail(Throwable cause) { + assert ctx.executor().inEventLoop(); + if (cause == null) { + throw new NullPointerException("cause"); + } + PendingWrite write = head; + if (write == null) { + return; + } + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + safeFail(promise, cause); + recycle(write); + } + + /** + * Remove all pending write operation and performs them via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWriteAll() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + // empty so just return null + return null; + } + if (size == 1) { + // No need to use ChannelPromiseAggregator for this case. + return removeAndWrite(); + } + ChannelPromise p = ctx.newPromise(); + ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p); + while (write != null) { + PendingWrite next = write.next; + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write); + ctx.write(msg, promise); + aggregator.add(promise); + write = next; + } + assertEmpty(); + return p; + } + + private void assertEmpty() { + assert tail == null && head == null && size == 0; + } + + /** + * Removes a pending write operation and performs it via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWrite() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write); + return ctx.write(msg, promise); + } + + /** + * Removes a pending write operation and release it's message via {@link ReferenceCountUtil#safeRelease(Object)}. + * + * @return {@link ChannelPromise} of the pending write or {@code null} if the queue is empty. + * + */ + public ChannelPromise remove() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + ChannelPromise promise = write.promise; + ReferenceCountUtil.safeRelease(write.msg); + recycle(write); + return promise; + } + + /** + * Return the current message or {@code null} if empty. + */ + public Object current() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + return write.msg; + } + + private void recycle(PendingWrite write) { + PendingWrite next = write.next; + + buffer.decrementPendingOutboundBytes(write.size); + write.recycle(); + size --; + if (next == null) { + // Handled last PendingWrite so rest head and tail + head = tail = null; + assert size == 0; + } else { + head = next; + assert size > 0; + } + } + + private static void safeFail(ChannelPromise promise, Throwable cause) { + if (!(promise instanceof VoidChannelPromise) && !promise.tryFailure(cause)) { + logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, cause); + } + } + + /** + * Holds all meta-data and construct the linked-list structure. + */ + static final class PendingWrite { + private static final Recycler RECYCLER = new Recycler() { + @Override + protected PendingWrite newObject(Handle handle) { + return new PendingWrite(handle); + } + }; + + private final Recycler.Handle handle; + private PendingWrite next; + private long size; + private ChannelPromise promise; + private Object msg; + + private PendingWrite(Recycler.Handle handle) { + this.handle = handle; + } + + static PendingWrite newInstance(Object msg, int size, ChannelPromise promise) { + PendingWrite write = RECYCLER.get(); + write.size = size; + write.msg = msg; + write.promise = promise; + return write; + } + + private void recycle() { + size = 0; + next = null; + msg = null; + promise = null; + RECYCLER.recycle(this, handle); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java new file mode 100644 index 0000000000..c3fdc1434e --- /dev/null +++ b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.Assert; +import org.junit.Test; + +public class PendingWriteQueueTest { + + @Test + public void testRemoveAndWrite() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); + + ChannelFuture future = queue.removeAndWrite(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndWriteAll() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); + + ChannelFuture future = queue.removeAndWriteAll(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 3); + } + + @Test + public void testRemoveAndFail() { + assertWriteFails(new TestHandler() { + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFail(new TestException()); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndFailAll() { + assertWriteFails(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFailAll(new TestException()); + super.flush(ctx); + } + }, 3); + } + + private static void assertWrite(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.config().setWriteBufferLowWaterMark(1); + channel.config().setWriteBufferHighWaterMark(3); + + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.duplicate().retain(); + } + Assert.assertTrue(channel.writeOutbound(buffers)); + Assert.assertTrue(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + for (int i = 0; i < buffers.length; i++) { + assertBuffer(channel, buffer); + } + buffer.release(); + Assert.assertNull(channel.readOutbound()); + } + + private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) { + ByteBuf written = channel.readOutbound(); + Assert.assertEquals(buffer, written); + written.release(); + } + + private static void assertQueueEmpty(PendingWriteQueue queue) { + Assert.assertTrue(queue.isEmpty()); + Assert.assertEquals(0, queue.size()); + Assert.assertNull(queue.current()); + Assert.assertNull(queue.removeAndWrite()); + Assert.assertNull(queue.removeAndWriteAll()); + } + + private static void assertWriteFails(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.duplicate().retain(); + } + try { + Assert.assertFalse(channel.writeOutbound(buffers)); + Assert.fail(); + } catch (Exception e) { + Assert.assertTrue(e instanceof TestException); + } + Assert.assertFalse(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + buffer.release(); + Assert.assertNull(channel.readOutbound()); + } + + private static class TestHandler extends ChannelDuplexHandler { + protected PendingWriteQueue queue; + private int expectedSize; + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + assertQueueEmpty(queue); + Assert.assertTrue("Should be writable", ctx.channel().isWritable()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + queue.add(msg, promise); + Assert.assertFalse(queue.isEmpty()); + Assert.assertEquals(++ expectedSize, queue.size()); + Assert.assertNotNull(queue.current()); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + queue = new PendingWriteQueue(ctx); + } + } + + private static final class TestException extends Exception { } +}