Directly receive remote address when call accept(...)
Motivation: There is a small race in the native transport where an accept(...) may success but a later try to obtain the remote address from the fd may fail is the fd is already closed. Modifications: Let accept(...) directly set the remote address. Result: No more race possible.
This commit is contained in:
parent
93343ff00c
commit
88dae15bc2
@ -173,7 +173,24 @@ jobject createInetSocketAddress(JNIEnv* env, struct sockaddr_storage addr) {
|
||||
return socketAddr;
|
||||
}
|
||||
|
||||
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
|
||||
|
||||
static inline int addressLength(struct sockaddr_storage addr) {
|
||||
if (addr.ss_family == AF_INET) {
|
||||
return 8;
|
||||
} else {
|
||||
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
|
||||
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
|
||||
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
|
||||
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
|
||||
// IPv4-mapped-on-IPv6
|
||||
return 8;
|
||||
} else {
|
||||
return 24;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void initInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr, jbyteArray bArray, int offset, int len) {
|
||||
int port;
|
||||
if (addr.ss_family == AF_INET) {
|
||||
struct sockaddr_in* s = (struct sockaddr_in*) &addr;
|
||||
@ -190,17 +207,12 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
||||
a[6] = port >> 8;
|
||||
a[7] = port;
|
||||
|
||||
jbyteArray bArray = (*env)->NewByteArray(env, 8);
|
||||
(*env)->SetByteArrayRegion(env, bArray, 0, 8, (jbyte*) &a);
|
||||
|
||||
return bArray;
|
||||
(*env)->SetByteArrayRegion(env, bArray, offset, 8, (jbyte*) &a);
|
||||
} else {
|
||||
struct sockaddr_in6* s = (struct sockaddr_in6*) &addr;
|
||||
port = ntohs(s->sin6_port);
|
||||
|
||||
if (s->sin6_addr.s6_addr[0] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[4] == 0x00
|
||||
&& s->sin6_addr.s6_addr[5] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[9] == 0x00
|
||||
&& s->sin6_addr.s6_addr[10] == 0xff && s->sin6_addr.s6_addr[11] == 0xff) {
|
||||
if (len == 8) {
|
||||
// IPv4-mapped-on-IPv6
|
||||
// Encode port into the array and write it into the jbyteArray
|
||||
unsigned char a[4];
|
||||
@ -209,12 +221,9 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
||||
a[2] = port >> 8;
|
||||
a[3] = port;
|
||||
|
||||
jbyteArray bArray = (*env)->NewByteArray(env, 8);
|
||||
// we only need the last 4 bytes for mapped address
|
||||
(*env)->SetByteArrayRegion(env, bArray, 0, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
|
||||
(*env)->SetByteArrayRegion(env, bArray, 4, 4, (jbyte*) &a);
|
||||
|
||||
return bArray;
|
||||
(*env)->SetByteArrayRegion(env, bArray, offset, 4, (jbyte*) &(s->sin6_addr.s6_addr[12]));
|
||||
(*env)->SetByteArrayRegion(env, bArray, offset + 4, 4, (jbyte*) &a);
|
||||
} else {
|
||||
// Encode scopeid and port into the array
|
||||
unsigned char a[8];
|
||||
@ -227,14 +236,20 @@ static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_stor
|
||||
a[6] = port >> 8;
|
||||
a[7] = port;
|
||||
|
||||
jbyteArray bArray = (*env)->NewByteArray(env, 24);
|
||||
(*env)->SetByteArrayRegion(env, bArray, 0, 16, (jbyte*) &(s->sin6_addr.s6_addr));
|
||||
(*env)->SetByteArrayRegion(env, bArray, 16, 8, (jbyte*) &a);
|
||||
return bArray;
|
||||
(*env)->SetByteArrayRegion(env, bArray, offset, 16, (jbyte*) &(s->sin6_addr.s6_addr));
|
||||
(*env)->SetByteArrayRegion(env, bArray, offset + 16, 8, (jbyte*) &a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static jbyteArray createInetSocketAddressArray(JNIEnv* env, struct sockaddr_storage addr) {
|
||||
int len = addressLength(addr);
|
||||
jbyteArray bArray = (*env)->NewByteArray(env, len);
|
||||
|
||||
initInetSocketAddressArray(env, addr, bArray, 0, len);
|
||||
return bArray;
|
||||
}
|
||||
|
||||
static jobject createDatagramSocketAddress(JNIEnv* env, struct sockaddr_storage addr, int len) {
|
||||
char ipstr[INET6_ADDRSTRLEN];
|
||||
int port;
|
||||
@ -537,6 +552,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
|
||||
throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.count");
|
||||
return JNI_ERR;
|
||||
}
|
||||
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
}
|
||||
@ -1018,21 +1034,30 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv*
|
||||
return -optval;
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd) {
|
||||
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress) {
|
||||
jint socketFd;
|
||||
int err;
|
||||
struct sockaddr_storage addr;
|
||||
socklen_t address_len = sizeof(addr);
|
||||
|
||||
do {
|
||||
if (accept4) {
|
||||
socketFd = accept4(fd, NULL, 0, SOCK_NONBLOCK | SOCK_CLOEXEC);
|
||||
socketFd = accept4(fd, (struct sockaddr*) &addr, &address_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
|
||||
} else {
|
||||
socketFd = accept(fd, NULL, 0);
|
||||
socketFd = accept(fd, (struct sockaddr*) &addr, &address_len);
|
||||
}
|
||||
} while (socketFd == -1 && ((err = errno) == EINTR));
|
||||
|
||||
if (socketFd == -1) {
|
||||
return -err;
|
||||
}
|
||||
|
||||
int len = addressLength(addr);
|
||||
|
||||
// Fill in remote address details
|
||||
(*env)->SetByteArrayRegion(env, acceptedAddress, 0, 4, (jbyte*) &len);
|
||||
initInetSocketAddressArray(env, addr, acceptedAddress, 1, len);
|
||||
|
||||
if (accept4) {
|
||||
return socketFd;
|
||||
} else {
|
||||
|
@ -69,7 +69,7 @@ jint Java_io_netty_channel_epoll_Native_listen0(JNIEnv* env, jclass clazz, jint
|
||||
jint Java_io_netty_channel_epoll_Native_connect(JNIEnv* env, jclass clazz, jint fd, jbyteArray address, jint scopeId, jint port);
|
||||
jint Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring address);
|
||||
jint Java_io_netty_channel_epoll_Native_finishConnect0(JNIEnv* env, jclass clazz, jint fd);
|
||||
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd);
|
||||
jint Java_io_netty_channel_epoll_Native_accept0(JNIEnv* env, jclass clazz, jint fd, jbyteArray acceptedAddress);
|
||||
jlong Java_io_netty_channel_epoll_Native_sendfile0(JNIEnv* env, jclass clazz, jint fd, jobject fileRegion, jlong base_off, jlong off, jlong len);
|
||||
jbyteArray Java_io_netty_channel_epoll_Native_remoteAddress0(JNIEnv* env, jclass clazz, jint fd);
|
||||
jbyteArray Java_io_netty_channel_epoll_Native_localAddress0(JNIEnv* env, jclass clazz, jint fd);
|
||||
|
@ -63,9 +63,13 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
protected abstract Channel newChildChannel(int fd) throws Exception;
|
||||
abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception;
|
||||
|
||||
final class EpollServerSocketUnsafe extends AbstractEpollUnsafe {
|
||||
// Will hold the remote address after accept(...) was sucesssful.
|
||||
// We need 24 bytes for the address as maximum + 1 byte for storing the length.
|
||||
// So use 26 bytes as it's a power of two.
|
||||
private final byte[] acceptedAddress = new byte[26];
|
||||
|
||||
@Override
|
||||
public void connect(SocketAddress socketAddress, SocketAddress socketAddress2, ChannelPromise channelPromise) {
|
||||
@ -94,7 +98,7 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
||||
? Integer.MAX_VALUE : config.getMaxMessagesPerRead();
|
||||
int messages = 0;
|
||||
do {
|
||||
int socketFd = Native.accept(fd().intValue());
|
||||
int socketFd = Native.accept(fd().intValue(), acceptedAddress);
|
||||
if (socketFd == -1) {
|
||||
// this means everything was handled for now
|
||||
break;
|
||||
@ -102,7 +106,8 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im
|
||||
readPending = false;
|
||||
|
||||
try {
|
||||
pipeline.fireChannelRead(newChildChannel(socketFd));
|
||||
int len = acceptedAddress[0];
|
||||
pipeline.fireChannelRead(newChildChannel(socketFd, acceptedAddress, 1, len));
|
||||
} catch (Throwable t) {
|
||||
// keep on reading as we use epoll ET and need to consume everything from the socket
|
||||
pipeline.fireChannelReadComplete();
|
||||
|
@ -46,7 +46,7 @@ public final class EpollServerDomainSocketChannel extends AbstractEpollServerCha
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Channel newChildChannel(int fd) throws Exception {
|
||||
protected Channel newChildChannel(int fd, byte[] addr, int offset, int len) throws Exception {
|
||||
return new EpollDomainSocketChannel(this, fd);
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ public final class EpollServerSocketChannel extends AbstractEpollServerChannel i
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Channel newChildChannel(int fd) throws Exception {
|
||||
return new EpollSocketChannel(this, fd);
|
||||
protected Channel newChildChannel(int fd, byte[] address, int offset, int len) throws Exception {
|
||||
return new EpollSocketChannel(this, fd, Native.address(address, offset, len));
|
||||
}
|
||||
}
|
||||
|
@ -40,12 +40,12 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
|
||||
private volatile InetSocketAddress local;
|
||||
private volatile InetSocketAddress remote;
|
||||
|
||||
EpollSocketChannel(Channel parent, int fd) {
|
||||
EpollSocketChannel(Channel parent, int fd, InetSocketAddress remote) {
|
||||
super(parent, fd);
|
||||
config = new EpollSocketChannelConfig(this);
|
||||
// Directly cache the remote and local addresses
|
||||
// See https://github.com/netty/netty/issues/2359
|
||||
remote = Native.remoteAddress(fd);
|
||||
this.remote = remote;
|
||||
local = Native.localAddress(fd);
|
||||
}
|
||||
|
||||
@ -184,7 +184,7 @@ public final class EpollSocketChannel extends AbstractEpollStreamChannel impleme
|
||||
if (super.doConnect(remoteAddress, localAddress)) {
|
||||
int fd = fd().intValue();
|
||||
local = Native.localAddress(fd);
|
||||
remote = Native.remoteAddress(fd);
|
||||
remote = (InetSocketAddress) remoteAddress;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -458,7 +458,7 @@ final class Native {
|
||||
if (addr == null) {
|
||||
return null;
|
||||
}
|
||||
return address(addr);
|
||||
return address(addr, 0, addr.length);
|
||||
}
|
||||
|
||||
public static InetSocketAddress localAddress(int fd) {
|
||||
@ -468,11 +468,10 @@ final class Native {
|
||||
if (addr == null) {
|
||||
return null;
|
||||
}
|
||||
return address(addr);
|
||||
return address(addr, 0, addr.length);
|
||||
}
|
||||
|
||||
static InetSocketAddress address(byte[] addr) {
|
||||
int len = addr.length;
|
||||
static InetSocketAddress address(byte[] addr, int offset, int len) {
|
||||
// The last 4 bytes are always the port
|
||||
final int port = decodeInt(addr, len - 4);
|
||||
final InetAddress address;
|
||||
@ -484,7 +483,7 @@ final class Native {
|
||||
// - 4 == port
|
||||
case 8:
|
||||
byte[] ipv4 = new byte[4];
|
||||
System.arraycopy(addr, 0, ipv4, 0, 4);
|
||||
System.arraycopy(addr, offset, ipv4, 0, 4);
|
||||
address = InetAddress.getByAddress(ipv4);
|
||||
break;
|
||||
|
||||
@ -494,7 +493,7 @@ final class Native {
|
||||
// - 4 == port
|
||||
case 24:
|
||||
byte[] ipv6 = new byte[16];
|
||||
System.arraycopy(addr, 0, ipv6, 0, 16);
|
||||
System.arraycopy(addr, offset, ipv6, 0, 16);
|
||||
int scopeId = decodeInt(addr, len - 8);
|
||||
address = Inet6Address.getByAddress(null, ipv6, scopeId);
|
||||
break;
|
||||
@ -507,7 +506,7 @@ final class Native {
|
||||
}
|
||||
}
|
||||
|
||||
private static int decodeInt(byte[] addr, int index) {
|
||||
static int decodeInt(byte[] addr, int index) {
|
||||
return (addr[index] & 0xff) << 24 |
|
||||
(addr[index + 1] & 0xff) << 16 |
|
||||
(addr[index + 2] & 0xff) << 8 |
|
||||
@ -517,8 +516,8 @@ final class Native {
|
||||
private static native byte[] remoteAddress0(int fd);
|
||||
private static native byte[] localAddress0(int fd);
|
||||
|
||||
public static int accept(int fd) throws IOException {
|
||||
int res = accept0(fd);
|
||||
public static int accept(int fd, byte[] addr) throws IOException {
|
||||
int res = accept0(fd, addr);
|
||||
if (res >= 0) {
|
||||
return res;
|
||||
}
|
||||
@ -529,7 +528,7 @@ final class Native {
|
||||
throw newIOException("accept", res);
|
||||
}
|
||||
|
||||
private static native int accept0(int fd);
|
||||
private static native int accept0(int fd, byte[] addr);
|
||||
|
||||
public static int recvFd(int fd) throws IOException {
|
||||
int res = recvFd0(fd);
|
||||
|
@ -32,7 +32,7 @@ public class NativeTest {
|
||||
ByteBuffer buffer = ByteBuffer.wrap(bytes);
|
||||
buffer.put(inetAddress.getAddress().getAddress());
|
||||
buffer.putInt(inetAddress.getPort());
|
||||
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
|
||||
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -44,6 +44,6 @@ public class NativeTest {
|
||||
buffer.put(address.getAddress());
|
||||
buffer.putInt(address.getScopeId());
|
||||
buffer.putInt(inetAddress.getPort());
|
||||
Assert.assertEquals(inetAddress, Native.address(buffer.array()));
|
||||
Assert.assertEquals(inetAddress, Native.address(buffer.array(), 0, bytes.length));
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user