Add support for abstract domain sockets

Motivation:

Because of java custom UTF encoding, it was previously impossible to use
nul-bytes in domain socket names, which is required for abstract domain
sockets.

Modifications:

- Pass the encoded string byte array to the native code
- Modify native code accordingly to work with nul-bytes in the the
array.
- Move the string encoding to UTF-8 in java code.

Result:

Unix domain socket addresses will work properly if they contain nul-
bytes. Address encoding for these addresses changes from UTF-8-like to
real UTF-8.
This commit is contained in:
Jonas Konrad 2015-08-15 23:19:03 +02:00 committed by Norman Maurer
parent 74eab518d7
commit 452818ca97
3 changed files with 57 additions and 15 deletions

View File

@ -1482,21 +1482,29 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_socketDomain(JNIEnv* e
return fd; return fd;
} }
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring socketPath) { // macro to calculate the length of a sockaddr_un struct for a given path length.
// see sys/un.h#SUN_LEN, this is modified to allow nul bytes
#define _UNIX_ADDR_LENGTH(path_len) (((struct sockaddr_un *) 0)->sun_path) + path_len
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEnv* env, jclass clazz, jint fd, jbyteArray socketPath) {
struct sockaddr_un addr; struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr)); memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX; addr.sun_family = AF_UNIX;
const char* socket_path = (*env)->GetStringUTFChars(env, socketPath, 0); const jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0);
memcpy(addr.sun_path, socket_path, strlen(socket_path)); jint socket_path_len = (*env)->GetArrayLength(env, socketPath);
if (socket_path_len > sizeof(addr.sun_path)) {
socket_path_len = sizeof(addr.sun_path);
}
memcpy(addr.sun_path, socket_path, socket_path_len);
if (unlink(socket_path) == -1 && errno != ENOENT) { if (unlink(socket_path) == -1 && errno != ENOENT) {
return -errno; return -errno;
} }
int res = bind(fd, (struct sockaddr*) &addr, sizeof(addr)); int res = bind(fd, (struct sockaddr*) &addr, _UNIX_ADDR_LENGTH(socket_path_len));
(*env)->ReleaseStringUTFChars(env, socketPath, socket_path); (*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0);
if (res == -1) { if (res == -1) {
return -errno; return -errno;
@ -1504,22 +1512,27 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEn
return res; return res;
} }
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring socketPath) { JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jbyteArray socketPath) {
struct sockaddr_un addr; struct sockaddr_un addr;
jint socket_path_len;
memset(&addr, 0, sizeof(addr)); memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX; addr.sun_family = AF_UNIX;
const char* socket_path = (*env)->GetStringUTFChars(env, socketPath, 0); const jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0);
strncpy(addr.sun_path, socket_path, sizeof(addr.sun_path) - 1); socket_path_len = (*env)->GetArrayLength(env, socketPath);
if (socket_path_len > sizeof(addr.sun_path)) {
socket_path_len = sizeof(addr.sun_path);
}
memcpy(addr.sun_path, socket_path, socket_path_len);
int res; int res;
int err; int err;
do { do {
res = connect(fd, (struct sockaddr*) &addr, sizeof(addr)); res = connect(fd, (struct sockaddr*) &addr, _UNIX_ADDR_LENGTH(socket_path_len));
} while (res == -1 && ((err = errno) == EINTR)); } while (res == -1 && ((err = errno) == EINTR));
(*env)->ReleaseStringUTFChars(env, socketPath, socket_path); (*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0);
if (res < 0) { if (res < 0) {
return -err; return -err;
@ -1688,4 +1701,4 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_splice0(JNIEnv* env, j
JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_ssizeMax(JNIEnv* env, jclass clazz) { JNIEXPORT jlong JNICALL Java_io_netty_channel_epoll_Native_ssizeMax(JNIEnv* env, jclass clazz) {
return SSIZE_MAX; return SSIZE_MAX;
} }

View File

@ -19,6 +19,7 @@ package io.netty.channel.epoll;
import io.netty.channel.ChannelException; import io.netty.channel.ChannelException;
import io.netty.channel.DefaultFileRegion; import io.netty.channel.DefaultFileRegion;
import io.netty.channel.unix.DomainSocketAddress; import io.netty.channel.unix.DomainSocketAddress;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.NativeLibraryLoader; import io.netty.util.internal.NativeLibraryLoader;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
@ -435,7 +436,7 @@ public final class Native {
} }
} else if (socketAddress instanceof DomainSocketAddress) { } else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress addr = (DomainSocketAddress) socketAddress; DomainSocketAddress addr = (DomainSocketAddress) socketAddress;
int res = bindDomainSocket(fd, addr.path()); int res = bindDomainSocket(fd, addr.path().getBytes(CharsetUtil.UTF_8));
if (res < 0) { if (res < 0) {
throw newIOException("bind", res); throw newIOException("bind", res);
} }
@ -445,7 +446,7 @@ public final class Native {
} }
private static native int bind(int fd, byte[] address, int scopeId, int port); private static native int bind(int fd, byte[] address, int scopeId, int port);
private static native int bindDomainSocket(int fd, String path); private static native int bindDomainSocket(int fd, byte[] path);
public static void listen(int fd, int backlog) throws IOException { public static void listen(int fd, int backlog) throws IOException {
int res = listen0(fd, backlog); int res = listen0(fd, backlog);
@ -464,7 +465,7 @@ public final class Native {
res = connect(fd, address.address, address.scopeId, inetSocketAddress.getPort()); res = connect(fd, address.address, address.scopeId, inetSocketAddress.getPort());
} else if (socketAddress instanceof DomainSocketAddress) { } else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress; DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress;
res = connectDomainSocket(fd, unixDomainSocketAddress.path()); res = connectDomainSocket(fd, unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8));
} else { } else {
throw new Error("Unexpected SocketAddress implementation " + socketAddress); throw new Error("Unexpected SocketAddress implementation " + socketAddress);
} }
@ -479,7 +480,7 @@ public final class Native {
} }
private static native int connect(int fd, byte[] address, int scopeId, int port); private static native int connect(int fd, byte[] address, int scopeId, int port);
private static native int connectDomainSocket(int fd, String path); private static native int connectDomainSocket(int fd, byte[] path);
public static boolean finishConnect(int fd) throws IOException { public static boolean finishConnect(int fd) throws IOException {
int res = finishConnect0(fd); int res = finishConnect0(fd);

View File

@ -0,0 +1,28 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.epoll;
import io.netty.channel.unix.DomainSocketAddress;
import java.net.SocketAddress;
import java.util.UUID;
public class EpollAbstractDomainSocketEchoTest extends EpollDomainSocketEchoTest {
@Override
protected SocketAddress newSocketAddress() {
// these don't actually show up in the file system so creating a temp file isn't reliable
return new DomainSocketAddress("\0/tmp/" + UUID.randomUUID());
}
}