Fix data corruption in FileRegion transfer with epoll transport

Related issue: #2764

Motivation:

EpollSocketChannel.writeFileRegion() does not handle the case where the
position of a FileRegion is non-zero properly.

Modifications:

- Improve SocketFileRegionTest so that it tests the cases where the file
  transfer begins from the middle of the file
- Add another jlong parameter named 'base_off' so that we can take the
  position of a FileRegion into account

Result:

Improved test passes. Corruption is gone.
This commit is contained in:
Trustin Lee 2014-08-13 16:50:18 -07:00
parent af625f2274
commit 061d5bc261
5 changed files with 38 additions and 14 deletions

View File

@ -24,6 +24,7 @@ import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.DefaultFileRegion;
import io.netty.channel.FileRegion;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test;
import java.io.File;
@ -33,15 +34,15 @@ import java.io.IOException;
import java.util.Random;
import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.*;
public class SocketFileRegionTest extends AbstractSocketTest {
private static final Random random = new Random();
static final byte[] data = new byte[1048576 * 10];
static {
random.nextBytes(data);
ThreadLocalRandom.current().nextBytes(data);
}
@Test
@ -82,11 +83,26 @@ public class SocketFileRegionTest extends AbstractSocketTest {
private static void testFileRegion0(
ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead) throws Throwable {
File file = File.createTempFile("netty-", ".tmp");
final File file = File.createTempFile("netty-", ".tmp");
file.deleteOnExit();
FileOutputStream out = new FileOutputStream(file);
final FileOutputStream out = new FileOutputStream(file);
final Random random = ThreadLocalRandom.current();
// Prepend random data which will not be transferred, so that we can test non-zero start offset
final int startOffset = random.nextInt(8192);
for (int i = 0; i < startOffset; i ++) {
out.write(random.nextInt());
}
// .. and here comes the real data to transfer.
out.write(data);
// .. and then some extra data which is not supposed to be transferred.
for (int i = random.nextInt(8192); i > 0; i --) {
out.write(random.nextInt());
}
out.close();
ChannelInboundHandler ch = new SimpleChannelInboundHandler<Object>() {
@ -114,7 +130,7 @@ public class SocketFileRegionTest extends AbstractSocketTest {
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();
FileRegion region = new DefaultFileRegion(new FileInputStream(file).getChannel(), 0L, file.length());
FileRegion region = new DefaultFileRegion(new FileInputStream(file).getChannel(), startOffset, data.length);
if (voidPromise) {
assertEquals(cc.voidPromise(), cc.writeAndFlush(region, cc.voidPromise()));
} else {
@ -143,6 +159,9 @@ public class SocketFileRegionTest extends AbstractSocketTest {
if (sh.exception.get() != null) {
throw sh.exception.get();
}
// Make sure we did not receive more than we expected.
assertThat(sh.counter, is(data.length));
}
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {

View File

@ -933,7 +933,7 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept(JNIEnv * env, j
return socketFd;
}
JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong off, jlong len) {
JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len) {
jobject fileChannel = (*env)->GetObjectField(env, fileRegion, fileChannelFieldId);
if (fileChannel == NULL) {
throwRuntimeException(env, "Unable to obtain FileChannel from FileRegion");
@ -950,7 +950,7 @@ JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env,
return -1;
}
ssize_t res;
off_t offset = off;
off_t offset = base_off + off;
int err;
do {
res = sendfile(fd, srcFd, &offset, (size_t) len);

View File

@ -62,7 +62,7 @@ void Java_io_netty_channel_epoll_Native_listen(JNIEnv * env, jclass clazz, jint
jboolean Java_io_netty_channel_epoll_Native_connect(JNIEnv * env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
jboolean Java_io_netty_channel_epoll_Native_finishConnect(JNIEnv * env, jclass clazz, jint fd);
jint Java_io_netty_channel_epoll_Native_accept(JNIEnv * env, jclass clazz, jint fd);
jlong Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong off, jlong len);
jlong Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
jobject Java_io_netty_channel_epoll_Native_remoteAddress(JNIEnv * env, jclass clazz, jint fd);
jobject Java_io_netty_channel_epoll_Native_localAddress(JNIEnv * env, jclass clazz, jint fd);
void Java_io_netty_channel_epoll_Native_setReuseAddress(JNIEnv * env, jclass clazz, jint fd, jint optval);

View File

@ -265,17 +265,19 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
* @return amount the amount of written bytes
*/
private boolean writeFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception {
if (region.transfered() >= region.count()) {
final long regionCount = region.count();
if (region.transfered() >= regionCount) {
in.remove();
return true;
}
final long baseOffset = region.position();
boolean done = false;
long flushedAmount = 0;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
long expected = region.count() - region.position();
long localFlushedAmount = Native.sendfile(fd, region, region.transfered(), expected);
final long offset = region.transfered();
final long localFlushedAmount = Native.sendfile(fd, region, baseOffset, offset, regionCount - offset);
if (localFlushedAmount == 0) {
// Returned EAGAIN need to set EPOLLOUT
setEpollOut();
@ -283,13 +285,15 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
}
flushedAmount += localFlushedAmount;
if (region.transfered() >= region.count()) {
if (region.transfered() >= regionCount) {
done = true;
break;
}
}
in.progress(flushedAmount);
if (flushedAmount > 0) {
in.progress(flushedAmount);
}
if (done) {
in.remove();

View File

@ -75,7 +75,8 @@ final class Native {
public static native int read(int fd, ByteBuffer buf, int pos, int limit) throws IOException;
public static native int readAddress(int fd, long address, int pos, int limit) throws IOException;
public static native long sendfile(int dest, DefaultFileRegion src, long offset, long length) throws IOException;
public static native long sendfile(
int dest, DefaultFileRegion src, long baseOffset, long offset, long length) throws IOException;
public static int sendTo(
int fd, ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException {