diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java new file mode 100644 index 0000000000..d467f3f22a --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketRstTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import org.junit.Test; + +import java.io.IOException; +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class SocketRstTest extends AbstractSocketTest { + protected void assertRstOnCloseException(IOException cause, Channel clientChannel) { + if (Locale.getDefault() == Locale.US || Locale.getDefault() == Locale.UK) { + assertTrue("actual message: " + cause.getMessage(), cause.getMessage().contains("reset")); + } + } + + @Test(timeout = 3000) + public void testSoLingerZeroCausesOnlyRstOnClose() throws Throwable { + run(); + } + + public void testSoLingerZeroCausesOnlyRstOnClose(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final AtomicReference serverChannelRef = new AtomicReference(); + final AtomicReference throwableRef = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch latch2 = new CountDownLatch(1); + // SO_LINGER=0 means that we must send ONLY a RST when closing (not a FIN + RST). + sb.childOption(ChannelOption.SO_LINGER, 0); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + serverChannelRef.compareAndSet(null, ch); + latch.countDown(); + } + }); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + throwableRef.compareAndSet(null, cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + latch2.countDown(); + } + }); + } + }); + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect(sc.localAddress()).sync().channel(); + + // Wait for the server to get setup. + latch.await(); + + // The server has SO_LINGER=0 and so it must send a RST when close is called. + serverChannelRef.get().close(); + + // Wait for the client to get channelInactive. + latch2.await(); + + // Verify the client received a RST. + Throwable cause = throwableRef.get(); + assertTrue("actual [type, message]: [" + cause.getClass() + ", " + cause.getMessage() + "]", + cause instanceof IOException); + assertRstOnCloseException((IOException) cause, cc); + } + + @Test(timeout = 3000) + public void testNoRstIfSoLingerOnClose() throws Throwable { + run(); + } + + public void testNoRstIfSoLingerOnClose(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final AtomicReference serverChannelRef = new AtomicReference(); + final AtomicReference throwableRef = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch latch2 = new CountDownLatch(1); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + serverChannelRef.compareAndSet(null, ch); + latch.countDown(); + } + }); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + throwableRef.compareAndSet(null, cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + latch2.countDown(); + } + }); + } + }); + Channel sc = sb.bind().sync().channel(); + cb.connect(sc.localAddress()).syncUninterruptibly(); + + // Wait for the server to get setup. + latch.await(); + + // The server has SO_LINGER=0 and so it must send a RST when close is called. + serverChannelRef.get().close(); + + // Wait for the client to get channelInactive. + latch2.await(); + + // Verify the client did not received a RST. + assertNull(throwableRef.get()); + } +} diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java index 4fd3ba07b4..023ddd50c0 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollChannel.java @@ -97,21 +97,11 @@ abstract class AbstractEpollChannel extends AbstractChannel implements UnixChann @Override protected void doClose() throws Exception { - this.active = false; - Socket fd = fileDescriptor; + active = false; try { - // deregister from epoll now and shutdown the socket. doDeregister(); - if (!fd.isShutdown()) { - try { - fd().shutdown(); - } catch (IOException ignored) { - // The FD will be closed, so if shutdown fails there is nothing we can do. - } - } } finally { - // Ensure the file descriptor is closed in all cases. - fd.close(); + fileDescriptor.close(); } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/unix/FileDescriptor.java b/transport-native-epoll/src/main/java/io/netty/channel/unix/FileDescriptor.java index 48fc555c3a..d4e6d32a4a 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/unix/FileDescriptor.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/unix/FileDescriptor.java @@ -34,18 +34,26 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; * {@link FileDescriptor} for it. */ public class FileDescriptor { - private static final AtomicIntegerFieldUpdater openUpdater; + private static final AtomicIntegerFieldUpdater stateUpdater; static { AtomicIntegerFieldUpdater updater - = PlatformDependent.newAtomicIntegerFieldUpdater(FileDescriptor.class, "open"); + = PlatformDependent.newAtomicIntegerFieldUpdater(FileDescriptor.class, "state"); if (updater == null) { - updater = AtomicIntegerFieldUpdater.newUpdater(FileDescriptor.class, "open"); + updater = AtomicIntegerFieldUpdater.newUpdater(FileDescriptor.class, "state"); } - openUpdater = updater; + stateUpdater = updater; } - - private final int fd; - private volatile int open = 1; + private static final int STATE_CLOSED_MASK = 1; + private static final int STATE_INPUT_SHUTDOWN_MASK = 1 << 1; + private static final int STATE_OUTPUT_SHUTDOWN_MASK = 1 << 2; + private static final int STATE_ALL_MASK = STATE_CLOSED_MASK | + STATE_INPUT_SHUTDOWN_MASK | + STATE_OUTPUT_SHUTDOWN_MASK; + /** + * Bit map = [Output Shutdown | Input Shutdown | Closed] + */ + volatile int state; + final int fd; public FileDescriptor(int fd) { if (fd < 0) { @@ -65,11 +73,19 @@ public class FileDescriptor { * Close the file descriptor. */ public void close() throws IOException { - if (openUpdater.compareAndSet(this, 1, 0)) { - int res = close(fd); - if (res < 0) { - throw newIOException("close", res); + for (;;) { + int state = this.state; + if (isClosed(state)) { + return; } + // Once a close operation happens, the channel is considered shutdown. + if (casState(state, state | STATE_ALL_MASK)) { + break; + } + } + int res = close(fd); + if (res < 0) { + throw newIOException("close", res); } } @@ -77,7 +93,7 @@ public class FileDescriptor { * Returns {@code true} if the file descriptor is open. */ public boolean isOpen() { - return open == 1; + return !isClosed(state); } public final int write(ByteBuffer buf, int pos, int limit) throws IOException { @@ -188,6 +204,36 @@ public class FileDescriptor { return new FileDescriptor[]{new FileDescriptor((int) (res >>> 32)), new FileDescriptor((int) res)}; } + final boolean casState(int expected, int update) { + return stateUpdater.compareAndSet(this, expected, update); + } + + static boolean isClosed(int state) { + return (state & STATE_CLOSED_MASK) != 0; + } + + static boolean isInputShutdown(int state) { + return (state & STATE_INPUT_SHUTDOWN_MASK) != 0; + } + + static boolean isOutputShutdown(int state) { + return (state & STATE_OUTPUT_SHUTDOWN_MASK) != 0; + } + + static boolean shouldAttemptShutdown(int state, boolean read, boolean write) { + return !isClosed(state) && (read && !isInputShutdown(state) || write && !isOutputShutdown(state)); + } + + static int calculateShutdownState(int state, boolean read, boolean write) { + if (read) { + state |= STATE_INPUT_SHUTDOWN_MASK; + } + if (write) { + state |= STATE_OUTPUT_SHUTDOWN_MASK; + } + return state; + } + private static native int open(String path); private static native int close(int fd); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/unix/Socket.java b/transport-native-epoll/src/main/java/io/netty/channel/unix/Socket.java index 7f62fdfbec..a1795832d4 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/unix/Socket.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/unix/Socket.java @@ -42,40 +42,41 @@ import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address; * Internal usage only! */ public final class Socket extends FileDescriptor { - private volatile boolean inputShutdown; - private volatile boolean outputShutdown; - public Socket(int fd) { super(fd); } public void shutdown() throws IOException { - shutdown(!inputShutdown, !outputShutdown); + shutdown(true, true); } public void shutdown(boolean read, boolean write) throws IOException { - inputShutdown = read || inputShutdown; - outputShutdown = write || outputShutdown; - shutdown0(read, write); - } - - private void shutdown0(boolean read, boolean write) throws IOException { - int res = shutdown(intValue(), read, write); + for (;;) { + int state = this.state; + if (!shouldAttemptShutdown(state, read, write)) { + return; + } + if (casState(state, calculateShutdownState(state, read, write))) { + break; + } + } + int res = shutdown(fd, read, write); if (res < 0) { ioResult("shutdown", res, CONNECTION_NOT_CONNECTED_SHUTDOWN_EXCEPTION); } } public boolean isShutdown() { - return isInputShutdown() && isOutputShutdown(); + int state = this.state; + return isInputShutdown(state) && isOutputShutdown(state); } public boolean isInputShutdown() { - return inputShutdown; + return isInputShutdown(state); } public boolean isOutputShutdown() { - return outputShutdown; + return isOutputShutdown(state); } public int sendTo(ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException { @@ -91,7 +92,7 @@ public final class Socket extends FileDescriptor { scopeId = 0; address = ipv4MappedIpv6Address(addr.getAddress()); } - int res = sendTo(intValue(), buf, pos, limit, address, scopeId, port); + int res = sendTo(fd, buf, pos, limit, address, scopeId, port); if (res >= 0) { return res; } @@ -112,7 +113,7 @@ public final class Socket extends FileDescriptor { scopeId = 0; address = ipv4MappedIpv6Address(addr.getAddress()); } - int res = sendToAddress(intValue(), memoryAddress, pos, limit, address, scopeId, port); + int res = sendToAddress(fd, memoryAddress, pos, limit, address, scopeId, port); if (res >= 0) { return res; } @@ -132,7 +133,7 @@ public final class Socket extends FileDescriptor { scopeId = 0; address = ipv4MappedIpv6Address(addr.getAddress()); } - int res = sendToAddresses(intValue(), memoryAddress, length, address, scopeId, port); + int res = sendToAddresses(fd, memoryAddress, length, address, scopeId, port); if (res >= 0) { return res; } @@ -140,11 +141,11 @@ public final class Socket extends FileDescriptor { } public DatagramSocketAddress recvFrom(ByteBuffer buf, int pos, int limit) throws IOException { - return recvFrom(intValue(), buf, pos, limit); + return recvFrom(fd, buf, pos, limit); } public DatagramSocketAddress recvFromAddress(long memoryAddress, int pos, int limit) throws IOException { - return recvFromAddress(intValue(), memoryAddress, pos, limit); + return recvFromAddress(fd, memoryAddress, pos, limit); } public boolean connect(SocketAddress socketAddress) throws IOException { @@ -152,10 +153,10 @@ public final class Socket extends FileDescriptor { if (socketAddress instanceof InetSocketAddress) { InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; NativeInetAddress address = NativeInetAddress.newInstance(inetSocketAddress.getAddress()); - res = connect(intValue(), address.address, address.scopeId, inetSocketAddress.getPort()); + res = connect(fd, address.address, address.scopeId, inetSocketAddress.getPort()); } else if (socketAddress instanceof DomainSocketAddress) { DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress; - res = connectDomainSocket(intValue(), unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8)); + res = connectDomainSocket(fd, unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8)); } else { throw new Error("Unexpected SocketAddress implementation " + socketAddress); } @@ -170,7 +171,7 @@ public final class Socket extends FileDescriptor { } public boolean finishConnect() throws IOException { - int res = finishConnect(intValue()); + int res = finishConnect(fd); if (res < 0) { if (res == ERRNO_EINPROGRESS_NEGATIVE) { // connect still in progress @@ -185,13 +186,13 @@ public final class Socket extends FileDescriptor { if (socketAddress instanceof InetSocketAddress) { InetSocketAddress addr = (InetSocketAddress) socketAddress; NativeInetAddress address = NativeInetAddress.newInstance(addr.getAddress()); - int res = bind(intValue(), address.address, address.scopeId, addr.getPort()); + int res = bind(fd, address.address, address.scopeId, addr.getPort()); if (res < 0) { throw newIOException("bind", res); } } else if (socketAddress instanceof DomainSocketAddress) { DomainSocketAddress addr = (DomainSocketAddress) socketAddress; - int res = bindDomainSocket(intValue(), addr.path().getBytes(CharsetUtil.UTF_8)); + int res = bindDomainSocket(fd, addr.path().getBytes(CharsetUtil.UTF_8)); if (res < 0) { throw newIOException("bind", res); } @@ -201,14 +202,14 @@ public final class Socket extends FileDescriptor { } public void listen(int backlog) throws IOException { - int res = listen(intValue(), backlog); + int res = listen(fd, backlog); if (res < 0) { throw newIOException("listen", res); } } public int accept(byte[] addr) throws IOException { - int res = accept(intValue(), addr); + int res = accept(fd, addr); if (res >= 0) { return res; } @@ -220,7 +221,7 @@ public final class Socket extends FileDescriptor { } public InetSocketAddress remoteAddress() { - byte[] addr = remoteAddress(intValue()); + byte[] addr = remoteAddress(fd); // addr may be null if getpeername failed. // See https://github.com/netty/netty/issues/3328 if (addr == null) { @@ -230,7 +231,7 @@ public final class Socket extends FileDescriptor { } public InetSocketAddress localAddress() { - byte[] addr = localAddress(intValue()); + byte[] addr = localAddress(fd); // addr may be null if getpeername failed. // See https://github.com/netty/netty/issues/3328 if (addr == null) { @@ -240,77 +241,77 @@ public final class Socket extends FileDescriptor { } public int getReceiveBufferSize() throws IOException { - return getReceiveBufferSize(intValue()); + return getReceiveBufferSize(fd); } public int getSendBufferSize() throws IOException { - return getSendBufferSize(intValue()); + return getSendBufferSize(fd); } public boolean isKeepAlive() throws IOException { - return isKeepAlive(intValue()) != 0; + return isKeepAlive(fd) != 0; } public boolean isTcpNoDelay() throws IOException { - return isTcpNoDelay(intValue()) != 0; + return isTcpNoDelay(fd) != 0; } public boolean isTcpCork() throws IOException { - return isTcpCork(intValue()) != 0; + return isTcpCork(fd) != 0; } public int getSoLinger() throws IOException { - return getSoLinger(intValue()); + return getSoLinger(fd); } public int getTcpDeferAccept() throws IOException { - return getTcpDeferAccept(intValue()); + return getTcpDeferAccept(fd); } public boolean isTcpQuickAck() throws IOException { - return isTcpQuickAck(intValue()) != 0; + return isTcpQuickAck(fd) != 0; } public int getSoError() { - return getSoError(intValue()); + return getSoError(fd); } public void setKeepAlive(boolean keepAlive) throws IOException { - setKeepAlive(intValue(), keepAlive ? 1 : 0); + setKeepAlive(fd, keepAlive ? 1 : 0); } public void setReceiveBufferSize(int receiveBufferSize) throws IOException { - setReceiveBufferSize(intValue(), receiveBufferSize); + setReceiveBufferSize(fd, receiveBufferSize); } public void setSendBufferSize(int sendBufferSize) throws IOException { - setSendBufferSize(intValue(), sendBufferSize); + setSendBufferSize(fd, sendBufferSize); } public void setTcpNoDelay(boolean tcpNoDelay) throws IOException { - setTcpNoDelay(intValue(), tcpNoDelay ? 1 : 0); + setTcpNoDelay(fd, tcpNoDelay ? 1 : 0); } public void setTcpCork(boolean tcpCork) throws IOException { - setTcpCork(intValue(), tcpCork ? 1 : 0); + setTcpCork(fd, tcpCork ? 1 : 0); } public void setSoLinger(int soLinger) throws IOException { - setSoLinger(intValue(), soLinger); + setSoLinger(fd, soLinger); } public void setTcpDeferAccept(int deferAccept) throws IOException { - setTcpDeferAccept(intValue(), deferAccept); + setTcpDeferAccept(fd, deferAccept); } public void setTcpQuickAck(boolean quickAck) throws IOException { - setTcpQuickAck(intValue(), quickAck ? 1 : 0); + setTcpQuickAck(fd, quickAck ? 1 : 0); } @Override public String toString() { return "Socket{" + - "fd=" + intValue() + + "fd=" + fd + '}'; } diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketRstTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketRstTest.java new file mode 100644 index 0000000000..3f1ad1f2ce --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketRstTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.unix.Errors; +import io.netty.channel.unix.Errors.NativeIoException; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.SocketRstTest; + +import java.io.IOException; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class EpollSocketRstTest extends SocketRstTest { + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.socket(); + } + + @Override + protected void assertRstOnCloseException(IOException cause, Channel clientChannel) { + if (!AbstractEpollChannel.class.isInstance(clientChannel)) { + super.assertRstOnCloseException(cause, clientChannel); + return; + } + + assertTrue("actual [type, message]: [" + cause.getClass() + ", " + cause.getMessage() + "]", + cause instanceof NativeIoException); + assertEquals(Errors.ERRNO_ECONNRESET_NEGATIVE, ((NativeIoException) cause).expectedErr()); + } +}