diff --git a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c index ecccaeced7..2c13f91cc4 100644 --- a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c +++ b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c @@ -581,14 +581,7 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_write(JNIEnv * env, jc throwRuntimeException(env, "Unable to access address of buffer"); return -1; } - jint res = write0(env, clazz, fd, buffer, pos, limit); - if (res > 0) { - // Increment the pos of the ByteBuffer as it may be only partial written to prevent data-corruption later once we - // try to write the remaining data. - // See https://github.com/netty/netty/issues/2371 - incrementPosition(env, jbuffer, res); - } - return res; + return write0(env, clazz, fd, buffer, pos, limit); } JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_writeAddress(JNIEnv * env, jclass clazz, jint fd, jlong address, jint pos, jint limit) { @@ -675,18 +668,6 @@ JNIEXPORT jobject JNICALL Java_io_netty_channel_epoll_Native_recvFromAddress(JNI return recvFrom0(env, fd, (void*) address, pos, limit); } -void incrementPosition(JNIEnv * env, jobject bufObj, int written) { - // Get the current position using the (*env)->GetIntField if possible and fallback - // to slower (*env)->CallIntMethod(...) if needed - if (posFieldId == NULL) { - jint pos = (*env)->CallIntMethod(env, bufObj, posId, NULL); - (*env)->CallObjectMethod(env, bufObj, updatePosId, pos + written); - } else { - jint pos = (*env)->GetIntField(env, bufObj, posFieldId); - (*env)->SetIntField(env, bufObj, posFieldId, pos + written); - } -} - jlong writev0(JNIEnv * env, jclass clazz, jint fd, struct iovec iov[], jint length) { ssize_t res; int err; diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java index de3e009b9a..fce2a21d28 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java @@ -107,28 +107,43 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So * Write bytes form the given {@link ByteBuf} to the underlying {@link java.nio.channels.Channel}. * @param buf the {@link ByteBuf} from which the bytes should be written */ - private int doWriteBytes(ByteBuf buf) throws Exception { - int readerIndex = buf.readerIndex(); - int localFlushedAmount; + private boolean writeBytes(ChannelOutboundBuffer in, ByteBuf buf) throws Exception { + int readableBytes = buf.readableBytes(); + if (readableBytes == 0) { + in.remove(); + return true; + } + boolean setEpollOut = false; + boolean done = false; + long writtenBytes = 0; if (buf.nioBufferCount() == 1) { - if (buf.hasMemoryAddress()) { - localFlushedAmount = Native.writeAddress(fd, buf.memoryAddress(), readerIndex, buf.writerIndex()); - } else { - ByteBuffer nioBuf = buf.internalNioBuffer(readerIndex, buf.readableBytes()); - localFlushedAmount = Native.write(fd, nioBuf, nioBuf.position(), nioBuf.limit()); + int readerIndex = buf.readerIndex(); + ByteBuffer nioBuf = buf.internalNioBuffer(readerIndex, buf.readableBytes()); + for (;;) { + int pos = nioBuf.position(); + int limit = nioBuf.limit(); + int localFlushedAmount = Native.write(fd, nioBuf, pos, limit); + if (localFlushedAmount > 0) { + nioBuf.position(pos + localFlushedAmount); + writtenBytes += localFlushedAmount; + if (writtenBytes == readableBytes) { + done = true; + break; + } + } else { + setEpollOut = true; + break; + } } + updateOutboundBuffer(in, writtenBytes, 1, done, setEpollOut); + return done; } else { - // backed by more then one buffer, do a gathering write... - ByteBuffer[] nioBufs = buf.nioBuffers(); - localFlushedAmount = (int) Native.writev(fd, nioBufs, 0, nioBufs.length); + ByteBuffer[] nioBuffers = buf.nioBuffers(); + return writeBytesMultiple0(in, 1, nioBuffers, nioBuffers.length, readableBytes); } - if (localFlushedAmount > 0) { - buf.readerIndex(readerIndex + localFlushedAmount); - } - return localFlushedAmount; } - private void writeBytesMultiple( + private boolean writeBytesMultiple( EpollChannelOutboundBuffer in, int msgCount, AddressEntry[] addresses) throws IOException { int addressCnt = in.addressCount(); @@ -138,9 +153,8 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So long writtenBytes = 0; int offset = 0; int end = offset + addressCnt; - int spinCount = config.getWriteSpinCount(); loop: while (addressCnt > 0) { - for (int i = spinCount - 1; i >= 0; i --) { + for (;;) { int cnt = addressCnt > Native.IOV_MAX? Native.IOV_MAX : addressCnt; long localWrittenBytes = Native.writevAddresses(fd, addresses, offset, cnt); @@ -165,30 +179,33 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So localWrittenBytes -= bytes; } } - } - if (expectedWrittenBytes == 0) { - done = true; - break; + if (expectedWrittenBytes == 0) { + done = true; + break; + } } } updateOutboundBuffer(in, writtenBytes, msgCount, done, setEpollOut); + return done; } - private void writeBytesMultiple( + private boolean writeBytesMultiple( NioSocketChannelOutboundBuffer in, int msgCount, ByteBuffer[] nioBuffers) throws IOException { + return writeBytesMultiple0(in, msgCount, nioBuffers, in.nioBufferCount(), in.nioBufferSize()); + } - int nioBufferCnt = in.nioBufferCount(); - long expectedWrittenBytes = in.nioBufferSize(); + private boolean writeBytesMultiple0( + ChannelOutboundBuffer in, int msgCount, ByteBuffer[] nioBuffers, + int nioBufferCnt, long expectedWrittenBytes) throws IOException { boolean done = false; boolean setEpollOut = false; long writtenBytes = 0; int offset = 0; int end = offset + nioBufferCnt; - int spinCount = config.getWriteSpinCount(); loop: while (nioBufferCnt > 0) { - for (int i = spinCount - 1; i >= 0; i --) { + for (;;) { int cnt = nioBufferCnt > Native.IOV_MAX? Native.IOV_MAX : nioBufferCnt; long localWrittenBytes = Native.writev(fd, nioBuffers, offset, cnt); @@ -206,6 +223,9 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So if (bytes > localWrittenBytes) { buffer.position(pos + (int) localWrittenBytes); // incomplete write + + // As we use edge-triggered we need to set EPOLLOUT as otherwise we may not get notified again + setEpollOut(); break; } else { offset++; @@ -213,15 +233,15 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So localWrittenBytes -= bytes; } } - } - if (expectedWrittenBytes == 0) { - done = true; - break; + if (expectedWrittenBytes == 0) { + done = true; + break; + } } } - updateOutboundBuffer(in, writtenBytes, msgCount, done, setEpollOut); + return done; } private void updateOutboundBuffer(ChannelOutboundBuffer in, long writtenBytes, int msgCount, @@ -231,11 +251,7 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So for (int i = msgCount; i > 0; i --) { in.remove(); } - - // Finish the write loop if no new messages were flushed by in.remove(). - if (in.isEmpty()) { - clearEpollOut(); - } + in.progress(writtenBytes); } else { // Did not write all buffers completely. // Release the fully written buffers and update the indexes of the partially written buffer. @@ -273,7 +289,7 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So flushTask = this.flushTask = new Runnable() { @Override public void run() { - flush(); + unsafe().flush(); } }; } @@ -287,8 +303,37 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So * @param region the {@link DefaultFileRegion} from which the bytes should be written * @return amount the amount of written bytes */ - private long doWriteFileRegion(DefaultFileRegion region, long count) throws Exception { - return Native.sendfile(fd, region, region.transfered(), count); + private boolean writeFileRegion(ChannelOutboundBuffer in, DefaultFileRegion region) throws Exception { + boolean setOpWrite = false; + boolean done = false; + long flushedAmount = 0; + + for (int i = config().getWriteSpinCount() - 1; i >= 0; i --) { + long expected = region.count() - region.position(); + long localFlushedAmount = Native.sendfile(fd, region, region.transfered(), expected); + if (localFlushedAmount == 0) { + setOpWrite = true; + break; + } + + flushedAmount += localFlushedAmount; + if (region.transfered() >= region.count()) { + done = true; + break; + } else { + // As we use edge-triggered we need to set EPOLLOUT as otherwise we may not get notified again + setEpollOut(); + } + } + + in.progress(flushedAmount); + + if (done) { + in.remove(); + } else { + incompleteWrite(setOpWrite); + } + return done; } @Override @@ -312,7 +357,11 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So // Ensure the pending writes are made of memoryaddresses only. AddressEntry[] addresses = epollIn.memoryAddresses(); if (addresses != null) { - writeBytesMultiple(epollIn, msgCount, addresses); + if (!writeBytesMultiple(epollIn, msgCount, addresses)) { + // was not able to write everything so break here we will get notified later again once + // the network stack can handle more writes. + break; + } // We do not break the loop here even if the outbound buffer was flushed completely, // because a user might have triggered another write and flush when we notify his or her @@ -322,9 +371,13 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So } else { NioSocketChannelOutboundBuffer nioIn = (NioSocketChannelOutboundBuffer) in; // Ensure the pending writes are made of memoryaddresses only. - ByteBuffer[] buffers = nioIn.nioBuffers(); - if (buffers != null) { - writeBytesMultiple(nioIn, msgCount, buffers); + ByteBuffer[] nioBuffers = nioIn.nioBuffers(); + if (nioBuffers != null) { + if (!writeBytesMultiple(nioIn, msgCount, nioBuffers)) { + // was not able to write everything so break here we will get notified later again once + // the network stack can handle more writes. + break; + } // We do not break the loop here even if the outbound buffer was flushed completely, // because a user might have triggered another write and flush when we notify his or her @@ -338,52 +391,18 @@ public final class EpollSocketChannel extends AbstractEpollChannel implements So Object msg = in.current(); if (msg instanceof ByteBuf) { ByteBuf buf = (ByteBuf) msg; - int readableBytes = buf.readableBytes(); - if (readableBytes == 0) { - in.remove(); - continue; - } - boolean setEpollOut = false; - boolean done = false; - long flushedAmount = 0; - int writeSpinCount = config().getWriteSpinCount(); - for (int i = writeSpinCount - 1; i >= 0; i --) { - int localFlushedAmount = doWriteBytes(buf); - if (localFlushedAmount == 0) { - setEpollOut = true; - break; - } - - flushedAmount += localFlushedAmount; - if (!buf.isReadable()) { - done = true; - break; - } - } - - in.progress(flushedAmount); - - if (done) { - in.remove(); - } else { - incompleteWrite(setEpollOut); + if (!writeBytes(in, buf)) { + // was not able to write everything so break here we will get notified later again once + // the network stack can handle more writes. break; } } else if (msg instanceof DefaultFileRegion) { DefaultFileRegion region = (DefaultFileRegion) msg; - - long expected = region.count() - region.position(); - long localFlushedAmount = doWriteFileRegion(region, expected); - in.progress(localFlushedAmount); - - if (localFlushedAmount < expected) { - setEpollOut(); + if (!writeFileRegion(in, region)) { + // was not able to write everything so break here we will get notified later again once + // the network stack can handle more writes. break; } - - if (region.transfered() >= region.count()) { - in.remove(); - } } else { throw new UnsupportedOperationException("unsupported message type: " + StringUtil.simpleClassName(msg)); }