Allow to cancel non-flushed writes

This commit is contained in:
Norman Maurer 2014-02-07 20:52:37 +01:00
parent d52dc3b740
commit 7041a9238e
2 changed files with 214 additions and 45 deletions

View File

@ -0,0 +1,122 @@
/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import org.junit.Test;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class SocketCancelWriteTest extends AbstractSocketTest {
@Test(timeout = 30000)
public void testCancelWrite() throws Throwable {
run();
}
public void testCancelWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
final TestHandler sh = new TestHandler();
final TestHandler ch = new TestHandler();
final ByteBuf a = Unpooled.buffer().writeByte('a');
final ByteBuf b = Unpooled.buffer().writeByte('b');
final ByteBuf c = Unpooled.buffer().writeByte('c');
final ByteBuf d = Unpooled.buffer().writeByte('d');
final ByteBuf e = Unpooled.buffer().writeByte('e');
cb.handler(ch);
sb.childHandler(sh);
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();
ChannelFuture f = cc.write(a);
assertTrue(f.cancel(false));
cc.writeAndFlush(b);
cc.write(c);
ChannelFuture f2 = cc.write(d);
assertTrue(f2.cancel(false));
cc.writeAndFlush(e);
while (sh.counter < 3) {
if (sh.exception.get() != null) {
break;
}
if (ch.exception.get() != null) {
break;
}
try {
Thread.sleep(50);
} catch (InterruptedException ignore) {
// Ignore.
}
}
sh.channel.close().sync();
ch.channel.close().sync();
sc.close().sync();
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
throw sh.exception.get();
}
if (sh.exception.get() != null) {
throw sh.exception.get();
}
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
throw ch.exception.get();
}
if (ch.exception.get() != null) {
throw ch.exception.get();
}
assertEquals(0, ch.counter);
assertEquals(Unpooled.wrappedBuffer(new byte[]{'b', 'c', 'e'}), sh.received);
}
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
volatile Channel channel;
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
volatile int counter;
final ByteBuf received = Unpooled.buffer();
@Override
public void channelActive(ChannelHandlerContext ctx)
throws Exception {
channel = ctx.channel();
}
@Override
public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
counter += in.readableBytes();
received.writeBytes(in);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx,
Throwable cause) throws Exception {
if (exception.compareAndSet(null, cause)) {
ctx.close();
}
}
}
}

View File

@ -22,6 +22,7 @@ package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.Recycler;
import io.netty.util.Recycler.Handle;
@ -238,7 +239,14 @@ public final class ChannelOutboundBuffer {
if (isEmpty()) {
return null;
} else {
return buffer[flushed].msg;
Entry entry = buffer[flushed];
if (!entry.cancelled && !entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel();
decrementPendingOutboundBytes(pending);
}
return entry.msg;
}
}
@ -280,9 +288,12 @@ public final class ChannelOutboundBuffer {
flushed = flushed + 1 & buffer.length - 1;
safeRelease(msg);
safeSuccess(promise);
decrementPendingOutboundBytes(size);
if (!e.cancelled) {
// only release message, notify and decrement if it was not canceled before.
safeRelease(msg);
safeSuccess(promise);
decrementPendingOutboundBytes(size);
}
return true;
}
@ -305,10 +316,13 @@ public final class ChannelOutboundBuffer {
flushed = flushed + 1 & buffer.length - 1;
safeRelease(msg);
if (!e.cancelled) {
// only release message, fail and decrement if it was not canceled before.
safeRelease(msg);
safeFail(promise, cause);
decrementPendingOutboundBytes(size);
safeFail(promise, cause);
decrementPendingOutboundBytes(size);
}
return true;
}
@ -340,41 +354,51 @@ public final class ChannelOutboundBuffer {
}
Entry entry = buffer[i];
ByteBuf buf = (ByteBuf) m;
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes > 0) {
nioBufferSize += readableBytes;
int count = entry.count;
if (count == -1) {
//noinspection ConstantValueVariableUse
entry.count = count = buf.nioBufferCount();
}
int neededSpace = nioBufferCount + count;
if (neededSpace > nioBuffers.length) {
this.nioBuffers = nioBuffers = expandNioBufferArray(nioBuffers, neededSpace, nioBufferCount);
}
if (buf.isDirect() || !alloc.isDirectBufferPooled()) {
if (count == 1) {
ByteBuffer nioBuf = entry.buf;
if (nioBuf == null) {
// cache ByteBuffer as it may need to create a new ByteBuffer instance if its a
// derived buffer
entry.buf = nioBuf = buf.internalNioBuffer(readerIndex, readableBytes);
}
nioBuffers[nioBufferCount ++] = nioBuf;
} else {
ByteBuffer[] nioBufs = entry.buffers;
if (nioBufs == null) {
// cached ByteBuffers as they may be expensive to create in terms of Object allocation
entry.buffers = nioBufs = buf.nioBuffers();
}
nioBufferCount = fillBufferArray(nioBufs, nioBuffers, nioBufferCount);
}
if (!entry.cancelled) {
if (!entry.promise.setUncancellable()) {
// Was cancelled so make sure we free up memory and notify about the freed bytes
int pending = entry.cancel();
decrementPendingOutboundBytes(pending);
} else {
nioBufferCount = fillBufferArrayNonDirect(entry, buf, readerIndex,
readableBytes, alloc, nioBuffers, nioBufferCount);
ByteBuf buf = (ByteBuf) m;
final int readerIndex = buf.readerIndex();
final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes > 0) {
nioBufferSize += readableBytes;
int count = entry.count;
if (count == -1) {
//noinspection ConstantValueVariableUse
entry.count = count = buf.nioBufferCount();
}
int neededSpace = nioBufferCount + count;
if (neededSpace > nioBuffers.length) {
this.nioBuffers = nioBuffers =
expandNioBufferArray(nioBuffers, neededSpace, nioBufferCount);
}
if (buf.isDirect() || !alloc.isDirectBufferPooled()) {
if (count == 1) {
ByteBuffer nioBuf = entry.buf;
if (nioBuf == null) {
// cache ByteBuffer as it may need to create a new ByteBuffer instance if its a
// derived buffer
entry.buf = nioBuf = buf.internalNioBuffer(readerIndex, readableBytes);
}
nioBuffers[nioBufferCount ++] = nioBuf;
} else {
ByteBuffer[] nioBufs = entry.buffers;
if (nioBufs == null) {
// cached ByteBuffers as they may be expensive to create in terms
// of Object allocation
entry.buffers = nioBufs = buf.nioBuffers();
}
nioBufferCount = fillBufferArray(nioBufs, nioBuffers, nioBufferCount);
}
} else {
nioBufferCount = fillBufferArrayNonDirect(entry, buf, readerIndex,
readableBytes, alloc, nioBuffers, nioBufferCount);
}
}
}
}
i = i + 1 & mask;
@ -495,10 +519,6 @@ public final class ChannelOutboundBuffer {
try {
for (int i = 0; i < unflushedCount; i++) {
Entry e = buffer[unflushed + i & buffer.length - 1];
safeRelease(e.msg);
e.msg = null;
safeFail(e.promise, cause);
e.promise = null;
// Just decrease; do not trigger any events via decrementPendingOutboundBytes()
int size = e.pendingSize;
@ -510,6 +530,12 @@ public final class ChannelOutboundBuffer {
}
e.pendingSize = 0;
if (!e.cancelled) {
safeRelease(e.msg);
safeFail(e.promise, cause);
}
e.msg = null;
e.promise = null;
}
} finally {
tail = unflushed;
@ -579,6 +605,26 @@ public final class ChannelOutboundBuffer {
long total;
int pendingSize;
int count = -1;
boolean cancelled;
public int cancel() {
if (!cancelled) {
cancelled = true;
int pSize = pendingSize;
// release message and replace with an empty buffer
safeRelease(msg);
msg = Unpooled.EMPTY_BUFFER;
pendingSize = 0;
total = 0;
progress = 0;
buffers = null;
buf = null;
return pSize;
}
return 0;
}
public void clear() {
buffers = null;
@ -589,6 +635,7 @@ public final class ChannelOutboundBuffer {
total = 0;
pendingSize = 0;
count = -1;
cancelled = false;
}
}