Fix data corruption in FileRegion transfer with epoll transport

Related issue: #2764

Motivation:

EpollSocketChannel.writeFileRegion() does not handle the case where the
position of a FileRegion is non-zero properly.

Modifications:

- Improve SocketFileRegionTest so that it tests the cases where the file
  transfer begins from the middle of the file
- Add another jlong parameter named 'base_off' so that we can take the
  position of a FileRegion into account

Result:

Improved test passes. Corruption is gone.
This commit is contained in:
Trustin Lee 2014-08-13 16:50:18 -07:00
parent af625f2274
commit 061d5bc261
5 changed files with 38 additions and 14 deletions

View File

@ -24,6 +24,7 @@ import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.DefaultFileRegion; import io.netty.channel.DefaultFileRegion;
import io.netty.channel.FileRegion; import io.netty.channel.FileRegion;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test; import org.junit.Test;
import java.io.File; import java.io.File;
@ -33,15 +34,15 @@ import java.io.IOException;
import java.util.Random; import java.util.Random;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class SocketFileRegionTest extends AbstractSocketTest { public class SocketFileRegionTest extends AbstractSocketTest {
private static final Random random = new Random();
static final byte[] data = new byte[1048576 * 10]; static final byte[] data = new byte[1048576 * 10];
static { static {
random.nextBytes(data); ThreadLocalRandom.current().nextBytes(data);
} }
@Test @Test
@ -82,11 +83,26 @@ public class SocketFileRegionTest extends AbstractSocketTest {
private static void testFileRegion0( private static void testFileRegion0(
ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead) throws Throwable { ServerBootstrap sb, Bootstrap cb, boolean voidPromise, final boolean autoRead) throws Throwable {
File file = File.createTempFile("netty-", ".tmp"); final File file = File.createTempFile("netty-", ".tmp");
file.deleteOnExit(); file.deleteOnExit();
FileOutputStream out = new FileOutputStream(file); final FileOutputStream out = new FileOutputStream(file);
final Random random = ThreadLocalRandom.current();
// Prepend random data which will not be transferred, so that we can test non-zero start offset
final int startOffset = random.nextInt(8192);
for (int i = 0; i < startOffset; i ++) {
out.write(random.nextInt());
}
// .. and here comes the real data to transfer.
out.write(data); out.write(data);
// .. and then some extra data which is not supposed to be transferred.
for (int i = random.nextInt(8192); i > 0; i --) {
out.write(random.nextInt());
}
out.close(); out.close();
ChannelInboundHandler ch = new SimpleChannelInboundHandler<Object>() { ChannelInboundHandler ch = new SimpleChannelInboundHandler<Object>() {
@ -114,7 +130,7 @@ public class SocketFileRegionTest extends AbstractSocketTest {
Channel sc = sb.bind().sync().channel(); Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel(); Channel cc = cb.connect().sync().channel();
FileRegion region = new DefaultFileRegion(new FileInputStream(file).getChannel(), 0L, file.length()); FileRegion region = new DefaultFileRegion(new FileInputStream(file).getChannel(), startOffset, data.length);
if (voidPromise) { if (voidPromise) {
assertEquals(cc.voidPromise(), cc.writeAndFlush(region, cc.voidPromise())); assertEquals(cc.voidPromise(), cc.writeAndFlush(region, cc.voidPromise()));
} else { } else {
@ -143,6 +159,9 @@ public class SocketFileRegionTest extends AbstractSocketTest {
if (sh.exception.get() != null) { if (sh.exception.get() != null) {
throw sh.exception.get(); throw sh.exception.get();
} }
// Make sure we did not receive more than we expected.
assertThat(sh.counter, is(data.length));
} }
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> { private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {

View File

@ -933,7 +933,7 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept(JNIEnv * env, j
return socketFd; return socketFd;
} }
JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong off, jlong len) { JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len) {
jobject fileChannel = (*env)->GetObjectField(env, fileRegion, fileChannelFieldId); jobject fileChannel = (*env)->GetObjectField(env, fileRegion, fileChannelFieldId);
if (fileChannel == NULL) { if (fileChannel == NULL) {
throwRuntimeException(env, "Unable to obtain FileChannel from FileRegion"); throwRuntimeException(env, "Unable to obtain FileChannel from FileRegion");
@ -950,7 +950,7 @@ JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env,
return -1; return -1;
} }
ssize_t res; ssize_t res;
off_t offset = off; off_t offset = base_off + off;
int err; int err;
do { do {
res = sendfile(fd, srcFd, &offset, (size_t) len); res = sendfile(fd, srcFd, &offset, (size_t) len);

View File

@ -62,7 +62,7 @@ void Java_io_netty_channel_epoll_Native_listen(JNIEnv * env, jclass clazz, jint
jboolean Java_io_netty_channel_epoll_Native_connect(JNIEnv * env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port); jboolean Java_io_netty_channel_epoll_Native_connect(JNIEnv * env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
jboolean Java_io_netty_channel_epoll_Native_finishConnect(JNIEnv * env, jclass clazz, jint fd); jboolean Java_io_netty_channel_epoll_Native_finishConnect(JNIEnv * env, jclass clazz, jint fd);
jint Java_io_netty_channel_epoll_Native_accept(JNIEnv * env, jclass clazz, jint fd); jint Java_io_netty_channel_epoll_Native_accept(JNIEnv * env, jclass clazz, jint fd);
jlong Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong off, jlong len); jlong Java_io_netty_channel_epoll_Native_sendfile(JNIEnv *env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
jobject Java_io_netty_channel_epoll_Native_remoteAddress(JNIEnv * env, jclass clazz, jint fd); jobject Java_io_netty_channel_epoll_Native_remoteAddress(JNIEnv * env, jclass clazz, jint fd);
jobject Java_io_netty_channel_epoll_Native_localAddress(JNIEnv * env, jclass clazz, jint fd); jobject Java_io_netty_channel_epoll_Native_localAddress(JNIEnv * env, jclass clazz, jint fd);
void Java_io_netty_channel_epoll_Native_setReuseAddress(JNIEnv * env, jclass clazz, jint fd, jint optval); void Java_io_netty_channel_epoll_Native_setReuseAddress(JNIEnv * env, jclass clazz, jint fd, jint optval);

View File

@ -265,17 +265,19 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
* @return amount the amount of written bytes * @return amount the amount of written bytes
*/ */
private boolean writeFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception { private boolean writeFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception {
if (region.transfered() >= region.count()) { final long regionCount = region.count();
if (region.transfered() >= regionCount) {
in.remove(); in.remove();
return true; return true;
} }
final long baseOffset = region.position();
boolean done = false; boolean done = false;
long flushedAmount = 0; long flushedAmount = 0;
for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) {
long expected = region.count() - region.position(); final long offset = region.transfered();
long localFlushedAmount = Native.sendfile(fd, region, region.transfered(), expected); final long localFlushedAmount = Native.sendfile(fd, region, baseOffset, offset, regionCount - offset);
if (localFlushedAmount == 0) { if (localFlushedAmount == 0) {
// Returned EAGAIN need to set EPOLLOUT // Returned EAGAIN need to set EPOLLOUT
setEpollOut(); setEpollOut();
@ -283,13 +285,15 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
} }
flushedAmount += localFlushedAmount; flushedAmount += localFlushedAmount;
if (region.transfered() >= region.count()) { if (region.transfered() >= regionCount) {
done = true; done = true;
break; break;
} }
} }
in.progress(flushedAmount); if (flushedAmount > 0) {
in.progress(flushedAmount);
}
if (done) { if (done) {
in.remove(); in.remove();

View File

@ -75,7 +75,8 @@ final class Native {
public static native int read(int fd, ByteBuffer buf, int pos, int limit) throws IOException; public static native int read(int fd, ByteBuffer buf, int pos, int limit) throws IOException;
public static native int readAddress(int fd, long address, int pos, int limit) throws IOException; public static native int readAddress(int fd, long address, int pos, int limit) throws IOException;
public static native long sendfile(int dest, DefaultFileRegion src, long offset, long length) throws IOException; public static native long sendfile(
int dest, DefaultFileRegion src, long baseOffset, long offset, long length) throws IOException;
public static int sendTo( public static int sendTo(
int fd, ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException { int fd, ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException {