diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 7a380266de..40ef0ce563 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -28,6 +28,7 @@ import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.PendingWriteQueue; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.EventExecutor; @@ -35,7 +36,6 @@ import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.internal.EmptyArrays; -import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -51,9 +51,7 @@ import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.DatagramChannel; import java.nio.channels.SocketChannel; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Deque; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -207,9 +205,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private final boolean startTls; private boolean sentFirstMessage; private boolean flushedBeforeHandshakeDone; + private PendingWriteQueue pendingUnencryptedWrites; + private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); - private final Deque pendingUnencryptedWrites = new ArrayDeque(); /** * Set by wrap*() methods when something is produced. @@ -370,12 +369,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { - for (;;) { - PendingWrite write = pendingUnencryptedWrites.poll(); - if (write == null) { - break; - } - write.failAndRecycle(new ChannelException("Pending write on removal of SslHandler")); + if (!pendingUnencryptedWrites.isEmpty()) { + // Check if queue is not empty first because create a new ChannelException is expensive + pendingUnencryptedWrites.removeAndFailAll(new ChannelException("Pending write on removal of SslHandler")); } } @@ -414,7 +410,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - pendingUnencryptedWrites.add(PendingWrite.newInstance(msg, promise)); + pendingUnencryptedWrites.add(msg, promise); } @Override @@ -423,18 +419,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // created with startTLS flag turned on. if (startTls && !sentFirstMessage) { sentFirstMessage = true; - for (;;) { - PendingWrite pendingWrite = pendingUnencryptedWrites.poll(); - if (pendingWrite == null) { - break; - } - ctx.write(pendingWrite.msg(), (ChannelPromise) pendingWrite.recycleAndGet()); - } + pendingUnencryptedWrites.removeAndWriteAll(); ctx.flush(); return; } if (pendingUnencryptedWrites.isEmpty()) { - pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); + pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.voidPromise()); } if (!handshakePromise.isDone()) { flushedBeforeHandshakeDone = true; @@ -448,18 +438,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ChannelPromise promise = null; try { for (;;) { - PendingWrite pending = pendingUnencryptedWrites.peek(); - if (pending == null) { + Object msg = pendingUnencryptedWrites.current(); + if (msg == null) { break; } - if (!(pending.msg() instanceof ByteBuf)) { - ctx.write(pending.msg(), (ChannelPromise) pending.recycleAndGet()); - pendingUnencryptedWrites.remove(); + if (!(msg instanceof ByteBuf)) { + pendingUnencryptedWrites.removeAndWrite(); continue; } - ByteBuf buf = (ByteBuf) pending.msg(); + ByteBuf buf = (ByteBuf) msg; if (out == null) { out = allocateOutNetBuf(ctx, buf.readableBytes()); } @@ -467,9 +456,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH SSLEngineResult result = wrap(engine, buf, out); if (!buf.isReadable()) { - buf.release(); - promise = (ChannelPromise) pending.recycleAndGet(); - pendingUnencryptedWrites.remove(); + promise = pendingUnencryptedWrites.remove(); } else { promise = null; } @@ -477,13 +464,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (result.getStatus() == Status.CLOSED) { // SSLEngine has been closed already. // Any further write attempts should be denied. - for (;;) { - PendingWrite w = pendingUnencryptedWrites.poll(); - if (w == null) { - break; - } - w.failAndRecycle(SSLENGINE_CLOSED); - } + pendingUnencryptedWrites.removeAndFailAll(SSLENGINE_CLOSED); return; } else { switch (result.getHandshakeStatus()) { @@ -1134,13 +1115,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } } notifyHandshakeFailure(cause); - for (;;) { - PendingWrite write = pendingUnencryptedWrites.poll(); - if (write == null) { - break; - } - write.failAndRecycle(cause); - } + pendingUnencryptedWrites.removeAndFailAll(cause); } private void notifyHandshakeFailure(Throwable cause) { @@ -1172,6 +1147,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH @Override public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; + pendingUnencryptedWrites = new PendingWriteQueue(ctx); if (ctx.channel().isActive() && engine.getUseClientMode()) { // channelActive() event has been fired already, which means this.channelActive() will diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index f289a86243..cdfeb7aaf7 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -155,7 +155,7 @@ public final class ChannelOutboundBuffer { * Increment the pending bytes which will be written at some point. * This method is thread-safe! */ - void incrementPendingOutboundBytes(int size) { + void incrementPendingOutboundBytes(long size) { if (size == 0) { return; } @@ -172,7 +172,7 @@ public final class ChannelOutboundBuffer { * Decrement the pending bytes which will be written at some point. * This method is thread-safe! */ - void decrementPendingOutboundBytes(int size) { + void decrementPendingOutboundBytes(long size) { if (size == 0) { return; } diff --git a/transport/src/main/java/io/netty/channel/PendingWriteQueue.java b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java new file mode 100644 index 0000000000..a1a0e56a8c --- /dev/null +++ b/transport/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -0,0 +1,278 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.Recycler; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * A queue of write operations which are pending for later execution. It also updates the + * {@linkplain Channel#isWritable() writability} of the associated {@link Channel}, so that + * the pending write operations are also considered to determine the writability. + */ +public final class PendingWriteQueue { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PendingWriteQueue.class); + + private final ChannelHandlerContext ctx; + private final ChannelOutboundBuffer buffer; + private final MessageSizeEstimator.Handle estimatorHandle; + + // head and tail pointers for the linked-list structure. If empty head and tail are null. + private PendingWrite head; + private PendingWrite tail; + private int size; + + public PendingWriteQueue(ChannelHandlerContext ctx) { + if (ctx == null) { + throw new NullPointerException("ctx"); + } + this.ctx = ctx; + buffer = ctx.channel().unsafe().outboundBuffer(); + estimatorHandle = ctx.channel().config().getMessageSizeEstimator().newHandle(); + } + + /** + * Returns {@code true} if there are no pending write operations left in this queue. + */ + public boolean isEmpty() { + assert ctx.executor().inEventLoop(); + return head == null; + } + + /** + * Returns the number of pending write operations. + */ + public int size() { + assert ctx.executor().inEventLoop(); + return size; + } + + /** + * Add the given {@code msg} and {@link ChannelPromise}. + */ + public void add(Object msg, ChannelPromise promise) { + assert ctx.executor().inEventLoop(); + if (msg == null) { + throw new NullPointerException("msg"); + } + if (promise == null) { + throw new NullPointerException("promise"); + } + int messageSize = estimatorHandle.size(msg); + if (messageSize < 0) { + // Size may be unknow so just use 0 + messageSize = 0; + } + PendingWrite write = PendingWrite.newInstance(msg, messageSize, promise); + PendingWrite currentTail = tail; + if (currentTail == null) { + tail = head = write; + } else { + currentTail.next = write; + tail = write; + } + size ++; + buffer.incrementPendingOutboundBytes(write.size); + } + + /** + * Remove all pending write operation and fail them with the given {@link Throwable}. The message will be released + * via {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFailAll(Throwable cause) { + assert ctx.executor().inEventLoop(); + if (cause == null) { + throw new NullPointerException("cause"); + } + PendingWrite write = head; + while (write != null) { + PendingWrite next = write.next; + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + recycle(write); + safeFail(promise, cause); + write = next; + } + assertEmpty(); + } + + /** + * Remove a pending write operation and fail it with the given {@link Throwable}. The message will be released via + * {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFail(Throwable cause) { + assert ctx.executor().inEventLoop(); + if (cause == null) { + throw new NullPointerException("cause"); + } + PendingWrite write = head; + if (write == null) { + return; + } + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + safeFail(promise, cause); + recycle(write); + } + + /** + * Remove all pending write operation and performs them via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWriteAll() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + // empty so just return null + return null; + } + if (size == 1) { + // No need to use ChannelPromiseAggregator for this case. + return removeAndWrite(); + } + ChannelPromise p = ctx.newPromise(); + ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p); + while (write != null) { + PendingWrite next = write.next; + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write); + ctx.write(msg, promise); + aggregator.add(promise); + write = next; + } + assertEmpty(); + return p; + } + + private void assertEmpty() { + assert tail == null && head == null && size == 0; + } + + /** + * Removes a pending write operation and performs it via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWrite() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write); + return ctx.write(msg, promise); + } + + /** + * Removes a pending write operation and release it's message via {@link ReferenceCountUtil#safeRelease(Object)}. + * + * @return {@link ChannelPromise} of the pending write or {@code null} if the queue is empty. + * + */ + public ChannelPromise remove() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + ChannelPromise promise = write.promise; + ReferenceCountUtil.safeRelease(write.msg); + recycle(write); + return promise; + } + + /** + * Return the current message or {@code null} if empty. + */ + public Object current() { + assert ctx.executor().inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + return write.msg; + } + + private void recycle(PendingWrite write) { + PendingWrite next = write.next; + + buffer.decrementPendingOutboundBytes(write.size); + write.recycle(); + size --; + if (next == null) { + // Handled last PendingWrite so rest head and tail + head = tail = null; + assert size == 0; + } else { + head = next; + assert size > 0; + } + } + + private static void safeFail(ChannelPromise promise, Throwable cause) { + if (!(promise instanceof VoidChannelPromise) && !promise.tryFailure(cause)) { + logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, cause); + } + } + + /** + * Holds all meta-data and construct the linked-list structure. + */ + static final class PendingWrite { + private static final Recycler RECYCLER = new Recycler() { + @Override + protected PendingWrite newObject(Handle handle) { + return new PendingWrite(handle); + } + }; + + private final Recycler.Handle handle; + private PendingWrite next; + private long size; + private ChannelPromise promise; + private Object msg; + + private PendingWrite(Recycler.Handle handle) { + this.handle = handle; + } + + static PendingWrite newInstance(Object msg, int size, ChannelPromise promise) { + PendingWrite write = RECYCLER.get(); + write.size = size; + write.msg = msg; + write.promise = promise; + return write; + } + + private void recycle() { + size = 0; + next = null; + msg = null; + promise = null; + RECYCLER.recycle(this, handle); + } + } +} diff --git a/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java new file mode 100644 index 0000000000..c6d82d09a4 --- /dev/null +++ b/transport/src/test/java/io/netty/channel/PendingWriteQueueTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.Assert; +import org.junit.Test; + +public class PendingWriteQueueTest { + + @Test + public void testRemoveAndWrite() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); + + ChannelFuture future = queue.removeAndWrite(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndWriteAll() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable()); + + ChannelFuture future = queue.removeAndWriteAll(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 3); + } + + @Test + public void testRemoveAndFail() { + assertWriteFails(new TestHandler() { + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFail(new TestException()); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndFailAll() { + assertWriteFails(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFailAll(new TestException()); + super.flush(ctx); + } + }, 3); + } + + private static void assertWrite(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.config().setWriteBufferLowWaterMark(1); + channel.config().setWriteBufferHighWaterMark(3); + + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.duplicate().retain(); + } + Assert.assertTrue(channel.writeOutbound(buffers)); + Assert.assertTrue(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + for (int i = 0; i < buffers.length; i++) { + assertBuffer(channel, buffer); + } + buffer.release(); + Assert.assertNull(channel.readOutbound()); + } + + private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) { + ByteBuf written = (ByteBuf) channel.readOutbound(); + Assert.assertEquals(buffer, written); + written.release(); + } + + private static void assertQueueEmpty(PendingWriteQueue queue) { + Assert.assertTrue(queue.isEmpty()); + Assert.assertEquals(0, queue.size()); + Assert.assertNull(queue.current()); + Assert.assertNull(queue.removeAndWrite()); + Assert.assertNull(queue.removeAndWriteAll()); + } + + private static void assertWriteFails(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.duplicate().retain(); + } + try { + Assert.assertFalse(channel.writeOutbound(buffers)); + Assert.fail(); + } catch (Exception e) { + Assert.assertTrue(e instanceof TestException); + } + Assert.assertFalse(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + buffer.release(); + Assert.assertNull(channel.readOutbound()); + } + + private static class TestHandler extends ChannelDuplexHandler { + protected PendingWriteQueue queue; + private int expectedSize; + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + assertQueueEmpty(queue); + Assert.assertTrue("Should be writable", ctx.channel().isWritable()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + queue.add(msg, promise); + Assert.assertFalse(queue.isEmpty()); + Assert.assertEquals(++ expectedSize, queue.size()); + Assert.assertNotNull(queue.current()); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + queue = new PendingWriteQueue(ctx); + } + } + + private static final class TestException extends Exception { } +}