Close IOUringEventLoop wakeup/eventfd close shutdown race (#10615)

Motivation

There is a race condition when shutting down the event loop: the
eventFd write performed in the wakeup() method may actually hit a
different fd if the eventFd has been closed and its fd number has been
reassigned in the meantime.

This was already encountered and addressed in the epoll case.

Modifications

Similar to what's done for epoll, in IOUringEventLoop:
- Reinstate the pendingWakeup flag, which tracks when there is a wakeup
pending (i.e. the CAS of nextWakeupNanos has been performed by another
thread in the wakeup() method)
- Add logic to the cleanup() method to wait for the corresponding READ
CQE before closing the eventFd
- Remove unused fields from IOUringCompletionQueue (cleanup)

Result

No event loop shutdown race
This commit is contained in:
Nick Hill 2020-09-27 23:47:23 -07:00 committed by GitHub
parent db73538737
commit f6c84541be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 26 deletions

View File

@ -33,9 +33,6 @@ final class IOUringCompletionQueue {
//these unsigned integer pointers(shared with the kernel) will be changed by the kernel //these unsigned integer pointers(shared with the kernel) will be changed by the kernel
private final long kHeadAddress; private final long kHeadAddress;
private final long kTailAddress; private final long kTailAddress;
private final long kRingMaskAddress;
private final long kRingEntriesAddress;
private final long kOverflowAddress;
private final long completionQueueArrayAddress; private final long completionQueueArrayAddress;
@ -51,9 +48,6 @@ final class IOUringCompletionQueue {
int ringFd) { int ringFd) {
this.kHeadAddress = kHeadAddress; this.kHeadAddress = kHeadAddress;
this.kTailAddress = kTailAddress; this.kTailAddress = kTailAddress;
this.kRingMaskAddress = kRingMaskAddress;
this.kRingEntriesAddress = kRingEntriesAddress;
this.kOverflowAddress = kOverflowAddress;
this.completionQueueArrayAddress = completionQueueArrayAddress; this.completionQueueArrayAddress = completionQueueArrayAddress;
this.ringSize = ringSize; this.ringSize = ringSize;
this.ringAddress = ringAddress; this.ringAddress = ringAddress;
@ -110,16 +104,10 @@ final class IOUringCompletionQueue {
/** /**
* Block until there is at least one completion ready to be processed. * Block until there is at least one completion ready to be processed.
*/ */
boolean ioUringWaitCqe() { void ioUringWaitCqe() {
//IORING_ENTER_GETEVENTS -> wait until an event is completely processed
int ret = Native.ioUringEnter(ringFd, 0, 1, Native.IORING_ENTER_GETEVENTS); int ret = Native.ioUringEnter(ringFd, 0, 1, Native.IORING_ENTER_GETEVENTS);
if (ret < 0) { if (ret < 0) {
//Todo throw exception! throw new RuntimeException("ioUringEnter syscall returned " + ret);
return false;
} else if (ret == 0) {
return true;
} }
//Todo throw Exception!
return false;
} }
} }

View File

@ -20,6 +20,7 @@ import io.netty.channel.SingleThreadEventLoop;
import io.netty.channel.unix.Errors; import io.netty.channel.unix.Errors;
import io.netty.channel.unix.FileDescriptor; import io.netty.channel.unix.FileDescriptor;
import io.netty.channel.unix.IovArray; import io.netty.channel.unix.IovArray;
import io.netty.channel.uring.IOUringCompletionQueue.IOUringCompletionQueueCallback;
import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap; import io.netty.util.collection.IntObjectMap;
import io.netty.util.concurrent.RejectedExecutionHandler; import io.netty.util.concurrent.RejectedExecutionHandler;
@ -32,8 +33,7 @@ import java.util.Queue;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
final class IOUringEventLoop extends SingleThreadEventLoop implements final class IOUringEventLoop extends SingleThreadEventLoop implements IOUringCompletionQueueCallback {
IOUringCompletionQueue.IOUringCompletionQueueCallback {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(IOUringEventLoop.class); private static final InternalLogger logger = InternalLoggerFactory.getInstance(IOUringEventLoop.class);
private final long eventfdReadBuf = PlatformDependent.allocateMemory(8); private final long eventfdReadBuf = PlatformDependent.allocateMemory(8);
@ -57,6 +57,7 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
private final byte[] inet6AddressArray = new byte[16]; private final byte[] inet6AddressArray = new byte[16];
private long prevDeadlineNanos = NONE; private long prevDeadlineNanos = NONE;
private boolean pendingWakeup;
IOUringEventLoop(IOUringEventLoopGroup parent, Executor executor, int ringSize, boolean ioseqAsync, IOUringEventLoop(IOUringEventLoopGroup parent, Executor executor, int ringSize, boolean ioseqAsync,
RejectedExecutionHandler rejectedExecutionHandler, EventLoopTaskQueueFactory queueFactory) { RejectedExecutionHandler rejectedExecutionHandler, EventLoopTaskQueueFactory queueFactory) {
@ -167,7 +168,9 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
} }
} }
} finally { } finally {
nextWakeupNanos.set(AWAKE); if (nextWakeupNanos.get() == AWAKE || nextWakeupNanos.getAndSet(AWAKE) == AWAKE) {
pendingWakeup = true;
}
} }
} catch (Throwable t) { } catch (Throwable t) {
handleLoopException(t); handleLoopException(t);
@ -220,9 +223,8 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
@Override @Override
public void handle(int fd, int res, int flags, int op, int data) { public void handle(int fd, int res, int flags, int op, int data) {
if (op == Native.IORING_OP_READ && eventfd.intValue() == fd) { if (op == Native.IORING_OP_READ && eventfd.intValue() == fd) {
if (res != Native.ERRNO_ECANCELED_NEGATIVE) { pendingWakeup = false;
addEventFdRead(ringBuffer.ioUringSubmissionQueue()); addEventFdRead(ringBuffer.ioUringSubmissionQueue());
}
} else if (op == Native.IORING_OP_TIMEOUT) { } else if (op == Native.IORING_OP_TIMEOUT) {
if (res == Native.ERRNO_ETIME_NEGATIVE) { if (res == Native.ERRNO_ETIME_NEGATIVE) {
prevDeadlineNanos = NONE; prevDeadlineNanos = NONE;
@ -288,6 +290,25 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
@Override @Override
protected void cleanup() { protected void cleanup() {
if (pendingWakeup) {
// Another thread is in the process of writing to the eventFd. We must wait to
// receive the corresponding CQE before closing it or else the fd int may be
// reassigned by the kernel in the meantime.
IOUringCompletionQueue completionQueue = ringBuffer.ioUringCompletionQueue();
IOUringCompletionQueueCallback callback = new IOUringCompletionQueueCallback() {
@Override
public void handle(int fd, int res, int flags, int op, int data) {
if (op == Native.IORING_OP_READ && eventfd.intValue() == fd) {
pendingWakeup = false;
}
}
};
completionQueue.process(callback);
while (pendingWakeup) {
completionQueue.ioUringWaitCqe();
completionQueue.process(callback);
}
}
try { try {
eventfd.close(); eventfd.close();
} catch (IOException e) { } catch (IOException e) {

View File

@ -58,7 +58,7 @@ public class NativeTest {
writeEventByteBuf.readerIndex(), writeEventByteBuf.writerIndex())); writeEventByteBuf.readerIndex(), writeEventByteBuf.writerIndex()));
submissionQueue.submit(); submissionQueue.submit();
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
public void handle(int fd, int res, int flags, int op, int mask) { public void handle(int fd, int res, int flags, int op, int mask) {
@ -72,7 +72,7 @@ public class NativeTest {
readEventByteBuf.writerIndex(), readEventByteBuf.capacity())); readEventByteBuf.writerIndex(), readEventByteBuf.capacity()));
submissionQueue.submit(); submissionQueue.submit();
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
public void handle(int fd, int res, int flags, int op, int mask) { public void handle(int fd, int res, int flags, int op, int mask) {
@ -103,7 +103,7 @@ public class NativeTest {
Thread thread = new Thread() { Thread thread = new Thread() {
@Override @Override
public void run() { public void run() {
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
try { try {
completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
@ -152,7 +152,7 @@ public class NativeTest {
} }
}.start(); }.start();
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
public void handle(int fd, int res, int flags, int op, int mask) { public void handle(int fd, int res, int flags, int op, int mask) {
@ -183,7 +183,7 @@ public class NativeTest {
Thread waitingCqe = new Thread() { Thread waitingCqe = new Thread() {
@Override @Override
public void run() { public void run() {
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
public void handle(int fd, int res, int flags, int op, int mask) { public void handle(int fd, int res, int flags, int op, int mask) {
@ -251,7 +251,7 @@ public class NativeTest {
@Override @Override
public void run() { public void run() {
try { try {
assertTrue(completionQueue.ioUringWaitCqe()); completionQueue.ioUringWaitCqe();
assertEquals(2, completionQueue.process(verifyCallback)); assertEquals(2, completionQueue.process(verifyCallback));
} catch (AssertionError error) { } catch (AssertionError error) {
errorRef.set(error); errorRef.set(error);