Add support for abstract domain sockets

Motivation:

Because of java custom UTF encoding, it was previously impossible to use
nul-bytes in domain socket names, which is required for abstract domain
sockets.

Modifications:

- Pass the encoded string byte array to the native code
- Modify native code accordingly to work with nul-bytes in the the
array.
- Move the string encoding to UTF-8 in java code.

Result:

Unix domain socket addresses will work properly if they contain nul-
bytes. Address encoding for these addresses changes from UTF-8-like to
real UTF-8.
This commit is contained in:
Jonas Konrad 2015-08-15 23:19:03 +02:00 committed by Norman Maurer
parent 74eab518d7
commit 452818ca97
3 changed files with 57 additions and 15 deletions

View File

@ -1482,21 +1482,29 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_socketDomain(JNIEnv* e
return fd;
}
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring socketPath) {
// macro to calculate the length of a sockaddr_un struct for a given path length.
// see sys/un.h#SUN_LEN, this is modified to allow nul bytes
#define _UNIX_ADDR_LENGTH(path_len) (((struct sockaddr_un *) 0)->sun_path) + path_len
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEnv* env, jclass clazz, jint fd, jbyteArray socketPath) {
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
const char* socket_path = (*env)->GetStringUTFChars(env, socketPath, 0);
memcpy(addr.sun_path, socket_path, strlen(socket_path));
const jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0);
jint socket_path_len = (*env)->GetArrayLength(env, socketPath);
if (socket_path_len > sizeof(addr.sun_path)) {
socket_path_len = sizeof(addr.sun_path);
}
memcpy(addr.sun_path, socket_path, socket_path_len);
if (unlink(socket_path) == -1 && errno != ENOENT) {
return -errno;
}
int res = bind(fd, (struct sockaddr*) &addr, sizeof(addr));
(*env)->ReleaseStringUTFChars(env, socketPath, socket_path);
int res = bind(fd, (struct sockaddr*) &addr, _UNIX_ADDR_LENGTH(socket_path_len));
(*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0);
if (res == -1) {
return -errno;
@ -1504,22 +1512,27 @@ JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_bindDomainSocket(JNIEn
return res;
}
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jstring socketPath) {
JNIEXPORT jint JNICALL Java_io_netty_channel_epoll_Native_connectDomainSocket(JNIEnv* env, jclass clazz, jint fd, jbyteArray socketPath) {
struct sockaddr_un addr;
jint socket_path_len;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
const char* socket_path = (*env)->GetStringUTFChars(env, socketPath, 0);
strncpy(addr.sun_path, socket_path, sizeof(addr.sun_path) - 1);
const jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0);
socket_path_len = (*env)->GetArrayLength(env, socketPath);
if (socket_path_len > sizeof(addr.sun_path)) {
socket_path_len = sizeof(addr.sun_path);
}
memcpy(addr.sun_path, socket_path, socket_path_len);
int res;
int err;
do {
res = connect(fd, (struct sockaddr*) &addr, sizeof(addr));
res = connect(fd, (struct sockaddr*) &addr, _UNIX_ADDR_LENGTH(socket_path_len));
} while (res == -1 && ((err = errno) == EINTR));
(*env)->ReleaseStringUTFChars(env, socketPath, socket_path);
(*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0);
if (res < 0) {
return -err;

View File

@ -19,6 +19,7 @@ package io.netty.channel.epoll;
import io.netty.channel.ChannelException;
import io.netty.channel.DefaultFileRegion;
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.NativeLibraryLoader;
import io.netty.util.internal.PlatformDependent;
@ -435,7 +436,7 @@ public final class Native {
}
} else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress addr = (DomainSocketAddress) socketAddress;
int res = bindDomainSocket(fd, addr.path());
int res = bindDomainSocket(fd, addr.path().getBytes(CharsetUtil.UTF_8));
if (res < 0) {
throw newIOException("bind", res);
}
@ -445,7 +446,7 @@ public final class Native {
}
private static native int bind(int fd, byte[] address, int scopeId, int port);
private static native int bindDomainSocket(int fd, String path);
private static native int bindDomainSocket(int fd, byte[] path);
public static void listen(int fd, int backlog) throws IOException {
int res = listen0(fd, backlog);
@ -464,7 +465,7 @@ public final class Native {
res = connect(fd, address.address, address.scopeId, inetSocketAddress.getPort());
} else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress;
res = connectDomainSocket(fd, unixDomainSocketAddress.path());
res = connectDomainSocket(fd, unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8));
} else {
throw new Error("Unexpected SocketAddress implementation " + socketAddress);
}
@ -479,7 +480,7 @@ public final class Native {
}
private static native int connect(int fd, byte[] address, int scopeId, int port);
private static native int connectDomainSocket(int fd, String path);
private static native int connectDomainSocket(int fd, byte[] path);
public static boolean finishConnect(int fd) throws IOException {
int res = finishConnect0(fd);

View File

@ -0,0 +1,28 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.epoll;
import io.netty.channel.unix.DomainSocketAddress;
import java.net.SocketAddress;
import java.util.UUID;
public class EpollAbstractDomainSocketEchoTest extends EpollDomainSocketEchoTest {
@Override
protected SocketAddress newSocketAddress() {
// these don't actually show up in the file system so creating a temp file isn't reliable
return new DomainSocketAddress("\0/tmp/" + UUID.randomUUID());
}
}