diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/AbstractSocketReuseFdTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/AbstractSocketReuseFdTest.java new file mode 100644 index 0000000000..6e8ae87eef --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/AbstractSocketReuseFdTest.java @@ -0,0 +1,180 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; +import org.junit.Test; + +import java.io.IOException; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public abstract class AbstractSocketReuseFdTest extends AbstractSocketTest { + @Override + protected abstract SocketAddress newSocketAddress(); + + @Override + protected abstract List> newFactories(); + + @Test(timeout = 60000) + public void testReuseFd() throws Throwable { + run(); + } + + public void testReuseFd(ServerBootstrap sb, Bootstrap cb) throws Throwable { + sb.childOption(ChannelOption.AUTO_READ, true); + cb.option(ChannelOption.AUTO_READ, true); + + // Use a number which will typically not exceed /proc/sys/net/core/somaxconn (which is 128 on linux by default + // often). + int numChannels = 100; + final AtomicReference globalException = new AtomicReference(); + final AtomicInteger serverRemaining = new AtomicInteger(numChannels); + final AtomicInteger clientRemaining = new AtomicInteger(numChannels); + final Promise serverDonePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + final Promise clientDonePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(Channel sch) { + ReuseFdHandler sh = new ReuseFdHandler( + false, + globalException, + serverRemaining, + serverDonePromise); + sch.pipeline().addLast("handler", sh); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(Channel sch) { + ReuseFdHandler ch = new ReuseFdHandler( + true, + globalException, + clientRemaining, + clientDonePromise); + sch.pipeline().addLast("handler", ch); + } + }); + + ChannelFutureListener listener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + clientDonePromise.tryFailure(future.cause()); + } + } + }; + + Channel sc = sb.bind().sync().channel(); + for (int i = 0; i < numChannels; i++) { + cb.connect(sc.localAddress()).addListener(listener); + } + + clientDonePromise.sync(); + serverDonePromise.sync(); + sc.close().sync(); + + if (globalException.get() != null && !(globalException.get() instanceof IOException)) { + throw globalException.get(); + } + } + + static class ReuseFdHandler extends ChannelInboundHandlerAdapter { + private static final String EXPECTED_PAYLOAD = "payload"; + + private final Promise donePromise; + private final AtomicInteger remaining; + private final boolean client; + volatile Channel channel; + final AtomicReference globalException; + final AtomicReference exception = new AtomicReference(); + final StringBuilder received = new StringBuilder(); + + ReuseFdHandler( + boolean client, + AtomicReference globalException, + AtomicInteger remaining, + Promise donePromise) { + this.client = client; + this.globalException = globalException; + this.remaining = remaining; + this.donePromise = donePromise; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + channel = ctx.channel(); + if (client) { + ctx.writeAndFlush(Unpooled.copiedBuffer(EXPECTED_PAYLOAD, CharsetUtil.US_ASCII)); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; + received.append(buf.toString(CharsetUtil.US_ASCII)); + buf.release(); + + if (received.toString().equals(EXPECTED_PAYLOAD)) { + if (client) { + ctx.close(); + } else { + ctx.writeAndFlush(Unpooled.copiedBuffer(EXPECTED_PAYLOAD, CharsetUtil.US_ASCII)); + } + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (exception.compareAndSet(null, cause)) { + donePromise.tryFailure(new IllegalStateException("exceptionCaught: " + ctx.channel(), cause)); + ctx.close(); + } + globalException.compareAndSet(null, cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (remaining.decrementAndGet() == 0) { + if (received.toString().equals(EXPECTED_PAYLOAD)) { + donePromise.setSuccess(null); + } else { + donePromise.tryFailure(new Exception("Unexpected payload:" + received)); + } + } + } + } +} diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java index c59f0ce267..2579b9ec59 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollHandler.java @@ -61,7 +61,7 @@ public class EpollHandler implements IoHandler { private final FileDescriptor epollFd; private final FileDescriptor eventFd; private final FileDescriptor timerFd; - private final IntObjectMap channels = new IntObjectHashMap(4096); + private final IntObjectMap channels = new IntObjectHashMap<>(4096); private final boolean allowGrowing; private final EpollEventArray events; @@ -225,7 +225,11 @@ public class EpollHandler implements IoHandler { private void add(AbstractEpollChannel ch) throws IOException { int fd = ch.socket.intValue(); Native.epollCtlAdd(epollFd.intValue(), fd, ch.flags); - channels.put(fd, ch); + AbstractEpollChannel old = channels.put(fd, ch); + + // We either expect to have no Channel in the map with the same FD or that the FD of the old Channel is already + // closed. + assert old == null || !old.isOpen(); } /** @@ -239,13 +243,19 @@ public class EpollHandler implements IoHandler { * Deregister the given channel from this {@link EpollHandler}. */ private void remove(AbstractEpollChannel ch) throws IOException { - if (ch.isOpen()) { - int fd = ch.socket.intValue(); - if (channels.remove(fd) != null) { - // Remove the epoll. This is only needed if it's still open as otherwise it will be automatically - // removed once the file-descriptor is closed. - Native.epollCtlDel(epollFd.intValue(), ch.fd().intValue()); - } + int fd = ch.socket.intValue(); + + AbstractEpollChannel old = channels.remove(fd); + if (old != null && old != ch) { + // The Channel mapping was already replaced due FD reuse, put back the stored Channel. + channels.put(fd, old); + + // If we found another Channel in the map that is mapped to the same FD the given Channel MUST be closed. + assert !ch.isOpen(); + } else if (ch.isOpen()) { + // Remove the epoll. This is only needed if it's still open as otherwise it will be automatically + // removed once the file-descriptor is closed. + Native.epollCtlDel(epollFd.intValue(), fd); } } @@ -367,11 +377,12 @@ public class EpollHandler implements IoHandler { } catch (IOException ignore) { // ignore on close } + // Using the intermediate collection to prevent ConcurrentModificationException. // In the `close()` method, the channel is deleted from `channels` map. AbstractEpollChannel[] localChannels = channels.values().toArray(new AbstractEpollChannel[0]); - for (AbstractEpollChannel ch : localChannels) { + for (AbstractEpollChannel ch: localChannels) { ch.unsafe().close(ch.unsafe().voidPromise()); } } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketReuseFdTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketReuseFdTest.java new file mode 100644 index 0000000000..487ea64a19 --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainSocketReuseFdTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.AbstractSocketReuseFdTest; + +import java.net.SocketAddress; +import java.util.List; + +public class EpollDomainSocketReuseFdTest extends AbstractSocketReuseFdTest { + @Override + protected SocketAddress newSocketAddress() { + return EpollSocketTestPermutation.newSocketAddress(); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.domainSocket(); + } +} diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java index f53c08e0b8..4efe48f1b2 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueChannel.java @@ -178,10 +178,18 @@ abstract class AbstractKQueueChannel extends AbstractChannel implements UnixChan evSet0(registration, Native.EVFILT_SOCK, Native.EV_ADD, Native.NOTE_RDHUP); } - void deregister0() throws IOException { + void deregister0() { + // As unregisteredFilters() may have not been called because isOpen() returned false we just set both filters + // to false to ensure a consistent state in all cases. + readFilterEnabled = false; + writeFilterEnabled = false; + } + + void unregisterFilters() throws Exception { // Make sure we unregister our filters from kqueue! readFilter(false); writeFilter(false); + if (registration != null) { evSet0(registration, Native.EVFILT_SOCK, Native.EV_DELETE, 0); registration = null; @@ -331,7 +339,7 @@ abstract class AbstractKQueueChannel extends AbstractChannel implements UnixChan } private void evSet(short filter, short flags) { - if (isOpen() && isRegistered()) { + if (isRegistered()) { evSet0(registration, filter, flags); } } @@ -341,7 +349,10 @@ abstract class AbstractKQueueChannel extends AbstractChannel implements UnixChan } private void evSet0(KQueueRegistration registration, short filter, short flags, int fflags) { - registration.evSet(filter, flags, fflags); + // Only try to add to changeList if the FD is still open, if not we already closed it in the meantime. + if (isOpen()) { + registration.evSet(filter, flags, fflags); + } } abstract class AbstractKQueueUnsafe extends AbstractUnsafe { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueHandler.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueHandler.java index acd8b3b083..aac1a48a8d 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueHandler.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueHandler.java @@ -118,7 +118,10 @@ public final class KQueueHandler implements IoHandler { public void register(Channel channel) { final AbstractKQueueChannel kQueueChannel = cast(channel); final int id = kQueueChannel.fd().intValue(); - channels.put(id, kQueueChannel); + AbstractKQueueChannel old = channels.put(id, kQueueChannel); + // We either expect to have no Channel in the map with the same FD or that the FD of the old Channel is already + // closed. + assert old == null || !old.isOpen(); kQueueChannel.register0(new KQueueRegistration() { @Override @@ -136,7 +139,23 @@ public final class KQueueHandler implements IoHandler { @Override public void deregister(Channel channel) throws Exception { AbstractKQueueChannel kQueueChannel = cast(channel); - channels.remove(kQueueChannel.fd().intValue()); + int fd = kQueueChannel.fd().intValue(); + + AbstractKQueueChannel old = channels.remove(fd); + if (old != null && old != kQueueChannel) { + // The Channel mapping was already replaced due FD reuse, put back the stored Channel. + channels.put(fd, old); + + // If we found another Channel in the map that is mapped to the same FD the given Channel MUST be closed. + assert !kQueueChannel.isOpen(); + } else if (kQueueChannel.isOpen()) { + // Remove the filters. This is only needed if it's still open as otherwise it will be automatically + // removed once the file-descriptor is closed. + // + // See also https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2 + kQueueChannel.unregisterFilters(); + } + kQueueChannel.deregister0(); } @@ -315,6 +334,14 @@ public final class KQueueHandler implements IoHandler { } catch (IOException e) { // ignore on close } + + // Using the intermediate collection to prevent ConcurrentModificationException. + // In the `close()` method, the channel is deleted from `channels` map. + AbstractKQueueChannel[] localChannels = channels.values().toArray(new AbstractKQueueChannel[0]); + + for (AbstractKQueueChannel ch: localChannels) { + ch.unsafe().close(ch.unsafe().voidPromise()); + } } private static void handleLoopException(Throwable t) { diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketReuseFdTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketReuseFdTest.java new file mode 100644 index 0000000000..2e239c21c7 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainSocketReuseFdTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.AbstractSocketReuseFdTest; + +import java.net.SocketAddress; +import java.util.List; + +public class KQueueDomainSocketReuseFdTest extends AbstractSocketReuseFdTest { + @Override + protected SocketAddress newSocketAddress() { + return KQueueSocketTestPermutation.newSocketAddress(); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.domainSocket(); + } +}