Correctly update the ByteBuffers position on write and writev

This commit is contained in:
Norman Maurer 2014-02-17 07:18:39 +01:00
parent 0b1acf35c9
commit e0669522a3
2 changed files with 41 additions and 9 deletions

View File

@ -34,6 +34,7 @@
extern int accept4(int sockFd, struct sockaddr *addr, socklen_t *addrlen, int flags) __attribute__((weak)); extern int accept4(int sockFd, struct sockaddr *addr, socklen_t *addrlen, int flags) __attribute__((weak));
// Those are initialized in the init(...) method and cached for performance reasons // Those are initialized in the init(...) method and cached for performance reasons
jmethodID updatePosId = NULL;
jmethodID posId = NULL; jmethodID posId = NULL;
jmethodID limitId = NULL; jmethodID limitId = NULL;
jfieldID posFieldId = NULL; jfieldID posFieldId = NULL;
@ -257,7 +258,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
throwRuntimeException(env, "Unable to find method ByteBuffer.limit()"); throwRuntimeException(env, "Unable to find method ByteBuffer.limit()");
return JNI_ERR; return JNI_ERR;
} }
updatePosId = (*env)->GetMethodID(env, cls, "position", "(I)Ljava/nio/Buffer;");
if (updatePosId == NULL) {
// position method was not found.. something is wrong so bail out
throwRuntimeException(env, "Unable to find method ByteBuffer.position(int)");
return JNI_ERR;
}
// Try to get the ids of the position and limit fields. We later then check if we was able // Try to get the ids of the position and limit fields. We later then check if we was able
// to find them and if so use them get the position and limit of the buffer. This is // to find them and if so use them get the position and limit of the buffer. This is
// much faster then call back into java via (*env)->CallIntMethod(...). // much faster then call back into java via (*env)->CallIntMethod(...).
@ -481,9 +487,26 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_write(JNIEnv * env, jc
throwIOException(env, exceptionMessage("Error while write(...): ", err)); throwIOException(env, exceptionMessage("Error while write(...): ", err));
return -1; return -1;
} }
if (posFieldId == NULL) {
(*env)->CallObjectMethod(env, jbuffer, updatePosId, pos + res);
} else {
(*env)->SetIntField(env, jbuffer, posFieldId, pos + res);
}
return (jint) res; return (jint) res;
} }
void incrementPosition(JNIEnv * env, jobject bufObj, int written) {
// Get the current position using the (*env)->GetIntField if possible and fallback
// to slower (*env)->CallIntMethod(...) if needed
if (posFieldId == NULL) {
jint pos = (*env)->CallIntMethod(env, bufObj, posId, NULL);
(*env)->CallObjectMethod(env, bufObj, updatePosId, pos + written);
} else {
jint pos = (*env)->GetIntField(env, bufObj, posFieldId);
(*env)->SetIntField(env, bufObj, posFieldId, pos + written);
}
}
JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_writev(JNIEnv * env, jclass clazz, jint fd, jobjectArray buffers, jint offset, jint length) { JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_writev(JNIEnv * env, jclass clazz, jint fd, jobjectArray buffers, jint offset, jint length) {
struct iovec iov[length]; struct iovec iov[length];
int i; int i;
@ -532,9 +555,25 @@ JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_writev(JNIEnv * env,
throwClosedChannelException(env); throwClosedChannelException(env);
return -1; return -1;
} }
throwIOException(env, exceptionMessage("Error while write(...): ", err)); throwIOException(env, exceptionMessage("Error while writev(...): ", err));
return -1; return -1;
} }
// update the position of the written buffers
int written = res;
int a;
for (a = 0; a < length; a++) {
int pos;
int len = iov[a].iov_len;
jobject bufObj = (*env)->GetObjectArrayElement(env, buffers, a + offset);
if (len >= written) {
incrementPosition(env, bufObj, written);
break;
} else {
incrementPosition(env, bufObj, len);
written -= len;
}
}
return res; return res;
} }
@ -856,4 +895,3 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_getTrafficClass(JNIEnv
} }
return optval; return optval;
} }

View File

@ -160,7 +160,6 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
long localWrittenBytes = Native.writev(fd, nioBuffers, 0, nioBufferCnt); long localWrittenBytes = Native.writev(fd, nioBuffers, 0, nioBufferCnt);
if (localWrittenBytes < expectedWrittenBytes) { if (localWrittenBytes < expectedWrittenBytes) {
int nioBufIndex = 0;
setEpollOut(); setEpollOut();
// Did not write all buffers completely. // Did not write all buffers completely.
@ -171,17 +170,12 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So
final int readableBytes = buf.writerIndex() - readerIndex; final int readableBytes = buf.writerIndex() - readerIndex;
if (readableBytes < localWrittenBytes) { if (readableBytes < localWrittenBytes) {
nioBufIndex += buf.nioBufferCount();
in.remove(); in.remove();
localWrittenBytes -= readableBytes; localWrittenBytes -= readableBytes;
} else if (readableBytes > localWrittenBytes) { } else if (readableBytes > localWrittenBytes) {
buf.readerIndex(readerIndex + (int) localWrittenBytes); buf.readerIndex(readerIndex + (int) localWrittenBytes);
in.progress(localWrittenBytes); in.progress(localWrittenBytes);
// update position in ByteBuffer as we not do this in the native methods
ByteBuffer bb = nioBuffers[nioBufIndex];
bb.position(bb.position() + (int) localWrittenBytes);
break; break;
} else { // readable == writtenBytes } else { // readable == writtenBytes
in.remove(); in.remove();