Add support for recvmmsg when using epoll transport (#9509)

Motivation:

When using datagram sockets that need to handle a lot of packets, it makes sense to use recvmmsg so that multiple datagram packets can be read with a single syscall.

Modifications:

- Add support for recvmmsg on Linux
- Add new EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE
- Add tests

Result:

Fixes https://github.com/netty/netty/issues/8446.
This commit is contained in:
Norman Maurer 2019-09-03 08:40:17 +02:00
parent b08bbd4c20
commit b3e6e41384
12 changed files with 527 additions and 112 deletions

View File

@ -305,7 +305,7 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo
msg[i].msg_hdr.msg_namelen = addrSize;
msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId);
msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId);;
msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId);
}
ssize_t res;
@ -321,6 +321,61 @@ static jint netty_epoll_native_sendmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo
return (jint) res;
}
/*
 * Receives up to len datagrams with one recvmmsg(2) call into the NativeDatagramPacket objects
 * packets[offset] .. packets[offset + len - 1].
 *
 * Before the call, each packet supplies the address of its iovec array (packetMemoryAddressFieldId)
 * and the number of iovec entries (packetCountFieldId). After the call, for every received datagram
 * the packet's count field is overwritten with the number of bytes received, and the sender's
 * address, scope id and port are stored into the packet.
 *
 * The ipv6 argument is not used by this function.
 *
 * Returns the number of datagrams received, or the negated errno on failure. EINTR is retried.
 */
static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jobjectArray packets, jint offset, jint len) {
    if (len <= 0) {
        // Nothing to receive into. Also avoids declaring zero-length VLAs below, which is undefined behaviour.
        return 0;
    }

    struct mmsghdr msg[len];
    memset(msg, 0, sizeof(msg));
    struct sockaddr_storage addr[len];
    memset(addr, 0, sizeof(addr));

    int i;
    for (i = 0; i < len; i++) {
        jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset);
        msg[i].msg_hdr.msg_iov = (struct iovec*) (intptr_t) (*env)->GetLongField(env, packet, packetMemoryAddressFieldId);
        msg[i].msg_hdr.msg_iovlen = (*env)->GetIntField(env, packet, packetCountFieldId);
        msg[i].msg_hdr.msg_name = addr + i;
        // Each message owns exactly one sockaddr_storage slot, so advertise the size of a single
        // slot and not the size of the whole addr array.
        msg[i].msg_hdr.msg_namelen = (socklen_t) sizeof(addr[i]);
        // Release the local reference right away so large batches do not exhaust the local reference table.
        (*env)->DeleteLocalRef(env, packet);
    }

    ssize_t res;
    int err = 0;
    do {
        res = recvmmsg(fd, msg, len, 0, NULL);
        // keep on reading if it was interrupted
    } while (res == -1 && ((err = errno) == EINTR));

    if (res < 0) {
        return -err;
    }

    for (i = 0; i < res; i++) {
        jobject packet = (*env)->GetObjectArrayElement(env, packets, i + offset);
        jbyteArray address = (jbyteArray) (*env)->GetObjectField(env, packet, packetAddrFieldId);

        // From here on the count field carries the number of bytes received for this datagram.
        (*env)->SetIntField(env, packet, packetCountFieldId, msg[i].msg_len);

        struct sockaddr_storage* from = (struct sockaddr_storage*) msg[i].msg_hdr.msg_name;
        if (from->ss_family == AF_INET) {
            struct sockaddr_in* ipaddr = (struct sockaddr_in*) from;
            // The Java side reads the whole 16-byte array, so store the sender as an IPv4-mapped
            // IPv6 address (::ffff:a.b.c.d). Writing only the first 4 bytes would leave the other
            // 12 bytes stale and yield a bogus address.
            jbyte mapped[16];
            memset(mapped, 0, sizeof(mapped));
            mapped[10] = (jbyte) 0xff;
            mapped[11] = (jbyte) 0xff;
            memcpy(mapped + 12, &ipaddr->sin_addr.s_addr, 4);
            (*env)->SetByteArrayRegion(env, address, 0, 16, mapped);
            (*env)->SetIntField(env, packet, packetScopeIdFieldId, 0);
            (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ipaddr->sin_port));
        } else {
            struct sockaddr_in6* ip6addr = (struct sockaddr_in6*) from;
            (*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr);
            (*env)->SetIntField(env, packet, packetScopeIdFieldId, ip6addr->sin6_scope_id);
            (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ip6addr->sin6_port));
        }

        (*env)->DeleteLocalRef(env, address);
        (*env)->DeleteLocalRef(env, packet);
    }

    return (jint) res;
}
static jstring netty_epoll_native_kernelVersion(JNIEnv* env, jclass clazz) {
struct utsname name;
@ -443,18 +498,26 @@ static const JNINativeMethod fixed_method_table[] = {
static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]);
/*
 * Number of entries in the native method table: the fixed entries plus the two entries whose
 * signatures are built at runtime (sendmmsg0 and recvmmsg0).
 */
static jint dynamicMethodsTableSize() {
    // Must match the number of entries written by createDynamicMethodsTable(); a stale "+ 1" here
    // would under-allocate the table.
    return fixed_method_table_size + 2; // 2 is for the dynamic method signatures.
}
static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) {
JNINativeMethod* dynamicMethods = malloc(sizeof(JNINativeMethod) * dynamicMethodsTableSize());
memcpy(dynamicMethods, fixed_method_table, sizeof(fixed_method_table));
char* dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/epoll/NativeDatagramPacketArray$NativeDatagramPacket;II)I");
JNINativeMethod* dynamicMethod = &dynamicMethods[fixed_method_table_size];
dynamicMethod->name = "sendmmsg0";
dynamicMethod->signature = netty_unix_util_prepend("(IZ[L", dynamicTypeName);
dynamicMethod->fnPtr = (void *) netty_epoll_native_sendmmsg0;
free(dynamicTypeName);
dynamicTypeName = netty_unix_util_prepend(packagePrefix, "io/netty/channel/epoll/NativeDatagramPacketArray$NativeDatagramPacket;II)I");
dynamicMethod = &dynamicMethods[fixed_method_table_size + 1];
dynamicMethod->name = "recvmmsg0";
dynamicMethod->signature = netty_unix_util_prepend("(IZ[L", dynamicTypeName);
dynamicMethod->fnPtr = (void *) netty_epoll_native_recvmmsg0;
free(dynamicTypeName);
return dynamicMethods;
}

View File

@ -45,6 +45,8 @@ public final class EpollChannelOption<T> extends UnixChannelOption<T> {
public static final ChannelOption<Map<InetAddress, byte[]>> TCP_MD5SIG = valueOf("TCP_MD5SIG");
public static final ChannelOption<Integer> MAX_DATAGRAM_PAYLOAD_SIZE = valueOf("MAX_DATAGRAM_PAYLOAD_SIZE");
// Private no-arg constructor with an empty body.
@SuppressWarnings({ "unused", "deprecation" })
private EpollChannelOption() {
}

View File

@ -17,6 +17,7 @@ package io.netty.channel.epoll;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelMetadata;
@ -34,6 +35,8 @@ import io.netty.channel.unix.Errors;
import io.netty.channel.unix.IovArray;
import io.netty.channel.unix.Socket;
import io.netty.channel.unix.UnixChannelUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.StringUtil;
import java.io.IOException;
@ -293,7 +296,7 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements
try {
// Check if sendmmsg(...) is supported which is only the case for GLIBC 2.14+
if (Native.IS_SUPPORTING_SENDMMSG && in.size() > 1) {
NativeDatagramPacketArray array = registration().cleanDatagramPacketArray();
NativeDatagramPacketArray array = cleanDatagramPacketArray();
in.forEachFlushedMessage(array);
int cnt = array.count();
@ -372,7 +375,7 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements
}
} else if (data.nioBufferCount() > 1) {
IovArray array = registration().cleanIovArray();
array.add(data);
array.add(data, data.readerIndex(), data.readableBytes());
int cnt = array.count();
assert cnt != 0;
@ -472,73 +475,31 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements
Throwable exception = null;
try {
ByteBuf byteBuf = null;
try {
boolean connected = isConnected();
do {
byteBuf = allocHandle.allocate(allocator);
allocHandle.attemptedBytesRead(byteBuf.writableBytes());
final DatagramPacket packet;
ByteBuf byteBuf = allocHandle.allocate(allocator);
final boolean read;
if (connected) {
try {
allocHandle.lastBytesRead(doReadBytes(byteBuf));
} catch (Errors.NativeIoException e) {
// We need to correctly translate connect errors to match NIO behaviour.
if (e.expectedErr() == Errors.ERROR_ECONNREFUSED_NEGATIVE) {
PortUnreachableException error = new PortUnreachableException(e.getMessage());
error.initCause(e);
throw error;
}
throw e;
}
if (allocHandle.lastBytesRead() <= 0) {
// nothing was read, release the buffer.
byteBuf.release();
byteBuf = null;
break;
}
packet = new DatagramPacket(
byteBuf, (InetSocketAddress) localAddress(), (InetSocketAddress) remoteAddress());
read = connectedRead(allocHandle, byteBuf);
} else {
final DatagramSocketAddress remoteAddress;
if (byteBuf.hasMemoryAddress()) {
// has a memory address so use optimized call
remoteAddress = socket.recvFromAddress(byteBuf.memoryAddress(), byteBuf.writerIndex(),
byteBuf.capacity());
int datagramSize = config().getMaxDatagramPayloadSize();
int numDatagram = datagramSize == 0 ? 1 : byteBuf.writableBytes() / datagramSize;
if (numDatagram <= 1) {
read = read(allocHandle, byteBuf, datagramSize);
} else {
ByteBuffer nioData = byteBuf.internalNioBuffer(
byteBuf.writerIndex(), byteBuf.writableBytes());
remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit());
// Try to use scattering reads via recvmmsg(...) syscall.
read = scatteringRead(allocHandle, byteBuf, datagramSize, numDatagram);
}
if (remoteAddress == null) {
allocHandle.lastBytesRead(-1);
byteBuf.release();
byteBuf = null;
break;
}
InetSocketAddress localAddress = remoteAddress.localAddress();
if (localAddress == null) {
localAddress = (InetSocketAddress) localAddress();
}
allocHandle.lastBytesRead(remoteAddress.receivedAmount());
byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead());
packet = new DatagramPacket(byteBuf, localAddress, remoteAddress);
}
allocHandle.incMessagesRead(1);
if (read) {
readPending = false;
pipeline.fireChannelRead(packet);
byteBuf = null;
} else {
break;
}
} while (allocHandle.continueReading());
} catch (Throwable t) {
if (byteBuf != null) {
byteBuf.release();
}
exception = t;
}
@ -554,4 +515,146 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements
}
}
}
/**
 * Performs one read on a connected channel and, if any bytes were read, fires the resulting
 * {@link DatagramPacket} through the pipeline.
 *
 * @param allocHandle the handle that records attempted / actual bytes and messages read.
 * @param byteBuf the buffer to read into; it is released here unless it was passed to the pipeline.
 * @return {@code true} if a packet was read and fired, {@code false} if nothing was read.
 */
private boolean connectedRead(EpollRecvByteAllocatorHandle allocHandle, ByteBuf byteBuf)
        throws Exception {
    boolean handedOff = false;
    try {
        allocHandle.attemptedBytesRead(byteBuf.writableBytes());
        try {
            int bytesRead = doReadBytes(byteBuf);
            allocHandle.lastBytesRead(bytesRead);
        } catch (Errors.NativeIoException e) {
            if (e.expectedErr() != Errors.ERROR_ECONNREFUSED_NEGATIVE) {
                throw e;
            }
            // We need to correctly translate connect errors to match NIO behaviour.
            PortUnreachableException translated = new PortUnreachableException(e.getMessage());
            translated.initCause(e);
            throw translated;
        }

        if (allocHandle.lastBytesRead() > 0) {
            DatagramPacket received = new DatagramPacket(byteBuf, localAddress(), remoteAddress());
            allocHandle.incMessagesRead(1);
            pipeline().fireChannelRead(received);
            handedOff = true;
            return true;
        }
        // Nothing was read; the buffer is released in the finally block.
        return false;
    } finally {
        if (!handedOff && byteBuf != null) {
            byteBuf.release();
        }
    }
}
/**
* Tries to receive several datagrams with a single {@code recvmmsg} call by registering up to
* {@code numDatagram} writable regions of {@code datagramSize} bytes each from {@code byteBuf}.
* Every received datagram is fired through the pipeline as its own {@link DatagramPacket}.
*
* @param allocHandle the handle that records attempted / actual bytes and messages read.
* @param byteBuf the buffer that backs all datagrams of this read.
* @param datagramSize the size in bytes of each region handed to {@code recvmmsg}.
* @param numDatagram the maximum number of regions to register.
* @return {@code true} if at least one packet was fired, {@code false} if nothing was received.
*/
private boolean scatteringRead(EpollRecvByteAllocatorHandle allocHandle,
ByteBuf byteBuf, int datagramSize, int numDatagram) throws IOException {
RecyclableArrayList bufferPackets = null;
try {
int offset = byteBuf.writerIndex();
NativeDatagramPacketArray array = cleanDatagramPacketArray();
// Register one region per expected datagram; stop as soon as the array refuses another entry.
for (int i = 0; i < numDatagram; i++, offset += datagramSize) {
if (!array.addWritable(byteBuf, offset, datagramSize)) {
break;
}
}
// offset has advanced by datagramSize for every region that was added.
allocHandle.attemptedBytesRead(offset - byteBuf.writerIndex());
NativeDatagramPacketArray.NativeDatagramPacket[] packets = array.packets();
int received = socket.recvmmsg(packets, 0, array.count());
if (received == 0) {
allocHandle.lastBytesRead(-1);
return false;
}
// NOTE(review): this counts a full datagramSize per datagram even if a datagram was shorter,
// and sets an absolute writer index, i.e. it appears to assume byteBuf started with
// writerIndex == 0 - confirm against the allocator in use.
int bytesReceived = received * datagramSize;
byteBuf.writerIndex(bytesReceived);
InetSocketAddress local = localAddress();
if (received == 1) {
// Single packet fast-path
DatagramPacket packet = packets[0].newDatagramPacket(byteBuf, local);
allocHandle.lastBytesRead(datagramSize);
allocHandle.incMessagesRead(1);
pipeline().fireChannelRead(packet);
// The buffer now belongs to the fired packet; null it so the finally block does not release it.
byteBuf = null;
return true;
}
// It's important that we process all received data out of the NativeDatagramPacketArray
// before we call fireChannelRead(...). This is because the user may call flush()
// in a channelRead(...) method and so may re-use the NativeDatagramPacketArray again.
bufferPackets = RecyclableArrayList.newInstance();
for (int i = 0; i < received; i++) {
DatagramPacket packet = packets[i].newDatagramPacket(byteBuf.readRetainedSlice(datagramSize), local);
bufferPackets.add(packet);
}
allocHandle.lastBytesRead(bytesReceived);
allocHandle.incMessagesRead(received);
for (int i = 0; i < received; i++) {
// set(...) returns the packet previously stored at i; the slot is overwritten with
// EMPTY_BUFFER so the cleanup loop in the finally block does not release a fired packet.
pipeline().fireChannelRead(bufferPackets.set(i, Unpooled.EMPTY_BUFFER));
}
bufferPackets.recycle();
bufferPackets = null;
return true;
} finally {
// On the multi-packet path byteBuf is still non-null here and is released; the packets were
// built from slices obtained via readRetainedSlice(...).
if (byteBuf != null) {
byteBuf.release();
}
// Only reached with a non-null list if an exception interrupted the path above:
// release whatever is still stored and return the list to its recycler.
if (bufferPackets != null) {
for (int i = 0; i < bufferPackets.size(); i++) {
ReferenceCountUtil.release(bufferPackets.get(i));
}
bufferPackets.recycle();
}
}
}
/**
 * Reads a single datagram into {@code byteBuf} and fires it through the pipeline.
 *
 * @param allocHandle the handle that records attempted / actual bytes and messages read.
 * @param byteBuf the buffer to read into; it is released here unless it was passed to the pipeline.
 * @param maxDatagramPacketSize if non-zero, at most this many bytes of the buffer are offered to
 *                              the socket; {@code 0} offers all writable bytes.
 * @return {@code true} if a datagram was read and fired, {@code false} if nothing was received.
 */
private boolean read(EpollRecvByteAllocatorHandle allocHandle, ByteBuf byteBuf, int maxDatagramPacketSize)
        throws IOException {
    try {
        int writable = maxDatagramPacketSize != 0 ? Math.min(byteBuf.writableBytes(), maxDatagramPacketSize)
                : byteBuf.writableBytes();
        allocHandle.attemptedBytesRead(writable);
        int writerIndex = byteBuf.writerIndex();
        final DatagramSocketAddress remoteAddress;
        if (byteBuf.hasMemoryAddress()) {
            // has a memory address so use optimized call
            remoteAddress = socket.recvFromAddress(
                    byteBuf.memoryAddress(), writerIndex, writerIndex + writable);
        } else {
            ByteBuffer nioData = byteBuf.internalNioBuffer(writerIndex, writable);
            remoteAddress = socket.recvFrom(nioData, nioData.position(), nioData.limit());
        }

        if (remoteAddress == null) {
            allocHandle.lastBytesRead(-1);
            return false;
        }

        InetSocketAddress localAddress = remoteAddress.localAddress();
        if (localAddress == null) {
            localAddress = localAddress();
        }

        int received = remoteAddress.receivedAmount();
        // NOTE(review): when a maximum datagram size is configured, the full attempted size is
        // reported to the allocator handle - presumably so the allocator keeps handing out
        // buffers large enough for scattering reads; confirm.
        allocHandle.lastBytesRead(maxDatagramPacketSize <= 0 ? received : writable);
        // Only expose the bytes that were actually received. Advancing by 'writable' would make
        // the packet include buffer content that was never written by the socket.
        byteBuf.writerIndex(writerIndex + received);
        allocHandle.incMessagesRead(1);

        pipeline().fireChannelRead(new DatagramPacket(byteBuf, localAddress, remoteAddress));
        // The buffer now belongs to the fired packet; null it so the finally block does not release it.
        byteBuf = null;
        return true;
    } finally {
        if (byteBuf != null) {
            byteBuf.release();
        }
    }
}
// Shortcut for registration().cleanDatagramPacketArray(); used by both the sendmmsg write path
// and the recvmmsg read path of this channel.
private NativeDatagramPacketArray cleanDatagramPacketArray() {
return registration().cleanDatagramPacketArray();
}
}

View File

@ -15,6 +15,7 @@
*/
package io.netty.channel.epoll;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelOption;
@ -23,6 +24,7 @@ import io.netty.channel.MessageSizeEstimator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.socket.DatagramChannelConfig;
import io.netty.util.internal.ObjectUtil;
import java.io.IOException;
import java.net.InetAddress;
@ -32,6 +34,7 @@ import java.util.Map;
public final class EpollDatagramChannelConfig extends EpollChannelConfig implements DatagramChannelConfig {
private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048);
private boolean activeOnOpen;
private volatile int maxDatagramSize;
EpollDatagramChannelConfig(EpollDatagramChannel channel) {
super(channel);
@ -48,7 +51,7 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme
ChannelOption.IP_MULTICAST_ADDR, ChannelOption.IP_MULTICAST_IF, ChannelOption.IP_MULTICAST_TTL,
ChannelOption.IP_TOS, ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION,
EpollChannelOption.SO_REUSEPORT, EpollChannelOption.IP_FREEBIND, EpollChannelOption.IP_TRANSPARENT,
EpollChannelOption.IP_RECVORIGDSTADDR);
EpollChannelOption.IP_RECVORIGDSTADDR, EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE);
}
@SuppressWarnings({ "unchecked", "deprecation" })
@ -96,6 +99,9 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme
if (option == EpollChannelOption.IP_RECVORIGDSTADDR) {
return (T) Boolean.valueOf(isIpRecvOrigDestAddr());
}
if (option == EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE) {
return (T) Integer.valueOf(getMaxDatagramPayloadSize());
}
return super.getOption(option);
}
@ -132,6 +138,8 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme
setIpTransparent((Boolean) value);
} else if (option == EpollChannelOption.IP_RECVORIGDSTADDR) {
setIpRecvOrigDestAddr((Boolean) value);
} else if (option == EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE) {
setMaxDatagramPayloadSize((Integer) value);
} else {
return super.setOption(option, value);
}
@ -499,4 +507,23 @@ public final class EpollDatagramChannelConfig extends EpollChannelConfig impleme
}
}
/**
 * Sets the maximum payload size of a {@link io.netty.channel.socket.DatagramPacket}.
 * <p>
 * The value decides whether {@code recvmmsg} is used when reading from the underlying socket:
 * the {@link ByteBuf} returned by the configured {@link RecvByteBufAllocator} is sliced into
 * regions of this size, so that several {@link io.netty.channel.socket.DatagramPacket}s may be
 * read with one syscall. Use {@code 0} to disable the usage of {@code recvmmsg}; any positive
 * value enables it.
 *
 * @param maxDatagramSize the maximum payload size, must be {@code >= 0}.
 * @return this config instance.
 */
public EpollDatagramChannelConfig setMaxDatagramPayloadSize(int maxDatagramSize) {
    final int checked = ObjectUtil.checkPositiveOrZero(maxDatagramSize, "maxDatagramSize");
    this.maxDatagramSize = checked;
    return this;
}
/**
* Get the maximum {@link io.netty.channel.socket.DatagramPacket} size as set via
* {@link #setMaxDatagramPayloadSize(int)}.
*
* @return the configured maximum size; the setter accepts only values {@code >= 0}.
*/
public int getMaxDatagramPayloadSize() {
return maxDatagramSize;
}
}

View File

@ -54,6 +54,11 @@ final class LinuxSocket extends Socket {
return Native.sendmmsg(intValue(), ipv6, msgs, offset, len);
}
// Delegates to Native.recvmmsg(...) for this socket's file descriptor, passing along this
// socket's ipv6 flag and the msgs / offset / len arguments unchanged.
int recvmmsg(NativeDatagramPacketArray.NativeDatagramPacket[] msgs,
int offset, int len) throws IOException {
return Native.recvmmsg(intValue(), ipv6, msgs, offset, len);
}
void setTimeToLive(int ttl) throws IOException {
setTimeToLive(intValue(), ttl);
}

View File

@ -170,6 +170,18 @@ public final class Native {
private static native int sendmmsg0(
int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs, int offset, int len);
/**
 * Invokes the native {@code recvmmsg0} and maps its result.
 *
 * @return the non-negative native result as-is; a negative result is handed to
 *         {@code ioResult("recvmmsg", ...)} and whatever that returns is returned.
 */
static int recvmmsg(int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs,
                    int offset, int len) throws IOException {
    final int result = recvmmsg0(fd, ipv6, msgs, offset, len);
    return result < 0 ? ioResult("recvmmsg", result) : result;
}
private static native int recvmmsg0(
int fd, boolean ipv6, NativeDatagramPacketArray.NativeDatagramPacket[] msgs, int offset, int len);
// epoll_event related
public static native int sizeofEpollEvent();
public static native int offsetofEpollData();

View File

@ -19,12 +19,15 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.unix.IovArray;
import io.netty.channel.unix.Limits;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import static io.netty.channel.unix.Limits.UIO_MAX_IOV;
import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address;
import static io.netty.channel.unix.NativeInetAddress.copyIpv4MappedIpv6Address;
/**
* Support <a href="http://linux.die.net/man/2/sendmmsg">sendmmsg(...)</a> on linux with GLIBC 2.14+
@ -45,29 +48,30 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr
}
}
/**
* Try to add the given {@link DatagramPacket}. Returns {@code true} on success,
* {@code false} otherwise.
*/
boolean add(DatagramPacket packet) {
// Adds the readable bytes of the packet's content, addressed to the packet's recipient
// (used for the write path). Returns whatever add0(...) returns.
private boolean addReadable(DatagramPacket packet) {
ByteBuf buf = packet.content();
return add0(buf, buf.readerIndex(), buf.readableBytes(), packet.recipient());
}
// Adds the region [index, index + len) of buf without a recipient (null), i.e. as a target
// region for received data. Returns whatever add0(...) returns.
boolean addWritable(ByteBuf buf, int index, int len) {
return add0(buf, index, len, null);
}
private boolean add0(ByteBuf buf, int index, int len, InetSocketAddress recipient) {
if (count == packets.length) {
// We already filled up to UIO_MAX_IOV messages. This is the max allowed per sendmmsg(...) call, we will
// try again later.
// We already filled up to UIO_MAX_IOV messages. This is the max allowed per
// recvmmsg(...) / sendmmsg(...) call, we will try again later.
return false;
}
ByteBuf content = packet.content();
int len = content.readableBytes();
if (len == 0) {
return true;
}
NativeDatagramPacket p = packets[count];
InetSocketAddress recipient = packet.recipient();
int offset = iovArray.count();
if (!iovArray.add(content)) {
if (offset == Limits.IOV_MAX || !iovArray.add(buf, index, len)) {
// Not enough space to hold the whole content, we will try again later.
return false;
}
NativeDatagramPacket p = packets[count];
p.init(iovArray.memoryAddress(offset), iovArray.count() - offset, recipient);
count++;
@ -76,7 +80,7 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr
/**
 * Adds the given message if it is a {@link DatagramPacket}.
 *
 * @return {@code true} if {@code msg} is a {@link DatagramPacket} and it was added,
 *         {@code false} otherwise.
 */
@Override
public boolean processMessage(Object msg) {
    // Single return through addReadable(...); a second, earlier return via the former add(...)
    // would make this statement unreachable.
    return msg instanceof DatagramPacket && addReadable((DatagramPacket) msg);
}
/**
@ -112,7 +116,8 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr
private long memoryAddress;
private int count;
private byte[] addr;
private final byte[] addr = new byte[16];
private int addrLen;
private int scopeId;
private int port;
@ -120,15 +125,33 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr
this.memoryAddress = memoryAddress;
this.count = count;
if (recipient == null) {
this.scopeId = 0;
this.port = 0;
this.addrLen = 0;
} else {
InetAddress address = recipient.getAddress();
if (address instanceof Inet6Address) {
addr = address.getAddress();
System.arraycopy(address.getAddress(), 0, addr, 0, addr.length);
scopeId = ((Inet6Address) address).getScopeId();
} else {
addr = ipv4MappedIpv6Address(address.getAddress());
copyIpv4MappedIpv6Address(address.getAddress(), addr);
scopeId = 0;
}
addrLen = addr.length;
port = recipient.getPort();
}
}
/**
 * Creates a {@link DatagramPacket} from the state stored in this object: the sender is built
 * from {@code addr}, {@code scopeId} and {@code port}, and the writer index of {@code buffer}
 * is set to {@code count}.
 *
 * @param buffer the buffer holding the payload.
 * @param localAddress the address used as recipient of the created packet.
 * @throws UnknownHostException if the stored address bytes are rejected by {@link InetAddress}.
 */
DatagramPacket newDatagramPacket(ByteBuf buffer, InetSocketAddress localAddress) throws UnknownHostException {
    // A non-zero scope id requires the Inet6Address factory, everything else goes through InetAddress.
    final InetAddress sender = scopeId == 0
            ? InetAddress.getByAddress(addr)
            : Inet6Address.getByAddress(null, addr, scopeId);
    return new DatagramPacket(
            buffer.writerIndex(count), localAddress, new InetSocketAddress(sender, port));
}
}
}

View File

@ -0,0 +1,153 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.epoll;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.AdaptiveRecvByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.DatagramPacket;
import io.netty.testsuite.transport.TestsuitePermutation;
import io.netty.testsuite.transport.socket.AbstractDatagramTest;
import org.junit.Test;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Verifies that an epoll datagram channel configured with
* {@link EpollChannelOption#MAX_DATAGRAM_PAYLOAD_SIZE} delivers more than one packet per read
* (scattering read) and that the payloads arrive unchanged.
*/
public class EpollDatagramScatteringReadTest extends AbstractDatagramTest {
@Override
protected List<TestsuitePermutation.BootstrapComboFactory<Bootstrap, Bootstrap>> newFactories() {
return EpollSocketTestPermutation.INSTANCE.epollOnlyDatagram(internetProtocolFamily());
}
@Test
public void testScatteringReadPartial() throws Throwable {
run();
}
public void testScatteringReadPartial(Bootstrap sb, Bootstrap cb) throws Throwable {
testScatteringRead(sb, cb, true);
}
@Test
public void testScatteringRead() throws Throwable {
run();
}
public void testScatteringRead(Bootstrap sb, Bootstrap cb) throws Throwable {
testScatteringRead(sb, cb, false);
}
// Sends numPackets datagrams of packetSize bytes from cb to sb and waits until sb has received them all.
// With partial == true the allocator's initial buffer size only covers half of the packets.
private void testScatteringRead(Bootstrap sb, Bootstrap cb, boolean partial) throws Throwable {
int packetSize = 512;
int numPackets = 4;
sb.option(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator(
packetSize, packetSize * (partial ? numPackets / 2 : numPackets), 64 * 1024));
sb.option(EpollChannelOption.MAX_DATAGRAM_PAYLOAD_SIZE, packetSize);
Channel sc = null;
Channel cc = null;
try {
cb.handler(new SimpleChannelInboundHandler<Object>() {
@Override
public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exception {
// Nothing will be sent.
}
});
final AtomicReference<Throwable> errorRef = new AtomicReference<Throwable>();
final byte[] bytes = new byte[packetSize];
ThreadLocalRandom.current().nextBytes(bytes);
final CountDownLatch latch = new CountDownLatch(numPackets);
sb.handler(new SimpleChannelInboundHandler<DatagramPacket>() {
// Number of packets seen since the last channelReadComplete(...).
private int counter;
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
// More than one packet per read is expected.
assertTrue(counter > 1);
counter = 0;
ctx.read();
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) {
assertEquals(bytes.length, msg.content().readableBytes());
byte[] receivedBytes = new byte[bytes.length];
msg.content().readBytes(receivedBytes);
assertArrayEquals(bytes, receivedBytes);
counter++;
latch.countDown();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// Keep only the first error; it is inspected below if the latch times out.
errorRef.compareAndSet(null, cause);
}
});
// Reading stays disabled until all packets have been written, see setAutoRead(true) below.
sb.option(ChannelOption.AUTO_READ, false);
sc = sb.bind(newSocketAddress()).sync().channel();
cc = cb.bind(newSocketAddress()).sync().channel();
InetSocketAddress addr = (InetSocketAddress) sc.localAddress();
List<ChannelFuture> futures = new ArrayList<ChannelFuture>(numPackets);
for (int i = 0; i < numPackets; i++) {
futures.add(cc.write(new DatagramPacket(cc.alloc().directBuffer().writeBytes(bytes), addr)));
}
cc.flush();
for (ChannelFuture f: futures) {
f.sync();
}
// Enable autoread now which also triggers a read, this should cause scattering reads (recvmmsg) to happen.
sc.config().setAutoRead(true);
// NOTE(review): errorRef is only consulted when the latch times out; an error recorded
// after the last countDown() is not reported by this test.
if (!latch.await(10, TimeUnit.SECONDS)) {
Throwable error = errorRef.get();
if (error != null) {
throw error;
}
fail("Timeout while waiting for packets");
}
} finally {
if (cc != null) {
cc.close().syncUninterruptibly();
}
if (sc != null) {
sc.close().syncUninterruptibly();
}
}
}
}

View File

@ -130,6 +130,26 @@ class EpollSocketTestPermutation extends SocketTestPermutation {
return combo(bfs, bfs);
}
// Returns the combinations built from exactly one epoll datagram bootstrap factory on each side,
// both created for the given protocol family.
List<TestsuitePermutation.BootstrapComboFactory<Bootstrap, Bootstrap>> epollOnlyDatagram(
final InternetProtocolFamily family) {
return combo(Collections.singletonList(datagramBootstrapFactory(family)),
Collections.singletonList(datagramBootstrapFactory(family)));
}
// Creates a factory producing Bootstraps on EPOLL_WORKER_GROUP whose channels are
// EpollDatagramChannels created for the given protocol family.
private BootstrapFactory<Bootstrap> datagramBootstrapFactory(final InternetProtocolFamily family) {
return () -> new Bootstrap().group(EPOLL_WORKER_GROUP).channelFactory(new ChannelFactory<Channel>() {
@Override
public Channel newChannel(EventLoop eventLoop) {
return new EpollDatagramChannel(eventLoop, family);
}
@Override
public String toString() {
// NOTE(review): this yields the simple name of the InternetProtocolFamily class plus ".class",
// not the name of the channel class or the selected family - confirm this is intended.
return InternetProtocolFamily.class.getSimpleName() + ".class";
}
});
}
public List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> domainSocket() {
List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> list =

View File

@ -308,7 +308,7 @@ public final class KQueueDatagramChannel extends AbstractKQueueChannel implement
}
} else if (data.nioBufferCount() > 1) {
IovArray array = registration().cleanArray();
array.add(data);
array.add(data, data.readerIndex(), data.readableBytes());
int cnt = array.count();
assert cnt != 0;

View File

@ -16,7 +16,6 @@
package io.netty.channel.unix;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelOutboundBuffer.MessageProcessor;
import io.netty.util.internal.PlatformDependent;
@ -78,33 +77,33 @@ public final class IovArray implements MessageProcessor {
}
/**
* Add the readable bytes of a {@link ByteBuf} to this {@link IovArray}.
* @param buf The {@link ByteBuf} to add.
* @return {@code true} if the entire {@link ByteBuf} has been added to this {@link IovArray}. Note in the event
* that {@link ByteBuf} is a {@code CompositeByteBuf} {@code false} may be returned even if some of the components
* have been added.
* @deprecated Use {@link #add(ByteBuf, int, int)}; this method is equivalent to calling it with
* {@code buf.readerIndex()} and {@code buf.readableBytes()}.
*/
@Deprecated
public boolean add(ByteBuf buf) {
return add(buf, buf.readerIndex(), buf.readableBytes());
}
public boolean add(ByteBuf buf, int offset, int len) {
if (count == IOV_MAX) {
// No more room!
return false;
} else if (buf.nioBufferCount() == 1) {
final int len = buf.readableBytes();
if (len == 0) {
return true;
}
if (buf.hasMemoryAddress()) {
return add(buf.memoryAddress(), buf.readerIndex(), len);
return add(buf.memoryAddress() + offset, len);
} else {
ByteBuffer nioBuffer = buf.internalNioBuffer(buf.readerIndex(), len);
return add(Buffer.memoryAddress(nioBuffer), nioBuffer.position(), len);
ByteBuffer nioBuffer = buf.internalNioBuffer(offset, len);
return add(Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), len);
}
} else {
ByteBuffer[] buffers = buf.nioBuffers();
ByteBuffer[] buffers = buf.nioBuffers(offset, len);
for (ByteBuffer nioBuffer : buffers) {
final int len = nioBuffer.remaining();
if (len != 0 &&
(!add(Buffer.memoryAddress(nioBuffer), nioBuffer.position(), len) || count == IOV_MAX)) {
final int remaining = nioBuffer.remaining();
if (remaining != 0 &&
(!add(Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), remaining) || count == IOV_MAX)) {
return false;
}
}
@ -112,7 +111,7 @@ public final class IovArray implements MessageProcessor {
}
}
private boolean add(long addr, int offset, int len) {
private boolean add(long addr, int len) {
assert addr != 0;
// If there is at least 1 entry then we enforce the maximum bytes. We want to accept at least one entry so we
@ -135,19 +134,19 @@ public final class IovArray implements MessageProcessor {
if (ADDRESS_SIZE == 8) {
// 64bit
if (PlatformDependent.hasUnsafe()) {
PlatformDependent.putLong(baseOffset + memoryAddress, addr + offset);
PlatformDependent.putLong(baseOffset + memoryAddress, addr);
PlatformDependent.putLong(lengthOffset + memoryAddress, len);
} else {
memory.putLong(baseOffset, addr + offset);
memory.putLong(baseOffset, addr);
memory.putLong(lengthOffset, len);
}
} else {
assert ADDRESS_SIZE == 4;
if (PlatformDependent.hasUnsafe()) {
PlatformDependent.putInt(baseOffset + memoryAddress, (int) addr + offset);
PlatformDependent.putInt(baseOffset + memoryAddress, (int) addr);
PlatformDependent.putInt(lengthOffset + memoryAddress, len);
} else {
memory.putInt(baseOffset, (int) addr + offset);
memory.putInt(baseOffset, (int) addr);
memory.putInt(lengthOffset, len);
}
}
@ -169,22 +168,22 @@ public final class IovArray implements MessageProcessor {
}
/**
* Set the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf, int, int)}
* <p>
* This will not impact the existing state of the {@link IovArray}, and only applies to subsequent calls to
* {@link #add(ByteBuf, int, int)}.
* <p>
* In order to ensure some progress is made at least one {@link ByteBuf} will be accepted even if it's size exceeds
* this value.
* @param maxBytes the maximum amount of bytes that can be added to this {@link IovArray}. The value is
* validated with {@code checkPositive} and capped at {@code SSIZE_MAX}.
*/
public void maxBytes(long maxBytes) {
this.maxBytes = min(SSIZE_MAX, checkPositive(maxBytes, "maxBytes"));
}
/**
* Get the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}.
* @return the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf)}.
* Get the maximum amount of bytes that can be added to this {@link IovArray}.
* @return the maximum amount of bytes that can be added to this {@link IovArray}.
*/
public long maxBytes() {
return maxBytes;
@ -206,7 +205,11 @@ public final class IovArray implements MessageProcessor {
/**
 * Adds the readable bytes of {@code msg} to this {@link IovArray} if it is a {@link ByteBuf}.
 *
 * @return the result of {@link #add(ByteBuf, int, int)} for a {@link ByteBuf},
 *         {@code false} for any other message type.
 */
@Override
public boolean processMessage(Object msg) throws Exception {
    // Go through the explicit offset/length overload only; an earlier return via the deprecated
    // add(ByteBuf) would leave the statements below unreachable.
    if (msg instanceof ByteBuf) {
        ByteBuf buffer = (ByteBuf) msg;
        return add(buffer, buffer.readerIndex(), buffer.readableBytes());
    }
    return false;
}
private static int idx(int index) {

View File

@ -58,11 +58,15 @@ public final class NativeInetAddress {
/**
 * Returns a newly allocated 16-byte array holding the IPv4-mapped IPv6 form of the given
 * IPv4 address.
 *
 * @param ipv4 the IPv4 address bytes; they are copied to index 12 onwards of the result.
 * @return a new 16-byte array.
 */
public static byte[] ipv4MappedIpv6Address(byte[] ipv4) {
    byte[] address = new byte[16];
    // The helper writes both the prefix and the IPv4 bytes, so no copying is needed here.
    copyIpv4MappedIpv6Address(ipv4, address);
    return address;
}
/**
* Writes the IPv4-mapped IPv6 form of {@code ipv4} into {@code ipv6}: {@code IPV4_MAPPED_IPV6_PREFIX} is
* copied to the start of {@code ipv6} and the bytes of {@code ipv4} are copied to index 12 onwards.
* No new array is allocated.
*
* @param ipv4 the source IPv4 address bytes.
* @param ipv6 the destination array; assumed to be at least 16 bytes long - callers in view pass 16-byte arrays.
*/
public static void copyIpv4MappedIpv6Address(byte[] ipv4, byte[] ipv6) {
System.arraycopy(IPV4_MAPPED_IPV6_PREFIX, 0, ipv6, 0, IPV4_MAPPED_IPV6_PREFIX.length);
System.arraycopy(ipv4, 0, ipv6, 12, ipv4.length);
}
public static InetSocketAddress address(byte[] addr, int offset, int len) {
// The last 4 bytes are always the port
final int port = decodeInt(addr, offset + len - 4);