diff --git a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c index 99c77e260d..c52ad62b48 100644 --- a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c +++ b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c @@ -34,6 +34,7 @@ extern int accept4(int sockFd, struct sockaddr *addr, socklen_t *addrlen, int flags) __attribute__((weak)); // Those are initialized in the init(...) method and cached for performance reasons +jmethodID updatePosId = NULL; jmethodID posId = NULL; jmethodID limitId = NULL; jfieldID posFieldId = NULL; @@ -257,7 +258,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { throwRuntimeException(env, "Unable to find method ByteBuffer.limit()"); return JNI_ERR; } - + updatePosId = (*env)->GetMethodID(env, cls, "position", "(I)Ljava/nio/Buffer;"); + if (updatePosId == NULL) { + // position method was not found.. something is wrong so bail out + throwRuntimeException(env, "Unable to find method ByteBuffer.position(int)"); + return JNI_ERR; + } // Try to get the ids of the position and limit fields. We later then check if we was able // to find them and if so use them get the position and limit of the buffer. This is // much faster then call back into java via (*env)->CallIntMethod(...). @@ -481,9 +487,26 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_write(JNIEnv * env, jc throwIOException(env, exceptionMessage("Error while write(...): ", err)); return -1; } + if (posFieldId == NULL) { + (*env)->CallObjectMethod(env, jbuffer, updatePosId, pos + res); + } else { + (*env)->SetIntField(env, jbuffer, posFieldId, pos + res); + } return (jint) res; } +void incrementPosition(JNIEnv * env, jobject bufObj, int written) { + // Get the current position using the (*env)->GetIntField if possible and fallback + // to slower (*env)->CallIntMethod(...) if needed + if (posFieldId == NULL) { + jint pos = (*env)->CallIntMethod(env, bufObj, posId, NULL); + (*env)->CallObjectMethod(env, bufObj, updatePosId, pos + written); + } else { + jint pos = (*env)->GetIntField(env, bufObj, posFieldId); + (*env)->SetIntField(env, bufObj, posFieldId, pos + written); + } +} + JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_writev(JNIEnv * env, jclass clazz, jint fd, jobjectArray buffers, jint offset, jint length) { struct iovec iov[length]; int i; @@ -532,9 +555,25 @@ JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_writev(JNIEnv * env, throwClosedChannelException(env); return -1; } - throwIOException(env, exceptionMessage("Error while write(...): ", err)); + throwIOException(env, exceptionMessage("Error while writev(...): ", err)); return -1; } + + // update the position of the written buffers + int written = res; + int a; + for (a = 0; a < length; a++) { + int pos; + int len = iov[a].iov_len; + jobject bufObj = (*env)->GetObjectArrayElement(env, buffers, a + offset); + if (len >= written) { + incrementPosition(env, bufObj, written); + break; + } else { + incrementPosition(env, bufObj, len); + written -= len; + } + } return res; } @@ -856,4 +895,3 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_getTrafficClass(JNIEnv } return optval; } - diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java index d7f03898e6..314cb99af2 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java @@ -160,7 +160,6 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So long localWrittenBytes = Native.writev(fd, nioBuffers, 0, nioBufferCnt); if (localWrittenBytes < expectedWrittenBytes) { - int nioBufIndex = 0; setEpollOut(); // Did not write all buffers completely. @@ -171,17 +170,12 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So final int readableBytes = buf.writerIndex() - readerIndex; if (readableBytes < localWrittenBytes) { - nioBufIndex += buf.nioBufferCount(); in.remove(); localWrittenBytes -= readableBytes; } else if (readableBytes > localWrittenBytes) { buf.readerIndex(readerIndex + (int) localWrittenBytes); in.progress(localWrittenBytes); - - // update position in ByteBuffer as we not do this in the native methods - ByteBuffer bb = nioBuffers[nioBufIndex]; - bb.position(bb.position() + (int) localWrittenBytes); break; } else { // readable == writtenBytes in.remove();