From 88dae15bc20e51f33d22f28ad84b8e62bf7d116d Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Thu, 19 Feb 2015 13:39:00 +0100 Subject: [PATCH] Directly receive remote address when call accept(...) Motivation: There is a small race in the native transport where an accept(...) may success but a later try to obtain the remote address from the fd may fail is the fd is already closed. Modifications: Let accept(...) directly set the remote address. Result: No more race possible. --- .../main/c/io_netty_channel_epoll_Native.c | 65 +++++++++++++------ .../main/c/io_netty_channel_epoll_Native.h | 2 +- .../epoll/AbstractEpollServerChannel.java | 11 +++- .../epoll/EpollServerDomainSocketChannel.java | 2 +- .../epoll/EpollServerSocketChannel.java | 4 +- .../channel/epoll/EpollSocketChannel.java | 6 +- .../java/io/netty/channel/epoll/Native.java | 19 +++--- .../io/netty/channel/epoll/NativeTest.java | 4 +- 8 files changed, 71 insertions(+), 42 deletions(-) diff --git a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c index 24231c1443..ef4d1f0fb3 100644 --- a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c +++ b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.c @@ -173,7 +173,24 @@ jobject createInetSocketAddress(JNIEnv* env, struct sockaddr_storage addr) { return socketAddr; } -static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) { + +static inline int addressLength(struct sockaddr_storage addr) { + if (addr.ss_family == AF_INET) { + return 8; + } else { + struct sockaddr_in6* s = (struct sockaddr_in6*) &addr; + if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00 + && s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00 + && s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) { + // IPv4-mapped-on-IPv6 + return 8; + } else { + return 24; + } + } +} + +static inline void initInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr, jbyteArray bArray, int offset, int len) { int port; if (addr.ss_family == AF_INET) { struct sockaddr_in* s = (struct sockaddr_in*) &addr; @@ -190,17 +207,12 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor a[6] = port >> 8; a[7] = port; - jbyteArray bArray = (*env)->NewByteArray(env, 8); - (*env)->SetByteArrayRegion(env, bArray, 0, 8, (jbyte*) &a); - - return bArray; + (*env)->SetByteArrayRegion(env, bArray, offset, 8, (jbyte*) &a); } else { struct sockaddr_in6* s = (struct sockaddr_in6*) &addr; port = ntohs(s->sin6_port); - if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00 - && s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00 - && s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) { + if (len == 8) { // IPv4-mapped-on-IPv6 // Encode port into the array and write it into the jbyteArray unsigned char a[4]; @@ -209,12 +221,9 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor a[2] = port >> 8; a[3] = port; - jbyteArray bArray = (*env)->NewByteArray(env, 8); // we only need the last 4 bytes for mapped address - (*env)->SetByteArrayRegion(env, bArray, 0, 4, (jbyte*) &(s->sin6_addr.s6_addr[12])); - (*env)->SetByteArrayRegion(env, bArray, 4, 4, (jbyte*) &a); - - return bArray; + (*env)->SetByteArrayRegion(env, bArray, offset, 4, (jbyte*) &(s->sin6_addr.s6_addr[12])); + (*env)->SetByteArrayRegion(env, bArray, offset + 4, 4, (jbyte*) &a); } else { // Encode scopeid and port into the array unsigned char a[8]; @@ -227,14 +236,20 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor a[6] = port >> 8; a[7] = port; - jbyteArray bArray = (*env)->NewByteArray(env, 24); - (*env)->SetByteArrayRegion(env, bArray, 0, 16, (jbyte*) &(s->sin6_addr.s6_addr)); - (*env)->SetByteArrayRegion(env, bArray, 16, 8, (jbyte*) &a); - return bArray; + (*env)->SetByteArrayRegion(env, bArray, offset, 16, (jbyte*) &(s->sin6_addr.s6_addr)); + (*env)->SetByteArrayRegion(env, bArray, offset + 16, 8, (jbyte*) &a); } } } +static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) { + int len = addressLength(addr); + jbyteArray bArray = (*env)->NewByteArray(env, len); + + initInetSocketAddressArray(env, addr, bArray, 0, len); + return bArray; +} + static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) { char ipstr[INET6_ADDRSTRLEN]; int port; @@ -537,6 +552,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count"); return JNI_ERR; } + return JNI_VERSION_1_6; } } @@ -1018,21 +1034,30 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* return -optval; } -JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd) { +JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress) { jint socketFd; int err; + struct sockaddr_storage addr; + socklen_t address_len = sizeof(addr); do { if (accept4) { - socketFd = accept4(fd, NULL, 0, SOCK_NONBLOCK | SOCK_CLOEXEC); + socketFd = accept4(fd, (struct sockaddr*) &addr, &address_len, SOCK_NONBLOCK | SOCK_CLOEXEC); } else { - socketFd = accept(fd, NULL, 0); + socketFd = accept(fd, (struct sockaddr*) &addr, &address_len); } } while (socketFd == -1 && ((err = errno) == EINTR)); if (socketFd == -1) { return -err; } + + int len = addressLength(addr); + + // Fill in remote address details + (*env)->SetByteArrayRegion(env, acceptedAddress, 0, 4, (jbyte*) &len); + initInetSocketAddressArray(env, addr, acceptedAddress, 1, len); + if (accept4) { return socketFd; } else { diff --git a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.h b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.h index a16c69199d..fa41f99e4f 100644 --- a/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.h +++ b/transport-native-epoll/src/main/c/io_netty_channel_epoll_Native.h @@ -69,7 +69,7 @@ jint Java_io_netty_channel_epoll_Native_listen0(JNIEnv* env, jclass clazz, jint jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port); jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address); jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd); -jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd); +jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress); jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len); jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd); jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java index 8339824de1..db66538500 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java @@ -63,9 +63,13 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im throw new UnsupportedOperationException(); } - protected abstract Channel newChildChannel(int fd) throws Exception; + abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception; final class EpollServerSocketUnsafe extends AbstractEpollUnsafe { + // Will hold the remote address after accept(...) was sucesssful. + // We need 24 bytes for the address as maximum + 1 byte for storing the length. + // So use 26 bytes as it's a power of two. + private final byte[] acceptedAddress = new byte[26]; @Override public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) { @@ -94,7 +98,7 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im ? Integer.MAX_VALUE : config.getMaxMessagesPerRead(); int messages = 0; do { - int socketFd = Native.accept(fd().intValue()); + int socketFd = Native.accept(fd().intValue(), acceptedAddress); if (socketFd == -1) { // this means everything was handled for now break; @@ -102,7 +106,8 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im readPending = false; try { - pipeline.fireChannelRead(newChildChannel(socketFd)); + int len = acceptedAddress[0]; + pipeline.fireChannelRead(newChildChannel(socketFd, acceptedAddress, 1, len)); } catch (Throwable t) { // keep on reading as we use epoll ET and need to consume everything from the socket pipeline.fireChannelReadComplete(); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerDomainSocketChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerDomainSocketChannel.java index 1b0ffcfb84..3850a56659 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerDomainSocketChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerDomainSocketChannel.java @@ -46,7 +46,7 @@ public final class EpollServerDomainSocketChannel extends AbstractEpollServerCha } @Override - protected Channel newChildChannel(int fd) throws Exception { + protected Channel newChildChannel(int fd, byte[] addr, int offset, int len) throws Exception { return new EpollDomainSocketChannel(this, fd); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannel.java index f3afb56bdb..03e74be8b6 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollServerSocketChannel.java @@ -86,7 +86,7 @@ public final class EpollServerSocketChannel extends AbstractEpollServerChannel i } @Override - protected Channel newChildChannel(int fd) throws Exception { - return new EpollSocketChannel(this, fd); + protected Channel newChildChannel(int fd, byte[] address, int offset, int len) throws Exception { + return new EpollSocketChannel(this, fd, Native.address(address, offset, len)); } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java index 68ce231666..9177e4db26 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollSocketChannel.java @@ -40,12 +40,12 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme private volatile InetSocketAddress local; private volatile InetSocketAddress remote; - EpollSocketChannel(Channel parent, int fd) { + EpollSocketChannel(Channel parent, int fd, InetSocketAddress remote) { super(parent, fd); config = new EpollSocketChannelConfig(this); // Directly cache the remote and local addresses // See https://github.com/netty/netty/issues/2359 - remote = Native.remoteAddress(fd); + this.remote = remote; local = Native.localAddress(fd); } @@ -184,7 +184,7 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme if (super.doConnect(remoteAddress, localAddress)) { int fd = fd().intValue(); local = Native.localAddress(fd); - remote = Native.remoteAddress(fd); + remote = (InetSocketAddress) remoteAddress; return true; } return false; diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java index 7c4794af3d..46f3b56838 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/Native.java @@ -458,7 +458,7 @@ final class Native { if (addr == null) { return null; } - return address(addr); + return address(addr, 0, addr.length); } public static InetSocketAddress localAddress(int fd) { @@ -468,11 +468,10 @@ final class Native { if (addr == null) { return null; } - return address(addr); + return address(addr, 0, addr.length); } - static InetSocketAddress address(byte[] addr) { - int len = addr.length; + static InetSocketAddress address(byte[] addr, int offset, int len) { // The last 4 bytes are always the port final int port = decodeInt(addr, len - 4); final InetAddress address; @@ -484,7 +483,7 @@ final class Native { // - 4 == port case 8: byte[] ipv4 = new byte[4]; - System.arraycopy(addr, 0, ipv4, 0, 4); + System.arraycopy(addr, offset, ipv4, 0, 4); address = InetAddress.getByAddress(ipv4); break; @@ -494,7 +493,7 @@ final class Native { // - 4 == port case 24: byte[] ipv6 = new byte[16]; - System.arraycopy(addr, 0, ipv6, 0, 16); + System.arraycopy(addr, offset, ipv6, 0, 16); int scopeId = decodeInt(addr, len - 8); address = Inet6Address.getByAddress(null, ipv6, scopeId); break; @@ -507,7 +506,7 @@ final class Native { } } - private static int decodeInt(byte[] addr, int index) { + static int decodeInt(byte[] addr, int index) { return (addr[index] & 0xff) << 24 | (addr[index + 1] & 0xff) << 16 | (addr[index + 2] & 0xff) << 8 | @@ -517,8 +516,8 @@ final class Native { private static native byte[] remoteAddress0(int fd); private static native byte[] localAddress0(int fd); - public static int accept(int fd) throws IOException { - int res = accept0(fd); + public static int accept(int fd, byte[] addr) throws IOException { + int res = accept0(fd, addr); if (res >= 0) { return res; } @@ -529,7 +528,7 @@ final class Native { throw newIOException("accept", res); } - private static native int accept0(int fd); + private static native int accept0(int fd, byte[] addr); public static int recvFd(int fd) throws IOException { int res = recvFd0(fd); diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/NativeTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/NativeTest.java index a99da2d234..a28d785e00 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/NativeTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/NativeTest.java @@ -32,7 +32,7 @@ public class NativeTest { ByteBuffer buffer = ByteBuffer.wrap(bytes); buffer.put(inetAddress.getAddress().getAddress()); buffer.putInt(inetAddress.getPort()); - Assert.assertEquals(inetAddress, Native.address(buffer.array())); + Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length)); } @Test @@ -44,6 +44,6 @@ public class NativeTest { buffer.put(address.getAddress()); buffer.putInt(address.getScopeId()); buffer.putInt(inetAddress.getPort()); - Assert.assertEquals(inetAddress, Native.address(buffer.array())); + Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length)); } }