From 2c4a7a253912458908d45dc653b52ff3a3aaab99 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Sun, 11 Sep 2016 09:27:58 +0200 Subject: [PATCH] [#5800] Support any FileRegion implementation when using epoll transport Motivation: At the moment only DefaultFileRegion is supported when using the native epoll transport. Modification: - Add support for any FileRegion implementation - Add test case Result: Also custom FileRegion implementation are supported when using the epoll transport. --- .../socket/SocketFileRegionTest.java | 98 +++++++++++++++- .../epoll/AbstractEpollStreamChannel.java | 110 +++++++++++++++++- 2 files changed, 197 insertions(+), 11 deletions(-) diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java index 4a44c62171..6b7ad3ca0e 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketFileRegionTest.java @@ -33,6 +33,7 @@ import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import java.nio.channels.WritableByteChannel; import java.util.Random; import java.util.concurrent.atomic.AtomicReference; @@ -52,6 +53,11 @@ public class SocketFileRegionTest extends AbstractSocketTest { run(); } + @Test + public void testCustomFileRegion() throws Throwable { + run(); + } + @Test public void testFileRegionNotAutoRead() throws Throwable { run(); @@ -68,23 +74,28 @@ public class SocketFileRegionTest extends AbstractSocketTest { } public void testFileRegion(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testFileRegion0(sb, cb, false, true); + testFileRegion0(sb, cb, false, true, true); + } + + public void testCustomFileRegion(ServerBootstrap sb, Bootstrap cb) throws Throwable { + testFileRegion0(sb, cb, false, true, false); } public void testFileRegionVoidPromise(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testFileRegion0(sb, cb, true, true); + testFileRegion0(sb, cb, true, true, true); } public void testFileRegionNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testFileRegion0(sb, cb, false, false); + testFileRegion0(sb, cb, false, false, true); } public void testFileRegionVoidPromiseNotAutoRead(ServerBootstrap sb, Bootstrap cb) throws Throwable { - testFileRegion0(sb, cb, true, false); + testFileRegion0(sb, cb, true, false, true); } private static void testFileRegion0( - ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead) throws Throwable { + ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead, boolean defaultFileRegion) + throws Throwable { sb.childOption(ChannelOption.AUTO_READ, autoRead); cb.option(ChannelOption.AUTO_READ, autoRead); @@ -140,6 +151,10 @@ public class SocketFileRegionTest extends AbstractSocketTest { new FileInputStream(file).getChannel(), startOffset, data.length - bufferSize); FileRegion emptyRegion = new DefaultFileRegion(new FileInputStream(file).getChannel(), 0, 0); + if (!defaultFileRegion) { + region = new FileRegionWrapper(region); + emptyRegion = new FileRegionWrapper(emptyRegion); + } // Do write ByteBuf and then FileRegion to ensure that mixed writes work // Also, write an empty FileRegion to test if writing an empty FileRegion does not cause any issues. // @@ -229,4 +244,77 @@ public class SocketFileRegionTest extends AbstractSocketTest { } } } + + private static final class FileRegionWrapper implements FileRegion { + private final FileRegion region; + + FileRegionWrapper(FileRegion region) { + this.region = region; + } + + @Override + public int refCnt() { + return region.refCnt(); + } + + @Override + public long position() { + return region.position(); + } + + @Override + @Deprecated + public long transfered() { + return region.transfered(); + } + + @Override + public boolean release() { + return region.release(); + } + + @Override + public long transferred() { + return region.transferred(); + } + + @Override + public long count() { + return region.count(); + } + + @Override + public boolean release(int decrement) { + return region.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + return region.transferTo(target, position); + } + + @Override + public FileRegion retain() { + region.retain(); + return this; + } + + @Override + public FileRegion retain(int increment) { + region.retain(increment); + return this; + } + + @Override + public FileRegion touch() { + region.touch(); + return this; + } + + @Override + public FileRegion touch(Object hint) { + region.touch(hint); + return this; + } + } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index f18e4c7430..f34d23eeb9 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -17,7 +17,9 @@ package io.netty.channel.epoll; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelFuture; @@ -29,6 +31,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.ConnectTimeoutException; import io.netty.channel.DefaultFileRegion; import io.netty.channel.EventLoop; +import io.netty.channel.FileRegion; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.socket.DuplexChannel; import io.netty.channel.unix.FileDescriptor; @@ -44,6 +47,7 @@ import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.ConnectionPendingException; +import java.nio.channels.WritableByteChannel; import java.util.Queue; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledFuture; @@ -83,6 +87,8 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private FileDescriptor pipeIn; private FileDescriptor pipeOut; + private WritableByteChannel byteChannel; + /** * @deprecated Use {@link #AbstractEpollStreamChannel(Channel, Socket)}. */ @@ -372,7 +378,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im * @param region the {@link DefaultFileRegion} from which the bytes should be written * @return amount the amount of written bytes */ - private boolean writeFileRegion( + private boolean writeDefaultFileRegion( ChannelOutboundBuffer in, DefaultFileRegion region, int writeSpinCount) throws Exception { final long regionCount = region.count(); if (region.transferred() >= regionCount) { @@ -409,6 +415,42 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im return done; } + private boolean writeFileRegion( + ChannelOutboundBuffer in, FileRegion region, final int writeSpinCount) throws Exception { + if (region.transferred() >= region.count()) { + in.remove(); + return true; + } + + boolean done = false; + long flushedAmount = 0; + + if (byteChannel == null) { + byteChannel = new SocketWritableByteChannel(); + } + for (int i = writeSpinCount - 1; i >= 0; i--) { + final long localFlushedAmount = region.transferTo(byteChannel, region.transferred()); + if (localFlushedAmount == 0) { + break; + } + + flushedAmount += localFlushedAmount; + if (region.transferred() >= region.count()) { + done = true; + break; + } + } + + if (flushedAmount > 0) { + in.progress(flushedAmount); + } + + if (done) { + in.remove(); + } + return done; + } + @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { int writeSpinCount = config().getWriteSpinCount(); @@ -448,15 +490,19 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im // The outbound buffer contains only one message or it contains a file region. Object msg = in.current(); if (msg instanceof ByteBuf) { - ByteBuf buf = (ByteBuf) msg; - if (!writeBytes(in, buf, writeSpinCount)) { + if (!writeBytes(in, (ByteBuf) msg, writeSpinCount)) { // was not able to write everything so break here we will get notified later again once // the network stack can handle more writes. return false; } } else if (msg instanceof DefaultFileRegion) { - DefaultFileRegion region = (DefaultFileRegion) msg; - if (!writeFileRegion(in, region, writeSpinCount)) { + if (!writeDefaultFileRegion(in, (DefaultFileRegion) msg, writeSpinCount)) { + // was not able to write everything so break here we will get notified later again once + // the network stack can handle more writes. + return false; + } + } else if (msg instanceof FileRegion) { + if (!writeFileRegion(in, (FileRegion) msg, writeSpinCount)) { // was not able to write everything so break here we will get notified later again once // the network stack can handle more writes. return false; @@ -533,7 +579,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im return buf; } - if (msg instanceof DefaultFileRegion || msg instanceof SpliceOutTask) { + if (msg instanceof FileRegion || msg instanceof SpliceOutTask) { return msg; } @@ -1211,4 +1257,56 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im } } } + + private final class SocketWritableByteChannel implements WritableByteChannel { + + @Override + public int write(ByteBuffer src) throws IOException { + final int written; + int position = src.position(); + int limit = src.limit(); + if (src.isDirect()) { + written = fd().write(src, position, src.limit()); + } else { + final int readableBytes = limit - position; + ByteBuf buffer = null; + try { + if (readableBytes == 0) { + buffer = Unpooled.EMPTY_BUFFER; + } else { + final ByteBufAllocator alloc = alloc(); + if (alloc.isDirectBufferPooled()) { + buffer = alloc.directBuffer(readableBytes); + } else { + buffer = ByteBufUtil.threadLocalDirectBuffer(); + if (buffer == null) { + buffer = Unpooled.directBuffer(readableBytes); + } + } + } + buffer.writeBytes(src.duplicate()); + ByteBuffer nioBuffer = buffer.internalNioBuffer(buffer.readerIndex(), readableBytes); + written = fd().write(nioBuffer, nioBuffer.position(), nioBuffer.limit()); + } finally { + if (buffer != null) { + buffer.release(); + } + } + } + if (written > 0) { + src.position(position + written); + } + return written; + } + + @Override + public boolean isOpen() { + return fd().isOpen(); + } + + @Override + public void close() throws IOException { + fd().close(); + } + } }