Fix server socket address already in use error even if close is called

Motivation:

when you submit a poll, io_uring still hold reference to it even if close is called
source io_uring mailing list(https://lore.kernel.org/io-uring/27657840-4E8E-422D-93BB-7F485F21341A@kernel.dk/T/#t)

Modification:

-To delete fd reference in io_uring, we use POLL_REMOVE to avoid a server socket address error
-Added a POLL_REMOVE test

Result:

server can be closed and restarted
This commit is contained in:
Josef Grieb 2020-08-26 19:35:46 +02:00
parent 6ab424f3c6
commit 56ace4228f
7 changed files with 115 additions and 21 deletions

View File

@ -146,7 +146,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
if (byteBuf.hasMemoryAddress()) { if (byteBuf.hasMemoryAddress()) {
readBuffer = byteBuf; readBuffer = byteBuf;
submissionQueue.addRead(socket.intValue(), byteBuf.memoryAddress(), submissionQueue.addRead(socket.intValue(), byteBuf.memoryAddress(),
byteBuf.writerIndex(), byteBuf.capacity()); byteBuf.writerIndex(), byteBuf.capacity());
submissionQueue.submit(); submissionQueue.submit();
} }
} }
@ -164,6 +164,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
} }
} }
} }
void readComplete(int localReadAmount) { void readComplete(int localReadAmount) {
boolean close = false; boolean close = false;
ByteBuf byteBuf = null; ByteBuf byteBuf = null;
@ -190,7 +191,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
close = allocHandle.lastBytesRead() < 0; close = allocHandle.lastBytesRead() < 0;
if (close) { if (close) {
// There is nothing left to read as we received an EOF. // There is nothing left to read as we received an EOF.
shutdownInput(false); shutdownInput(false);
} }
allocHandle.readComplete(); allocHandle.readComplete();
pipeline.fireChannelReadComplete(); pipeline.fireChannelReadComplete();
@ -275,6 +276,10 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
protected void doClose() throws Exception { protected void doClose() throws Exception {
if (parent() == null) { if (parent() == null) {
logger.trace("ServerSocket Close: {}", this.socket.intValue()); logger.trace("ServerSocket Close: {}", this.socket.intValue());
IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop();
IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue();
submissionQueue.addPollRemove(socket.intValue());
submissionQueue.submit();
} }
active = false; active = false;
// Even if we allow half closed sockets we should give up on reading. Otherwise we may allow a read attempt on a // Even if we allow half closed sockets we should give up on reading. Otherwise we may allow a read attempt on a
@ -359,7 +364,7 @@ abstract class AbstractIOUringChannel extends AbstractChannel implements UnixCha
IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop(); IOUringEventLoop ioUringEventLoop = (IOUringEventLoop) eventLoop();
IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue(); IOUringSubmissionQueue submissionQueue = ioUringEventLoop.getRingBuffer().getIoUringSubmissionQueue();
submissionQueue.addWrite(socket.intValue(), buf.memoryAddress(), buf.readerIndex(), submissionQueue.addWrite(socket.intValue(), buf.memoryAddress(), buf.readerIndex(),
buf.writerIndex()); buf.writerIndex());
submissionQueue.submit(); submissionQueue.submit();
writeable = false; writeable = false;
} }

View File

@ -28,10 +28,12 @@ abstract class AbstractIOUringServerChannel extends AbstractIOUringChannel imple
AbstractIOUringServerChannel(int fd) { AbstractIOUringServerChannel(int fd) {
super(null, new LinuxSocket(fd)); super(null, new LinuxSocket(fd));
System.out.println("Server Socket fd: " + fd);
} }
AbstractIOUringServerChannel(LinuxSocket fd) { AbstractIOUringServerChannel(LinuxSocket fd) {
super(null, fd); super(null, fd);
System.out.println("Server Socket fd: " + fd);
} }
@Override @Override

View File

@ -26,6 +26,7 @@ final class IOUring {
static final int OP_ACCEPT = 13; static final int OP_ACCEPT = 13;
static final int OP_READ = 22; static final int OP_READ = 22;
static final int OP_WRITE = 23; static final int OP_WRITE = 23;
static final int OP_POLL_REMOVE = 7;
static final int POLLMASK_LINK = 1; static final int POLLMASK_LINK = 1;
static final int POLLMASK_OUT = 4; static final int POLLMASK_OUT = 4;

View File

@ -74,6 +74,7 @@ final class IOUringCompletionQueue {
int opMask = (int) udata; int opMask = (int) udata;
short op = (short) (opMask >> 16); short op = (short) (opMask >> 16);
short mask = (short) opMask; short mask = (short) opMask;
i++; i++;
if (!callback.handle(fd, res, flags, op, mask)) { if (!callback.handle(fd, res, flags, op, mask)) {
break; break;

View File

@ -37,10 +37,12 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
//Todo set config ring buffer size //Todo set config ring buffer size
private final int ringSize = 32; private final int ringSize = 32;
private final int ENOENT = -2;
//just temporary -> Todo use ErrorsStaticallyReferencedJniMethods like in Epoll //just temporary -> Todo use ErrorsStaticallyReferencedJniMethods like in Epoll
private final int SOCKET_ERROR_EPIPE = -32; private static final int SOCKET_ERROR_EPIPE = -32;
private static long ETIME = -62; private static final long ETIME = -62;
static final long ECANCELED = -125;
// events should be unique to identify which event type that was // events should be unique to identify which event type that was
private long eventIdCounter; private long eventIdCounter;
@ -175,6 +177,11 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue(); IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
switch (op) { switch (op) {
case IOUring.OP_ACCEPT: case IOUring.OP_ACCEPT:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
}
AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd); AbstractIOUringServerChannel acceptChannel = (AbstractIOUringServerChannel) channels.get(fd);
if (acceptChannel == null) { if (acceptChannel == null) {
break; break;
@ -217,6 +224,11 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
break; break;
case IOUring.IO_POLL: case IOUring.IO_POLL:
//Todo error handle the res
if (res == ECANCELED) {
logger.trace("POLL_LINK canceled");
break;
}
if (eventfd.intValue() == fd) { if (eventfd.intValue() == fd) {
pendingWakeup = false; pendingWakeup = false;
// We need to consume the data as otherwise we would see another event // We need to consume the data as otherwise we would see another event
@ -238,7 +250,12 @@ final class IOUringEventLoop extends SingleThreadEventLoop implements
break; break;
} }
} }
case IOUring.OP_POLL_REMOVE:
if (res == ENOENT) {
logger.trace("POLL_REMOVE OPERATION not permitted");
} else if (res == 0) {
logger.trace("POLL_REMOVE OPERATION successful");
}
break; break;
} }
return true; return true;

View File

@ -116,10 +116,14 @@ final class IOUringSubmissionQueue {
PlatformDependent.putLong(sqe + SQE_ADDRESS_FIELD, bufferAddress); PlatformDependent.putLong(sqe + SQE_ADDRESS_FIELD, bufferAddress);
PlatformDependent.putInt(sqe + SQE_LEN_FIELD, length); PlatformDependent.putInt(sqe + SQE_LEN_FIELD, length);
// Store the fd and event type in the user_data field //user_data should be same as POLL_LINK fd
int opMask = (((short) op) << 16) | (((short) pollMask) & 0xFFFF); if (op == IOUring.OP_POLL_REMOVE) {
long uData = (long) fd << 32 | opMask & 0xFFFFFFFFL; PlatformDependent.putInt(sqe + SQE_FD_FIELD, -1);
long pollLinkuData = convertToUserData((byte) IOUring.IO_POLL, fd, IOUring.POLLMASK_LINK);
PlatformDependent.putLong(sqe + SQE_ADDRESS_FIELD, pollLinkuData);
}
long uData = convertToUserData(op, fd, pollMask);
PlatformDependent.putLong(sqe + SQE_USER_DATA_FIELD, uData); PlatformDependent.putLong(sqe + SQE_USER_DATA_FIELD, uData);
//poll<link>read or accept operation //poll<link>read or accept operation
@ -213,6 +217,17 @@ final class IOUringSubmissionQueue {
return true; return true;
} }
//fill the user_data which is associated with server poll link
public boolean addPollRemove(int fd) {
long sqe = getSqe();
if (sqe == 0) {
return false;
}
setData(sqe, (byte) IOUring.OP_POLL_REMOVE, 0, fd, 0, 0, 0);
return true;
}
private int flushSqe() { private int flushSqe() {
long kTail = toUnsignedLong(PlatformDependent.getInt(kTailAddress)); long kTail = toUnsignedLong(PlatformDependent.getInt(kTailAddress));
long kHead = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress)); long kHead = toUnsignedLong(PlatformDependent.getIntVolatile(kHeadAddress));
@ -270,6 +285,11 @@ final class IOUringSubmissionQueue {
PlatformDependent.putLong(timeoutMemoryAddress + KERNEL_TIMESPEC_TV_NSEC_FIELD, nanoSeconds); PlatformDependent.putLong(timeoutMemoryAddress + KERNEL_TIMESPEC_TV_NSEC_FIELD, nanoSeconds);
} }
private long convertToUserData(byte op, int fd, int pollMask) {
int opMask = (((short) op) << 16) | (((short) pollMask) & 0xFFFF);
return (long) fd << 32 | opMask & 0xFFFFFFFFL;
}
//delete memory //delete memory
public void release() { public void release() {
Buffer.free(timeoutMemory); Buffer.free(timeoutMemory);

View File

@ -16,11 +16,17 @@
package io.netty.channel.uring; package io.netty.channel.uring;
import io.netty.channel.unix.FileDescriptor; import io.netty.channel.unix.FileDescriptor;
import io.netty.channel.unix.Socket;
import org.junit.Test; import org.junit.Test;
import java.io.File;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import org.junit.experimental.theories.suppliers.TestedOn; import org.junit.experimental.theories.suppliers.TestedOn;
@ -46,7 +52,7 @@ public class NativeTest {
assertNotNull(completionQueue); assertNotNull(completionQueue);
assertTrue(submissionQueue.addWrite(fd, writeEventByteBuf.memoryAddress(), assertTrue(submissionQueue.addWrite(fd, writeEventByteBuf.memoryAddress(),
writeEventByteBuf.readerIndex(), writeEventByteBuf.writerIndex())); writeEventByteBuf.readerIndex(), writeEventByteBuf.writerIndex()));
submissionQueue.submit(); submissionQueue.submit();
assertTrue(completionQueue.ioUringWaitCqe()); assertTrue(completionQueue.ioUringWaitCqe());
@ -61,7 +67,7 @@ public class NativeTest {
final ByteBuf readEventByteBuf = allocator.directBuffer(100); final ByteBuf readEventByteBuf = allocator.directBuffer(100);
assertTrue(submissionQueue.addRead(fd, readEventByteBuf.memoryAddress(), assertTrue(submissionQueue.addRead(fd, readEventByteBuf.memoryAddress(),
readEventByteBuf.writerIndex(), readEventByteBuf.capacity())); readEventByteBuf.writerIndex(), readEventByteBuf.capacity()));
submissionQueue.submit(); submissionQueue.submit();
assertTrue(completionQueue.ioUringWaitCqe()); assertTrue(completionQueue.ioUringWaitCqe());
@ -94,14 +100,14 @@ public class NativeTest {
Thread thread = new Thread() { Thread thread = new Thread() {
@Override @Override
public void run() { public void run() {
assertTrue(completionQueue.ioUringWaitCqe()); assertTrue(completionQueue.ioUringWaitCqe());
completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() { completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override @Override
public boolean handle(int fd, int res, long flags, int op, int mask) { public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(-62, res); assertEquals(-62, res);
return true; return true;
} }
}); });
} }
}; };
thread.start(); thread.start();
@ -133,7 +139,7 @@ public class NativeTest {
new Thread() { new Thread() {
@Override @Override
public void run() { public void run() {
Native.eventFdWrite(eventFd.intValue(), 1L); Native.eventFdWrite(eventFd.intValue(), 1L);
} }
}.start(); }.start();
@ -182,7 +188,7 @@ public class NativeTest {
new Thread() { new Thread() {
@Override @Override
public void run() { public void run() {
Native.eventFdWrite(eventFd.intValue(), 1L); Native.eventFdWrite(eventFd.intValue(), 1L);
} }
}.start(); }.start();
@ -196,4 +202,46 @@ public class NativeTest {
RingBuffer ringBuffer = Native.createRingBuffer(); RingBuffer ringBuffer = Native.createRingBuffer();
ringBuffer.close(); ringBuffer.close();
} }
@Test
public void ioUringPollRemoveTest() throws InterruptedException {
RingBuffer ringBuffer = Native.createRingBuffer(32);
IOUringSubmissionQueue submissionQueue = ringBuffer.getIoUringSubmissionQueue();
final IOUringCompletionQueue completionQueue = ringBuffer.getIoUringCompletionQueue();
FileDescriptor eventFd = Native.newEventFd();
submissionQueue.addPollLink(eventFd.intValue());
submissionQueue.submit();
Thread.sleep(10);
submissionQueue.addPollRemove(eventFd.intValue());
submissionQueue.submit();
Thread waitingCqe = new Thread() {
@Override
public void run() {
assertTrue(completionQueue.ioUringWaitCqe());
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(IOUringEventLoop.ECANCELED, res);
assertEquals(IOUring.IO_POLL, op);
return true;
}
}));
assertEquals(1, completionQueue.process(new IOUringCompletionQueue.IOUringCompletionQueueCallback() {
@Override
public boolean handle(int fd, int res, long flags, int op, int mask) {
assertEquals(0, res);
assertEquals(IOUring.OP_POLL_REMOVE, op);
return true;
}
}));
}
};
waitingCqe.start();
waitingCqe.join();
}
} }