diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketTestPermutation.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketTestPermutation.java index bba3aef345..8c41f37a54 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketTestPermutation.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketTestPermutation.java @@ -70,6 +70,7 @@ public class SocketTestPermutation { new OioEventLoopGroup(Integer.MAX_VALUE, new DefaultThreadFactory("testsuite-oio-worker", true)); protected , B extends AbstractBootstrap> + List> combo(List> sbfs, List> cbfs) { List> list = new ArrayList>(); diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index fc0c883b37..000d7bf94f 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -330,7 +330,7 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo msg[i].msg_hdr.msg_namelen = addrSize; msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId); - msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId);; + msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId); } ssize_t res; @@ -346,6 +346,61 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo return (jint) res; } +static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jobjectArray packets, jint offset, jint len) { + struct mmsghdr msg[len]; + memset(msg, 0, sizeof(msg)); + struct sockaddr_storage addr[len]; + int addrSize = sizeof(addr); + memset(addr, 0, addrSize); + + int i; + + for (i = 0; i < len; i++) { + jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId); + msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId); + + msg[i].msg_hdr.msg_name = addr + i; + msg[i].msg_hdr.msg_namelen = (socklen_t) addrSize; + } + + ssize_t res; + int err; + do { + res = recvmmsg(fd, msg, len, 0, NULL); + // keep on reading if it was interrupted + } while (res == -1 && ((err = errno) == EINTR)); + + if (res < 0) { + return -err; + } + + for (i = 0; i < res; i++) { + jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset); + jbyteArray address = (jbyteArray) (*env)->GetObjectField(env, packet, packetAddrFieldId); + + (*env)->SetIntField(env, packet, packetCountFieldId, msg[i].msg_len); + + struct sockaddr_storage* addr = (struct sockaddr_storage*) msg[i].msg_hdr.msg_name; + + if (addr->ss_family == AF_INET) { + struct sockaddr_in* ipaddr = (struct sockaddr_in*) addr; + + (*env)->SetByteArrayRegion(env, address, 0, 4, (jbyte*) &ipaddr->sin_addr.s_addr); + (*env)->SetIntField(env, packet, packetScopeIdFieldId, 0); + (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ipaddr->sin_port)); + } else { + struct sockaddr_in6* ip6addr = (struct sockaddr_in6*) addr; + + (*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr); + (*env)->SetIntField(env, packet, packetScopeIdFieldId, ip6addr->sin6_scope_id); + (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ip6addr->sin6_port)); + } + } + + return (jint) res; +} + static jstring netty_epoll_native_kernelVersion(JNIEnv* env, jclass clazz) { struct utsname name; @@ -470,18 +525,26 @@ static const JNINativeMethod fixed_method_table[] = { static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]); static jint dynamicMethodsTableSize() { - return fixed_method_table_size + 1; // 1 is for the dynamic method signatures. + return fixed_method_table_size + 2; // 2 is for the dynamic method signatures. } static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) { JNINativeMethod* dynamicMethods = malloc(sizeof(JNINativeMethod) * dynamicMethodsTableSize()); memcpy(dynamicMethods, fixed_method_table, sizeof(fixed_method_table)); + char* dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/epoll/NativeDatagramPacketArray$NativeDatagramPacket;II)I"); JNINativeMethod* dynamicMethod = &dynamicMethods[fixed_method_table_size]; dynamicMethod->name = "sendmmsg0"; dynamicMethod->signature = netty_unix_util_prepend("(IZ[L", dynamicTypeName); dynamicMethod->fnPtr = (void *) netty_epoll_native_sendmmsg0; free(dynamicTypeName); + + dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/epoll/NativeDatagramPacketArray$NativeDatagramPacket;II)I"); + dynamicMethod = &dynamicMethods[fixed_method_table_size + 1]; + dynamicMethod->name = "recvmmsg0"; + dynamicMethod->signature = netty_unix_util_prepend("(IZ[L", dynamicTypeName); + dynamicMethod->fnPtr = (void *) netty_epoll_native_recvmmsg0; + free(dynamicTypeName); return dynamicMethods; } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java index 1f5127c5dc..765eea25d1 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollChannelOption.java @@ -45,6 +45,8 @@ public final class EpollChannelOption extends UnixChannelOption { public static final ChannelOption> TCP_MD5SIG = valueOf("TCP_MD5SIG"); + public static final ChannelOption MAX_DATAGRAM_PAYLOAD_SIZE = valueOf("MAX_DATAGRAM_PAYLOAD_SIZE"); + @SuppressWarnings({ "unused", "deprecation" }) private EpollChannelOption() { } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index 8679c9e404..b6a5e0ff47 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -17,6 +17,7 @@ package io.netty.channel.epoll; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; import io.netty.channel.AddressedEnvelope; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelMetadata; @@ -33,6 +34,8 @@ import io.netty.channel.unix.Errors; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.Socket; import io.netty.channel.unix.UnixChannelUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.StringUtil; import java.io.IOException; @@ -307,7 +310,7 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements try { // Check if sendmmsg(...) is supported which is only the case for GLIBC 2.14+ if (Native.IS_SUPPORTING_SENDMMSG && in.size() > 1) { - NativeDatagramPacketArray array = ((EpollEventLoop) eventLoop()).cleanDatagramPacketArray(); + NativeDatagramPacketArray array = cleanDatagramPacketArray(); in.forEachFlushedMessage(array); int cnt = array.count(); @@ -386,7 +389,7 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements } } else if (data.nioBufferCount() > 1) { IovArray array = ((EpollEventLoop) eventLoop()).cleanIovArray(); - array.add(data); + array.add(data, data.readerIndex(), data.readableBytes()); int cnt = array.count(); assert cnt != 0; @@ -486,73 +489,31 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements Throwable exception = null; try { - ByteBuf byteBuf = null; try { boolean connected = isConnected(); do { - byteBuf = allocHandle.allocate(allocator); - allocHandle.attemptedBytesRead(byteBuf.writableBytes()); - - final DatagramPacket packet; + ByteBuf byteBuf = allocHandle.allocate(allocator); + final boolean read; if (connected) { - try { - allocHandle.lastBytesRead(doReadBytes(byteBuf)); - } catch (Errors.NativeIoException e) { - // We need to correctly translate connect errors to match NIO behaviour. - if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { - PortUnreachableException error = new PortUnreachableException(e.getMessage()); - error.initCause(e); - throw error; - } - throw e; - } - if (allocHandle.lastBytesRead() <= 0) { - // nothing was read, release the buffer. - byteBuf.release(); - byteBuf = null; - break; - } - packet = new DatagramPacket( - byteBuf, (InetSocketAddress) localAddress(), (InetSocketAddress) remoteAddress()); + read = connectedRead(allocHandle, byteBuf); } else { - final DatagramSocketAddress remoteAddress; - if (byteBuf.hasMemoryAddress()) { - // has a memory address so use optimized call - remoteAddress = socket.recvFromAddress(byteBuf.memoryAddress(), byteBuf.writerIndex(), - byteBuf.capacity()); + int datagramSize = config().getMaxDatagramPayloadSize(); + int numDatagram = datagramSize == 0 ? 1 : byteBuf.writableBytes() / datagramSize; + + if (numDatagram <= 1) { + read = read(allocHandle, byteBuf, datagramSize); } else { - ByteBuffer nioData = byteBuf.internalNioBuffer( - byteBuf.writerIndex(), byteBuf.writableBytes()); - remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); + // Try to use scattering reads via recvmmsg(...) syscall. + read = scatteringRead(allocHandle, byteBuf, datagramSize, numDatagram); } - - if (remoteAddress == null) { - allocHandle.lastBytesRead(-1); - byteBuf.release(); - byteBuf = null; - break; - } - InetSocketAddress localAddress = remoteAddress.localAddress(); - if (localAddress == null) { - localAddress = (InetSocketAddress) localAddress(); - } - allocHandle.lastBytesRead(remoteAddress.receivedAmount()); - byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); - - packet = new DatagramPacket(byteBuf, localAddress, remoteAddress); } - - allocHandle.incMessagesRead(1); - - readPending = false; - pipeline.fireChannelRead(packet); - - byteBuf = null; + if (read) { + readPending = false; + } else { + break; + } } while (allocHandle.continueReading()); } catch (Throwable t) { - if (byteBuf != null) { - byteBuf.release(); - } exception = t; } @@ -567,4 +528,146 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements } } } + + private boolean connectedRead(EpollRecvByteAllocatorHandle allocHandle, ByteBuf byteBuf) + throws Exception { + try { + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + + try { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + } catch (Errors.NativeIoException e) { + // We need to correctly translate connect errors to match NIO behaviour. + if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) { + PortUnreachableException error = new PortUnreachableException(e.getMessage()); + error.initCause(e); + throw error; + } + throw e; + } + if (allocHandle.lastBytesRead() <= 0) { + // nothing was read, release the buffer. + return false; + } + DatagramPacket packet = new DatagramPacket(byteBuf, localAddress(), remoteAddress()); + allocHandle.incMessagesRead(1); + + pipeline().fireChannelRead(packet); + byteBuf = null; + return true; + } finally { + if (byteBuf != null) { + byteBuf.release(); + } + } + } + + private boolean scatteringRead(EpollRecvByteAllocatorHandle allocHandle, + ByteBuf byteBuf, int datagramSize, int numDatagram) throws IOException { + RecyclableArrayList bufferPackets = null; + try { + int offset = byteBuf.writerIndex(); + NativeDatagramPacketArray array = cleanDatagramPacketArray(); + + for (int i = 0; i < numDatagram; i++, offset += datagramSize) { + if (!array.addWritable(byteBuf, offset, datagramSize)) { + break; + } + } + + allocHandle.attemptedBytesRead(offset - byteBuf.writerIndex()); + + NativeDatagramPacketArray.NativeDatagramPacket[] packets = array.packets(); + int received = socket.recvmmsg(packets, 0, array.count()); + if (received == 0) { + allocHandle.lastBytesRead(-1); + return false; + } + int bytesReceived = received * datagramSize; + byteBuf.writerIndex(bytesReceived); + InetSocketAddress local = localAddress(); + if (received == 1) { + // Single packet fast-path + DatagramPacket packet = packets[0].newDatagramPacket(byteBuf, local); + allocHandle.lastBytesRead(datagramSize); + allocHandle.incMessagesRead(1); + pipeline().fireChannelRead(packet); + byteBuf = null; + return true; + } + + // Its important that we process all received data out of the NativeDatagramPacketArray + // before we call fireChannelRead(...). This is because the user may call flush() + // in a channelRead(...) method and so may re-use the NativeDatagramPacketArray again. + bufferPackets = RecyclableArrayList.newInstance(); + for (int i = 0; i < received; i++) { + DatagramPacket packet = packets[i].newDatagramPacket(byteBuf.readRetainedSlice(datagramSize), local); + bufferPackets.add(packet); + } + + allocHandle.lastBytesRead(bytesReceived); + allocHandle.incMessagesRead(received); + + for (int i = 0; i < received; i++) { + pipeline().fireChannelRead(bufferPackets.set(i, Unpooled.EMPTY_BUFFER)); + } + bufferPackets.recycle(); + bufferPackets = null; + return true; + } finally { + if (byteBuf != null) { + byteBuf.release(); + } + if (bufferPackets != null) { + for (int i = 0; i < bufferPackets.size(); i++) { + ReferenceCountUtil.release(bufferPackets.get(i)); + } + bufferPackets.recycle(); + } + } + } + + private boolean read(EpollRecvByteAllocatorHandle allocHandle, ByteBuf byteBuf, int maxDatagramPacketSize) + throws IOException { + try { + int writable = maxDatagramPacketSize != 0 ? Math.min(byteBuf.writableBytes(), maxDatagramPacketSize) + : byteBuf.writableBytes(); + allocHandle.attemptedBytesRead(writable); + int writerIndex = byteBuf.writerIndex(); + final DatagramSocketAddress remoteAddress; + if (byteBuf.hasMemoryAddress()) { + // has a memory address so use optimized call + remoteAddress = socket.recvFromAddress( + byteBuf.memoryAddress(), writerIndex, writerIndex + writable); + } else { + ByteBuffer nioData = byteBuf.internalNioBuffer(writerIndex, writable); + remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit()); + } + + if (remoteAddress == null) { + allocHandle.lastBytesRead(-1); + return false; + } + InetSocketAddress localAddress = remoteAddress.localAddress(); + if (localAddress == null) { + localAddress = localAddress(); + } + allocHandle.lastBytesRead(maxDatagramPacketSize <= 0 ? + remoteAddress.receivedAmount() : writable); + byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); + allocHandle.incMessagesRead(1); + + pipeline().fireChannelRead(new DatagramPacket(byteBuf, localAddress, remoteAddress)); + byteBuf = null; + return true; + } finally { + if (byteBuf != null) { + byteBuf.release(); + } + } + } + + private NativeDatagramPacketArray cleanDatagramPacketArray() { + return ((EpollEventLoop) eventLoop()).cleanDatagramPacketArray(); + } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java index 4f646329f8..e97d2c5eda 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannelConfig.java @@ -15,6 +15,7 @@ */ package io.netty.channel.epoll; +import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.channel.ChannelException; import io.netty.channel.ChannelOption; @@ -23,6 +24,7 @@ import io.netty.channel.MessageSizeEstimator; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.util.internal.ObjectUtil; import java.io.IOException; import java.net.InetAddress; @@ -32,6 +34,7 @@ import java.util.Map; public final class EpollDatagramChannelConfig extends EpollChannelConfig implements DatagramChannelConfig { private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048); private boolean activeOnOpen; + private volatile int maxDatagramSize; EpollDatagramChannelConfig(EpollDatagramChannel channel) { super(channel); @@ -48,7 +51,7 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme ChannelOption.IP_MULTICAST_ADDR, ChannelOption.IP_MULTICAST_IF, ChannelOption.IP_MULTICAST_TTL, ChannelOption.IP_TOS, ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, EpollChannelOption.SO_REUSEPORT, EpollChannelOption.IP_FREEBIND, EpollChannelOption.IP_TRANSPARENT, - EpollChannelOption.IP_RECVORIGDSTADDR); + EpollChannelOption.IP_RECVORIGDSTADDR, EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE); } @SuppressWarnings({ "unchecked", "deprecation" }) @@ -96,6 +99,9 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme if (option == EpollChannelOption.IP_RECVORIGDSTADDR) { return (T) Boolean.valueOf(isIpRecvOrigDestAddr()); } + if (option == EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE) { + return (T) Integer.valueOf(getMaxDatagramPayloadSize()); + } return super.getOption(option); } @@ -132,6 +138,8 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme setIpTransparent((Boolean) value); } else if (option == EpollChannelOption.IP_RECVORIGDSTADDR) { setIpRecvOrigDestAddr((Boolean) value); + } else if (option == EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE) { + setMaxDatagramPayloadSize((Integer) value); } else { return super.setOption(option, value); } @@ -499,4 +507,23 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme } } + /** + * Set the maximum {@link io.netty.channel.socket.DatagramPacket} size. This will be used to determine if + * {@code recvmmsg} should be used when reading from the underlying socket. When {@code recvmmsg} is used + * we may be able to read multiple {@link io.netty.channel.socket.DatagramPacket}s with one syscall and so + * greatly improve the performance. This number will be used to slice {@link ByteBuf}s returned by the used + * {@link RecvByteBufAllocator}. You can use {@code 0} to disable the usage of recvmmsg, any other bigger value + * will enable it. + */ + public EpollDatagramChannelConfig setMaxDatagramPayloadSize(int maxDatagramSize) { + this.maxDatagramSize = ObjectUtil.checkPositiveOrZero(maxDatagramSize, "maxDatagramSize"); + return this; + } + + /** + * Get the maximum {@link io.netty.channel.socket.DatagramPacket} size. + */ + public int getMaxDatagramPayloadSize() { + return maxDatagramSize; + } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java index 99cc6fe6b8..db2d2469c8 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java @@ -54,6 +54,11 @@ final class LinuxSocket extends Socket { return Native.sendmmsg(intValue(), ipv6, msgs, offset, len); } + int recvmmsg(NativeDatagramPacketArray.NativeDatagramPacket[] msgs, + int offset, int len) throws IOException { + return Native.recvmmsg(intValue(), ipv6, msgs, offset, len); + } + void setTimeToLive(int ttl) throws IOException { setTimeToLive(intValue(), ttl); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java index b8d17f011a..40af0e742a 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java @@ -184,6 +184,18 @@ public final class Native { private static native int sendmmsg0( int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs, int offset, int len); + static int recvmmsg(int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs, + int offset, int len) throws IOException { + int res = recvmmsg0(fd, ipv6, msgs, offset, len); + if (res >= 0) { + return res; + } + return ioResult("recvmmsg", res); + } + + private static native int recvmmsg0( + int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs, int offset, int len); + // epoll_event related public static native int sizeofEpollEvent(); public static native int offsetofEpollData(); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java index 6107af2c55..81f9ece9a4 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java @@ -19,12 +19,15 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.IovArray; +import io.netty.channel.unix.Limits; + import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.UnknownHostException; import static io.netty.channel.unix.Limits.UIO_MAX_IOV; -import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address; +import static io.netty.channel.unix.NativeInetAddress.copyIpv4MappedIpv6Address; /** * Support sendmmsg(...) on linux with GLIBC 2.14+ @@ -45,29 +48,30 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr } } - /** - * Try to add the given {@link DatagramPacket}. Returns {@code true} on success, - * {@code false} otherwise. - */ - boolean add(DatagramPacket packet) { + private boolean addReadable(DatagramPacket packet) { + ByteBuf buf = packet.content(); + return add0(buf, buf.readerIndex(), buf.readableBytes(), packet.recipient()); + } + + boolean addWritable(ByteBuf buf, int index, int len) { + return add0(buf, index, len, null); + } + + private boolean add0(ByteBuf buf, int index, int len, InetSocketAddress recipient) { if (count == packets.length) { - // We already filled up to UIO_MAX_IOV messages. This is the max allowed per sendmmsg(...) call, we will - // try again later. + // We already filled up to UIO_MAX_IOV messages. This is the max allowed per + // recvmmsg(...) / sendmmsg(...) call, we will try again later. return false; } - ByteBuf content = packet.content(); - int len = content.readableBytes(); if (len == 0) { return true; } - NativeDatagramPacket p = packets[count]; - InetSocketAddress recipient = packet.recipient(); - int offset = iovArray.count(); - if (!iovArray.add(content)) { + if (offset == Limits.IOV_MAX || !iovArray.add(buf, index, len)) { // Not enough space to hold the whole content, we will try again later. return false; } + NativeDatagramPacket p = packets[count]; p.init(iovArray.memoryAddress(offset), iovArray.count() - offset, recipient); count++; @@ -76,7 +80,7 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr @Override public boolean processMessage(Object msg) { - return msg instanceof DatagramPacket && add((DatagramPacket) msg); + return msg instanceof DatagramPacket && addReadable((DatagramPacket) msg); } /** @@ -112,7 +116,8 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr private long memoryAddress; private int count; - private byte[] addr; + private final byte[] addr = new byte[16]; + private int addrLen; private int scopeId; private int port; @@ -120,15 +125,33 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr this.memoryAddress = memoryAddress; this.count = count; - InetAddress address = recipient.getAddress(); - if (address instanceof Inet6Address) { - addr = address.getAddress(); - scopeId = ((Inet6Address) address).getScopeId(); + if (recipient == null) { + this.scopeId = 0; + this.port = 0; + this.addrLen = 0; } else { - addr = ipv4MappedIpv6Address(address.getAddress()); - scopeId = 0; + InetAddress address = recipient.getAddress(); + if (address instanceof Inet6Address) { + System.arraycopy(address.getAddress(), 0, addr, 0, addr.length); + scopeId = ((Inet6Address) address).getScopeId(); + } else { + copyIpv4MappedIpv6Address(address.getAddress(), addr); + scopeId = 0; + } + addrLen = addr.length; + port = recipient.getPort(); } - port = recipient.getPort(); + } + + DatagramPacket newDatagramPacket(ByteBuf buffer, InetSocketAddress localAddress) throws UnknownHostException { + final InetAddress address; + if (scopeId != 0) { + address = Inet6Address.getByAddress(null, addr, scopeId); + } else { + address = InetAddress.getByAddress(addr); + } + return new DatagramPacket(buffer.writerIndex(count), + localAddress, new InetSocketAddress(address, port)); } } } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java new file mode 100644 index 0000000000..46d1c840d4 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.AdaptiveRecvByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.DatagramPacket; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.AbstractDatagramTest; +import io.netty.util.internal.PlatformDependent; +import org.junit.Test; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class EpollDatagramScatteringReadTest extends AbstractDatagramTest { + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.epollOnlyDatagram(internetProtocolFamily()); + } + + @Test + public void testScatteringReadPartial() throws Throwable { + run(); + } + + public void testScatteringReadPartial(Bootstrap sb, Bootstrap cb) throws Throwable { + testScatteringRead(sb, cb, true); + } + + @Test + public void testScatteringRead() throws Throwable { + run(); + } + + public void testScatteringRead(Bootstrap sb, Bootstrap cb) throws Throwable { + testScatteringRead(sb, cb, false); + } + + private void testScatteringRead(Bootstrap sb, Bootstrap cb, boolean partial) throws Throwable { + int packetSize = 512; + int numPackets = 4; + sb.option(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator( + packetSize, packetSize * (partial ? numPackets / 2 : numPackets), 64 * 1024)); + sb.option(EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE, packetSize); + + Channel sc = null; + Channel cc = null; + + try { + cb.handler(new SimpleChannelInboundHandler() { + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exception { + // Nothing will be sent. + } + }); + final AtomicReference errorRef = new AtomicReference(); + final byte[] bytes = new byte[packetSize]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + final CountDownLatch latch = new CountDownLatch(numPackets); + sb.handler(new SimpleChannelInboundHandler() { + private int counter; + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + assertTrue(counter > 1); + counter = 0; + ctx.read(); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + assertEquals(bytes.length, msg.content().readableBytes()); + byte[] receivedBytes = new byte[bytes.length]; + msg.content().readBytes(receivedBytes); + assertArrayEquals(bytes, receivedBytes); + + counter++; + latch.countDown(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + + sb.option(ChannelOption.AUTO_READ, false); + sc = sb.bind(newSocketAddress()).sync().channel(); + cc = cb.bind(newSocketAddress()).sync().channel(); + + InetSocketAddress addr = (InetSocketAddress) sc.localAddress(); + + List futures = new ArrayList(numPackets); + for (int i = 0; i < numPackets; i++) { + futures.add(cc.write(new DatagramPacket(cc.alloc().directBuffer().writeBytes(bytes), addr))); + } + + cc.flush(); + + for (ChannelFuture f: futures) { + f.sync(); + } + + // Enable autoread now which also triggers a read, this should cause scattering reads (recvmmsg) to happen. + sc.config().setAutoRead(true); + + if (!latch.await(10, TimeUnit.SECONDS)) { + Throwable error = errorRef.get(); + if (error != null) { + throw error; + } + fail("Timeout while waiting for packets"); + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + } + } + +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java index 1989c55c91..e5c5ed463a 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java @@ -159,6 +159,31 @@ class EpollSocketTestPermutation extends SocketTestPermutation { return combo(bfs, bfs); } + List> epollOnlyDatagram( + final InternetProtocolFamily family) { + return combo(Collections.singletonList(datagramBootstrapFactory(family)), + Collections.singletonList(datagramBootstrapFactory(family))); + } + + private BootstrapFactory datagramBootstrapFactory(final InternetProtocolFamily family) { + return new BootstrapFactory() { + @Override + public Bootstrap newInstance() { + return new Bootstrap().group(EPOLL_WORKER_GROUP).channelFactory(new ChannelFactory() { + @Override + public Channel newChannel() { + return new EpollDatagramChannel(family); + } + + @Override + public String toString() { + return InternetProtocolFamily.class.getSimpleName() + ".class"; + } + }); + } + }; + } + public List> domainSocket() { List> list = diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 919db07240..c23f61a41c 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -323,7 +323,7 @@ public final class KQueueDatagramChannel extends AbstractKQueueChannel implement } } else if (data.nioBufferCount() > 1) { IovArray array = ((KQueueEventLoop) eventLoop()).cleanArray(); - array.add(data); + array.add(data, data.readerIndex(), data.readableBytes()); int cnt = array.count(); assert cnt != 0; diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java index e8b4067a86..77ec3a7a49 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java @@ -16,7 +16,6 @@ package io.netty.channel.unix; import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.ChannelOutboundBuffer.MessageProcessor; import io.netty.util.internal.PlatformDependent; @@ -78,33 +77,33 @@ public final class IovArray implements MessageProcessor { } /** - * Add a {@link ByteBuf} to this {@link IovArray}. - * @param buf The {@link ByteBuf} to add. - * @return {@code true} if the entire {@link ByteBuf} has been added to this {@link IovArray}. Note in the event - * that {@link ByteBuf} is a {@link CompositeByteBuf} {@code false} may be returned even if some of the components - * have been added. + * @deprecated Use {@link #add(ByteBuf, int, int)} */ + @Deprecated public boolean add(ByteBuf buf) { + return add(buf, buf.readerIndex(), buf.readableBytes()); + } + + public boolean add(ByteBuf buf, int offset, int len) { if (count == IOV_MAX) { // No more room! return false; } else if (buf.nioBufferCount() == 1) { - final int len = buf.readableBytes(); if (len == 0) { return true; } if (buf.hasMemoryAddress()) { - return add(buf.memoryAddress(), buf.readerIndex(), len); + return add(buf.memoryAddress() + offset, len); } else { - ByteBuffer nioBuffer = buf.internalNioBuffer(buf.readerIndex(), len); - return add(Buffer.memoryAddress(nioBuffer), nioBuffer.position(), len); + ByteBuffer nioBuffer = buf.internalNioBuffer(offset, len); + return add(Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), len); } } else { - ByteBuffer[] buffers = buf.nioBuffers(); + ByteBuffer[] buffers = buf.nioBuffers(offset, len); for (ByteBuffer nioBuffer : buffers) { - final int len = nioBuffer.remaining(); - if (len != 0 && - (!add(Buffer.memoryAddress(nioBuffer), nioBuffer.position(), len) || count == IOV_MAX)) { + final int remaining = nioBuffer.remaining(); + if (remaining != 0 && + (!add(Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), remaining) || count == IOV_MAX)) { return false; } } @@ -112,7 +111,7 @@ public final class IovArray implements MessageProcessor { } } - private boolean add(long addr, int offset, int len) { + private boolean add(long addr, int len) { assert addr != 0; // If there is at least 1 entry then we enforce the maximum bytes. We want to accept at least one entry so we @@ -135,19 +134,19 @@ public final class IovArray implements MessageProcessor { if (ADDRESS_SIZE == 8) { // 64bit if (PlatformDependent.hasUnsafe()) { - PlatformDependent.putLong(baseOffset + memoryAddress, addr + offset); + PlatformDependent.putLong(baseOffset + memoryAddress, addr); PlatformDependent.putLong(lengthOffset + memoryAddress, len); } else { - memory.putLong(baseOffset, addr + offset); + memory.putLong(baseOffset, addr); memory.putLong(lengthOffset, len); } } else { assert ADDRESS_SIZE == 4; if (PlatformDependent.hasUnsafe()) { - PlatformDependent.putInt(baseOffset + memoryAddress, (int) addr + offset); + PlatformDependent.putInt(baseOffset + memoryAddress, (int) addr); PlatformDependent.putInt(lengthOffset + memoryAddress, len); } else { - memory.putInt(baseOffset, (int) addr + offset); + memory.putInt(baseOffset, (int) addr); memory.putInt(lengthOffset, len); } } @@ -169,22 +168,22 @@ public final class IovArray implements MessageProcessor { } /** - * Set the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}. + * Set the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf, int, int)} *

* This will not impact the existing state of the {@link IovArray}, and only applies to subsequent calls to * {@link #add(ByteBuf)}. *

* In order to ensure some progress is made at least one {@link ByteBuf} will be accepted even if it's size exceeds * this value. - * @param maxBytes the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}. + * @param maxBytes the maximum amount of bytes that can be added to this {@link IovArray}. */ public void maxBytes(long maxBytes) { this.maxBytes = min(SSIZE_MAX, checkPositive(maxBytes, "maxBytes")); } /** - * Get the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}. - * @return the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}. + * Get the maximum amount of bytes that can be added to this {@link IovArray}. + * @return the maximum amount of bytes that can be added to this {@link IovArray}. */ public long maxBytes() { return maxBytes; @@ -206,7 +205,11 @@ public final class IovArray implements MessageProcessor { @Override public boolean processMessage(Object msg) throws Exception { - return msg instanceof ByteBuf && add((ByteBuf) msg); + if (msg instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) msg; + return add(buffer, buffer.readerIndex(), buffer.readableBytes()); + } + return false; } private static int idx(int index) { diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/NativeInetAddress.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/NativeInetAddress.java index 599f823122..3e76ab9837 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/NativeInetAddress.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/NativeInetAddress.java @@ -58,11 +58,15 @@ public final class NativeInetAddress { public static byte[] ipv4MappedIpv6Address(byte[] ipv4) { byte[] address = new byte[16]; - System.arraycopy(IPV4_MAPPED_IPV6_PREFIX, 0, address, 0, IPV4_MAPPED_IPV6_PREFIX.length); - System.arraycopy(ipv4, 0, address, 12, ipv4.length); + copyIpv4MappedIpv6Address(ipv4, address); return address; } + public static void copyIpv4MappedIpv6Address(byte[] ipv4, byte[] ipv6) { + System.arraycopy(IPV4_MAPPED_IPV6_PREFIX, 0, ipv6, 0, IPV4_MAPPED_IPV6_PREFIX.length); + System.arraycopy(ipv4, 0, ipv6, 12, ipv4.length); + } + public static InetSocketAddress address(byte[] addr, int offset, int len) { // The last 4 bytes are always the port final int port = decodeInt(addr, offset + len - 4);