Add connect

Motivation:

If connect returns EINPROGRESS, we submit a POLLOUT poll and then check
via socket.finishConnect whether the connection was successful.

Modifications:

- Added a new io_uring connect op
- Use a direct buffer for the socket address

Result:

You are now able to connect to a peer.
This commit is contained in:
Josef Grieb 2020-08-27 13:45:44 +02:00
parent 9fcd2926f1
commit 8096b2c15f
15 changed files with 617 additions and 141 deletions

View File

@ -211,6 +211,40 @@ static void netty_io_uring_eventFdRead(JNIEnv* env, jclass clazz, jint fd) {
}
}
/*
 * Converts the given Java address/scope/port into a native socket address and copies it
 * into the caller-supplied memory at remoteMemoryAddress.
 *
 * A full sizeof(struct sockaddr_storage) bytes are always copied, so the destination must
 * be at least that large.
 *
 * Returns the address length produced by netty_unix_socket_initSockaddr, or -1 if that
 * helper failed (in which case nothing is copied).
 */
static jint netty_io_uring_initAddress(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jbyteArray address, jint scopeId, jint port, jlong remoteMemoryAddress) {
    struct sockaddr_storage sockAddr;
    socklen_t sockAddrLen;

    if (netty_unix_socket_initSockaddr(env, ipv6, address, scopeId, port, &sockAddr, &sockAddrLen) != -1) {
        memcpy((void *) remoteMemoryAddress, &sockAddr, sizeof(struct sockaddr_storage));
        return sockAddrLen;
    }

    // A runtime exception was thrown
    return -1;
}
//static jint netty_unix_socket_connect(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jbyteArray address, jint scopeId, jint port) {
// struct sockaddr_storage addr;
// socklen_t addrSize;
// if (netty_unix_socket_initSockaddr(env, ipv6, address, scopeId, port, &addr, &addrSize) == -1) {
// // A runtime exception was thrown
// return -1;
// }
//
// int res;
// int err;
// do {
// res = connect(fd, (struct sockaddr*) &addr, addrSize);
// } while (res == -1 && ((err = errno) == EINTR));
//
// if (res < 0) {
// return -err;
// }
// return 0;
//}
static void netty_io_uring_eventFdWrite(JNIEnv* env, jclass clazz, jint fd, jlong value) {
uint64_t val;
@ -338,7 +372,8 @@ static const JNINativeMethod method_table[] = {
{"eventFdWrite", "(IJ)V", (void *) netty_io_uring_eventFdWrite },
{"eventFdRead", "(I)V", (void *) netty_io_uring_eventFdRead },
{"ioUringRegisterEventFd", "(II)I", (void *) netty_io_uring_register_event_fd},
{"ioUringUnregisterEventFd", "(I)I", (void *) netty_io_uring_unregister_event_fd}
{"ioUringUnregisterEventFd", "(I)I", (void *) netty_io_uring_unregister_event_fd},
{"initAddress", "(IZ[BIIJ)I", (void *) netty_io_uring_initAddress }
};
static const jint method_table_size =
sizeof(method_table) / sizeof(method_table[0]);

View File

@ -22,32 +22,47 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.ConnectTimeoutException;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.channel.socket.ChannelInputShutdownReadComplete;
import io.netty.channel.socket.SocketChannelConfig;
import io.netty.channel.unix.Buffer;
import io.netty.channel.unix.FileDescriptor;
import io.netty.channel.unix.NativeInetAddress;
import io.netty.channel.unix.Socket;
import io.netty.channel.unix.UnixChannel;
import io.netty.channel.unix.UnixChannelUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.NoRouteToHostException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AlreadyConnectedException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ConnectionPendingException;
import java.nio.channels.NotYetConnectedException;
import java.nio.channels.UnresolvedAddressException;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import static io.netty.channel.unix.Errors.*;
import static io.netty.channel.unix.UnixChannelUtil.*;
import static io.netty.util.internal.ObjectUtil.*;
abstract class AbstractIOUringChannel extends AbstractChannel implements UnixChannel {
@ -58,6 +73,8 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
boolean uringInReadyPending;
boolean inputClosedSeenErrorOnRead;
static final int SOCK_ADDR_LEN = 128;
//can only submit one write operation at a time
private boolean writeable = true;
/**
@ -67,6 +84,9 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
private ScheduledFuture<?> connectTimeoutFuture;
private SocketAddress requestedRemoteAddress;
private final ByteBuffer remoteAddressMemory;
private final long remoteAddressMemoryAddress;
private volatile SocketAddress local;
private volatile SocketAddress remote;
@ -88,9 +108,12 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
} else {
logger.trace("Create Server Socket: {}", socket.intValue());
}
remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN);
remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory);
}
protected AbstractIOUringChannel(final Channel parent, LinuxSocket socket, boolean active) {
AbstractIOUringChannel(final Channel parent, LinuxSocket socket, boolean active) {
super(parent);
this.socket = checkNotNull(socket, "fd");
this.active = active;
@ -105,6 +128,22 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
} else {
logger.trace("Create Server Socket: {}", socket.intValue());
}
remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN);
remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory);
}
/**
 * Creates a channel that is marked active from the start, wrapping the given socket and caching
 * the supplied remote address together with the socket's current local address.
 *
 * @param parent the parent channel, passed through to the superclass constructor
 * @param fd     the socket backing this channel; must not be {@code null}
 * @param remote the remote address to cache for this channel
 */
AbstractIOUringChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
super(parent);
this.socket = checkNotNull(fd, "fd");
this.active = true;
// Directly cache the remote and local addresses
// See https://github.com/netty/netty/issues/2359
this.remote = remote;
this.local = fd.localAddress();
// Direct buffer of SOCK_ADDR_LEN bytes plus its native memory address; connect(...) writes the
// native socket address into this memory and hands the address to the submission queue.
remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN);
remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory);
}
public boolean isOpen() {
@ -333,7 +372,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
// Channel/ChannelHandlerContext.read() was called
@Override
protected void doBeginRead() {
logger.trace("Begin Read");
System.out.println("Begin Read");
final AbstractUringUnsafe unsafe = (AbstractUringUnsafe) unsafe();
if (!uringInReadyPending) {
unsafe.executeUringReadOperator();
@ -374,7 +413,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
private void addPollOut() {
IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop();
IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue();
submissionQueue.addPollOut(socket.intValue());
submissionQueue.addPollOutLink(socket.intValue());
submissionQueue.submit();
}
@ -388,14 +427,21 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
}
};
IOUringRecvByteAllocatorHandle newIOUringHandle(RecvByteBufAllocator.ExtendedHandle handle) {
return new IOUringRecvByteAllocatorHandle(handle);
/**
 * Fails the given connect promise with {@code t}, annotated with the remote address, and then
 * calls {@code closeIfClosed()}.
 *
 * @param promise       the connect promise to fail; when {@code null} nothing else is done
 * @param t             the failure that ended the connect attempt
 * @param remoteAddress the address the connect attempt was aimed at
 */
public void fulfillConnectPromise(ChannelPromise promise, Throwable t, SocketAddress remoteAddress) {
    final Throwable annotated = annotateConnectException(t, remoteAddress);
    if (promise != null) {
        // Use tryFailure() instead of setFailure() to avoid the race against cancel().
        promise.tryFailure(annotated);
        closeIfClosed();
    }
    // A null promise means: closed via cancellation and the promise has been notified already.
}
@Override
public void connect(final SocketAddress remoteAddress, final SocketAddress localAddress,
final ChannelPromise promise) {
promise.setFailure(new UnsupportedOperationException());
IOUringRecvByteAllocatorHandle newIOUringHandle(RecvByteBufAllocator.ExtendedHandle handle) {
return new IOUringRecvByteAllocatorHandle(handle);
}
@Override
@ -415,6 +461,91 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
}
public abstract void uringEventExecution();
/**
 * Starts an asynchronous connect to {@code remoteAddress} by queueing a connect operation on
 * the event loop's submission queue.
 *
 * <p>Steps: mark the promise uncancellable and verify the channel is open; reject a second
 * connect while one is pending; run {@code doConnect} (address checks and optional bind);
 * write the native socket address into this channel's address memory; queue and submit the
 * connect. On success the promise and requested address are remembered and, if a positive
 * connect timeout is configured, a timeout task is scheduled.
 *
 * <p>If any step throws, or if the connect operation cannot be queued, the promise is failed
 * right away instead of being left pending.
 *
 * @param remoteAddress the peer to connect to; cast to {@link InetSocketAddress}
 * @param localAddress  optional local address to bind to first, may be {@code null}
 * @param promise       notified about the outcome of the connect attempt
 */
@Override
public void connect(
        final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) {
    if (!promise.setUncancellable() || !ensureOpen(promise)) {
        return;
    }

    try {
        if (connectPromise != null) {
            throw new ConnectionPendingException();
        }

        doConnect(remoteAddress, localAddress);
        InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
        NativeInetAddress address = NativeInetAddress.newInstance(inetSocketAddress.getAddress());
        socket.initAddress(address.address(), address.scopeId(), inetSocketAddress.getPort(),
                remoteAddressMemoryAddress);
        IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop();
        final IOUringSubmissionQueue ioUringSubmissionQueue =
                ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue();
        // addConnect(...) returns false when it could not obtain a submission queue entry. In that
        // case nothing was queued, so no completion would ever arrive for this connect: fail fast.
        if (!ioUringSubmissionQueue.addConnect(
                socket.intValue(), remoteAddressMemoryAddress, SOCK_ADDR_LEN)) {
            throw new IOException("connect(..) failed: no submission queue entry available");
        }
        ioUringSubmissionQueue.submit();
    } catch (Throwable t) {
        closeIfClosed();
        promise.tryFailure(annotateConnectException(t, remoteAddress));
        return;
    }
    connectPromise = promise;
    requestedRemoteAddress = remoteAddress;

    // Schedule connect timeout.
    int connectTimeoutMillis = config().getConnectTimeoutMillis();
    if (connectTimeoutMillis > 0) {
        connectTimeoutFuture = eventLoop().schedule(new Runnable() {
            @Override
            public void run() {
                ChannelPromise connectPromise = AbstractIOUringChannel.this.connectPromise;
                ConnectTimeoutException cause =
                        new ConnectTimeoutException("connection timed out: " + remoteAddress);
                // Only close the channel if this task is the one that failed the promise.
                if (connectPromise != null && connectPromise.tryFailure(cause)) {
                    close(voidPromise());
                }
            }
        }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
    }

    promise.addListener(new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isCancelled()) {
                if (connectTimeoutFuture != null) {
                    connectTimeoutFuture.cancel(false);
                }
                connectPromise = null;
                close(voidPromise());
            }
        }
    });
}
}
/**
 * Completes a successful connect: marks the channel active, tries to succeed the promise, fires
 * {@code channelActive} when the channel just became active, and closes the channel if the
 * promise could not be marked successful.
 *
 * @param promise   the connect promise; when {@code null} nothing is done
 * @param wasActive whether the channel was already active before the connect finished
 */
public void fulfillConnectPromise(ChannelPromise promise, boolean wasActive) {
    if (promise == null) {
        // Closed via cancellation and the promise has been notified already.
        return;
    }
    this.active = true;

    // Capture the state before trySuccess(): it may run a ChannelFutureListener that closes the
    // Channel, and fireChannelActive() still has to be called in that case.
    final boolean activeNow = isActive();

    // trySuccess() will return false if a user cancelled the connection attempt.
    final boolean promiseSet = promise.trySuccess();

    // Regardless if the connection attempt was cancelled, channelActive() event should be triggered,
    // because what happened is what happened.
    if (activeNow && !wasActive) {
        pipeline().fireChannelActive();
    }

    // If a user cancelled the connection attempt, close the channel, which is followed by channelInactive().
    if (!promiseSet) {
        close(voidPromise());
    }
}
@Override
@ -473,6 +604,53 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
return socket;
}
/**
 * Prepares a connect to the remote peer: checks that the given inet addresses are resolvable,
 * rejects the call when a remote address is already set, and binds to {@code localAddress}
 * when one is given. This method does not itself issue the connect.
 *
 * @param remoteAddress the peer address; only checked when it is an {@link InetSocketAddress}
 * @param localAddress  the local address to bind to, or {@code null} to skip binding
 * @throws AlreadyConnectedException if this channel already has a remote address
 * @throws Exception if an address check or the bind fails
 */
protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
    if (localAddress instanceof InetSocketAddress) {
        checkResolvable((InetSocketAddress) localAddress);
    }
    if (remoteAddress instanceof InetSocketAddress) {
        checkResolvable((InetSocketAddress) remoteAddress);
    }

    // Check if already connected before trying to connect. This is needed as connect(...) will not return -1
    // and set errno to EISCONN if a previous connect(...) attempt was setting errno to EINPROGRESS and finished
    // later.
    if (remote != null) {
        throw new AlreadyConnectedException();
    }

    if (localAddress == null) {
        return;
    }
    socket.bind(localAddress);
}
// public void setRemote() {
// remote = remoteSocketAddr == null ?
// remoteAddress : computeRemoteAddr(remoteSocketAddr, socket.remoteAddress());
// }
/**
 * Calls {@code socket.connect(remote)} directly and closes the channel via {@code doClose()}
 * if that call throws.
 *
 * NOTE(review): no caller of this method is visible in this change, and the "not connected"
 * branch below is an empty placeholder — confirm whether this method is still needed.
 *
 * @param remote the address to connect to
 * @return the value returned by {@code socket.connect(remote)}
 * @throws Exception if the connect attempt (or the subsequent close) fails
 */
private boolean doConnect0(SocketAddress remote) throws Exception {
// Stays false if socket.connect(...) throws, so the finally block knows to close.
boolean success = false;
try {
boolean connected = socket.connect(remote);
if (!connected) {
// Placeholder: nothing is registered here yet for a connect that has not completed.
//setFlag(Native.EPOLLOUT);
}
success = true;
return connected;
} finally {
if (!success) {
doClose();
}
}
}
void shutdownInput(boolean rdHup) {
logger.trace("shutdownInput Fd: {}", this.socket.intValue());
if (!socket.isInputShutdown()) {
@ -498,6 +676,26 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
}
}
//Todo we should move it to a error class
// copy unix Errors
/**
 * Translates a negative error code from a failed connect into the matching Java exception.
 * This method never returns normally.
 *
 * @param method the name of the operation that failed, used in the fallback message
 * @param err    the (negative) error code reported for the operation
 * @throws ConnectionPendingException if {@code err} is {@code ERROR_EALREADY_NEGATIVE}
 * @throws NoRouteToHostException     if {@code err} is {@code ERROR_ENETUNREACH_NEGATIVE}
 * @throws AlreadyConnectedException  if {@code err} is {@code ERROR_EISCONN_NEGATIVE}
 * @throws FileNotFoundException      if {@code err} is {@code ERRNO_ENOENT_NEGATIVE}
 * @throws ConnectException           for every other error code
 */
static void throwConnectException(String method, int err)
        throws IOException {
    if (err == ERROR_EALREADY_NEGATIVE) {
        throw new ConnectionPendingException();
    }
    if (err == ERROR_ENETUNREACH_NEGATIVE) {
        throw new NoRouteToHostException();
    }
    if (err == ERROR_EISCONN_NEGATIVE) {
        throw new AlreadyConnectedException();
    }
    if (err == ERRNO_ENOENT_NEGATIVE) {
        throw new FileNotFoundException();
    }
    // Include the error code: the message previously ended in a dangling "failed: " with no detail.
    throw new ConnectException(method + "(..) failed: error code " + err);
}
private static boolean isAllowHalfClosure(ChannelConfig config) {
return config instanceof SocketChannelConfig &&
((SocketChannelConfig) config).isAllowHalfClosure();
@ -508,6 +706,31 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
close(voidPromise());
}
/**
 * Cancels the scheduled connect-timeout task, if there is one, and clears the stored connect
 * promise.
 */
void cancelTimeoutFuture() {
    final ScheduledFuture<?> timeoutTask = connectTimeoutFuture;
    if (timeoutTask != null) {
        // false: do not interrupt the task if it is already running.
        timeoutTask.cancel(false);
    }
    connectPromise = null;
}
/**
 * Asks the socket whether the pending connect has finished. On success the cached remote
 * address is computed (for inet addresses) and the requested remote address is cleared.
 *
 * @return {@code true} if {@code socket.finishConnect()} reported completion, else {@code false}
 * @throws Exception if finishing the connect fails
 */
boolean doFinishConnect() throws Exception {
    if (!socket.finishConnect()) {
        return false;
    }

    if (requestedRemoteAddress instanceof InetSocketAddress) {
        final InetSocketAddress requested = (InetSocketAddress) requestedRemoteAddress;
        remote = computeRemoteAddr(requested, socket.remoteAddress());
    }
    requestedRemoteAddress = null;
    return true;
}
/**
 * Sets the cached remote address from the requested remote address and the socket's reported
 * remote address. Does nothing unless the requested address is an {@link InetSocketAddress}.
 */
void computeRemote() {
    if (!(requestedRemoteAddress instanceof InetSocketAddress)) {
        return;
    }
    final InetSocketAddress requested = (InetSocketAddress) requestedRemoteAddress;
    remote = computeRemoteAddr(requested, socket.remoteAddress());
}
final boolean shouldBreakIoUringInReady(ChannelConfig config) {
return socket.isInputShutdown() && (inputClosedSeenErrorOnRead || !isAllowHalfClosure(config));
}
@ -515,4 +738,20 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
public void setWriteable(boolean writeable) {
this.writeable = writeable;
}
/**
 * Returns the native memory address of this channel's direct buffer for the remote socket
 * address.
 */
public long getRemoteAddressMemoryAddress() {
return remoteAddressMemoryAddress;
}
/**
 * Returns the promise of the connect attempt currently stored on this channel, or {@code null}
 * if none is stored.
 */
public ChannelPromise getConnectPromise() {
return connectPromise;
}
/**
 * Returns the scheduled connect-timeout task, or {@code null} if none has been scheduled.
 */
public ScheduledFuture<?> getConnectTimeoutFuture() {
return connectTimeoutFuture;
}
/**
 * Returns the remote address that was requested by the last connect call, or {@code null} if it
 * is not set.
 */
public SocketAddress getRequestedRemoteAddress() {
return requestedRemoteAddress;
}
}

View File

@ -87,7 +87,7 @@ abstract class AbstractIOUringServerChannel extends AbstractIOUringChannel imple
public void uringEventExecution() {
final IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop();
IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue();
submissionQueue.addPollLink(socket.intValue());
submissionQueue.addPollInLink(socket.intValue());
//Todo get network addresses
submissionQueue.addAccept(fd().intValue());

View File

@ -30,6 +30,7 @@ import io.netty.util.internal.UnstableApi;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.SocketAddress;
import java.util.concurrent.Executor;
abstract class AbstractIOUringStreamChannel extends AbstractIOUringChannel implements DuplexChannel {
@ -43,6 +44,11 @@ abstract class AbstractIOUringStreamChannel extends AbstractIOUringChannel imple
super(parent, socket, active);
}
/**
 * Creates a stream channel for an existing socket with a known remote address by delegating to
 * the superclass constructor.
 *
 * @param parent the parent channel
 * @param fd     the socket backing this channel
 * @param remote the remote address of the peer
 */
AbstractIOUringStreamChannel(Channel parent, LinuxSocket fd, SocketAddress remote) {
super(parent, fd, remote);
// NOTE(review): the comment that used to sit here said "Add EPOLLRDHUP so we are notified once the
// remote peer close the connection", but this constructor registers nothing itself — confirm
// where the RDHUP poll is added for channels created through this constructor.
}
@Override
public ChannelFuture shutdown() {
System.out.println("AbstractStreamChannel shutdown");

View File

@ -27,10 +27,12 @@ final class IOUring {
static final int OP_READ = 22;
static final int OP_WRITE = 23;
static final int OP_POLL_REMOVE = 7;
static final int OP_CONNECT = 16;
static final int POLLMASK_LINK = 1;
static final int POLLMASK_OUT = 4;
static final int POLLMASK_IN_LINK = 1;
static final int POLLMASK_OUT_LINK = 4;
static final int POLLMASK_RDHUP = 8192;
static final int POLLMASK_OUT = 4;
static {
Throwable cause = null;

View File

@ -55,7 +55,7 @@ final class IOUringCompletionQueue {
this.ringFd = ringFd;
}
public int process(IOUringCompletionQueueCallback callback) {
public int process(IOUringCompletionQueueCallback callback) throws Exception {
int i = 0;
for (;;) {
long head = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress));

View File

@ -18,6 +18,7 @@ package io.netty.channel.uring;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.channel.unix.FileDescriptor;
import io.netty.channel.uring.AbstractIOUringChannel.AbstractUringUnsafe;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap;
import io.netty.util.internal.PlatformDependent;
@ -32,7 +33,7 @@ import java.util.concurrent.atomic.AtomicLong;
import static io.netty.channel.unix.Errors.*;
final class IOUringEventLoop extends SingleThreadEventLoop implements
IOUringCompletionQueue.IOUringCompletionQueueCallback {
IOUringCompletionQueue.IOUringCompletionQueueCallback {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(IOUringEventLoop.class);
//Todo set config ring buffer size
@ -75,7 +76,7 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
private static Queue<Runnable> newTaskQueue0(int maxPendingTasks) {
// This event loop never calls takeTask()
return maxPendingTasks == Integer.MAX_VALUE ? PlatformDependent.<Runnable>newMpscQueue()
return maxPendingTasks == Integer.MAX_VALUE? PlatformDependent.<Runnable>newMpscQueue()
: PlatformDependent.<Runnable>newMpscQueue(maxPendingTasks);
}
@ -117,10 +118,10 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
final IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
// Lets add the eventfd related events before starting to do any real work.
submissionQueue.addPollLink(eventfd.intValue());
submissionQueue.addPollInLink(eventfd.intValue());
submissionQueue.submit();
for (;;) {
for (; ; ) {
logger.trace("Run IOUringEventLoop {}", this.toString());
long curDeadlineNanos = nextScheduledTaskDeadlineNanos();
if (curDeadlineNanos == -1L) {
@ -144,6 +145,8 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
logger.trace("ioUringWaitCqe {}", this.toString());
completionQueue.ioUringWaitCqe();
}
} catch (Throwable t) {
//Todo handle exception
} finally {
if (nextWakeupNanos.get() == AWAKE || nextWakeupNanos.getAndSet(AWAKE) == AWAKE) {
pendingWakeup = true;
@ -151,7 +154,11 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
}
}
completionQueue.process(this);
try {
completionQueue.process(this);
} catch (Exception e) {
//Todo handle exception
}
if (hasTasks()) {
runAllTasks();
@ -174,90 +181,140 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
public boolean handle(int fd, int res, long flags, int op, int pollMask) {
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
switch (op) {
case IOUring.OP_ACCEPT:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
}
AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd);
if (acceptChannel == null) {
break;
}
logger.trace("EventLoop Accept filedescriptor: {}", res);
acceptChannel.setUringInReadyPending(false);
if (res != -1 && res != ERRNO_EAGAIN_NEGATIVE &&
res != ERRNO_EWOULDBLOCK_NEGATIVE) {
logger.trace("server filedescriptor Fd: {}", fd);
if (acceptChannel.acceptComplete(res)) {
// all childChannels should poll POLLRDHUP
submissionQueue.addPollRdHup(res);
submissionQueue.submit();
}
}
case IOUring.OP_ACCEPT:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
case IOUring.OP_READ:
AbstractIOUringChannel readChannel = channels.get(fd);
if (readChannel == null) {
break;
}
readChannel.readComplete(res);
}
AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd);
if (acceptChannel == null) {
break;
case IOUring.OP_WRITE:
AbstractIOUringChannel writeChannel = channels.get(fd);
if (writeChannel == null) {
break;
}
//localFlushAmount -> res
logger.trace("EventLoop Write Res: {}", res);
logger.trace("EventLoop Fd: {}", fd);
if (res == SOCKET_ERROR_EPIPE) {
writeChannel.shutdownInput(false);
} else {
writeChannel.writeComplete(res);
}
break;
case IOUring.IO_TIMEOUT:
if (res == ETIME) {
prevDeadlineNanos = NONE;
}
break;
case IOUring.IO_POLL:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
}
if (eventfd.intValue() == fd) {
pendingWakeup = false;
// We need to consume the data as otherwise we would see another event
// in the completionQueue without
// an extra eventfd_write(....)
Native.eventFdRead(eventfd.intValue());
submissionQueue.addPollLink(eventfd.intValue());
// Submit so its picked up
}
logger.trace("EventLoop Accept filedescriptor: {}", res);
acceptChannel.setUringInReadyPending(false);
if (res != -1 && res != ERRNO_EAGAIN_NEGATIVE &&
res != ERRNO_EWOULDBLOCK_NEGATIVE) {
logger.trace("server filedescriptor Fd: {}", fd);
if (acceptChannel.acceptComplete(res)) {
// all childChannels should poll POLLRDHUP
submissionQueue.addPollRdHup(res);
submissionQueue.submit();
} else {
if (pollMask == IOUring.POLLMASK_RDHUP) {
AbstractIOUringChannel channel = channels.get(fd);
if (channel != null && !channel.isActive()) {
channel.shutdownInput(true);
}
} else {
//Todo error handling error
logger.trace("POLL_LINK Res: {}", res);
}
}
break;
case IOUring.OP_READ:
AbstractIOUringChannel readChannel = channels.get(fd);
if (readChannel == null) {
break;
}
readChannel.readComplete(res);
break;
case IOUring.OP_WRITE:
AbstractIOUringChannel writeChannel = channels.get(fd);
if (writeChannel == null) {
break;
}
//localFlushAmount -> res
logger.trace("EventLoop Write Res: {}", res);
logger.trace("EventLoop Fd: {}", fd);
if (res == SOCKET_ERROR_EPIPE) {
writeChannel.shutdownInput(false);
} else {
writeChannel.writeComplete(res);
}
break;
case IOUring.IO_TIMEOUT:
if (res == ETIME) {
prevDeadlineNanos = NONE;
}
break;
case IOUring.IO_POLL:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
}
if (eventfd.intValue() == fd) {
pendingWakeup = false;
// We need to consume the data as otherwise we would see another event
// in the completionQueue without
// an extra eventfd_write(....)
Native.eventFdRead(eventfd.intValue());
submissionQueue.addPollInLink(eventfd.intValue());
// Submit so its picked up
submissionQueue.submit();
} else {
if (pollMask == IOUring.POLLMASK_RDHUP) {
AbstractIOUringChannel channel = channels.get(fd);
if (channel != null && !channel.isActive()) {
channel.shutdownInput(true);
}
} else if (pollMask == IOUring.POLLMASK_OUT) {
//connect successful
AbstractIOUringChannel ch = channels.get(fd);
boolean wasActive = ch.isActive();
try {
if (ch.doFinishConnect()) {
ch.fulfillConnectPromise(ch.getConnectPromise(), wasActive);
ch.cancelTimeoutFuture();
} else {
//submit pollout
submissionQueue.addPollOut(fd);
submissionQueue.submit();
}
} catch (Throwable t) {
AbstractUringUnsafe unsafe = (AbstractUringUnsafe) ch.unsafe();
unsafe.fulfillConnectPromise(ch.getConnectPromise(), t, ch.getRequestedRemoteAddress());
}
} else {
//Todo error handling error
logger.trace("POLL_LINK Res: {}", res);
}
}
break;
case IOUring.OP_POLL_REMOVE:
if (res == ENOENT) {
System.out.println(("POLL_REMOVE OPERATION not permitted"));
} else if (res == 0) {
System.out.println(("POLL_REMOVE OPERATION successful"));
}
break;
case IOUring.OP_CONNECT:
AbstractIOUringChannel channel = channels.get(fd);
System.out.println("Connect res: " + res);
if (res == 0) {
channel.fulfillConnectPromise(channel.getConnectPromise(), channel.active);
channel.cancelTimeoutFuture();
channel.computeRemote();
break;
case IOUring.OP_POLL_REMOVE:
if (res == ENOENT) {
logger.trace("POLL_REMOVE OPERATION not permitted");
} else if (res == 0) {
logger.trace("POLL_REMOVE OPERATION successful");
}
if (res == -1 || res == -4) {
submissionQueue.addConnect(fd, channel.getRemoteAddressMemoryAddress(),
AbstractIOUringChannel.SOCK_ADDR_LEN);
submissionQueue.submit();
break;
}
if (res < 0) {
if (res == ERRNO_EINPROGRESS_NEGATIVE) {
// connect not complete yet need to wait for poll_out event
submissionQueue.addPollOut(fd);
submissionQueue.submit();
break;
}
try {
channel.doClose();
} catch (Exception e) {
//Todo error handling
}
//Todo error handling
//AbstractIOUringChannel.throwConnectException("connect", res);
break;
}
break;
}
return true;
}

View File

@ -33,11 +33,15 @@ import io.netty.channel.unix.Socket;
import io.netty.channel.uring.AbstractIOUringStreamChannel.IOUringStreamUnsafe;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Collections;
public final class IOUringSocketChannel extends AbstractIOUringStreamChannel implements SocketChannel {
private final IOUringSocketChannelConfig config;
//private volatile Collection<InetAddress> tcpMd5SigAddresses = Collections.emptyList();
public IOUringSocketChannel() {
super(null, LinuxSocket.newSocketStream(), false);
@ -49,6 +53,15 @@ public final class IOUringSocketChannel extends AbstractIOUringStreamChannel imp
this.config = new IOUringSocketChannelConfig(this);
}
/**
 * Creates a socket channel for an existing socket with a known remote address and attaches a
 * fresh {@link IOUringSocketChannelConfig} to it.
 *
 * @param parent        the parent channel
 * @param fd            the socket backing this channel
 * @param remoteAddress the remote address of the peer
 */
IOUringSocketChannel(Channel parent, LinuxSocket fd, InetSocketAddress remoteAddress) {
super(parent, fd, remoteAddress);
this.config = new IOUringSocketChannelConfig(this);
// Disabled: copying the TCP MD5 signature addresses from the parent channel.
// if (parent instanceof IOUringSocketChannel) {
// tcpMd5SigAddresses = ((IOUringSocketChannel) parent).tcpMd5SigAddresses();
// }
}
@Override
public ServerSocketChannel parent() {
return (ServerSocketChannel) super.parent();

View File

@ -24,6 +24,7 @@ import io.netty.channel.MessageSizeEstimator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.socket.SocketChannelConfig;
import io.netty.util.internal.PlatformDependent;
import java.io.IOException;
import java.net.InetAddress;
@ -39,6 +40,11 @@ public class IOUringSocketChannelConfig extends DefaultChannelConfig implements
public IOUringSocketChannelConfig(Channel channel) {
super(channel);
if (PlatformDependent.canEnableTcpNoDelayByDefault()) {
setTcpNoDelay(true);
}
calculateMaxBytesPerGatheringWrite();
}
@Override

View File

@ -119,7 +119,7 @@ final class IOUringSubmissionQueue {
//user_data should be same as POLL_LINK fd
if (op == IOUring.OP_POLL_REMOVE) {
PlatformDependent.putInt(sqe + SQE_FD_FIELD, -1);
long pollLinkuData = convertToUserData((byte) IOUring.IO_POLL, fd, IOUring.POLLMASK_LINK);
long pollLinkuData = convertToUserData((byte) IOUring.IO_POLL, fd, IOUring.POLLMASK_IN_LINK);
PlatformDependent.putLong(sqe + SQE_ADDRESS_FIELD, pollLinkuData);
}
@ -127,7 +127,7 @@ final class IOUringSubmissionQueue {
PlatformDependent.putLong(sqe + SQE_USER_DATA_FIELD, uData);
//poll<link>read or accept operation
if (op == 6 && (pollMask == IOUring.POLLMASK_OUT || pollMask == IOUring.POLLMASK_LINK)) {
if (op == 6 && (pollMask == IOUring.POLLMASK_OUT_LINK || pollMask == IOUring.POLLMASK_IN_LINK)) {
PlatformDependent.putByte(sqe + SQE_FLAGS_FIELD, (byte) IOSQE_IO_LINK);
} else {
PlatformDependent.putByte(sqe + SQE_FLAGS_FIELD, (byte) 0);
@ -152,6 +152,8 @@ final class IOUringSubmissionQueue {
PlatformDependent.putInt(sqe + SQE_RW_FLAGS_FIELD, pollMask);
}
logger.trace("UserDataField: {}", PlatformDependent.getLong(sqe + SQE_USER_DATA_FIELD));
logger.trace("BufferAddress: {}", PlatformDependent.getLong(sqe + SQE_ADDRESS_FIELD));
logger.trace("Length: {}", PlatformDependent.getInt(sqe + SQE_LEN_FIELD));
@ -168,18 +170,22 @@ final class IOUringSubmissionQueue {
return true;
}
public boolean addPollLink(int fd) {
return addPoll(fd, IOUring.POLLMASK_LINK);
public boolean addPollInLink(int fd) {
return addPoll(fd, IOUring.POLLMASK_IN_LINK);
}
public boolean addPollOut(int fd) {
return addPoll(fd, IOUring.POLLMASK_OUT);
public boolean addPollOutLink(int fd) {
return addPoll(fd, IOUring.POLLMASK_OUT_LINK);
}
public boolean addPollRdHup(int fd) {
return addPoll(fd, IOUring.POLLMASK_RDHUP);
}
public boolean addPollOut(int fd) {
return addPoll(fd, IOUring.POLLMASK_OUT);
}
private boolean addPoll(int fd, int pollMask) {
long sqe = getSqe();
if (sqe == 0) {
@ -217,7 +223,7 @@ final class IOUringSubmissionQueue {
return true;
}
//fill the user_data which is associated with server poll link
//fill the adddress which is associated with server poll link user_data
public boolean addPollRemove(int fd) {
long sqe = getSqe();
if (sqe == 0) {
@ -228,6 +234,16 @@ final class IOUringSubmissionQueue {
return true;
}
/**
 * Queues a connect operation for the given file descriptor.
 *
 * @param fd                  the socket file descriptor to connect
 * @param socketAddress       native memory address of the socket address to connect to
 * @param socketAddressLength length value passed along with the socket address
 * @return {@code true} if an entry was obtained and filled in, {@code false} if {@code getSqe()}
 *         returned no entry (in which case nothing is queued)
 */
public boolean addConnect(int fd, long socketAddress, long socketAddressLength) {
    final long entry = getSqe();
    if (entry != 0) {
        setData(entry, (byte) IOUring.OP_CONNECT, 0, fd, socketAddress, 0, socketAddressLength);
        return true;
    }
    return false;
}
private int flushSqe() {
long kTail = toUnsignedLong(PlatformDependent.getInt(kTailAddress));
long kHead = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress));

View File

@ -67,6 +67,10 @@ final class LinuxSocket extends Socket {
setInterface(intValue(), ipv6, nativeAddress.address(), nativeAddress.scopeId(), interfaceIndex(netInterface));
}
/**
 * Delegates to {@code Native.initAddress} with this socket's file descriptor and {@code ipv6}
 * setting, passing the given address bytes, scope id, port and target memory address through.
 *
 * @param address                    the raw address bytes
 * @param scopeId                    the scope id belonging to the address
 * @param port                       the port
 * @param remoteAddressMemoryAddress native memory address handed to the native call
 * @return the value returned by {@code Native.initAddress}
 */
public int initAddress(byte[] address, int scopeId, int port, long remoteAddressMemoryAddress) {
return Native.initAddress(intValue(), ipv6, address, scopeId, port, remoteAddressMemoryAddress);
}
InetAddress getInterface() throws IOException {
NetworkInterface inf = getNetworkInterface();
if (inf != null) {

View File

@ -67,6 +67,8 @@ final class Native {
Socket.initialize();
}
public static RingBuffer createRingBuffer(int ringSize) {
//Todo throw Exception if it's null
return ioUringSetup(ringSize);
@ -95,6 +97,8 @@ final class Native {
public static native void ioUringExit(RingBuffer ringBuffer);
public static native int initAddress(int fd, boolean ipv6, byte[] address, int scopeId, int port, long remoteMemoryAddress);
private static native int eventFd();
// for testing(it is only temporary)

View File

@ -100,14 +100,14 @@ public class IOUringSocketTestPermutation extends SocketTestPermutation {
new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class);
return new Bootstrap().group(IO_URING_WORKER_GROUP).channel(IOUringSocketChannel.class);
//.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 100000);
}
},
new BootstrapFactory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class);
return new Bootstrap().group(IO_URING_WORKER_GROUP).channel(IOUringSocketChannel.class);
// .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 100000);
}
}

View File

@ -31,7 +31,7 @@ import io.netty.buffer.ByteBuf;
public class NativeTest {
@Test
public void canWriteFile() {
public void canWriteFile() throws Exception {
ByteBufAllocator allocator = new UnpooledByteBufAllocator(true);
final ByteBuf writeEventByteBuf = allocator.directBuffer(100);
final String inputString = "Hello World!";
@ -85,7 +85,7 @@ public class NativeTest {
}
@Test
public void timeoutTest() throws InterruptedException {
public void timeoutTest() throws Exception {
RingBuffer ringBuffer = Native.createRingBuffer(32);
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
@ -99,13 +99,17 @@ public class NativeTest {
@Override
public void run() {
assertTrue(completionQueue.ioUringWaitCqe());
completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(-62, res);
return true;
}
});
try {
completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(-62, res);
return true;
}
});
} catch (Exception e) {
e.printStackTrace();
}
}
};
thread.start();
@ -124,7 +128,7 @@ public class NativeTest {
//Todo clean
@Test
public void eventfdTest() throws IOException {
public void eventfdTest() throws Exception {
RingBuffer ringBuffer = Native.createRingBuffer(32);
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
final IOUringCompletionQueue completionQueue = ringBuffer.getIoUringCompletionQueue();
@ -134,7 +138,7 @@ public class NativeTest {
assertNotNull(completionQueue);
final FileDescriptor eventFd = Native.newEventFd();
assertTrue(submissionQueue.addPollLink(eventFd.intValue()));
assertTrue(submissionQueue.addPollInLink(eventFd.intValue()));
submissionQueue.submit();
new Thread() {
@ -163,7 +167,7 @@ public class NativeTest {
//eventfd signal doesn't work when ioUringWaitCqe and eventFdWrite are executed in a thread
//created this test to reproduce this "weird" bug
@Test(timeout = 8000)
public void eventfdNoSignal() throws InterruptedException {
public void eventfdNoSignal() throws Exception {
RingBuffer ringBuffer = Native.createRingBuffer(32);
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
@ -177,18 +181,22 @@ public class NativeTest {
@Override
public void run() {
assertTrue(completionQueue.ioUringWaitCqe());
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(1, res);
return true;
}
}));
try {
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(1, res);
return true;
}
}));
} catch (Exception e) {
e.printStackTrace();
}
}
};
waitingCqe.start();
final FileDescriptor eventFd = Native.newEventFd();
assertTrue(submissionQueue.addPollLink(eventFd.intValue()));
assertTrue(submissionQueue.addPollInLink(eventFd.intValue()));
submissionQueue.submit();
new Thread() {
@ -216,7 +224,7 @@ public class NativeTest {
final IOUringCompletionQueue completionQueue = ringBuffer.getIoUringCompletionQueue();
FileDescriptor eventFd = Native.newEventFd();
submissionQueue.addPollLink(eventFd.intValue());
submissionQueue.addPollInLink(eventFd.intValue());
submissionQueue.submit();
Thread.sleep(10);
@ -228,22 +236,30 @@ public class NativeTest {
@Override
public void run() {
assertTrue(completionQueue.ioUringWaitCqe());
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(IOUringEventLoop.ECANCELED, res);
assertEquals(IOUring.IO_POLL, op);
return true;
}
}));
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(0, res);
assertEquals(IOUring.OP_POLL_REMOVE, op);
return true;
}
}));
try {
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(IOUringEventLoop.ECANCELED, res);
assertEquals(IOUring.IO_POLL, op);
return true;
}
}));
} catch (Exception e) {
e.printStackTrace();
}
try {
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(0, res);
assertEquals(IOUring.OP_POLL_REMOVE, op);
return true;
}
}));
} catch (Exception e) {
e.printStackTrace();
}
}
};
waitingCqe.start();

View File

@ -0,0 +1,78 @@
package io.netty.channel.uring;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import org.junit.Test;
/**
 * Starts an io_uring based echo server on a fixed port, keeps it open for a
 * while and closes it again, twice in a row.
 *
 * <p>The test contains no assertions; it passes when both runs complete
 * without throwing. Because both runs bind the same port, the second run can
 * only succeed if the first one released the listening socket on close.
 * NOTE(review): the class name suggests the intent is to exercise poll
 * removal on channel close; confirm against the event loop implementation.
 */
public class PollRemoveTest {

    /** Echo handler: writes every inbound message back and flushes once the read batch is complete. */
    @Sharable
    class EchoUringServerHandler extends ChannelInboundHandlerAdapter {

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            ctx.write(msg);
        }

        @Override
        public void channelReadComplete(ChannelHandlerContext ctx) {
            ctx.flush();
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            // Close the connection when an exception is raised.
            cause.printStackTrace();
            ctx.close();
        }
    }

    /**
     * Binds an echo server to port 2020, leaves it running for 15 seconds and
     * then closes the server channel.
     *
     * <p>The event loop groups created here are always shut down before the
     * method returns, so repeated invocations do not accumulate event loop
     * threads.
     *
     * @throws Exception if binding, waiting or closing fails
     */
    void io_uring_test() throws Exception {
        final EventLoopGroup bossGroup = new IOUringEventLoopGroup(1);
        final EventLoopGroup workerGroup = new IOUringEventLoopGroup(1);
        try {
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(IOUringServerSocketChannel.class)
                    .handler(new LoggingHandler(LogLevel.TRACE))
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new EchoUringServerHandler());
                        }
                    });

            Channel sc = b.bind(2020).sync().channel();
            try {
                // Keep the server up so clients can connect to it during this window.
                Thread.sleep(15000);
            } finally {
                // Close the server channel even if the wait was interrupted.
                sc.close().sync();
            }
        } finally {
            // Release the event loop threads; previously they were leaked on every call.
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }

    @Test
    public void test() throws Exception {
        io_uring_test();

        System.out.println("io_uring --------------------------------");

        // Second run re-binds the same port after the first server was closed.
        io_uring_test();
    }
}