Allow to cancel non-flushed writes
This commit is contained in:
parent
d52dc3b740
commit
7041a9238e
@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright 2014 The Netty Project
|
||||
*
|
||||
* The Netty Project licenses this file to you under the Apache License,
|
||||
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at:
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.netty.testsuite.transport.socket;
|
||||
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class SocketCancelWriteTest extends AbstractSocketTest {
|
||||
|
||||
@Test(timeout = 30000)
|
||||
public void testCancelWrite() throws Throwable {
|
||||
run();
|
||||
}
|
||||
|
||||
public void testCancelWrite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
||||
final TestHandler sh = new TestHandler();
|
||||
final TestHandler ch = new TestHandler();
|
||||
final ByteBuf a = Unpooled.buffer().writeByte('a');
|
||||
final ByteBuf b = Unpooled.buffer().writeByte('b');
|
||||
final ByteBuf c = Unpooled.buffer().writeByte('c');
|
||||
final ByteBuf d = Unpooled.buffer().writeByte('d');
|
||||
final ByteBuf e = Unpooled.buffer().writeByte('e');
|
||||
|
||||
cb.handler(ch);
|
||||
sb.childHandler(sh);
|
||||
|
||||
Channel sc = sb.bind().sync().channel();
|
||||
Channel cc = cb.connect().sync().channel();
|
||||
|
||||
ChannelFuture f = cc.write(a);
|
||||
assertTrue(f.cancel(false));
|
||||
cc.writeAndFlush(b);
|
||||
cc.write(c);
|
||||
ChannelFuture f2 = cc.write(d);
|
||||
assertTrue(f2.cancel(false));
|
||||
cc.writeAndFlush(e);
|
||||
|
||||
while (sh.counter < 3) {
|
||||
if (sh.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
try {
|
||||
Thread.sleep(50);
|
||||
} catch (InterruptedException ignore) {
|
||||
// Ignore.
|
||||
}
|
||||
}
|
||||
sh.channel.close().sync();
|
||||
ch.channel.close().sync();
|
||||
sc.close().sync();
|
||||
|
||||
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (sh.exception.get() != null) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
assertEquals(0, ch.counter);
|
||||
assertEquals(Unpooled.wrappedBuffer(new byte[]{'b', 'c', 'e'}), sh.received);
|
||||
}
|
||||
|
||||
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
|
||||
volatile Channel channel;
|
||||
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
|
||||
volatile int counter;
|
||||
final ByteBuf received = Unpooled.buffer();
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx)
|
||||
throws Exception {
|
||||
channel = ctx.channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
|
||||
counter += in.readableBytes();
|
||||
received.writeBytes(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx,
|
||||
Throwable cause) throws Exception {
|
||||
if (exception.compareAndSet(null, cause)) {
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -22,6 +22,7 @@ package io.netty.channel;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.buffer.ByteBufHolder;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.socket.nio.NioSocketChannel;
|
||||
import io.netty.util.Recycler;
|
||||
import io.netty.util.Recycler.Handle;
|
||||
@ -238,7 +239,14 @@ public final class ChannelOutboundBuffer {
|
||||
if (isEmpty()) {
|
||||
return null;
|
||||
} else {
|
||||
return buffer[flushed].msg;
|
||||
Entry entry = buffer[flushed];
|
||||
if (!entry.cancelled && !entry.promise.setUncancellable()) {
|
||||
// Was cancelled so make sure we free up memory and notify about the freed bytes
|
||||
int pending = entry.cancel();
|
||||
decrementPendingOutboundBytes(pending);
|
||||
}
|
||||
|
||||
return entry.msg;
|
||||
}
|
||||
}
|
||||
|
||||
@ -280,9 +288,12 @@ public final class ChannelOutboundBuffer {
|
||||
|
||||
flushed = flushed + 1 & buffer.length - 1;
|
||||
|
||||
safeRelease(msg);
|
||||
safeSuccess(promise);
|
||||
decrementPendingOutboundBytes(size);
|
||||
if (!e.cancelled) {
|
||||
// only release message, notify and decrement if it was not canceled before.
|
||||
safeRelease(msg);
|
||||
safeSuccess(promise);
|
||||
decrementPendingOutboundBytes(size);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -305,10 +316,13 @@ public final class ChannelOutboundBuffer {
|
||||
|
||||
flushed = flushed + 1 & buffer.length - 1;
|
||||
|
||||
safeRelease(msg);
|
||||
if (!e.cancelled) {
|
||||
// only release message, fail and decrement if it was not canceled before.
|
||||
safeRelease(msg);
|
||||
|
||||
safeFail(promise, cause);
|
||||
decrementPendingOutboundBytes(size);
|
||||
safeFail(promise, cause);
|
||||
decrementPendingOutboundBytes(size);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -340,41 +354,51 @@ public final class ChannelOutboundBuffer {
|
||||
}
|
||||
|
||||
Entry entry = buffer[i];
|
||||
ByteBuf buf = (ByteBuf) m;
|
||||
final int readerIndex = buf.readerIndex();
|
||||
final int readableBytes = buf.writerIndex() - readerIndex;
|
||||
|
||||
if (readableBytes > 0) {
|
||||
nioBufferSize += readableBytes;
|
||||
int count = entry.count;
|
||||
if (count == -1) {
|
||||
//noinspection ConstantValueVariableUse
|
||||
entry.count = count = buf.nioBufferCount();
|
||||
}
|
||||
int neededSpace = nioBufferCount + count;
|
||||
if (neededSpace > nioBuffers.length) {
|
||||
this.nioBuffers = nioBuffers = expandNioBufferArray(nioBuffers, neededSpace, nioBufferCount);
|
||||
}
|
||||
if (buf.isDirect() || !alloc.isDirectBufferPooled()) {
|
||||
if (count == 1) {
|
||||
ByteBuffer nioBuf = entry.buf;
|
||||
if (nioBuf == null) {
|
||||
// cache ByteBuffer as it may need to create a new ByteBuffer instance if its a
|
||||
// derived buffer
|
||||
entry.buf = nioBuf = buf.internalNioBuffer(readerIndex, readableBytes);
|
||||
}
|
||||
nioBuffers[nioBufferCount ++] = nioBuf;
|
||||
} else {
|
||||
ByteBuffer[] nioBufs = entry.buffers;
|
||||
if (nioBufs == null) {
|
||||
// cached ByteBuffers as they may be expensive to create in terms of Object allocation
|
||||
entry.buffers = nioBufs = buf.nioBuffers();
|
||||
}
|
||||
nioBufferCount = fillBufferArray(nioBufs, nioBuffers, nioBufferCount);
|
||||
}
|
||||
if (!entry.cancelled) {
|
||||
if (!entry.promise.setUncancellable()) {
|
||||
// Was cancelled so make sure we free up memory and notify about the freed bytes
|
||||
int pending = entry.cancel();
|
||||
decrementPendingOutboundBytes(pending);
|
||||
} else {
|
||||
nioBufferCount = fillBufferArrayNonDirect(entry, buf, readerIndex,
|
||||
readableBytes, alloc, nioBuffers, nioBufferCount);
|
||||
ByteBuf buf = (ByteBuf) m;
|
||||
final int readerIndex = buf.readerIndex();
|
||||
final int readableBytes = buf.writerIndex() - readerIndex;
|
||||
|
||||
if (readableBytes > 0) {
|
||||
nioBufferSize += readableBytes;
|
||||
int count = entry.count;
|
||||
if (count == -1) {
|
||||
//noinspection ConstantValueVariableUse
|
||||
entry.count = count = buf.nioBufferCount();
|
||||
}
|
||||
int neededSpace = nioBufferCount + count;
|
||||
if (neededSpace > nioBuffers.length) {
|
||||
this.nioBuffers = nioBuffers =
|
||||
expandNioBufferArray(nioBuffers, neededSpace, nioBufferCount);
|
||||
}
|
||||
if (buf.isDirect() || !alloc.isDirectBufferPooled()) {
|
||||
if (count == 1) {
|
||||
ByteBuffer nioBuf = entry.buf;
|
||||
if (nioBuf == null) {
|
||||
// cache ByteBuffer as it may need to create a new ByteBuffer instance if its a
|
||||
// derived buffer
|
||||
entry.buf = nioBuf = buf.internalNioBuffer(readerIndex, readableBytes);
|
||||
}
|
||||
nioBuffers[nioBufferCount ++] = nioBuf;
|
||||
} else {
|
||||
ByteBuffer[] nioBufs = entry.buffers;
|
||||
if (nioBufs == null) {
|
||||
// cached ByteBuffers as they may be expensive to create in terms
|
||||
// of Object allocation
|
||||
entry.buffers = nioBufs = buf.nioBuffers();
|
||||
}
|
||||
nioBufferCount = fillBufferArray(nioBufs, nioBuffers, nioBufferCount);
|
||||
}
|
||||
} else {
|
||||
nioBufferCount = fillBufferArrayNonDirect(entry, buf, readerIndex,
|
||||
readableBytes, alloc, nioBuffers, nioBufferCount);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i = i + 1 & mask;
|
||||
@ -495,10 +519,6 @@ public final class ChannelOutboundBuffer {
|
||||
try {
|
||||
for (int i = 0; i < unflushedCount; i++) {
|
||||
Entry e = buffer[unflushed + i & buffer.length - 1];
|
||||
safeRelease(e.msg);
|
||||
e.msg = null;
|
||||
safeFail(e.promise, cause);
|
||||
e.promise = null;
|
||||
|
||||
// Just decrease; do not trigger any events via decrementPendingOutboundBytes()
|
||||
int size = e.pendingSize;
|
||||
@ -510,6 +530,12 @@ public final class ChannelOutboundBuffer {
|
||||
}
|
||||
|
||||
e.pendingSize = 0;
|
||||
if (!e.cancelled) {
|
||||
safeRelease(e.msg);
|
||||
safeFail(e.promise, cause);
|
||||
}
|
||||
e.msg = null;
|
||||
e.promise = null;
|
||||
}
|
||||
} finally {
|
||||
tail = unflushed;
|
||||
@ -579,6 +605,26 @@ public final class ChannelOutboundBuffer {
|
||||
long total;
|
||||
int pendingSize;
|
||||
int count = -1;
|
||||
boolean cancelled;
|
||||
|
||||
public int cancel() {
|
||||
if (!cancelled) {
|
||||
cancelled = true;
|
||||
int pSize = pendingSize;
|
||||
|
||||
// release message and replace with an empty buffer
|
||||
safeRelease(msg);
|
||||
msg = Unpooled.EMPTY_BUFFER;
|
||||
|
||||
pendingSize = 0;
|
||||
total = 0;
|
||||
progress = 0;
|
||||
buffers = null;
|
||||
buf = null;
|
||||
return pSize;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public void clear() {
|
||||
buffers = null;
|
||||
@ -589,6 +635,7 @@ public final class ChannelOutboundBuffer {
|
||||
total = 0;
|
||||
pendingSize = 0;
|
||||
count = -1;
|
||||
cancelled = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user