Directly receive remote address when call accept(...)
Motivation: There is a small race in the native transport where an accept(...) may success but a later try to obtain the remote address from the fd may fail is the fd is already closed. Modifications: Let accept(...) directly set the remote address. Result: No more race possible.
This commit is contained in:
parent
556a3d5980
commit
c21067aab6
@ -173,7 +173,24 @@ jobject createInetSocketAddress(JNIEnv* env, struct sockaddr_storage addr) {
|
|||||||
return socketAddr;
|
return socketAddr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
|
|
||||||
|
static inline int addressLength(struct sockaddr_storage addr) {
|
||||||
|
if (addr.ss_family == AF_INET) {
|
||||||
|
return 8;
|
||||||
|
} else {
|
||||||
|
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
|
||||||
|
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
|
||||||
|
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
|
||||||
|
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
|
||||||
|
// IPv4-mapped-on-IPv6
|
||||||
|
return 8;
|
||||||
|
} else {
|
||||||
|
return 24;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void initInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr, jbyteArray bArray, int offset, int len) {
|
||||||
int port;
|
int port;
|
||||||
if (addr.ss_family == AF_INET) {
|
if (addr.ss_family == AF_INET) {
|
||||||
struct sockaddr_in* s = (struct sockaddr_in*) &addr;
|
struct sockaddr_in* s = (struct sockaddr_in*) &addr;
|
||||||
@ -190,17 +207,12 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
|||||||
a[6] = port >> 8;
|
a[6] = port >> 8;
|
||||||
a[7] = port;
|
a[7] = port;
|
||||||
|
|
||||||
jbyteArray bArray = (*env)->NewByteArray(env, 8);
|
(*env)->SetByteArrayRegion(env, bArray, offset, 8, (jbyte*) &a);
|
||||||
(*env)->SetByteArrayRegion(env, bArray, 0, 8, (jbyte*) &a);
|
|
||||||
|
|
||||||
return bArray;
|
|
||||||
} else {
|
} else {
|
||||||
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
|
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
|
||||||
port = ntohs(s->sin6_port);
|
port = ntohs(s->sin6_port);
|
||||||
|
|
||||||
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
|
if (len == 8) {
|
||||||
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
|
|
||||||
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
|
|
||||||
// IPv4-mapped-on-IPv6
|
// IPv4-mapped-on-IPv6
|
||||||
// Encode port into the array and write it into the jbyteArray
|
// Encode port into the array and write it into the jbyteArray
|
||||||
unsigned char a[4];
|
unsigned char a[4];
|
||||||
@ -209,12 +221,9 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
|||||||
a[2] = port >> 8;
|
a[2] = port >> 8;
|
||||||
a[3] = port;
|
a[3] = port;
|
||||||
|
|
||||||
jbyteArray bArray = (*env)->NewByteArray(env, 8);
|
|
||||||
// we only need the last 4 bytes for mapped address
|
// we only need the last 4 bytes for mapped address
|
||||||
(*env)->SetByteArrayRegion(env, bArray, 0, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
|
(*env)->SetByteArrayRegion(env, bArray, offset, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
|
||||||
(*env)->SetByteArrayRegion(env, bArray, 4, 4, (jbyte*) &a);
|
(*env)->SetByteArrayRegion(env, bArray, offset + 4, 4, (jbyte*) &a);
|
||||||
|
|
||||||
return bArray;
|
|
||||||
} else {
|
} else {
|
||||||
// Encode scopeid and port into the array
|
// Encode scopeid and port into the array
|
||||||
unsigned char a[8];
|
unsigned char a[8];
|
||||||
@ -227,14 +236,20 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
|||||||
a[6] = port >> 8;
|
a[6] = port >> 8;
|
||||||
a[7] = port;
|
a[7] = port;
|
||||||
|
|
||||||
jbyteArray bArray = (*env)->NewByteArray(env, 24);
|
(*env)->SetByteArrayRegion(env, bArray, offset, 16, (jbyte*) &(s->sin6_addr.s6_addr));
|
||||||
(*env)->SetByteArrayRegion(env, bArray, 0, 16, (jbyte*) &(s->sin6_addr.s6_addr));
|
(*env)->SetByteArrayRegion(env, bArray, offset + 16, 8, (jbyte*) &a);
|
||||||
(*env)->SetByteArrayRegion(env, bArray, 16, 8, (jbyte*) &a);
|
|
||||||
return bArray;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
|
||||||
|
int len = addressLength(addr);
|
||||||
|
jbyteArray bArray = (*env)->NewByteArray(env, len);
|
||||||
|
|
||||||
|
initInetSocketAddressArray(env, addr, bArray, 0, len);
|
||||||
|
return bArray;
|
||||||
|
}
|
||||||
|
|
||||||
static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) {
|
static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) {
|
||||||
char ipstr[INET6_ADDRSTRLEN];
|
char ipstr[INET6_ADDRSTRLEN];
|
||||||
int port;
|
int port;
|
||||||
@ -537,6 +552,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
|
|||||||
throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count");
|
throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count");
|
||||||
return JNI_ERR;
|
return JNI_ERR;
|
||||||
}
|
}
|
||||||
|
|
||||||
return JNI_VERSION_1_6;
|
return JNI_VERSION_1_6;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1018,21 +1034,30 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv*
|
|||||||
return -optval;
|
return -optval;
|
||||||
}
|
}
|
||||||
|
|
||||||
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd) {
|
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress) {
|
||||||
jint socketFd;
|
jint socketFd;
|
||||||
int err;
|
int err;
|
||||||
|
struct sockaddr_storage addr;
|
||||||
|
socklen_t address_len = sizeof(addr);
|
||||||
|
|
||||||
do {
|
do {
|
||||||
if (accept4) {
|
if (accept4) {
|
||||||
socketFd = accept4(fd, NULL, 0, SOCK_NONBLOCK | SOCK_CLOEXEC);
|
socketFd = accept4(fd, (struct sockaddr*) &addr, &address_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
|
||||||
} else {
|
} else {
|
||||||
socketFd = accept(fd, NULL, 0);
|
socketFd = accept(fd, (struct sockaddr*) &addr, &address_len);
|
||||||
}
|
}
|
||||||
} while (socketFd == -1 && ((err = errno) == EINTR));
|
} while (socketFd == -1 && ((err = errno) == EINTR));
|
||||||
|
|
||||||
if (socketFd == -1) {
|
if (socketFd == -1) {
|
||||||
return -err;
|
return -err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int len = addressLength(addr);
|
||||||
|
|
||||||
|
// Fill in remote address details
|
||||||
|
(*env)->SetByteArrayRegion(env, acceptedAddress, 0, 4, (jbyte*) &len);
|
||||||
|
initInetSocketAddressArray(env, addr, acceptedAddress, 1, len);
|
||||||
|
|
||||||
if (accept4) {
|
if (accept4) {
|
||||||
return socketFd;
|
return socketFd;
|
||||||
} else {
|
} else {
|
||||||
|
@ -69,7 +69,7 @@ jint Java_io_netty_channel_epoll_Native_listen0(JNIEnv* env, jclass clazz, jint
|
|||||||
jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
|
jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
|
||||||
jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address);
|
jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address);
|
||||||
jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd);
|
jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd);
|
||||||
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd);
|
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress);
|
||||||
jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
|
jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
|
||||||
jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd);
|
jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd);
|
||||||
jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd);
|
jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd);
|
||||||
|
@ -63,9 +63,13 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
|||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract Channel newChildChannel(int fd) throws Exception;
|
abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception;
|
||||||
|
|
||||||
final class EpollServerSocketUnsafe extends AbstractEpollUnsafe {
|
final class EpollServerSocketUnsafe extends AbstractEpollUnsafe {
|
||||||
|
// Will hold the remote address after accept(...) was sucesssful.
|
||||||
|
// We need 24 bytes for the address as maximum + 1 byte for storing the length.
|
||||||
|
// So use 26 bytes as it's a power of two.
|
||||||
|
private final byte[] acceptedAddress = new byte[26];
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) {
|
public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) {
|
||||||
@ -94,7 +98,7 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
|||||||
? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
|
? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
|
||||||
int messages = 0;
|
int messages = 0;
|
||||||
do {
|
do {
|
||||||
int socketFd = Native.accept(fd().intValue());
|
int socketFd = Native.accept(fd().intValue(), acceptedAddress);
|
||||||
if (socketFd == -1) {
|
if (socketFd == -1) {
|
||||||
// this means everything was handled for now
|
// this means everything was handled for now
|
||||||
break;
|
break;
|
||||||
@ -102,7 +106,8 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
|||||||
readPending = false;
|
readPending = false;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
pipeline.fireChannelRead(newChildChannel(socketFd));
|
int len = acceptedAddress[0];
|
||||||
|
pipeline.fireChannelRead(newChildChannel(socketFd, acceptedAddress, 1, len));
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
// keep on reading as we use epoll ET and need to consume everything from the socket
|
// keep on reading as we use epoll ET and need to consume everything from the socket
|
||||||
pipeline.fireChannelReadComplete();
|
pipeline.fireChannelReadComplete();
|
||||||
|
@ -46,7 +46,7 @@ public final class EpollServerDomainSocketChannel extends AbstractEpollServerCha
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Channel newChildChannel(int fd) throws Exception {
|
protected Channel newChildChannel(int fd, byte[] addr, int offset, int len) throws Exception {
|
||||||
return new EpollDomainSocketChannel(this, fd);
|
return new EpollDomainSocketChannel(this, fd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ public final class EpollServerSocketChannel extends AbstractEpollServerChannel i
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Channel newChildChannel(int fd) throws Exception {
|
protected Channel newChildChannel(int fd, byte[] address, int offset, int len) throws Exception {
|
||||||
return new EpollSocketChannel(this, fd);
|
return new EpollSocketChannel(this, fd, Native.address(address, offset, len));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,12 +40,12 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
|
|||||||
private volatile InetSocketAddress local;
|
private volatile InetSocketAddress local;
|
||||||
private volatile InetSocketAddress remote;
|
private volatile InetSocketAddress remote;
|
||||||
|
|
||||||
EpollSocketChannel(Channel parent, int fd) {
|
EpollSocketChannel(Channel parent, int fd, InetSocketAddress remote) {
|
||||||
super(parent, fd);
|
super(parent, fd);
|
||||||
config = new EpollSocketChannelConfig(this);
|
config = new EpollSocketChannelConfig(this);
|
||||||
// Directly cache the remote and local addresses
|
// Directly cache the remote and local addresses
|
||||||
// See https://github.com/netty/netty/issues/2359
|
// See https://github.com/netty/netty/issues/2359
|
||||||
remote = Native.remoteAddress(fd);
|
this.remote = remote;
|
||||||
local = Native.localAddress(fd);
|
local = Native.localAddress(fd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
|
|||||||
if (super.doConnect(remoteAddress, localAddress)) {
|
if (super.doConnect(remoteAddress, localAddress)) {
|
||||||
int fd = fd().intValue();
|
int fd = fd().intValue();
|
||||||
local = Native.localAddress(fd);
|
local = Native.localAddress(fd);
|
||||||
remote = Native.remoteAddress(fd);
|
remote = (InetSocketAddress) remoteAddress;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -458,7 +458,7 @@ final class Native {
|
|||||||
if (addr == null) {
|
if (addr == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return address(addr);
|
return address(addr, 0, addr.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static InetSocketAddress localAddress(int fd) {
|
public static InetSocketAddress localAddress(int fd) {
|
||||||
@ -468,11 +468,10 @@ final class Native {
|
|||||||
if (addr == null) {
|
if (addr == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return address(addr);
|
return address(addr, 0, addr.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
static InetSocketAddress address(byte[] addr) {
|
static InetSocketAddress address(byte[] addr, int offset, int len) {
|
||||||
int len = addr.length;
|
|
||||||
// The last 4 bytes are always the port
|
// The last 4 bytes are always the port
|
||||||
final int port = decodeInt(addr, len - 4);
|
final int port = decodeInt(addr, len - 4);
|
||||||
final InetAddress address;
|
final InetAddress address;
|
||||||
@ -484,7 +483,7 @@ final class Native {
|
|||||||
// - 4 == port
|
// - 4 == port
|
||||||
case 8:
|
case 8:
|
||||||
byte[] ipv4 = new byte[4];
|
byte[] ipv4 = new byte[4];
|
||||||
System.arraycopy(addr, 0, ipv4, 0, 4);
|
System.arraycopy(addr, offset, ipv4, 0, 4);
|
||||||
address = InetAddress.getByAddress(ipv4);
|
address = InetAddress.getByAddress(ipv4);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@ -494,7 +493,7 @@ final class Native {
|
|||||||
// - 4 == port
|
// - 4 == port
|
||||||
case 24:
|
case 24:
|
||||||
byte[] ipv6 = new byte[16];
|
byte[] ipv6 = new byte[16];
|
||||||
System.arraycopy(addr, 0, ipv6, 0, 16);
|
System.arraycopy(addr, offset, ipv6, 0, 16);
|
||||||
int scopeId = decodeInt(addr, len - 8);
|
int scopeId = decodeInt(addr, len - 8);
|
||||||
address = Inet6Address.getByAddress(null, ipv6, scopeId);
|
address = Inet6Address.getByAddress(null, ipv6, scopeId);
|
||||||
break;
|
break;
|
||||||
@ -507,7 +506,7 @@ final class Native {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int decodeInt(byte[] addr, int index) {
|
static int decodeInt(byte[] addr, int index) {
|
||||||
return (addr[index] & 0xff) << 24 |
|
return (addr[index] & 0xff) << 24 |
|
||||||
(addr[index + 1] & 0xff) << 16 |
|
(addr[index + 1] & 0xff) << 16 |
|
||||||
(addr[index + 2] & 0xff) << 8 |
|
(addr[index + 2] & 0xff) << 8 |
|
||||||
@ -517,8 +516,8 @@ final class Native {
|
|||||||
private static native byte[] remoteAddress0(int fd);
|
private static native byte[] remoteAddress0(int fd);
|
||||||
private static native byte[] localAddress0(int fd);
|
private static native byte[] localAddress0(int fd);
|
||||||
|
|
||||||
public static int accept(int fd) throws IOException {
|
public static int accept(int fd, byte[] addr) throws IOException {
|
||||||
int res = accept0(fd);
|
int res = accept0(fd, addr);
|
||||||
if (res >= 0) {
|
if (res >= 0) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -529,7 +528,7 @@ final class Native {
|
|||||||
throw newIOException("accept", res);
|
throw newIOException("accept", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static native int accept0(int fd);
|
private static native int accept0(int fd, byte[] addr);
|
||||||
|
|
||||||
public static int recvFd(int fd) throws IOException {
|
public static int recvFd(int fd) throws IOException {
|
||||||
int res = recvFd0(fd);
|
int res = recvFd0(fd);
|
||||||
|
@ -32,7 +32,7 @@ public class NativeTest {
|
|||||||
ByteBuffer buffer = ByteBuffer.wrap(bytes);
|
ByteBuffer buffer = ByteBuffer.wrap(bytes);
|
||||||
buffer.put(inetAddress.getAddress().getAddress());
|
buffer.put(inetAddress.getAddress().getAddress());
|
||||||
buffer.putInt(inetAddress.getPort());
|
buffer.putInt(inetAddress.getPort());
|
||||||
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
|
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -44,6 +44,6 @@ public class NativeTest {
|
|||||||
buffer.put(address.getAddress());
|
buffer.put(address.getAddress());
|
||||||
buffer.putInt(address.getScopeId());
|
buffer.putInt(address.getScopeId());
|
||||||
buffer.putInt(inetAddress.getPort());
|
buffer.putInt(inetAddress.getPort());
|
||||||
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
|
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user