Directly receive remote address when call accept(...)

Motivation:

There is a small race in the native transport where an accept(...) may success but a later try to obtain the remote address from the fd may fail is the fd is already closed.

Modifications:

Let accept(...) directly set the remote address.

Result:

No more race possible.
This commit is contained in:
Norman Maurer 2015-02-19 13:39:00 +01:00
parent 556a3d5980
commit c21067aab6
8 changed files with 71 additions and 42 deletions

View File

@ -173,7 +173,24 @@ jobject createInetSocketAddress(JNIEnv* env, struct sockaddr_storage addr) {
return socketAddr; return socketAddr;
} }
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
static inline int addressLength(struct sockaddr_storage addr) {
if (addr.ss_family == AF_INET) {
return 8;
} else {
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
// IPv4-mapped-on-IPv6
return 8;
} else {
return 24;
}
}
}
static inline void initInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr, jbyteArray bArray, int offset, int len) {
int port; int port;
if (addr.ss_family == AF_INET) { if (addr.ss_family == AF_INET) {
struct sockaddr_in* s = (struct sockaddr_in*) &addr; struct sockaddr_in* s = (struct sockaddr_in*) &addr;
@ -190,17 +207,12 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[6] = port >> 8; a[6] = port >> 8;
a[7] = port; a[7] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 8); (*env)->SetByteArrayRegion(env, bArray, offset, 8, (jbyte*) &a);
(*env)->SetByteArrayRegion(env, bArray, 0, 8, (jbyte*) &a);
return bArray;
} else { } else {
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr; struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
port = ntohs(s->sin6_port); port = ntohs(s->sin6_port);
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00 if (len == 8) {
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
// IPv4-mapped-on-IPv6 // IPv4-mapped-on-IPv6
// Encode port into the array and write it into the jbyteArray // Encode port into the array and write it into the jbyteArray
unsigned char a[4]; unsigned char a[4];
@ -209,12 +221,9 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[2] = port >> 8; a[2] = port >> 8;
a[3] = port; a[3] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 8);
// we only need the last 4 bytes for mapped address // we only need the last 4 bytes for mapped address
(*env)->SetByteArrayRegion(env, bArray, 0, 4, (jbyte*) &(s->sin6_addr.s6_addr[12])); (*env)->SetByteArrayRegion(env, bArray, offset, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
(*env)->SetByteArrayRegion(env, bArray, 4, 4, (jbyte*) &a); (*env)->SetByteArrayRegion(env, bArray, offset + 4, 4, (jbyte*) &a);
return bArray;
} else { } else {
// Encode scopeid and port into the array // Encode scopeid and port into the array
unsigned char a[8]; unsigned char a[8];
@ -227,12 +236,18 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[6] = port >> 8; a[6] = port >> 8;
a[7] = port; a[7] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 24); (*env)->SetByteArrayRegion(env, bArray, offset, 16, (jbyte*) &(s->sin6_addr.s6_addr));
(*env)->SetByteArrayRegion(env, bArray, 0, 16, (jbyte*) &(s->sin6_addr.s6_addr)); (*env)->SetByteArrayRegion(env, bArray, offset + 16, 8, (jbyte*) &a);
(*env)->SetByteArrayRegion(env, bArray, 16, 8, (jbyte*) &a); }
}
}
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
int len = addressLength(addr);
jbyteArray bArray = (*env)->NewByteArray(env, len);
initInetSocketAddressArray(env, addr, bArray, 0, len);
return bArray; return bArray;
}
}
} }
static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) { static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) {
@ -537,6 +552,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count"); throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count");
return JNI_ERR; return JNI_ERR;
} }
return JNI_VERSION_1_6; return JNI_VERSION_1_6;
} }
} }
@ -1018,21 +1034,30 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv*
return -optval; return -optval;
} }
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd) { JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress) {
jint socketFd; jint socketFd;
int err; int err;
struct sockaddr_storage addr;
socklen_t address_len = sizeof(addr);
do { do {
if (accept4) { if (accept4) {
socketFd = accept4(fd, NULL, 0, SOCK_NONBLOCK | SOCK_CLOEXEC); socketFd = accept4(fd, (struct sockaddr*) &addr, &address_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
} else { } else {
socketFd = accept(fd, NULL, 0); socketFd = accept(fd, (struct sockaddr*) &addr, &address_len);
} }
} while (socketFd == -1 && ((err = errno) == EINTR)); } while (socketFd == -1 && ((err = errno) == EINTR));
if (socketFd == -1) { if (socketFd == -1) {
return -err; return -err;
} }
int len = addressLength(addr);
// Fill in remote address details
(*env)->SetByteArrayRegion(env, acceptedAddress, 0, 4, (jbyte*) &len);
initInetSocketAddressArray(env, addr, acceptedAddress, 1, len);
if (accept4) { if (accept4) {
return socketFd; return socketFd;
} else { } else {

View File

@ -69,7 +69,7 @@ jint Java_io_netty_channel_epoll_Native_listen0(JNIEnv* env, jclass clazz, jint
jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port); jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address); jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address);
jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd); jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd);
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd); jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress);
jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len); jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd); jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd);
jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd); jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd);

View File

@ -63,9 +63,13 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
protected abstract Channel newChildChannel(int fd) throws Exception; abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception;
final class EpollServerSocketUnsafe extends AbstractEpollUnsafe { final class EpollServerSocketUnsafe extends AbstractEpollUnsafe {
// Will hold the remote address after accept(...) was sucesssful.
// We need 24 bytes for the address as maximum + 1 byte for storing the length.
// So use 26 bytes as it's a power of two.
private final byte[] acceptedAddress = new byte[26];
@Override @Override
public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) { public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) {
@ -94,7 +98,7 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
? Integer.MAX_VALUE : config.getMaxMessagesPerRead(); ? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
int messages = 0; int messages = 0;
do { do {
int socketFd = Native.accept(fd().intValue()); int socketFd = Native.accept(fd().intValue(), acceptedAddress);
if (socketFd == -1) { if (socketFd == -1) {
// this means everything was handled for now // this means everything was handled for now
break; break;
@ -102,7 +106,8 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
readPending = false; readPending = false;
try { try {
pipeline.fireChannelRead(newChildChannel(socketFd)); int len = acceptedAddress[0];
pipeline.fireChannelRead(newChildChannel(socketFd, acceptedAddress, 1, len));
} catch (Throwable t) { } catch (Throwable t) {
// keep on reading as we use epoll ET and need to consume everything from the socket // keep on reading as we use epoll ET and need to consume everything from the socket
pipeline.fireChannelReadComplete(); pipeline.fireChannelReadComplete();

View File

@ -46,7 +46,7 @@ public final class EpollServerDomainSocketChannel extends AbstractEpollServerCha
} }
@Override @Override
protected Channel newChildChannel(int fd) throws Exception { protected Channel newChildChannel(int fd, byte[] addr, int offset, int len) throws Exception {
return new EpollDomainSocketChannel(this, fd); return new EpollDomainSocketChannel(this, fd);
} }

View File

@ -86,7 +86,7 @@ public final class EpollServerSocketChannel extends AbstractEpollServerChannel i
} }
@Override @Override
protected Channel newChildChannel(int fd) throws Exception { protected Channel newChildChannel(int fd, byte[] address, int offset, int len) throws Exception {
return new EpollSocketChannel(this, fd); return new EpollSocketChannel(this, fd, Native.address(address, offset, len));
} }
} }

View File

@ -40,12 +40,12 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
private volatile InetSocketAddress local; private volatile InetSocketAddress local;
private volatile InetSocketAddress remote; private volatile InetSocketAddress remote;
EpollSocketChannel(Channel parent, int fd) { EpollSocketChannel(Channel parent, int fd, InetSocketAddress remote) {
super(parent, fd); super(parent, fd);
config = new EpollSocketChannelConfig(this); config = new EpollSocketChannelConfig(this);
// Directly cache the remote and local addresses // Directly cache the remote and local addresses
// See https://github.com/netty/netty/issues/2359 // See https://github.com/netty/netty/issues/2359
remote = Native.remoteAddress(fd); this.remote = remote;
local = Native.localAddress(fd); local = Native.localAddress(fd);
} }
@ -184,7 +184,7 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
if (super.doConnect(remoteAddress, localAddress)) { if (super.doConnect(remoteAddress, localAddress)) {
int fd = fd().intValue(); int fd = fd().intValue();
local = Native.localAddress(fd); local = Native.localAddress(fd);
remote = Native.remoteAddress(fd); remote = (InetSocketAddress) remoteAddress;
return true; return true;
} }
return false; return false;

View File

@ -458,7 +458,7 @@ final class Native {
if (addr == null) { if (addr == null) {
return null; return null;
} }
return address(addr); return address(addr, 0, addr.length);
} }
public static InetSocketAddress localAddress(int fd) { public static InetSocketAddress localAddress(int fd) {
@ -468,11 +468,10 @@ final class Native {
if (addr == null) { if (addr == null) {
return null; return null;
} }
return address(addr); return address(addr, 0, addr.length);
} }
static InetSocketAddress address(byte[] addr) { static InetSocketAddress address(byte[] addr, int offset, int len) {
int len = addr.length;
// The last 4 bytes are always the port // The last 4 bytes are always the port
final int port = decodeInt(addr, len - 4); final int port = decodeInt(addr, len - 4);
final InetAddress address; final InetAddress address;
@ -484,7 +483,7 @@ final class Native {
// - 4 == port // - 4 == port
case 8: case 8:
byte[] ipv4 = new byte[4]; byte[] ipv4 = new byte[4];
System.arraycopy(addr, 0, ipv4, 0, 4); System.arraycopy(addr, offset, ipv4, 0, 4);
address = InetAddress.getByAddress(ipv4); address = InetAddress.getByAddress(ipv4);
break; break;
@ -494,7 +493,7 @@ final class Native {
// - 4 == port // - 4 == port
case 24: case 24:
byte[] ipv6 = new byte[16]; byte[] ipv6 = new byte[16];
System.arraycopy(addr, 0, ipv6, 0, 16); System.arraycopy(addr, offset, ipv6, 0, 16);
int scopeId = decodeInt(addr, len - 8); int scopeId = decodeInt(addr, len - 8);
address = Inet6Address.getByAddress(null, ipv6, scopeId); address = Inet6Address.getByAddress(null, ipv6, scopeId);
break; break;
@ -507,7 +506,7 @@ final class Native {
} }
} }
private static int decodeInt(byte[] addr, int index) { static int decodeInt(byte[] addr, int index) {
return (addr[index] & 0xff) << 24 | return (addr[index] & 0xff) << 24 |
(addr[index + 1] & 0xff) << 16 | (addr[index + 1] & 0xff) << 16 |
(addr[index + 2] & 0xff) << 8 | (addr[index + 2] & 0xff) << 8 |
@ -517,8 +516,8 @@ final class Native {
private static native byte[] remoteAddress0(int fd); private static native byte[] remoteAddress0(int fd);
private static native byte[] localAddress0(int fd); private static native byte[] localAddress0(int fd);
public static int accept(int fd) throws IOException { public static int accept(int fd, byte[] addr) throws IOException {
int res = accept0(fd); int res = accept0(fd, addr);
if (res >= 0) { if (res >= 0) {
return res; return res;
} }
@ -529,7 +528,7 @@ final class Native {
throw newIOException("accept", res); throw newIOException("accept", res);
} }
private static native int accept0(int fd); private static native int accept0(int fd, byte[] addr);
public static int recvFd(int fd) throws IOException { public static int recvFd(int fd) throws IOException {
int res = recvFd0(fd); int res = recvFd0(fd);

View File

@ -32,7 +32,7 @@ public class NativeTest {
ByteBuffer buffer = ByteBuffer.wrap(bytes); ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.put(inetAddress.getAddress().getAddress()); buffer.put(inetAddress.getAddress().getAddress());
buffer.putInt(inetAddress.getPort()); buffer.putInt(inetAddress.getPort());
Assert.assertEquals(inetAddress, Native.address(buffer.array())); Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
} }
@Test @Test
@ -44,6 +44,6 @@ public class NativeTest {
buffer.put(address.getAddress()); buffer.put(address.getAddress());
buffer.putInt(address.getScopeId()); buffer.putInt(address.getScopeId());
buffer.putInt(inetAddress.getPort()); buffer.putInt(inetAddress.getPort());
Assert.assertEquals(inetAddress, Native.address(buffer.array())); Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
} }
} }