EPOLL SO_LINGER=0 sends FIN+RST

Motivation:
If SO_LINGER is set to 0 the EPOLL transport will send a FIN followed by a RST. This is not consistent with the behavior of the NIO transport. This variation in behavior can cause protocol violations in streaming protocols (e.g. HTTP) where a FIN may be interpreted as a valid end to a data stream, but RST may be treated as the data is corrupted and should be discarded.

https://github.com/netty/netty/issues/4170 Claims the behavior of NIO always issues a shutdown when close occurs. I could not find any evidence of this in Netty's NIO transport nor in the JDK's SocketChannel.close() implementation.

Modifications:
- AbstractEpollChannel should be consistent with the NIO transport and not force a shutdown on every close
- FileDescriptor to keep state in a consistent manner with the JDK and not allow a shutdown after a close
- Unit tests for NIO and EPOLL to ensure consistent behavior

Result:
EPOLL is capable of sending just a RST to terminate a connection.
This commit is contained in:
Scott Mitchell 2016-03-09 18:16:29 -08:00
parent 01835fdf18
commit 8dbf5d02e5
5 changed files with 300 additions and 70 deletions

View File

@ -0,0 +1,144 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import org.junit.Test;
import java.io.IOException;
import java.util.Locale;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
public class SocketRstTest extends AbstractSocketTest {
protected void assertRstOnCloseException(IOException cause, Channel clientChannel) {
if (Locale.getDefault() == Locale.US || Locale.getDefault() == Locale.UK) {
assertTrue("actual message: " + cause.getMessage(), cause.getMessage().contains("reset"));
}
}
@Test(timeout = 3000)
public void testSoLingerZeroCausesOnlyRstOnClose() throws Throwable {
run();
}
public void testSoLingerZeroCausesOnlyRstOnClose(ServerBootstrap sb, Bootstrap cb) throws Throwable {
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
final AtomicReference<Throwable> throwableRef = new AtomicReference<Throwable>();
final CountDownLatch latch = new CountDownLatch(1);
final CountDownLatch latch2 = new CountDownLatch(1);
// SO_LINGER=0 means that we must send ONLY a RST when closing (not a FIN + RST).
sb.childOption(ChannelOption.SO_LINGER, 0);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
serverChannelRef.compareAndSet(null, ch);
latch.countDown();
}
});
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
throwableRef.compareAndSet(null, cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
latch2.countDown();
}
});
}
});
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect(sc.localAddress()).sync().channel();
// Wait for the server to get setup.
latch.await();
// The server has SO_LINGER=0 and so it must send a RST when close is called.
serverChannelRef.get().close();
// Wait for the client to get channelInactive.
latch2.await();
// Verify the client received a RST.
Throwable cause = throwableRef.get();
assertTrue("actual [type, message]: [" + cause.getClass() + ", " + cause.getMessage() + "]",
cause instanceof IOException);
assertRstOnCloseException((IOException) cause, cc);
}
@Test(timeout = 3000)
public void testNoRstIfSoLingerOnClose() throws Throwable {
run();
}
public void testNoRstIfSoLingerOnClose(ServerBootstrap sb, Bootstrap cb) throws Throwable {
final AtomicReference<Channel> serverChannelRef = new AtomicReference<Channel>();
final AtomicReference<Throwable> throwableRef = new AtomicReference<Throwable>();
final CountDownLatch latch = new CountDownLatch(1);
final CountDownLatch latch2 = new CountDownLatch(1);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
serverChannelRef.compareAndSet(null, ch);
latch.countDown();
}
});
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
throwableRef.compareAndSet(null, cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
latch2.countDown();
}
});
}
});
Channel sc = sb.bind().sync().channel();
cb.connect(sc.localAddress()).syncUninterruptibly();
// Wait for the server to get setup.
latch.await();
// The server has SO_LINGER=0 and so it must send a RST when close is called.
serverChannelRef.get().close();
// Wait for the client to get channelInactive.
latch2.await();
// Verify the client did not received a RST.
assertNull(throwableRef.get());
}
}

View File

@ -97,21 +97,11 @@ abstract class AbstractEpollChannel extends AbstractChannel implements UnixChann
@Override
protected void doClose() throws Exception {
this.active = false;
Socket fd = fileDescriptor;
active = false;
try {
// deregister from epoll now and shutdown the socket.
doDeregister();
if (!fd.isShutdown()) {
try {
fd().shutdown();
} catch (IOException ignored) {
// The FD will be closed, so if shutdown fails there is nothing we can do.
}
}
} finally {
// Ensure the file descriptor is closed in all cases.
fd.close();
fileDescriptor.close();
}
}

View File

@ -34,18 +34,26 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
* {@link FileDescriptor} for it.
*/
public class FileDescriptor {
private static final AtomicIntegerFieldUpdater<FileDescriptor> openUpdater;
private static final AtomicIntegerFieldUpdater<FileDescriptor> stateUpdater;
static {
AtomicIntegerFieldUpdater<FileDescriptor> updater
= PlatformDependent.newAtomicIntegerFieldUpdater(FileDescriptor.class, "open");
= PlatformDependent.newAtomicIntegerFieldUpdater(FileDescriptor.class, "state");
if (updater == null) {
updater = AtomicIntegerFieldUpdater.newUpdater(FileDescriptor.class, "open");
updater = AtomicIntegerFieldUpdater.newUpdater(FileDescriptor.class, "state");
}
openUpdater = updater;
stateUpdater = updater;
}
private final int fd;
private volatile int open = 1;
private static final int STATE_CLOSED_MASK = 1;
private static final int STATE_INPUT_SHUTDOWN_MASK = 1 << 1;
private static final int STATE_OUTPUT_SHUTDOWN_MASK = 1 << 2;
private static final int STATE_ALL_MASK = STATE_CLOSED_MASK |
STATE_INPUT_SHUTDOWN_MASK |
STATE_OUTPUT_SHUTDOWN_MASK;
/**
* Bit map = [Output Shutdown | Input Shutdown | Closed]
*/
volatile int state;
final int fd;
public FileDescriptor(int fd) {
if (fd < 0) {
@ -65,11 +73,19 @@ public class FileDescriptor {
* Close the file descriptor.
*/
public void close() throws IOException {
if (openUpdater.compareAndSet(this, 1, 0)) {
int res = close(fd);
if (res < 0) {
throw newIOException("close", res);
for (;;) {
int state = this.state;
if (isClosed(state)) {
return;
}
// Once a close operation happens, the channel is considered shutdown.
if (casState(state, state | STATE_ALL_MASK)) {
break;
}
}
int res = close(fd);
if (res < 0) {
throw newIOException("close", res);
}
}
@ -77,7 +93,7 @@ public class FileDescriptor {
* Returns {@code true} if the file descriptor is open.
*/
public boolean isOpen() {
return open == 1;
return !isClosed(state);
}
public final int write(ByteBuffer buf, int pos, int limit) throws IOException {
@ -188,6 +204,36 @@ public class FileDescriptor {
return new FileDescriptor[]{new FileDescriptor((int) (res >>> 32)), new FileDescriptor((int) res)};
}
final boolean casState(int expected, int update) {
return stateUpdater.compareAndSet(this, expected, update);
}
static boolean isClosed(int state) {
return (state & STATE_CLOSED_MASK) != 0;
}
static boolean isInputShutdown(int state) {
return (state & STATE_INPUT_SHUTDOWN_MASK) != 0;
}
static boolean isOutputShutdown(int state) {
return (state & STATE_OUTPUT_SHUTDOWN_MASK) != 0;
}
static boolean shouldAttemptShutdown(int state, boolean read, boolean write) {
return !isClosed(state) && (read && !isInputShutdown(state) || write && !isOutputShutdown(state));
}
static int calculateShutdownState(int state, boolean read, boolean write) {
if (read) {
state |= STATE_INPUT_SHUTDOWN_MASK;
}
if (write) {
state |= STATE_OUTPUT_SHUTDOWN_MASK;
}
return state;
}
private static native int open(String path);
private static native int close(int fd);

View File

@ -42,40 +42,41 @@ import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address;
* <strong>Internal usage only!</strong>
*/
public final class Socket extends FileDescriptor {
private volatile boolean inputShutdown;
private volatile boolean outputShutdown;
public Socket(int fd) {
super(fd);
}
public void shutdown() throws IOException {
shutdown(!inputShutdown, !outputShutdown);
shutdown(true, true);
}
public void shutdown(boolean read, boolean write) throws IOException {
inputShutdown = read || inputShutdown;
outputShutdown = write || outputShutdown;
shutdown0(read, write);
}
private void shutdown0(boolean read, boolean write) throws IOException {
int res = shutdown(intValue(), read, write);
for (;;) {
int state = this.state;
if (!shouldAttemptShutdown(state, read, write)) {
return;
}
if (casState(state, calculateShutdownState(state, read, write))) {
break;
}
}
int res = shutdown(fd, read, write);
if (res < 0) {
ioResult("shutdown", res, CONNECTION_NOT_CONNECTED_SHUTDOWN_EXCEPTION);
}
}
public boolean isShutdown() {
return isInputShutdown() && isOutputShutdown();
int state = this.state;
return isInputShutdown(state) && isOutputShutdown(state);
}
public boolean isInputShutdown() {
return inputShutdown;
return isInputShutdown(state);
}
public boolean isOutputShutdown() {
return outputShutdown;
return isOutputShutdown(state);
}
public int sendTo(ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException {
@ -91,7 +92,7 @@ public final class Socket extends FileDescriptor {
scopeId = 0;
address = ipv4MappedIpv6Address(addr.getAddress());
}
int res = sendTo(intValue(), buf, pos, limit, address, scopeId, port);
int res = sendTo(fd, buf, pos, limit, address, scopeId, port);
if (res >= 0) {
return res;
}
@ -112,7 +113,7 @@ public final class Socket extends FileDescriptor {
scopeId = 0;
address = ipv4MappedIpv6Address(addr.getAddress());
}
int res = sendToAddress(intValue(), memoryAddress, pos, limit, address, scopeId, port);
int res = sendToAddress(fd, memoryAddress, pos, limit, address, scopeId, port);
if (res >= 0) {
return res;
}
@ -132,7 +133,7 @@ public final class Socket extends FileDescriptor {
scopeId = 0;
address = ipv4MappedIpv6Address(addr.getAddress());
}
int res = sendToAddresses(intValue(), memoryAddress, length, address, scopeId, port);
int res = sendToAddresses(fd, memoryAddress, length, address, scopeId, port);
if (res >= 0) {
return res;
}
@ -140,11 +141,11 @@ public final class Socket extends FileDescriptor {
}
public DatagramSocketAddress recvFrom(ByteBuffer buf, int pos, int limit) throws IOException {
return recvFrom(intValue(), buf, pos, limit);
return recvFrom(fd, buf, pos, limit);
}
public DatagramSocketAddress recvFromAddress(long memoryAddress, int pos, int limit) throws IOException {
return recvFromAddress(intValue(), memoryAddress, pos, limit);
return recvFromAddress(fd, memoryAddress, pos, limit);
}
public boolean connect(SocketAddress socketAddress) throws IOException {
@ -152,10 +153,10 @@ public final class Socket extends FileDescriptor {
if (socketAddress instanceof InetSocketAddress) {
InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress;
NativeInetAddress address = NativeInetAddress.newInstance(inetSocketAddress.getAddress());
res = connect(intValue(), address.address, address.scopeId, inetSocketAddress.getPort());
res = connect(fd, address.address, address.scopeId, inetSocketAddress.getPort());
} else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress;
res = connectDomainSocket(intValue(), unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8));
res = connectDomainSocket(fd, unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8));
} else {
throw new Error("Unexpected SocketAddress implementation " + socketAddress);
}
@ -170,7 +171,7 @@ public final class Socket extends FileDescriptor {
}
public boolean finishConnect() throws IOException {
int res = finishConnect(intValue());
int res = finishConnect(fd);
if (res < 0) {
if (res == ERRNO_EINPROGRESS_NEGATIVE) {
// connect still in progress
@ -185,13 +186,13 @@ public final class Socket extends FileDescriptor {
if (socketAddress instanceof InetSocketAddress) {
InetSocketAddress addr = (InetSocketAddress) socketAddress;
NativeInetAddress address = NativeInetAddress.newInstance(addr.getAddress());
int res = bind(intValue(), address.address, address.scopeId, addr.getPort());
int res = bind(fd, address.address, address.scopeId, addr.getPort());
if (res < 0) {
throw newIOException("bind", res);
}
} else if (socketAddress instanceof DomainSocketAddress) {
DomainSocketAddress addr = (DomainSocketAddress) socketAddress;
int res = bindDomainSocket(intValue(), addr.path().getBytes(CharsetUtil.UTF_8));
int res = bindDomainSocket(fd, addr.path().getBytes(CharsetUtil.UTF_8));
if (res < 0) {
throw newIOException("bind", res);
}
@ -201,14 +202,14 @@ public final class Socket extends FileDescriptor {
}
public void listen(int backlog) throws IOException {
int res = listen(intValue(), backlog);
int res = listen(fd, backlog);
if (res < 0) {
throw newIOException("listen", res);
}
}
public int accept(byte[] addr) throws IOException {
int res = accept(intValue(), addr);
int res = accept(fd, addr);
if (res >= 0) {
return res;
}
@ -220,7 +221,7 @@ public final class Socket extends FileDescriptor {
}
public InetSocketAddress remoteAddress() {
byte[] addr = remoteAddress(intValue());
byte[] addr = remoteAddress(fd);
// addr may be null if getpeername failed.
// See https://github.com/netty/netty/issues/3328
if (addr == null) {
@ -230,7 +231,7 @@ public final class Socket extends FileDescriptor {
}
public InetSocketAddress localAddress() {
byte[] addr = localAddress(intValue());
byte[] addr = localAddress(fd);
// addr may be null if getpeername failed.
// See https://github.com/netty/netty/issues/3328
if (addr == null) {
@ -240,77 +241,77 @@ public final class Socket extends FileDescriptor {
}
public int getReceiveBufferSize() throws IOException {
return getReceiveBufferSize(intValue());
return getReceiveBufferSize(fd);
}
public int getSendBufferSize() throws IOException {
return getSendBufferSize(intValue());
return getSendBufferSize(fd);
}
public boolean isKeepAlive() throws IOException {
return isKeepAlive(intValue()) != 0;
return isKeepAlive(fd) != 0;
}
public boolean isTcpNoDelay() throws IOException {
return isTcpNoDelay(intValue()) != 0;
return isTcpNoDelay(fd) != 0;
}
public boolean isTcpCork() throws IOException {
return isTcpCork(intValue()) != 0;
return isTcpCork(fd) != 0;
}
public int getSoLinger() throws IOException {
return getSoLinger(intValue());
return getSoLinger(fd);
}
public int getTcpDeferAccept() throws IOException {
return getTcpDeferAccept(intValue());
return getTcpDeferAccept(fd);
}
public boolean isTcpQuickAck() throws IOException {
return isTcpQuickAck(intValue()) != 0;
return isTcpQuickAck(fd) != 0;
}
public int getSoError() {
return getSoError(intValue());
return getSoError(fd);
}
public void setKeepAlive(boolean keepAlive) throws IOException {
setKeepAlive(intValue(), keepAlive ? 1 : 0);
setKeepAlive(fd, keepAlive ? 1 : 0);
}
public void setReceiveBufferSize(int receiveBufferSize) throws IOException {
setReceiveBufferSize(intValue(), receiveBufferSize);
setReceiveBufferSize(fd, receiveBufferSize);
}
public void setSendBufferSize(int sendBufferSize) throws IOException {
setSendBufferSize(intValue(), sendBufferSize);
setSendBufferSize(fd, sendBufferSize);
}
public void setTcpNoDelay(boolean tcpNoDelay) throws IOException {
setTcpNoDelay(intValue(), tcpNoDelay ? 1 : 0);
setTcpNoDelay(fd, tcpNoDelay ? 1 : 0);
}
public void setTcpCork(boolean tcpCork) throws IOException {
setTcpCork(intValue(), tcpCork ? 1 : 0);
setTcpCork(fd, tcpCork ? 1 : 0);
}
public void setSoLinger(int soLinger) throws IOException {
setSoLinger(intValue(), soLinger);
setSoLinger(fd, soLinger);
}
public void setTcpDeferAccept(int deferAccept) throws IOException {
setTcpDeferAccept(intValue(), deferAccept);
setTcpDeferAccept(fd, deferAccept);
}
public void setTcpQuickAck(boolean quickAck) throws IOException {
setTcpQuickAck(intValue(), quickAck ? 1 : 0);
setTcpQuickAck(fd, quickAck ? 1 : 0);
}
@Override
public String toString() {
return "Socket{" +
"fd=" + intValue() +
"fd=" + fd +
'}';
}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.epoll;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.unix.Errors;
import io.netty.channel.unix.Errors.NativeIoException;
import io.netty.testsuite.transport.TestsuitePermutation;
import io.netty.testsuite.transport.socket.SocketRstTest;
import java.io.IOException;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class EpollSocketRstTest extends SocketRstTest {
@Override
protected List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories() {
return EpollSocketTestPermutation.INSTANCE.socket();
}
@Override
protected void assertRstOnCloseException(IOException cause, Channel clientChannel) {
if (!AbstractEpollChannel.class.isInstance(clientChannel)) {
super.assertRstOnCloseException(cause, clientChannel);
return;
}
assertTrue("actual [type, message]: [" + cause.getClass() + ", " + cause.getMessage() + "]",
cause instanceof NativeIoException);
assertEquals(Errors.ERRNO_ECONNRESET_NEGATIVE, ((NativeIoException) cause).expectedErr());
}
}