Directly receive remote address when call accept(...)

Motivation:

There is a small race in the native transport where an accept(...) may success but a later try to obtain the remote address from the fd may fail is the fd is already closed.

Modifications:

Let accept(...) directly set the remote address.

Result:

No more race possible.
This commit is contained in:
Norman Maurer 2015-02-19 13:39:00 +01:00
parent 93343ff00c
commit 88dae15bc2
8 changed files with 71 additions and 42 deletions

View File

@ -173,7 +173,24 @@ jobject createInetSocketAddress(JNIEnv* env, struct sockaddr_storage addr) {
return socketAddr;
}
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
static inline int addressLength(struct sockaddr_storage addr) {
if (addr.ss_family == AF_INET) {
return 8;
} else {
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
// IPv4-mapped-on-IPv6
return 8;
} else {
return 24;
}
}
}
static inline void initInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr, jbyteArray bArray, int offset, int len) {
int port;
if (addr.ss_family == AF_INET) {
struct sockaddr_in* s = (struct sockaddr_in*) &addr;
@ -190,17 +207,12 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[6] = port >> 8;
a[7] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 8);
(*env)->SetByteArrayRegion(env, bArray, 0, 8, (jbyte*) &a);
return bArray;
(*env)->SetByteArrayRegion(env, bArray, offset, 8, (jbyte*) &a);
} else {
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
port = ntohs(s->sin6_port);
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
if (len == 8) {
// IPv4-mapped-on-IPv6
// Encode port into the array and write it into the jbyteArray
unsigned char a[4];
@ -209,12 +221,9 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[2] = port >> 8;
a[3] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 8);
// we only need the last 4 bytes for mapped address
(*env)->SetByteArrayRegion(env, bArray, 0, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
(*env)->SetByteArrayRegion(env, bArray, 4, 4, (jbyte*) &a);
return bArray;
(*env)->SetByteArrayRegion(env, bArray, offset, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
(*env)->SetByteArrayRegion(env, bArray, offset + 4, 4, (jbyte*) &a);
} else {
// Encode scopeid and port into the array
unsigned char a[8];
@ -227,14 +236,20 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
a[6] = port >> 8;
a[7] = port;
jbyteArray bArray = (*env)->NewByteArray(env, 24);
(*env)->SetByteArrayRegion(env, bArray, 0, 16, (jbyte*) &(s->sin6_addr.s6_addr));
(*env)->SetByteArrayRegion(env, bArray, 16, 8, (jbyte*) &a);
return bArray;
(*env)->SetByteArrayRegion(env, bArray, offset, 16, (jbyte*) &(s->sin6_addr.s6_addr));
(*env)->SetByteArrayRegion(env, bArray, offset + 16, 8, (jbyte*) &a);
}
}
}
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
int len = addressLength(addr);
jbyteArray bArray = (*env)->NewByteArray(env, len);
initInetSocketAddressArray(env, addr, bArray, 0, len);
return bArray;
}
static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) {
char ipstr[INET6_ADDRSTRLEN];
int port;
@ -537,6 +552,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count");
return JNI_ERR;
}
return JNI_VERSION_1_6;
}
}
@ -1018,21 +1034,30 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv*
return -optval;
}
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd) {
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress) {
jint socketFd;
int err;
struct sockaddr_storage addr;
socklen_t address_len = sizeof(addr);
do {
if (accept4) {
socketFd = accept4(fd, NULL, 0, SOCK_NONBLOCK | SOCK_CLOEXEC);
socketFd = accept4(fd, (struct sockaddr*) &addr, &address_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
} else {
socketFd = accept(fd, NULL, 0);
socketFd = accept(fd, (struct sockaddr*) &addr, &address_len);
}
} while (socketFd == -1 && ((err = errno) == EINTR));
if (socketFd == -1) {
return -err;
}
int len = addressLength(addr);
// Fill in remote address details
(*env)->SetByteArrayRegion(env, acceptedAddress, 0, 4, (jbyte*) &len);
initInetSocketAddressArray(env, addr, acceptedAddress, 1, len);
if (accept4) {
return socketFd;
} else {

View File

@ -69,7 +69,7 @@ jint Java_io_netty_channel_epoll_Native_listen0(JNIEnv* env, jclass clazz, jint
jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address);
jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd);
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd);
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress);
jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd);
jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd);

View File

@ -63,9 +63,13 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
throw new UnsupportedOperationException();
}
protected abstract Channel newChildChannel(int fd) throws Exception;
abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception;
final class EpollServerSocketUnsafe extends AbstractEpollUnsafe {
// Will hold the remote address after accept(...) was sucesssful.
// We need 24 bytes for the address as maximum + 1 byte for storing the length.
// So use 26 bytes as it's a power of two.
private final byte[] acceptedAddress = new byte[26];
@Override
public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) {
@ -94,7 +98,7 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
int messages = 0;
do {
int socketFd = Native.accept(fd().intValue());
int socketFd = Native.accept(fd().intValue(), acceptedAddress);
if (socketFd == -1) {
// this means everything was handled for now
break;
@ -102,7 +106,8 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
readPending = false;
try {
pipeline.fireChannelRead(newChildChannel(socketFd));
int len = acceptedAddress[0];
pipeline.fireChannelRead(newChildChannel(socketFd, acceptedAddress, 1, len));
} catch (Throwable t) {
// keep on reading as we use epoll ET and need to consume everything from the socket
pipeline.fireChannelReadComplete();

View File

@ -46,7 +46,7 @@ public final class EpollServerDomainSocketChannel extends AbstractEpollServerCha
}
@Override
protected Channel newChildChannel(int fd) throws Exception {
protected Channel newChildChannel(int fd, byte[] addr, int offset, int len) throws Exception {
return new EpollDomainSocketChannel(this, fd);
}

View File

@ -86,7 +86,7 @@ public final class EpollServerSocketChannel extends AbstractEpollServerChannel i
}
@Override
protected Channel newChildChannel(int fd) throws Exception {
return new EpollSocketChannel(this, fd);
protected Channel newChildChannel(int fd, byte[] address, int offset, int len) throws Exception {
return new EpollSocketChannel(this, fd, Native.address(address, offset, len));
}
}

View File

@ -40,12 +40,12 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
private volatile InetSocketAddress local;
private volatile InetSocketAddress remote;
EpollSocketChannel(Channel parent, int fd) {
EpollSocketChannel(Channel parent, int fd, InetSocketAddress remote) {
super(parent, fd);
config = new EpollSocketChannelConfig(this);
// Directly cache the remote and local addresses
// See https://github.com/netty/netty/issues/2359
remote = Native.remoteAddress(fd);
this.remote = remote;
local = Native.localAddress(fd);
}
@ -184,7 +184,7 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
if (super.doConnect(remoteAddress, localAddress)) {
int fd = fd().intValue();
local = Native.localAddress(fd);
remote = Native.remoteAddress(fd);
remote = (InetSocketAddress) remoteAddress;
return true;
}
return false;

View File

@ -458,7 +458,7 @@ final class Native {
if (addr == null) {
return null;
}
return address(addr);
return address(addr, 0, addr.length);
}
public static InetSocketAddress localAddress(int fd) {
@ -468,11 +468,10 @@ final class Native {
if (addr == null) {
return null;
}
return address(addr);
return address(addr, 0, addr.length);
}
static InetSocketAddress address(byte[] addr) {
int len = addr.length;
static InetSocketAddress address(byte[] addr, int offset, int len) {
// The last 4 bytes are always the port
final int port = decodeInt(addr, len - 4);
final InetAddress address;
@ -484,7 +483,7 @@ final class Native {
// - 4 == port
case 8:
byte[] ipv4 = new byte[4];
System.arraycopy(addr, 0, ipv4, 0, 4);
System.arraycopy(addr, offset, ipv4, 0, 4);
address = InetAddress.getByAddress(ipv4);
break;
@ -494,7 +493,7 @@ final class Native {
// - 4 == port
case 24:
byte[] ipv6 = new byte[16];
System.arraycopy(addr, 0, ipv6, 0, 16);
System.arraycopy(addr, offset, ipv6, 0, 16);
int scopeId = decodeInt(addr, len - 8);
address = Inet6Address.getByAddress(null, ipv6, scopeId);
break;
@ -507,7 +506,7 @@ final class Native {
}
}
private static int decodeInt(byte[] addr, int index) {
static int decodeInt(byte[] addr, int index) {
return (addr[index] & 0xff) << 24 |
(addr[index + 1] & 0xff) << 16 |
(addr[index + 2] & 0xff) << 8 |
@ -517,8 +516,8 @@ final class Native {
private static native byte[] remoteAddress0(int fd);
private static native byte[] localAddress0(int fd);
public static int accept(int fd) throws IOException {
int res = accept0(fd);
public static int accept(int fd, byte[] addr) throws IOException {
int res = accept0(fd, addr);
if (res >= 0) {
return res;
}
@ -529,7 +528,7 @@ final class Native {
throw newIOException("accept", res);
}
private static native int accept0(int fd);
private static native int accept0(int fd, byte[] addr);
public static int recvFd(int fd) throws IOException {
int res = recvFd0(fd);

View File

@ -32,7 +32,7 @@ public class NativeTest {
ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.put(inetAddress.getAddress().getAddress());
buffer.putInt(inetAddress.getPort());
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
}
@Test
@ -44,6 +44,6 @@ public class NativeTest {
buffer.put(address.getAddress());
buffer.putInt(address.getScopeId());
buffer.putInt(inetAddress.getPort());
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
}
}