[#2752] Add PendingWriteQueue for queue up writes

Motivation:

Sometimes ChannelHandler need to queue writes to some point and then process these. We currently have no datastructure for this so the user will use an Queue or something like this. The problem is with this Channel.isWritable() will not work as expected and so the user risk to write to fast. That's exactly what happened in our SslHandler. For this purpose we need to add a special datastructure which will also take care of update the Channel and so be sure that Channel.isWritable() works as expected.

Modifications:

- Add PendingWriteQueue which can be used for this purpose
- Make use of PendingWriteQueue in SslHandler

Result:

It is now possible to queue writes in a ChannelHandler and still have Channel.isWritable() working as expected. This also fixes #2752.
This commit is contained in:
Norman Maurer 2014-08-10 13:40:41 +02:00
parent 857713ad4c
commit 8f019ae4fa
4 changed files with 468 additions and 44 deletions

View File

@ -28,6 +28,7 @@ import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.PendingWriteQueue;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.concurrent.DefaultPromise; import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
@ -35,7 +36,6 @@ import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateExecutor; import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PendingWrite;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -51,9 +51,7 @@ import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel; import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
@ -207,9 +205,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final boolean startTls; private final boolean startTls;
private boolean sentFirstMessage; private boolean sentFirstMessage;
private boolean flushedBeforeHandshakeDone; private boolean flushedBeforeHandshakeDone;
private PendingWriteQueue pendingUnencryptedWrites;
private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise handshakePromise = new LazyChannelPromise();
private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise();
private final Deque<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
/** /**
* Set by wrap*() methods when something is produced. * Set by wrap*() methods when something is produced.
@ -370,12 +369,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
for (;;) { if (!pendingUnencryptedWrites.isEmpty()) {
PendingWrite write = pendingUnencryptedWrites.poll(); // Check if queue is not empty first because create a new ChannelException is expensive
if (write == null) { pendingUnencryptedWrites.removeAndFailAll(new ChannelException("Pending write on removal of SslHandler"));
break;
}
write.failAndRecycle(new ChannelException("Pending write on removal of SslHandler"));
} }
} }
@ -414,7 +410,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
pendingUnencryptedWrites.add(PendingWrite.newInstance(msg, promise)); pendingUnencryptedWrites.add(msg, promise);
} }
@Override @Override
@ -423,18 +419,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// created with startTLS flag turned on. // created with startTLS flag turned on.
if (startTls && !sentFirstMessage) { if (startTls && !sentFirstMessage) {
sentFirstMessage = true; sentFirstMessage = true;
for (;;) { pendingUnencryptedWrites.removeAndWriteAll();
PendingWrite pendingWrite = pendingUnencryptedWrites.poll();
if (pendingWrite == null) {
break;
}
ctx.write(pendingWrite.msg(), (ChannelPromise) pendingWrite.recycleAndGet());
}
ctx.flush(); ctx.flush();
return; return;
} }
if (pendingUnencryptedWrites.isEmpty()) { if (pendingUnencryptedWrites.isEmpty()) {
pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.voidPromise());
} }
if (!handshakePromise.isDone()) { if (!handshakePromise.isDone()) {
flushedBeforeHandshakeDone = true; flushedBeforeHandshakeDone = true;
@ -448,18 +438,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ChannelPromise promise = null; ChannelPromise promise = null;
try { try {
for (;;) { for (;;) {
PendingWrite pending = pendingUnencryptedWrites.peek(); Object msg = pendingUnencryptedWrites.current();
if (pending == null) { if (msg == null) {
break; break;
} }
if (!(pending.msg() instanceof ByteBuf)) { if (!(msg instanceof ByteBuf)) {
ctx.write(pending.msg(), (ChannelPromise) pending.recycleAndGet()); pendingUnencryptedWrites.removeAndWrite();
pendingUnencryptedWrites.remove();
continue; continue;
} }
ByteBuf buf = (ByteBuf) pending.msg(); ByteBuf buf = (ByteBuf) msg;
if (out == null) { if (out == null) {
out = allocateOutNetBuf(ctx, buf.readableBytes()); out = allocateOutNetBuf(ctx, buf.readableBytes());
} }
@ -467,9 +456,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
SSLEngineResult result = wrap(engine, buf, out); SSLEngineResult result = wrap(engine, buf, out);
if (!buf.isReadable()) { if (!buf.isReadable()) {
buf.release(); promise = pendingUnencryptedWrites.remove();
promise = (ChannelPromise) pending.recycleAndGet();
pendingUnencryptedWrites.remove();
} else { } else {
promise = null; promise = null;
} }
@ -477,13 +464,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (result.getStatus() == Status.CLOSED) { if (result.getStatus() == Status.CLOSED) {
// SSLEngine has been closed already. // SSLEngine has been closed already.
// Any further write attempts should be denied. // Any further write attempts should be denied.
for (;;) { pendingUnencryptedWrites.removeAndFailAll(SSLENGINE_CLOSED);
PendingWrite w = pendingUnencryptedWrites.poll();
if (w == null) {
break;
}
w.failAndRecycle(SSLENGINE_CLOSED);
}
return; return;
} else { } else {
switch (result.getHandshakeStatus()) { switch (result.getHandshakeStatus()) {
@ -1134,13 +1115,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
} }
notifyHandshakeFailure(cause); notifyHandshakeFailure(cause);
for (;;) { pendingUnencryptedWrites.removeAndFailAll(cause);
PendingWrite write = pendingUnencryptedWrites.poll();
if (write == null) {
break;
}
write.failAndRecycle(cause);
}
} }
private void notifyHandshakeFailure(Throwable cause) { private void notifyHandshakeFailure(Throwable cause) {
@ -1172,6 +1147,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx; this.ctx = ctx;
pendingUnencryptedWrites = new PendingWriteQueue(ctx);
if (ctx.channel().isActive() && engine.getUseClientMode()) { if (ctx.channel().isActive() && engine.getUseClientMode()) {
// channelActive() event has been fired already, which means this.channelActive() will // channelActive() event has been fired already, which means this.channelActive() will

View File

@ -155,7 +155,7 @@ public final class ChannelOutboundBuffer {
* Increment the pending bytes which will be written at some point. * Increment the pending bytes which will be written at some point.
* This method is thread-safe! * This method is thread-safe!
*/ */
void incrementPendingOutboundBytes(int size) { void incrementPendingOutboundBytes(long size) {
if (size == 0) { if (size == 0) {
return; return;
} }
@ -172,7 +172,7 @@ public final class ChannelOutboundBuffer {
* Decrement the pending bytes which will be written at some point. * Decrement the pending bytes which will be written at some point.
* This method is thread-safe! * This method is thread-safe!
*/ */
void decrementPendingOutboundBytes(int size) { void decrementPendingOutboundBytes(long size) {
if (size == 0) { if (size == 0) {
return; return;
} }

View File

@ -0,0 +1,278 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel;
import io.netty.util.Recycler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
/**
* A queue of write operations which are pending for later execution. It also updates the
* {@linkplain Channel#isWritable() writability} of the associated {@link Channel}, so that
* the pending write operations are also considered to determine the writability.
*/
public final class PendingWriteQueue {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(PendingWriteQueue.class);
private final ChannelHandlerContext ctx;
private final ChannelOutboundBuffer buffer;
private final MessageSizeEstimator.Handle estimatorHandle;
// head and tail pointers for the linked-list structure. If empty head and tail are null.
private PendingWrite head;
private PendingWrite tail;
private int size;
public PendingWriteQueue(ChannelHandlerContext ctx) {
if (ctx == null) {
throw new NullPointerException("ctx");
}
this.ctx = ctx;
buffer = ctx.channel().unsafe().outboundBuffer();
estimatorHandle = ctx.channel().config().getMessageSizeEstimator().newHandle();
}
/**
* Returns {@code true} if there are no pending write operations left in this queue.
*/
public boolean isEmpty() {
assert ctx.executor().inEventLoop();
return head == null;
}
/**
* Returns the number of pending write operations.
*/
public int size() {
assert ctx.executor().inEventLoop();
return size;
}
/**
* Add the given {@code msg} and {@link ChannelPromise}.
*/
public void add(Object msg, ChannelPromise promise) {
assert ctx.executor().inEventLoop();
if (msg == null) {
throw new NullPointerException("msg");
}
if (promise == null) {
throw new NullPointerException("promise");
}
int messageSize = estimatorHandle.size(msg);
if (messageSize < 0) {
// Size may be unknow so just use 0
messageSize = 0;
}
PendingWrite write = PendingWrite.newInstance(msg, messageSize, promise);
PendingWrite currentTail = tail;
if (currentTail == null) {
tail = head = write;
} else {
currentTail.next = write;
tail = write;
}
size ++;
buffer.incrementPendingOutboundBytes(write.size);
}
/**
* Remove all pending write operation and fail them with the given {@link Throwable}. The message will be released
* via {@link ReferenceCountUtil#safeRelease(Object)}.
*/
public void removeAndFailAll(Throwable cause) {
assert ctx.executor().inEventLoop();
if (cause == null) {
throw new NullPointerException("cause");
}
PendingWrite write = head;
while (write != null) {
PendingWrite next = write.next;
ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise;
recycle(write);
safeFail(promise, cause);
write = next;
}
assertEmpty();
}
/**
* Remove a pending write operation and fail it with the given {@link Throwable}. The message will be released via
* {@link ReferenceCountUtil#safeRelease(Object)}.
*/
public void removeAndFail(Throwable cause) {
assert ctx.executor().inEventLoop();
if (cause == null) {
throw new NullPointerException("cause");
}
PendingWrite write = head;
if (write == null) {
return;
}
ReferenceCountUtil.safeRelease(write.msg);
ChannelPromise promise = write.promise;
safeFail(promise, cause);
recycle(write);
}
/**
* Remove all pending write operation and performs them via
* {@link ChannelHandlerContext#write(Object, ChannelPromise)}.
*
* @return {@link ChannelFuture} if something was written and {@code null}
* if the {@link PendingWriteQueue} is empty.
*/
public ChannelFuture removeAndWriteAll() {
assert ctx.executor().inEventLoop();
PendingWrite write = head;
if (write == null) {
// empty so just return null
return null;
}
if (size == 1) {
// No need to use ChannelPromiseAggregator for this case.
return removeAndWrite();
}
ChannelPromise p = ctx.newPromise();
ChannelPromiseAggregator aggregator = new ChannelPromiseAggregator(p);
while (write != null) {
PendingWrite next = write.next;
Object msg = write.msg;
ChannelPromise promise = write.promise;
recycle(write);
ctx.write(msg, promise);
aggregator.add(promise);
write = next;
}
assertEmpty();
return p;
}
private void assertEmpty() {
assert tail == null && head == null && size == 0;
}
/**
* Removes a pending write operation and performs it via
* {@link ChannelHandlerContext#write(Object, ChannelPromise)}.
*
* @return {@link ChannelFuture} if something was written and {@code null}
* if the {@link PendingWriteQueue} is empty.
*/
public ChannelFuture removeAndWrite() {
assert ctx.executor().inEventLoop();
PendingWrite write = head;
if (write == null) {
return null;
}
Object msg = write.msg;
ChannelPromise promise = write.promise;
recycle(write);
return ctx.write(msg, promise);
}
/**
* Removes a pending write operation and release it's message via {@link ReferenceCountUtil#safeRelease(Object)}.
*
* @return {@link ChannelPromise} of the pending write or {@code null} if the queue is empty.
*
*/
public ChannelPromise remove() {
assert ctx.executor().inEventLoop();
PendingWrite write = head;
if (write == null) {
return null;
}
ChannelPromise promise = write.promise;
ReferenceCountUtil.safeRelease(write.msg);
recycle(write);
return promise;
}
/**
* Return the current message or {@code null} if empty.
*/
public Object current() {
assert ctx.executor().inEventLoop();
PendingWrite write = head;
if (write == null) {
return null;
}
return write.msg;
}
private void recycle(PendingWrite write) {
PendingWrite next = write.next;
buffer.decrementPendingOutboundBytes(write.size);
write.recycle();
size --;
if (next == null) {
// Handled last PendingWrite so rest head and tail
head = tail = null;
assert size == 0;
} else {
head = next;
assert size > 0;
}
}
private static void safeFail(ChannelPromise promise, Throwable cause) {
if (!(promise instanceof VoidChannelPromise) && !promise.tryFailure(cause)) {
logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, cause);
}
}
/**
* Holds all meta-data and construct the linked-list structure.
*/
static final class PendingWrite {
private static final Recycler<PendingWrite> RECYCLER = new Recycler<PendingWrite>() {
@Override
protected PendingWrite newObject(Handle handle) {
return new PendingWrite(handle);
}
};
private final Recycler.Handle handle;
private PendingWrite next;
private long size;
private ChannelPromise promise;
private Object msg;
private PendingWrite(Recycler.Handle handle) {
this.handle = handle;
}
static PendingWrite newInstance(Object msg, int size, ChannelPromise promise) {
PendingWrite write = RECYCLER.get();
write.size = size;
write.msg = msg;
write.promise = promise;
return write;
}
private void recycle() {
size = 0;
next = null;
msg = null;
promise = null;
RECYCLER.recycle(this, handle);
}
}
}

View File

@ -0,0 +1,170 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test;
public class PendingWriteQueueTest {
@Test
public void testRemoveAndWrite() {
assertWrite(new TestHandler() {
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable());
ChannelFuture future = queue.removeAndWrite();
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
assertQueueEmpty(queue);
}
});
super.flush(ctx);
}
}, 1);
}
@Test
public void testRemoveAndWriteAll() {
assertWrite(new TestHandler() {
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
Assert.assertFalse("Should not be writable anymore", ctx.channel().isWritable());
ChannelFuture future = queue.removeAndWriteAll();
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
assertQueueEmpty(queue);
}
});
super.flush(ctx);
}
}, 3);
}
@Test
public void testRemoveAndFail() {
assertWriteFails(new TestHandler() {
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
queue.removeAndFail(new TestException());
super.flush(ctx);
}
}, 1);
}
@Test
public void testRemoveAndFailAll() {
assertWriteFails(new TestHandler() {
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
queue.removeAndFailAll(new TestException());
super.flush(ctx);
}
}, 3);
}
private static void assertWrite(ChannelHandler handler, int count) {
final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
final EmbeddedChannel channel = new EmbeddedChannel(handler);
channel.config().setWriteBufferLowWaterMark(1);
channel.config().setWriteBufferHighWaterMark(3);
ByteBuf[] buffers = new ByteBuf[count];
for (int i = 0; i < buffers.length; i++) {
buffers[i] = buffer.duplicate().retain();
}
Assert.assertTrue(channel.writeOutbound(buffers));
Assert.assertTrue(channel.finish());
channel.closeFuture().syncUninterruptibly();
for (int i = 0; i < buffers.length; i++) {
assertBuffer(channel, buffer);
}
buffer.release();
Assert.assertNull(channel.readOutbound());
}
private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) {
ByteBuf written = (ByteBuf) channel.readOutbound();
Assert.assertEquals(buffer, written);
written.release();
}
private static void assertQueueEmpty(PendingWriteQueue queue) {
Assert.assertTrue(queue.isEmpty());
Assert.assertEquals(0, queue.size());
Assert.assertNull(queue.current());
Assert.assertNull(queue.removeAndWrite());
Assert.assertNull(queue.removeAndWriteAll());
}
private static void assertWriteFails(ChannelHandler handler, int count) {
final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII);
final EmbeddedChannel channel = new EmbeddedChannel(handler);
ByteBuf[] buffers = new ByteBuf[count];
for (int i = 0; i < buffers.length; i++) {
buffers[i] = buffer.duplicate().retain();
}
try {
Assert.assertFalse(channel.writeOutbound(buffers));
Assert.fail();
} catch (Exception e) {
Assert.assertTrue(e instanceof TestException);
}
Assert.assertFalse(channel.finish());
channel.closeFuture().syncUninterruptibly();
buffer.release();
Assert.assertNull(channel.readOutbound());
}
private static class TestHandler extends ChannelDuplexHandler {
protected PendingWriteQueue queue;
private int expectedSize;
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
super.channelActive(ctx);
assertQueueEmpty(queue);
Assert.assertTrue("Should be writable", ctx.channel().isWritable());
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
queue.add(msg, promise);
Assert.assertFalse(queue.isEmpty());
Assert.assertEquals(++ expectedSize, queue.size());
Assert.assertNotNull(queue.current());
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
queue = new PendingWriteQueue(ctx);
}
}
private static final class TestException extends Exception { }
}