diff --git a/transport-native-io_uring/src/main/c/netty_io_uring_native.c b/transport-native-io_uring/src/main/c/netty_io_uring_native.c index be98778a2f..1244224c5e 100644 --- a/transport-native-io_uring/src/main/c/netty_io_uring_native.c +++ b/transport-native-io_uring/src/main/c/netty_io_uring_native.c @@ -211,6 +211,40 @@ static void netty_io_uring_eventFdRead(JNIEnv* env, jclass clazz, jint fd) { } } +static jint netty_io_uring_initAddress(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jbyteArray address, jint scopeId, jint port, jlong remoteMemoryAddress) { + struct sockaddr_storage addr; + socklen_t addrSize; + if (netty_unix_socket_initSockaddr(env, ipv6, address, scopeId, port, &addr, &addrSize) == -1) { + // A runtime exception was thrown + return -1; + } + + memcpy((void *) remoteMemoryAddress, &addr, sizeof(struct sockaddr_storage)); + + return addrSize; +} + +//static jint netty_unix_socket_connect(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jbyteArray address, jint scopeId, jint port) { +// struct sockaddr_storage addr; +// socklen_t addrSize; +// if (netty_unix_socket_initSockaddr(env, ipv6, address, scopeId, port, &addr, &addrSize) == -1) { +// // A runtime exception was thrown +// return -1; +// } +// +// int res; +// int err; +// do { +// res = connect(fd, (struct sockaddr*) &addr, addrSize); +// } while (res == -1 && ((err = errno) == EINTR)); +// +// if (res < 0) { +// return -err; +// } +// return 0; +//} + + static void netty_io_uring_eventFdWrite(JNIEnv* env, jclass clazz, jint fd, jlong value) { uint64_t val; @@ -338,7 +372,8 @@ static const JNINativeMethod method_table[] = { {"eventFdWrite", "(IJ)V", (void *) netty_io_uring_eventFdWrite }, {"eventFdRead", "(I)V", (void *) netty_io_uring_eventFdRead }, {"ioUringRegisterEventFd", "(II)I", (void *) netty_io_uring_register_event_fd}, - {"ioUringUnregisterEventFd", "(I)I", (void *) netty_io_uring_unregister_event_fd} + {"ioUringUnregisterEventFd", "(I)I", (void *) netty_io_uring_unregister_event_fd}, + {"initAddress", "(IZ[BIIJ)I", (void *) netty_io_uring_initAddress } }; static const jint method_table_size = sizeof(method_table) / sizeof(method_table[0]); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringChannel.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringChannel.java index 848bae2685..7177e978e3 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringChannel.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringChannel.java @@ -22,32 +22,47 @@ import io.netty.buffer.Unpooled; import io.netty.channel.AbstractChannel; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelMetadata; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import io.netty.channel.ConnectTimeoutException; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.EventLoop; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.socket.ChannelInputShutdownEvent; import io.netty.channel.socket.ChannelInputShutdownReadComplete; import io.netty.channel.socket.SocketChannelConfig; +import io.netty.channel.unix.Buffer; import io.netty.channel.unix.FileDescriptor; +import io.netty.channel.unix.NativeInetAddress; import io.netty.channel.unix.Socket; import io.netty.channel.unix.UnixChannel; import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.io.FileNotFoundException; import java.io.IOException; +import java.net.ConnectException; import java.net.InetSocketAddress; +import java.net.NoRouteToHostException; import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.AlreadyConnectedException; import java.nio.channels.ClosedChannelException; +import java.nio.channels.ConnectionPendingException; import java.nio.channels.NotYetConnectedException; import java.nio.channels.UnresolvedAddressException; import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import static io.netty.channel.unix.Errors.*; +import static io.netty.channel.unix.UnixChannelUtil.*; import static io.netty.util.internal.ObjectUtil.*; abstract class AbstractIOUringChannel extends AbstractChannel implements UnixChannel { @@ -58,6 +73,8 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha boolean uringInReadyPending; boolean inputClosedSeenErrorOnRead; + static final int SOCK_ADDR_LEN = 128; + //can only submit one write operation at a time private boolean writeable = true; /** @@ -67,6 +84,9 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha private ScheduledFuture connectTimeoutFuture; private SocketAddress requestedRemoteAddress; + private final ByteBuffer remoteAddressMemory; + private final long remoteAddressMemoryAddress; + private volatile SocketAddress local; private volatile SocketAddress remote; @@ -88,9 +108,12 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha } else { logger.trace("Create Server Socket: {}", socket.intValue()); } + + remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN); + remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory); } - protected AbstractIOUringChannel(final Channel parent, LinuxSocket socket, boolean active) { + AbstractIOUringChannel(final Channel parent, LinuxSocket socket, boolean active) { super(parent); this.socket = checkNotNull(socket, "fd"); this.active = active; @@ -105,6 +128,22 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha } else { logger.trace("Create Server Socket: {}", socket.intValue()); } + remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN); + remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory); + } + + AbstractIOUringChannel(Channel parent, LinuxSocket fd, SocketAddress remote) { + super(parent); + this.socket = checkNotNull(fd, "fd"); + this.active = true; + + // Directly cache the remote and local addresses + // See https://github.com/netty/netty/issues/2359 + this.remote = remote; + this.local = fd.localAddress(); + + remoteAddressMemory = Buffer.allocateDirectWithNativeOrder(SOCK_ADDR_LEN); + remoteAddressMemoryAddress = Buffer.memoryAddress(remoteAddressMemory); } public boolean isOpen() { @@ -333,7 +372,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha // Channel/ChannelHandlerContext.read() was called @Override protected void doBeginRead() { - logger.trace("Begin Read"); + System.out.println("Begin Read"); final AbstractUringUnsafe unsafe = (AbstractUringUnsafe) unsafe(); if (!uringInReadyPending) { unsafe.executeUringReadOperator(); @@ -374,7 +413,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha private void addPollOut() { IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop(); IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue(); - submissionQueue.addPollOut(socket.intValue()); + submissionQueue.addPollOutLink(socket.intValue()); submissionQueue.submit(); } @@ -388,14 +427,21 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha } }; - IOUringRecvByteAllocatorHandle newIOUringHandle(RecvByteBufAllocator.ExtendedHandle handle) { - return new IOUringRecvByteAllocatorHandle(handle); + public void fulfillConnectPromise(ChannelPromise promise, Throwable t, SocketAddress remoteAddress) { + Throwable cause = annotateConnectException(t, remoteAddress); + if (promise == null) { + // Closed via cancellation and the promise has been notified already. + return; + } + + // Use tryFailure() instead of setFailure() to avoid the race against cancel(). + promise.tryFailure(cause); + closeIfClosed(); } - @Override - public void connect(final SocketAddress remoteAddress, final SocketAddress localAddress, - final ChannelPromise promise) { - promise.setFailure(new UnsupportedOperationException()); + + IOUringRecvByteAllocatorHandle newIOUringHandle(RecvByteBufAllocator.ExtendedHandle handle) { + return new IOUringRecvByteAllocatorHandle(handle); } @Override @@ -415,6 +461,91 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha } public abstract void uringEventExecution(); + + @Override + public void connect( + final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) { + if (!promise.setUncancellable() || !ensureOpen(promise)) { + return; + } + + try { + if (connectPromise != null) { + throw new ConnectionPendingException(); + } + + doConnect(remoteAddress, localAddress); + InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress; + NativeInetAddress address = NativeInetAddress.newInstance(inetSocketAddress.getAddress()); + socket.initAddress(address.address(), address.scopeId(), inetSocketAddress.getPort(),remoteAddressMemoryAddress); + IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop(); + final IOUringSubmissionQueue ioUringSubmissionQueue = + ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue(); + ioUringSubmissionQueue.addConnect(socket.intValue(), remoteAddressMemoryAddress, SOCK_ADDR_LEN); + ioUringSubmissionQueue.submit(); + + } catch (Throwable t) { + closeIfClosed(); + promise.tryFailure(annotateConnectException(t, remoteAddress)); + return; + } + connectPromise = promise; + requestedRemoteAddress = remoteAddress; + // Schedule connect timeout. + int connectTimeoutMillis = config().getConnectTimeoutMillis(); + if (connectTimeoutMillis > 0) { + connectTimeoutFuture = eventLoop().schedule(new Runnable() { + @Override + public void run() { + ChannelPromise connectPromise = AbstractIOUringChannel.this.connectPromise; + ConnectTimeoutException cause = + new ConnectTimeoutException("connection timed out: " + remoteAddress); + if (connectPromise != null && connectPromise.tryFailure(cause)) { + close(voidPromise()); + } + } + }, connectTimeoutMillis, TimeUnit.MILLISECONDS); + } + + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isCancelled()) { + if (connectTimeoutFuture != null) { + connectTimeoutFuture.cancel(false); + } + connectPromise = null; + close(voidPromise()); + } + } + }); + } + } + + public void fulfillConnectPromise(ChannelPromise promise, boolean wasActive) { + if (promise == null) { + // Closed via cancellation and the promise has been notified already. + return; + } + active = true; + + // Get the state as trySuccess() may trigger an ChannelFutureListener that will close the Channel. + // We still need to ensure we call fireChannelActive() in this case. + boolean active = isActive(); + + // trySuccess() will return false if a user cancelled the connection attempt. + boolean promiseSet = promise.trySuccess(); + + // Regardless if the connection attempt was cancelled, channelActive() event should be triggered, + // because what happened is what happened. + if (!wasActive && active) { + pipeline().fireChannelActive(); + } + + // If a user cancelled the connection attempt, close the channel, which is followed by channelInactive(). + if (!promiseSet) { + close(voidPromise()); + } } @Override @@ -473,6 +604,53 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha return socket; } + /** + * Connect to the remote peer + */ + protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (localAddress instanceof InetSocketAddress) { + checkResolvable((InetSocketAddress) localAddress); + } + + InetSocketAddress remoteSocketAddr = remoteAddress instanceof InetSocketAddress + ? (InetSocketAddress) remoteAddress : null; + if (remoteSocketAddr != null) { + checkResolvable(remoteSocketAddr); + } + + if (remote != null) { + // Check if already connected before trying to connect. This is needed as connect(...) will not return -1 + // and set errno to EISCONN if a previous connect(...) attempt was setting errno to EINPROGRESS and finished + // later. + throw new AlreadyConnectedException(); + } + + if (localAddress != null) { + socket.bind(localAddress); + } + } + +// public void setRemote() { +// remote = remoteSocketAddr == null ? +// remoteAddress : computeRemoteAddr(remoteSocketAddr, socket.remoteAddress()); +// } + + private boolean doConnect0(SocketAddress remote) throws Exception { + boolean success = false; + try { + boolean connected = socket.connect(remote); + if (!connected) { + //setFlag(Native.EPOLLOUT); + } + success = true; + return connected; + } finally { + if (!success) { + doClose(); + } + } + } + void shutdownInput(boolean rdHup) { logger.trace("shutdownInput Fd: {}", this.socket.intValue()); if (!socket.isInputShutdown()) { @@ -498,6 +676,26 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha } } + + //Todo we should move it to a error class + // copy unix Errors + static void throwConnectException(String method, int err) + throws IOException { + if (err == ERROR_EALREADY_NEGATIVE) { + throw new ConnectionPendingException(); + } + if (err == ERROR_ENETUNREACH_NEGATIVE) { + throw new NoRouteToHostException(); + } + if (err == ERROR_EISCONN_NEGATIVE) { + throw new AlreadyConnectedException(); + } + if (err == ERRNO_ENOENT_NEGATIVE) { + throw new FileNotFoundException(); + } + throw new ConnectException(method + "(..) failed: "); + } + private static boolean isAllowHalfClosure(ChannelConfig config) { return config instanceof SocketChannelConfig && ((SocketChannelConfig) config).isAllowHalfClosure(); @@ -508,6 +706,31 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha close(voidPromise()); } + void cancelTimeoutFuture() { + if (connectTimeoutFuture != null) { + connectTimeoutFuture.cancel(false); + } + connectPromise = null; + } + + boolean doFinishConnect() throws Exception { + if (socket.finishConnect()) { + if (requestedRemoteAddress instanceof InetSocketAddress) { + remote = computeRemoteAddr((InetSocketAddress) requestedRemoteAddress, socket.remoteAddress()); + } + requestedRemoteAddress = null; + + return true; + } + return false; + } + + void computeRemote() { + if (requestedRemoteAddress instanceof InetSocketAddress) { + remote = computeRemoteAddr((InetSocketAddress) requestedRemoteAddress, socket.remoteAddress()); + } + } + final boolean shouldBreakIoUringInReady(ChannelConfig config) { return socket.isInputShutdown() && (inputClosedSeenErrorOnRead || !isAllowHalfClosure(config)); } @@ -515,4 +738,20 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha public void setWriteable(boolean writeable) { this.writeable = writeable; } + + public long getRemoteAddressMemoryAddress() { + return remoteAddressMemoryAddress; + } + + public ChannelPromise getConnectPromise() { + return connectPromise; + } + + public ScheduledFuture getConnectTimeoutFuture() { + return connectTimeoutFuture; + } + + public SocketAddress getRequestedRemoteAddress() { + return requestedRemoteAddress; + } } diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringServerChannel.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringServerChannel.java index cb084b17c8..e1063e1acb 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringServerChannel.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringServerChannel.java @@ -87,7 +87,7 @@ abstract class AbstractIOUringServerChannel extends AbstractIOUringChannel imple public void uringEventExecution() { final IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop(); IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue(); - submissionQueue.addPollLink(socket.intValue()); + submissionQueue.addPollInLink(socket.intValue()); //Todo get network addresses submissionQueue.addAccept(fd().intValue()); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringStreamChannel.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringStreamChannel.java index d13c349bd5..6f19edb477 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringStreamChannel.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/AbstractIOUringStreamChannel.java @@ -30,6 +30,7 @@ import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.net.SocketAddress; import java.util.concurrent.Executor; abstract class AbstractIOUringStreamChannel extends AbstractIOUringChannel implements DuplexChannel { @@ -43,6 +44,11 @@ abstract class AbstractIOUringStreamChannel extends AbstractIOUringChannel imple super(parent, socket, active); } + AbstractIOUringStreamChannel(Channel parent, LinuxSocket fd, SocketAddress remote) { + super(parent, fd, remote); + // Add EPOLLRDHUP so we are notified once the remote peer close the connection. + } + @Override public ChannelFuture shutdown() { System.out.println("AbstractStreamChannel shutdown"); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUring.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUring.java index 27acef68d9..08bd9069e6 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUring.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUring.java @@ -27,10 +27,12 @@ final class IOUring { static final int OP_READ = 22; static final int OP_WRITE = 23; static final int OP_POLL_REMOVE = 7; + static final int OP_CONNECT = 16; - static final int POLLMASK_LINK = 1; - static final int POLLMASK_OUT = 4; + static final int POLLMASK_IN_LINK = 1; + static final int POLLMASK_OUT_LINK = 4; static final int POLLMASK_RDHUP = 8192; + static final int POLLMASK_OUT = 4; static { Throwable cause = null; diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringCompletionQueue.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringCompletionQueue.java index 4877637ae3..23111e0f2f 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringCompletionQueue.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringCompletionQueue.java @@ -55,7 +55,7 @@ final class IOUringCompletionQueue { this.ringFd = ringFd; } - public int process(IOUringCompletionQueueCallback callback) { + public int process(IOUringCompletionQueueCallback callback) throws Exception { int i = 0; for (;;) { long head = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress)); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringEventLoop.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringEventLoop.java index 0839141e22..b2bd621402 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringEventLoop.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringEventLoop.java @@ -18,6 +18,7 @@ package io.netty.channel.uring; import io.netty.channel.EventLoopGroup; import io.netty.channel.SingleThreadEventLoop; import io.netty.channel.unix.FileDescriptor; +import io.netty.channel.uring.AbstractIOUringChannel.AbstractUringUnsafe; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import io.netty.util.internal.PlatformDependent; @@ -32,7 +33,7 @@ import java.util.concurrent.atomic.AtomicLong; import static io.netty.channel.unix.Errors.*; final class IOUringEventLoop extends SingleThreadEventLoop implements - IOUringCompletionQueue.IOUringCompletionQueueCallback { + IOUringCompletionQueue.IOUringCompletionQueueCallback { private static final InternalLogger logger = InternalLoggerFactory.getInstance(IOUringEventLoop.class); //Todo set config ring buffer size @@ -75,7 +76,7 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements private static Queue newTaskQueue0(int maxPendingTasks) { // This event loop never calls takeTask() - return maxPendingTasks == Integer.MAX_VALUE ? PlatformDependent.newMpscQueue() + return maxPendingTasks == Integer.MAX_VALUE? PlatformDependent.newMpscQueue() : PlatformDependent.newMpscQueue(maxPendingTasks); } @@ -117,10 +118,10 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements final IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); // Lets add the eventfd related events before starting to do any real work. - submissionQueue.addPollLink(eventfd.intValue()); + submissionQueue.addPollInLink(eventfd.intValue()); submissionQueue.submit(); - for (;;) { + for (; ; ) { logger.trace("Run IOUringEventLoop {}", this.toString()); long curDeadlineNanos = nextScheduledTaskDeadlineNanos(); if (curDeadlineNanos == -1L) { @@ -144,6 +145,8 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements logger.trace("ioUringWaitCqe {}", this.toString()); completionQueue.ioUringWaitCqe(); } + } catch (Throwable t) { + //Todo handle exception } finally { if (nextWakeupNanos.get() == AWAKE || nextWakeupNanos.getAndSet(AWAKE) == AWAKE) { pendingWakeup = true; @@ -151,7 +154,11 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements } } - completionQueue.process(this); + try { + completionQueue.process(this); + } catch (Exception e) { + //Todo handle exception + } if (hasTasks()) { runAllTasks(); @@ -174,90 +181,140 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements public boolean handle(int fd, int res, long flags, int op, int pollMask) { IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); switch (op) { - case IOUring.OP_ACCEPT: - //Todo error handle the res - if (res == ECANCELED) { - logger.trace("POLL_LINK canceled"); - break; - } - AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd); - if (acceptChannel == null) { - break; - } - logger.trace("EventLoop Accept filedescriptor: {}", res); - acceptChannel.setUringInReadyPending(false); - if (res != -1 && res != ERRNO_EAGAIN_NEGATIVE && - res != ERRNO_EWOULDBLOCK_NEGATIVE) { - logger.trace("server filedescriptor Fd: {}", fd); - if (acceptChannel.acceptComplete(res)) { - // all childChannels should poll POLLRDHUP - submissionQueue.addPollRdHup(res); - submissionQueue.submit(); - } - } + case IOUring.OP_ACCEPT: + //Todo error handle the res + if (res == ECANCELED) { + logger.trace("POLL_LINK canceled"); break; - case IOUring.OP_READ: - AbstractIOUringChannel readChannel = channels.get(fd); - if (readChannel == null) { - break; - } - readChannel.readComplete(res); + } + AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd); + if (acceptChannel == null) { break; - case IOUring.OP_WRITE: - AbstractIOUringChannel writeChannel = channels.get(fd); - if (writeChannel == null) { - break; - } - //localFlushAmount -> res - logger.trace("EventLoop Write Res: {}", res); - logger.trace("EventLoop Fd: {}", fd); - - if (res == SOCKET_ERROR_EPIPE) { - writeChannel.shutdownInput(false); - } else { - writeChannel.writeComplete(res); - } - break; - case IOUring.IO_TIMEOUT: - if (res == ETIME) { - prevDeadlineNanos = NONE; - } - - break; - case IOUring.IO_POLL: - //Todo error handle the res - if (res == ECANCELED) { - logger.trace("POLL_LINK canceled"); - break; - } - if (eventfd.intValue() == fd) { - pendingWakeup = false; - // We need to consume the data as otherwise we would see another event - // in the completionQueue without - // an extra eventfd_write(....) - Native.eventFdRead(eventfd.intValue()); - submissionQueue.addPollLink(eventfd.intValue()); - // Submit so its picked up + } + logger.trace("EventLoop Accept filedescriptor: {}", res); + acceptChannel.setUringInReadyPending(false); + if (res != -1 && res != ERRNO_EAGAIN_NEGATIVE && + res != ERRNO_EWOULDBLOCK_NEGATIVE) { + logger.trace("server filedescriptor Fd: {}", fd); + if (acceptChannel.acceptComplete(res)) { + // all childChannels should poll POLLRDHUP + submissionQueue.addPollRdHup(res); submissionQueue.submit(); - } else { - if (pollMask == IOUring.POLLMASK_RDHUP) { - AbstractIOUringChannel channel = channels.get(fd); - if (channel != null && !channel.isActive()) { - channel.shutdownInput(true); - } - } else { - //Todo error handling error - logger.trace("POLL_LINK Res: {}", res); + } + } + break; + case IOUring.OP_READ: + AbstractIOUringChannel readChannel = channels.get(fd); + if (readChannel == null) { + break; + } + readChannel.readComplete(res); + break; + case IOUring.OP_WRITE: + AbstractIOUringChannel writeChannel = channels.get(fd); + if (writeChannel == null) { + break; + } + //localFlushAmount -> res + logger.trace("EventLoop Write Res: {}", res); + logger.trace("EventLoop Fd: {}", fd); + + if (res == SOCKET_ERROR_EPIPE) { + writeChannel.shutdownInput(false); + } else { + writeChannel.writeComplete(res); + } + break; + case IOUring.IO_TIMEOUT: + if (res == ETIME) { + prevDeadlineNanos = NONE; + } + + break; + case IOUring.IO_POLL: + //Todo error handle the res + if (res == ECANCELED) { + logger.trace("POLL_LINK canceled"); + break; + } + if (eventfd.intValue() == fd) { + pendingWakeup = false; + // We need to consume the data as otherwise we would see another event + // in the completionQueue without + // an extra eventfd_write(....) + Native.eventFdRead(eventfd.intValue()); + submissionQueue.addPollInLink(eventfd.intValue()); + // Submit so its picked up + submissionQueue.submit(); + } else { + if (pollMask == IOUring.POLLMASK_RDHUP) { + AbstractIOUringChannel channel = channels.get(fd); + if (channel != null && !channel.isActive()) { + channel.shutdownInput(true); } + } else if (pollMask == IOUring.POLLMASK_OUT) { + //connect successful + AbstractIOUringChannel ch = channels.get(fd); + boolean wasActive = ch.isActive(); + try { + if (ch.doFinishConnect()) { + ch.fulfillConnectPromise(ch.getConnectPromise(), wasActive); + ch.cancelTimeoutFuture(); + } else { + //submit pollout + submissionQueue.addPollOut(fd); + submissionQueue.submit(); + } + } catch (Throwable t) { + AbstractUringUnsafe unsafe = (AbstractUringUnsafe) ch.unsafe(); + unsafe.fulfillConnectPromise(ch.getConnectPromise(), t, ch.getRequestedRemoteAddress()); + } + } else { + //Todo error handling error + logger.trace("POLL_LINK Res: {}", res); } + } + break; + case IOUring.OP_POLL_REMOVE: + if (res == ENOENT) { + System.out.println(("POLL_REMOVE OPERATION not permitted")); + } else if (res == 0) { + System.out.println(("POLL_REMOVE OPERATION successful")); + } + break; + case IOUring.OP_CONNECT: + AbstractIOUringChannel channel = channels.get(fd); + System.out.println("Connect res: " + res); + if (res == 0) { + channel.fulfillConnectPromise(channel.getConnectPromise(), channel.active); + channel.cancelTimeoutFuture(); + channel.computeRemote(); break; - case IOUring.OP_POLL_REMOVE: - if (res == ENOENT) { - logger.trace("POLL_REMOVE OPERATION not permitted"); - } else if (res == 0) { - logger.trace("POLL_REMOVE OPERATION successful"); + } + if (res == -1 || res == -4) { + submissionQueue.addConnect(fd, channel.getRemoteAddressMemoryAddress(), + AbstractIOUringChannel.SOCK_ADDR_LEN); + submissionQueue.submit(); + break; + } + if (res < 0) { + if (res == ERRNO_EINPROGRESS_NEGATIVE) { + // connect not complete yet need to wait for poll_out event + submissionQueue.addPollOut(fd); + submissionQueue.submit(); + break; } + try { + channel.doClose(); + } catch (Exception e) { + //Todo error handling + } + + //Todo error handling + //AbstractIOUringChannel.throwConnectException("connect", res); break; + } + break; } return true; } diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannel.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannel.java index 34a4629b3d..8a112f43dc 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannel.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannel.java @@ -33,11 +33,15 @@ import io.netty.channel.unix.Socket; import io.netty.channel.uring.AbstractIOUringStreamChannel.IOUringStreamUnsafe; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; public final class IOUringSocketChannel extends AbstractIOUringStreamChannel implements SocketChannel { private final IOUringSocketChannelConfig config; + //private volatile Collection tcpMd5SigAddresses = Collections.emptyList(); public IOUringSocketChannel() { super(null, LinuxSocket.newSocketStream(), false); @@ -49,6 +53,15 @@ public final class IOUringSocketChannel extends AbstractIOUringStreamChannel imp this.config = new IOUringSocketChannelConfig(this); } + IOUringSocketChannel(Channel parent, LinuxSocket fd, InetSocketAddress remoteAddress) { + super(parent, fd, remoteAddress); + this.config = new IOUringSocketChannelConfig(this); + +// if (parent instanceof IOUringSocketChannel) { +// tcpMd5SigAddresses = ((IOUringSocketChannel) parent).tcpMd5SigAddresses(); +// } + } + @Override public ServerSocketChannel parent() { return (ServerSocketChannel) super.parent(); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannelConfig.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannelConfig.java index 7ad44f3f5c..0a566d8a6b 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannelConfig.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSocketChannelConfig.java @@ -24,6 +24,7 @@ import io.netty.channel.MessageSizeEstimator; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.SocketChannelConfig; +import io.netty.util.internal.PlatformDependent; import java.io.IOException; import java.net.InetAddress; @@ -39,6 +40,11 @@ public class IOUringSocketChannelConfig extends DefaultChannelConfig implements public IOUringSocketChannelConfig(Channel channel) { super(channel); + if (PlatformDependent.canEnableTcpNoDelayByDefault()) { + setTcpNoDelay(true); + } + calculateMaxBytesPerGatheringWrite(); + } @Override diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSubmissionQueue.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSubmissionQueue.java index 8ac8119a6b..b5fd524cf2 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSubmissionQueue.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/IOUringSubmissionQueue.java @@ -119,7 +119,7 @@ final class IOUringSubmissionQueue { //user_data should be same as POLL_LINK fd if (op == IOUring.OP_POLL_REMOVE) { PlatformDependent.putInt(sqe + SQE_FD_FIELD, -1); - long pollLinkuData = convertToUserData((byte) IOUring.IO_POLL, fd, IOUring.POLLMASK_LINK); + long pollLinkuData = convertToUserData((byte) IOUring.IO_POLL, fd, IOUring.POLLMASK_IN_LINK); PlatformDependent.putLong(sqe + SQE_ADDRESS_FIELD, pollLinkuData); } @@ -127,7 +127,7 @@ final class IOUringSubmissionQueue { PlatformDependent.putLong(sqe + SQE_USER_DATA_FIELD, uData); //pollread or accept operation - if (op == 6 && (pollMask == IOUring.POLLMASK_OUT || pollMask == IOUring.POLLMASK_LINK)) { + if (op == 6 && (pollMask == IOUring.POLLMASK_OUT_LINK || pollMask == IOUring.POLLMASK_IN_LINK)) { PlatformDependent.putByte(sqe + SQE_FLAGS_FIELD, (byte) IOSQE_IO_LINK); } else { PlatformDependent.putByte(sqe + SQE_FLAGS_FIELD, (byte) 0); @@ -152,6 +152,8 @@ final class IOUringSubmissionQueue { PlatformDependent.putInt(sqe + SQE_RW_FLAGS_FIELD, pollMask); } + + logger.trace("UserDataField: {}", PlatformDependent.getLong(sqe + SQE_USER_DATA_FIELD)); logger.trace("BufferAddress: {}", PlatformDependent.getLong(sqe + SQE_ADDRESS_FIELD)); logger.trace("Length: {}", PlatformDependent.getInt(sqe + SQE_LEN_FIELD)); @@ -168,18 +170,22 @@ final class IOUringSubmissionQueue { return true; } - public boolean addPollLink(int fd) { - return addPoll(fd, IOUring.POLLMASK_LINK); + public boolean addPollInLink(int fd) { + return addPoll(fd, IOUring.POLLMASK_IN_LINK); } - public boolean addPollOut(int fd) { - return addPoll(fd, IOUring.POLLMASK_OUT); + public boolean addPollOutLink(int fd) { + return addPoll(fd, IOUring.POLLMASK_OUT_LINK); } public boolean addPollRdHup(int fd) { return addPoll(fd, IOUring.POLLMASK_RDHUP); } + public boolean addPollOut(int fd) { + return addPoll(fd, IOUring.POLLMASK_OUT); + } + private boolean addPoll(int fd, int pollMask) { long sqe = getSqe(); if (sqe == 0) { @@ -217,7 +223,7 @@ final class IOUringSubmissionQueue { return true; } - //fill the user_data which is associated with server poll link + //fill the adddress which is associated with server poll link user_data public boolean addPollRemove(int fd) { long sqe = getSqe(); if (sqe == 0) { @@ -228,6 +234,16 @@ final class IOUringSubmissionQueue { return true; } + public boolean addConnect(int fd, long socketAddress, long socketAddressLength) { + long sqe = getSqe(); + if (sqe == 0) { + return false; + } + setData(sqe, (byte) IOUring.OP_CONNECT, 0, fd, socketAddress, 0, socketAddressLength); + + return true; + } + private int flushSqe() { long kTail = toUnsignedLong(PlatformDependent.getInt(kTailAddress)); long kHead = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress)); diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/LinuxSocket.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/LinuxSocket.java index da79e73fd8..38504e2172 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/LinuxSocket.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/LinuxSocket.java @@ -67,6 +67,10 @@ final class LinuxSocket extends Socket { setInterface(intValue(), ipv6, nativeAddress.address(), nativeAddress.scopeId(), interfaceIndex(netInterface)); } + public int initAddress(byte[] address, int scopeId, int port, long remoteAddressMemoryAddress) { + return Native.initAddress(intValue(), ipv6, address, scopeId, port, remoteAddressMemoryAddress); + } + InetAddress getInterface() throws IOException { NetworkInterface inf = getNetworkInterface(); if (inf != null) { diff --git a/transport-native-io_uring/src/main/java/io/netty/channel/uring/Native.java b/transport-native-io_uring/src/main/java/io/netty/channel/uring/Native.java index e8a79b24a9..caf389bb6d 100644 --- a/transport-native-io_uring/src/main/java/io/netty/channel/uring/Native.java +++ b/transport-native-io_uring/src/main/java/io/netty/channel/uring/Native.java @@ -67,6 +67,8 @@ final class Native { Socket.initialize(); } + + public static RingBuffer createRingBuffer(int ringSize) { //Todo throw Exception if it's null return ioUringSetup(ringSize); @@ -95,6 +97,8 @@ final class Native { public static native void ioUringExit(RingBuffer ringBuffer); + public static native int initAddress(int fd, boolean ipv6, byte[] address, int scopeId, int port, long remoteMemoryAddress); + private static native int eventFd(); // for testing(it is only temporary) diff --git a/transport-native-io_uring/src/test/java/io/netty/channel/uring/IOUringSocketTestPermutation.java b/transport-native-io_uring/src/test/java/io/netty/channel/uring/IOUringSocketTestPermutation.java index 2a7e9a62b8..46c612a4c9 100644 --- a/transport-native-io_uring/src/test/java/io/netty/channel/uring/IOUringSocketTestPermutation.java +++ b/transport-native-io_uring/src/test/java/io/netty/channel/uring/IOUringSocketTestPermutation.java @@ -100,14 +100,14 @@ public class IOUringSocketTestPermutation extends SocketTestPermutation { new BootstrapFactory() { @Override public Bootstrap newInstance() { - return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class); + return new Bootstrap().group(IO_URING_WORKER_GROUP).channel(IOUringSocketChannel.class); //.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 100000); } }, new BootstrapFactory() { @Override public Bootstrap newInstance() { - return new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class); + return new Bootstrap().group(IO_URING_WORKER_GROUP).channel(IOUringSocketChannel.class); // .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 100000); } } diff --git a/transport-native-io_uring/src/test/java/io/netty/channel/uring/NativeTest.java b/transport-native-io_uring/src/test/java/io/netty/channel/uring/NativeTest.java index 5b137b4c82..f4816a6008 100644 --- a/transport-native-io_uring/src/test/java/io/netty/channel/uring/NativeTest.java +++ b/transport-native-io_uring/src/test/java/io/netty/channel/uring/NativeTest.java @@ -31,7 +31,7 @@ import io.netty.buffer.ByteBuf; public class NativeTest { @Test - public void canWriteFile() { + public void canWriteFile() throws Exception { ByteBufAllocator allocator = new UnpooledByteBufAllocator(true); final ByteBuf writeEventByteBuf = allocator.directBuffer(100); final String inputString = "Hello World!"; @@ -85,7 +85,7 @@ public class NativeTest { } @Test - public void timeoutTest() throws InterruptedException { + public void timeoutTest() throws Exception { RingBuffer ringBuffer = Native.createRingBuffer(32); IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); @@ -99,13 +99,17 @@ public class NativeTest { @Override public void run() { assertTrue(completionQueue.ioUringWaitCqe()); - completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { - @Override - public boolean handle(int fd, int res, long flags, int op, int mask) { - assertEquals(-62, res); - return true; - } - }); + try { + completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { + @Override + public boolean handle(int fd, int res, long flags, int op, int mask) { + assertEquals(-62, res); + return true; + } + }); + } catch (Exception e) { + e.printStackTrace(); + } } }; thread.start(); @@ -124,7 +128,7 @@ public class NativeTest { //Todo clean @Test - public void eventfdTest() throws IOException { + public void eventfdTest() throws Exception { RingBuffer ringBuffer = Native.createRingBuffer(32); IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); final IOUringCompletionQueue completionQueue = ringBuffer.getIoUringCompletionQueue(); @@ -134,7 +138,7 @@ public class NativeTest { assertNotNull(completionQueue); final FileDescriptor eventFd = Native.newEventFd(); - assertTrue(submissionQueue.addPollLink(eventFd.intValue())); + assertTrue(submissionQueue.addPollInLink(eventFd.intValue())); submissionQueue.submit(); new Thread() { @@ -163,7 +167,7 @@ public class NativeTest { //eventfd signal doesnt work when ioUringWaitCqe and eventFdWrite are executed in a thread //created this test to reproduce this "weird" bug @Test(timeout = 8000) - public void eventfdNoSignal() throws InterruptedException { + public void eventfdNoSignal() throws Exception { RingBuffer ringBuffer = Native.createRingBuffer(32); IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); @@ -177,18 +181,22 @@ public class NativeTest { @Override public void run() { assertTrue(completionQueue.ioUringWaitCqe()); - assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { - @Override - public boolean handle(int fd, int res, long flags, int op, int mask) { - assertEquals(1, res); - return true; - } - })); + try { + assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { + @Override + public boolean handle(int fd, int res, long flags, int op, int mask) { + assertEquals(1, res); + return true; + } + })); + } catch (Exception e) { + e.printStackTrace(); + } } }; waitingCqe.start(); final FileDescriptor eventFd = Native.newEventFd(); - assertTrue(submissionQueue.addPollLink(eventFd.intValue())); + assertTrue(submissionQueue.addPollInLink(eventFd.intValue())); submissionQueue.submit(); new Thread() { @@ -216,7 +224,7 @@ public class NativeTest { final IOUringCompletionQueue completionQueue = ringBuffer.getIoUringCompletionQueue(); FileDescriptor eventFd = Native.newEventFd(); - submissionQueue.addPollLink(eventFd.intValue()); + submissionQueue.addPollInLink(eventFd.intValue()); submissionQueue.submit(); Thread.sleep(10); @@ -228,22 +236,30 @@ public class NativeTest { @Override public void run() { assertTrue(completionQueue.ioUringWaitCqe()); - assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { - @Override - public boolean handle(int fd, int res, long flags, int op, int mask) { - assertEquals(IOUringEventLoop.ECANCELED, res); - assertEquals(IOUring.IO_POLL, op); - return true; - } - })); - assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { - @Override - public boolean handle(int fd, int res, long flags, int op, int mask) { - assertEquals(0, res); - assertEquals(IOUring.OP_POLL_REMOVE, op); - return true; - } - })); + try { + assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { + @Override + public boolean handle(int fd, int res, long flags, int op, int mask) { + assertEquals(IOUringEventLoop.ECANCELED, res); + assertEquals(IOUring.IO_POLL, op); + return true; + } + })); + } catch (Exception e) { + e.printStackTrace(); + } + try { + assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { + @Override + public boolean handle(int fd, int res, long flags, int op, int mask) { + assertEquals(0, res); + assertEquals(IOUring.OP_POLL_REMOVE, op); + return true; + } + })); + } catch (Exception e) { + e.printStackTrace(); + } } }; waitingCqe.start(); diff --git a/transport-native-io_uring/src/test/java/io/netty/channel/uring/PollRemoveTest.java b/transport-native-io_uring/src/test/java/io/netty/channel/uring/PollRemoveTest.java new file mode 100644 index 0000000000..6aad1ef361 --- /dev/null +++ b/transport-native-io_uring/src/test/java/io/netty/channel/uring/PollRemoveTest.java @@ -0,0 +1,78 @@ +package io.netty.channel.uring; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import org.junit.Test; + +public class PollRemoveTest { + + @Sharable + class EchoUringServerHandler extends ChannelInboundHandlerAdapter { + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.write(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + // Close the connection when an exception is raised. + cause.printStackTrace(); + ctx.close(); + } + + } + + void io_uring_test() throws Exception { + Class clazz; + final EventLoopGroup bossGroup = new IOUringEventLoopGroup(1); + final EventLoopGroup workerGroup = new IOUringEventLoopGroup(1); + clazz = IOUringServerSocketChannel.class; + + ServerBootstrap b = new ServerBootstrap(); + b.group(bossGroup, workerGroup) + .channel(clazz) + .handler(new LoggingHandler(LogLevel.TRACE)) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + + p.addLast(new EchoUringServerHandler()); + } + }); + + Channel sc = b.bind(2020).sync().channel(); + Thread.sleep(15000); + + //close ServerChannel + sc.close().sync(); + + } + + @Test + public void test() throws Exception { + + io_uring_test(); + + System.out.println("io_uring --------------------------------"); + + io_uring_test(); + + } +} +