diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java index 37adc20e38..6b009e7d27 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java @@ -30,6 +30,7 @@ import io.netty.channel.socket.DatagramPacket; import org.junit.Test; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.nio.channels.NotYetConnectedException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -157,14 +158,19 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } }); - final CountDownLatch latch = new CountDownLatch(count); - sc = setupServerChannel(sb, bytes, latch, false); + final SocketAddress sender; if (bindClient) { cc = cb.bind(newSocketAddress()).sync().channel(); + sender = cc.localAddress(); } else { cb.option(ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, true); cc = cb.register().sync().channel(); + sender = null; } + + final CountDownLatch latch = new CountDownLatch(count); + sc = setupServerChannel(sb, bytes, sender, latch, false); + InetSocketAddress addr = (InetSocketAddress) sc.localAddress(); for (int i = 0; i < count; i++) { switch (wrapType) { @@ -231,8 +237,10 @@ public class DatagramUnicastTest extends AbstractDatagramTest { DatagramChannel cc = null; try { final CountDownLatch latch = new CountDownLatch(count); - sc = setupServerChannel(sb, bytes, latch, true); - cc = (DatagramChannel) cb.connect(sc.localAddress()).sync().channel(); + cc = (DatagramChannel) cb.bind(newSocketAddress()).sync().channel(); + sc = setupServerChannel(sb, bytes, cc.localAddress(), latch, true); + + cc.connect(sc.localAddress()).syncUninterruptibly(); for (int i = 0; i < count; i++) { switch (wrapType) { @@ -275,7 +283,8 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } @SuppressWarnings("deprecation") - private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final CountDownLatch latch, final boolean echo) + private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender, + final CountDownLatch latch, final boolean echo) throws Throwable { sb.handler(new ChannelInitializer() { @Override @@ -283,6 +292,12 @@ public class DatagramUnicastTest extends AbstractDatagramTest { ch.pipeline().addLast(new SimpleChannelInboundHandler() { @Override public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception { + if (sender == null) { + assertNotNull(msg.sender()); + } else { + assertEquals(sender, msg.sender()); + } + ByteBuf buf = msg.content(); assertEquals(bytes.length, buf.readableBytes()); for (int i = 0; i < bytes.length; i++) { diff --git a/transport-native-epoll/src/main/c/netty_epoll_native.c b/transport-native-epoll/src/main/c/netty_epoll_native.c index 123e2fdb21..69c3e915ee 100644 --- a/transport-native-epoll/src/main/c/netty_epoll_native.c +++ b/transport-native-epoll/src/main/c/netty_epoll_native.c @@ -67,6 +67,7 @@ struct mmsghdr { // Those are initialized in the init(...) method and cached for performance reasons static jfieldID packetAddrFieldId = NULL; +static jfieldID packetAddrLenFieldId = NULL; static jfieldID packetScopeIdFieldId = NULL; static jfieldID packetPortFieldId = NULL; static jfieldID packetMemoryAddressFieldId = NULL; @@ -362,12 +363,20 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo struct sockaddr_in* ipaddr = (struct sockaddr_in*) addr; (*env)->SetByteArrayRegion(env, address, 0, 4, (jbyte*) &ipaddr->sin_addr.s_addr); + (*env)->SetIntField(env, packet, packetAddrLenFieldId, 4); (*env)->SetIntField(env, packet, packetScopeIdFieldId, 0); (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ipaddr->sin_port)); } else { + int addrLen = netty_unix_socket_ipAddressLength(addr); struct sockaddr_in6* ip6addr = (struct sockaddr_in6*) addr; - (*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr); + if (addrLen == 4) { + // IPV4 mapped IPV6 address + (*env)->SetByteArrayRegion(env, address, 12, 4, (jbyte*) &ip6addr->sin6_addr.s6_addr); + } else { + (*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr); + } + (*env)->SetIntField(env, packet, packetAddrLenFieldId, addrLen); (*env)->SetIntField(env, packet, packetScopeIdFieldId, ip6addr->sin6_scope_id); (*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ip6addr->sin6_port)); } @@ -605,6 +614,11 @@ static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.addr"); goto error; } + packetAddrLenFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "addrLen", "I"); + if (packetAddrLenFieldId == NULL) { + netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.addrLen"); + goto error; + } packetScopeIdFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "scopeId", "I"); if (packetScopeIdFieldId == NULL) { netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.scopeId"); @@ -649,6 +663,7 @@ error: netty_epoll_linuxsocket_JNI_OnUnLoad(env); } packetAddrFieldId = NULL; + packetAddrLenFieldId = NULL; packetScopeIdFieldId = NULL; packetPortFieldId = NULL; packetMemoryAddressFieldId = NULL; @@ -666,6 +681,7 @@ static void netty_epoll_native_JNI_OnUnLoad(JNIEnv* env) { netty_epoll_linuxsocket_JNI_OnUnLoad(env); packetAddrFieldId = NULL; + packetAddrLenFieldId = NULL; packetScopeIdFieldId = NULL; packetPortFieldId = NULL; packetMemoryAddressFieldId = NULL; diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java index 81f9ece9a4..0e97cd52ac 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/NativeDatagramPacketArray.java @@ -40,6 +40,10 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr // We share one IovArray for all NativeDatagramPackets to reduce memory overhead. This will allow us to write // up to IOV_MAX iovec across all messages in one sendmmsg(...) call. private final IovArray iovArray = new IovArray(); + + // temporary array to copy the ipv4 part of ipv6-mapped-ipv4 addresses and then create a Inet4Address out of it. + private final byte[] ipv4Bytes = new byte[4]; + private int count; NativeDatagramPacketArray() { @@ -110,13 +114,14 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr * Used to pass needed data to JNI. */ @SuppressWarnings("unused") - static final class NativeDatagramPacket { + final class NativeDatagramPacket { // This is the actual struct iovec* private long memoryAddress; private int count; private final byte[] addr = new byte[16]; + private int addrLen; private int scopeId; private int port; @@ -145,10 +150,11 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr DatagramPacket newDatagramPacket(ByteBuf buffer, InetSocketAddress localAddress) throws UnknownHostException { final InetAddress address; - if (scopeId != 0) { - address = Inet6Address.getByAddress(null, addr, scopeId); + if (addrLen == ipv4Bytes.length) { + System.arraycopy(addr, 0, ipv4Bytes, 0, addrLen); + address = InetAddress.getByAddress(ipv4Bytes); } else { - address = InetAddress.getByAddress(addr); + address = Inet6Address.getByAddress(null, addr, scopeId); } return new DatagramPacket(buffer.writerIndex(count), localAddress, new InetSocketAddress(address, port)); diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java index 9fe2a49186..869f6587ed 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramScatteringReadTest.java @@ -28,6 +28,7 @@ import io.netty.testsuite.transport.socket.AbstractDatagramTest; import org.junit.Test; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -82,6 +83,9 @@ public class EpollDatagramScatteringReadTest extends AbstractDatagramTest { // Nothing will be sent. } }); + cc = cb.bind(newSocketAddress()).sync().channel(); + final SocketAddress ccAddress = cc.localAddress(); + final AtomicReference errorRef = new AtomicReference(); final byte[] bytes = new byte[packetSize]; ThreadLocalRandom.current().nextBytes(bytes); @@ -98,6 +102,8 @@ public class EpollDatagramScatteringReadTest extends AbstractDatagramTest { @Override protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + assertEquals(ccAddress, msg.sender()); + assertEquals(bytes.length, msg.content().readableBytes()); byte[] receivedBytes = new byte[bytes.length]; msg.content().readBytes(receivedBytes); @@ -115,7 +121,6 @@ public class EpollDatagramScatteringReadTest extends AbstractDatagramTest { sb.option(ChannelOption.AUTO_READ, false); sc = sb.bind(newSocketAddress()).sync().channel(); - cc = cb.bind(newSocketAddress()).sync().channel(); InetSocketAddress addr = (InetSocketAddress) sc.localAddress(); diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.c b/transport-native-unix-common/src/main/c/netty_unix_socket.c index 67f22058fa..db4b2914ae 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.c +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.c @@ -43,6 +43,7 @@ static jclass inetSocketAddressClass = NULL; static int socketType = AF_INET; static const unsigned char wildcardAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; static const unsigned char ipv4MappedWildcardAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 }; +static const unsigned char ipv4MappedAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff }; // Optional external methods extern int accept4(int sockFd, struct sockaddr* addr, socklen_t* addrlen, int flags) __attribute__((weak)) __attribute__((weak_import)); @@ -67,15 +68,13 @@ static int nettyNonBlockingSocket(int domain, int type, int protocol) { #endif } -static int ipAddressLength(const struct sockaddr_storage* addr) { +int netty_unix_socket_ipAddressLength(const struct sockaddr_storage* addr) { if (addr->ss_family == AF_INET) { return 4; } struct sockaddr_in6* s = (struct sockaddr_in6*) addr; - if (s->sin6_addr.s6_addr[11] == 0xff && s->sin6_addr.s6_addr[10] == 0xff && - s->sin6_addr.s6_addr[9] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[5] == 0x00 && - s->sin6_addr.s6_addr[4] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[0] == 0x00) { - // IPv4-mapped-on-IPv6 + if (memcmp(s->sin6_addr.s6_addr, ipv4MappedAddress, 12) == 0) { + // IPv4-mapped-on-IPv6 return 4; } return 16; @@ -84,7 +83,7 @@ static int ipAddressLength(const struct sockaddr_storage* addr) { static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, int len, jobject local) { int port; int scopeId; - int ipLength = ipAddressLength(addr); + int ipLength = netty_unix_socket_ipAddressLength(addr); jbyteArray addressBytes = (*env)->NewByteArray(env, ipLength); if (addr->ss_family == AF_INET) { @@ -114,7 +113,7 @@ static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_st } static jsize addressLength(const struct sockaddr_storage* addr) { - int len = ipAddressLength(addr); + int len = netty_unix_socket_ipAddressLength(addr); if (len == 4) { // Only encode port into it return len + 4; diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.h b/transport-native-unix-common/src/main/c/netty_unix_socket.h index 88a03b2d16..e24de06b54 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.h +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.h @@ -23,6 +23,7 @@ int netty_unix_socket_initSockaddr(JNIEnv* env, jboolean ipv6, jbyteArray address, jint scopeId, jint jport, const struct sockaddr_storage* addr, socklen_t* addrSize); int netty_unix_socket_getOption(JNIEnv* env, jint fd, int level, int optname, void* optval, socklen_t optlen); int netty_unix_socket_setOption(JNIEnv* env, jint fd, int level, int optname, const void* optval, socklen_t len); +int netty_unix_socket_ipAddressLength(const struct sockaddr_storage* addr); // These method is sometimes needed if you want to special handle some errno value before throwing an exception. int netty_unix_socket_getOption0(jint fd, int level, int optname, void* optval, socklen_t optlen);