diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java index d67c3baa3d..b5600d4252 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketHalfClosedTest.java @@ -28,15 +28,19 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.ChannelInputShutdownEvent; import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.ChannelOutputShutdownEvent; import io.netty.channel.socket.DuplexChannel; import io.netty.util.UncheckedBooleanSupplier; import org.junit.Test; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; public class SocketHalfClosedTest extends AbstractSocketTest { @@ -50,8 +54,8 @@ public class SocketHalfClosedTest extends AbstractSocketTest { testAllDataReadAfterHalfClosure(false, sb, cb); } - public void testAllDataReadAfterHalfClosure(final boolean autoRead, - ServerBootstrap sb, Bootstrap cb) throws Throwable { + private static void testAllDataReadAfterHalfClosure(final boolean autoRead, + ServerBootstrap sb, Bootstrap cb) throws Throwable { final int totalServerBytesWritten = 1024 * 16; final int numReadsPerReadLoop = 2; final CountDownLatch serverInitializedLatch = new CountDownLatch(1); @@ -150,6 +154,200 @@ public class SocketHalfClosedTest extends AbstractSocketTest { } } + @Test + public void testAutoCloseFalseDoesShutdownOutput() throws Throwable { + run(); + } + + public void testAutoCloseFalseDoesShutdownOutput(ServerBootstrap sb, Bootstrap cb) throws Throwable { + testAutoCloseFalseDoesShutdownOutput(false, false, sb, cb); + testAutoCloseFalseDoesShutdownOutput(false, true, sb, cb); + testAutoCloseFalseDoesShutdownOutput(true, false, sb, cb); + testAutoCloseFalseDoesShutdownOutput(true, true, sb, cb); + } + + private static void testAutoCloseFalseDoesShutdownOutput(boolean allowHalfClosed, + final boolean clientIsLeader, + ServerBootstrap sb, + Bootstrap cb) throws InterruptedException { + final int expectedBytes = 100; + final CountDownLatch serverReadExpectedLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(1); + final AtomicReference causeRef = new AtomicReference(); + Channel serverChannel = null; + Channel clientChannel = null; + try { + cb.option(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed) + .option(ChannelOption.AUTO_CLOSE, false) + .option(ChannelOption.SO_LINGER, 0); + sb.childOption(ChannelOption.ALLOW_HALF_CLOSURE, allowHalfClosed) + .childOption(ChannelOption.AUTO_CLOSE, false) + .childOption(ChannelOption.SO_LINGER, 0); + + final SimpleChannelInboundHandler leaderHandler = new AutoCloseFalseLeader(expectedBytes, + serverReadExpectedLatch, doneLatch, causeRef); + final SimpleChannelInboundHandler followerHandler = new AutoCloseFalseFollower(expectedBytes, + serverReadExpectedLatch, doneLatch, causeRef); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(clientIsLeader ? followerHandler :leaderHandler); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(clientIsLeader ? leaderHandler : followerHandler); + } + }); + + serverChannel = sb.bind().sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + + doneLatch.await(); + assertNull(causeRef.get()); + } finally { + if (clientChannel != null) { + clientChannel.close().sync(); + } + if (serverChannel != null) { + serverChannel.close().sync(); + } + } + } + + private static final class AutoCloseFalseFollower extends SimpleChannelInboundHandler { + private final int expectedBytes; + private final CountDownLatch followerCloseLatch; + private final CountDownLatch doneLatch; + private final AtomicReference causeRef; + private int bytesRead; + + AutoCloseFalseFollower(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch, + AtomicReference causeRef) { + this.expectedBytes = expectedBytes; + this.followerCloseLatch = followerCloseLatch; + this.doneLatch = doneLatch; + this.causeRef = causeRef; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + checkPrematureClose(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + checkPrematureClose(); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + bytesRead += msg.readableBytes(); + if (bytesRead >= expectedBytes) { + // We write a reply and immediately close our end of the socket. + ByteBuf buf = ctx.alloc().buffer(expectedBytes); + buf.writerIndex(buf.writerIndex() + expectedBytes); + ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + future.channel().close().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + followerCloseLatch.countDown(); + } + }); + } + }); + } + } + + private void checkPrematureClose() { + if (bytesRead < expectedBytes) { + causeRef.set(new IllegalStateException("follower premature close")); + doneLatch.countDown(); + } + } + } + + private static final class AutoCloseFalseLeader extends SimpleChannelInboundHandler { + private final int expectedBytes; + private final CountDownLatch followerCloseLatch; + private final CountDownLatch doneLatch; + private final AtomicReference causeRef; + private int bytesRead; + private boolean seenOutputShutdown; + + AutoCloseFalseLeader(int expectedBytes, CountDownLatch followerCloseLatch, CountDownLatch doneLatch, + AtomicReference causeRef) { + this.expectedBytes = expectedBytes; + this.followerCloseLatch = followerCloseLatch; + this.doneLatch = doneLatch; + this.causeRef = causeRef; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ByteBuf buf = ctx.alloc().buffer(expectedBytes); + buf.writerIndex(buf.writerIndex() + expectedBytes); + ctx.writeAndFlush(buf.retainedDuplicate()); + + // We wait here to ensure that we write before we have a chance to process the outbound + // shutdown event. + followerCloseLatch.await(); + + // This write should fail, but we should still be allowed to read the peer's data + ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.cause() == null) { + causeRef.set(new IllegalStateException("second write should have failed!")); + doneLatch.countDown(); + } + } + }); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + bytesRead += msg.readableBytes(); + if (bytesRead >= expectedBytes) { + if (!seenOutputShutdown) { + causeRef.set(new IllegalStateException( + ChannelOutputShutdownEvent.class.getSimpleName() + " event was not seen")); + } + doneLatch.countDown(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof ChannelOutputShutdownEvent) { + seenOutputShutdown = true; + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + checkPrematureClose(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + checkPrematureClose(); + } + + private void checkPrematureClose() { + if (bytesRead < expectedBytes || !seenOutputShutdown) { + causeRef.set(new IllegalStateException("leader premature close")); + doneLatch.countDown(); + } + } + } + @Test public void testAllDataReadClosure() throws Throwable { run(); @@ -162,8 +360,8 @@ public class SocketHalfClosedTest extends AbstractSocketTest { testAllDataReadClosure(false, true, sb, cb); } - public void testAllDataReadClosure(final boolean autoRead, final boolean allowHalfClosed, - ServerBootstrap sb, Bootstrap cb) throws Throwable { + private static void testAllDataReadClosure(final boolean autoRead, final boolean allowHalfClosed, + ServerBootstrap sb, Bootstrap cb) throws Throwable { final int totalServerBytesWritten = 1024 * 16; final int numReadsPerReadLoop = 2; final CountDownLatch serverInitializedLatch = new CountDownLatch(1); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java index ebda9ebfe7..52ad42e3ce 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollServerChannel.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.channel.ServerChannel; +import io.netty.util.internal.UnstableApi; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -72,6 +73,16 @@ public abstract class AbstractEpollServerChannel extends AbstractEpollChannel im throw new UnsupportedOperationException(); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + try { + super.doShutdownOutput(cause); + } finally { + close(); + } + } + abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception; final class EpollServerSocketUnsafe extends AbstractEpollUnsafe { diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index 4430736a8c..03f5559458 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -37,6 +37,7 @@ import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -542,10 +543,27 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + try { + // The native socket implementation may throw a NotYetConnected exception when we attempt to shut it down. + // However NIO doesn't propagate an exception in the same situation (write failure), and we just want to + // update the socket state to flag that it has been shutdown. So don't use a voidPromise but instead create + // a new promise and ignore the results. + shutdownOutput0(newPromise()); + } finally { + super.doShutdownOutput(cause); + } + } + private void shutdownOutput0(final ChannelPromise promise) { try { - socket.shutdown(false, true); - ((AbstractUnsafe) unsafe()).shutdownOutput(); + try { + socket.shutdown(false, true); + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -563,8 +581,11 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private void shutdown0(final ChannelPromise promise) { try { - socket.shutdown(true, true); - ((AbstractUnsafe) unsafe()).shutdownOutput(); + try { + socket.shutdown(true, true); + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index eca49d5fa1..e3ee726535 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -31,6 +31,7 @@ import io.netty.channel.unix.DatagramSocketAddress; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; import java.io.IOException; import java.net.InetAddress; @@ -39,8 +40,6 @@ import java.net.NetworkInterface; import java.net.SocketAddress; import java.net.SocketException; import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; import static io.netty.channel.epoll.LinuxSocket.newSocketDgram; @@ -424,6 +423,16 @@ public final class EpollDatagramChannel extends AbstractEpollChannel implements return false; } + @UnstableApi + @Override + protected void doShutdownOutput(Throwable cause) throws Exception { + // UDP sockets are not connected. A write failure may just be temporary or disconnect was called. + ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer(); + if (channelOutboundBuffer != null) { + channelOutboundBuffer.remove(cause); + } + } + @Override protected void doClose() throws Exception { super.doClose(); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java index 47dad9c394..0760527761 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueServerChannel.java @@ -70,6 +70,16 @@ public abstract class AbstractKQueueServerChannel extends AbstractKQueueChannel throw new UnsupportedOperationException(); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + try { + super.doShutdownOutput(cause); + } finally { + close(); + } + } + abstract Channel newChildChannel(int fd, byte[] remote, int offset, int len) throws Exception; @Override diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index 2944e99153..9e9a59eab7 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -370,10 +370,27 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + try { + // The native socket implementation may throw a NotYetConnected exception when we attempt to shut it down. + // However NIO doesn't propagate an exception in the same situation (write failure), and we just want to + // update the socket state to flag that it has been shutdown. So don't use a voidPromise but instead create + // a new promise and ignore the results. + shutdownOutput0(newPromise()); + } finally { + super.doShutdownOutput(cause); + } + } + private void shutdownOutput0(final ChannelPromise promise) { try { - socket.shutdown(false, true); - ((AbstractUnsafe) unsafe()).shutdownOutput(); + try { + socket.shutdown(false, true); + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -391,8 +408,11 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel private void shutdown0(final ChannelPromise promise) { try { - socket.shutdown(true, true); - ((AbstractUnsafe) unsafe()).shutdownOutput(); + try { + socket.shutdown(true, true); + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 887951a28f..59e730114c 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -397,6 +397,16 @@ public final class KQueueDatagramChannel extends AbstractKQueueChannel implement return false; } + @UnstableApi + @Override + protected void doShutdownOutput(Throwable cause) throws Exception { + // UDP sockets are not connected. A write failure may just be temporary or {@link #doDisconnect()} was called. + ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer(); + if (channelOutboundBuffer != null) { + channelOutboundBuffer.remove(cause); + } + } + @Override protected void doClose() throws Exception { super.doClose(); diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index b9fbaac329..bf248075b2 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -611,14 +611,27 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha close(promise, CLOSE_CLOSED_CHANNEL_EXCEPTION, CLOSE_CLOSED_CHANNEL_EXCEPTION, false); } + /** + * Shutdown the output portion of the corresponding {@link Channel}. + * For example this will clean up the {@link ChannelOutboundBuffer} and not allow any more writes. + */ @UnstableApi - public void shutdownOutput() { + public final void shutdownOutput() { + shutdownOutput(null); + } + + /** + * Shutdown the output portion of the corresponding {@link Channel}. + * For example this will clean up the {@link ChannelOutboundBuffer} and not allow any more writes. + * @param cause The cause which may provide rational for the shutdown. + */ + final void shutdownOutput(Throwable cause) { final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer; if (outboundBuffer == null) { return; } this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer. - ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output explicitly shutdown"); + ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output shutdown", cause); outboundBuffer.failFlushed(e, false); outboundBuffer.close(e, true); pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE); @@ -885,7 +898,11 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha */ close(voidPromise(), t, FLUSH0_CLOSED_CHANNEL_EXCEPTION, false); } else { - outboundBuffer.failFlushed(t, true); + try { + doShutdownOutput(t); + } catch (Throwable t2) { + close(voidPromise(), t2, FLUSH0_CLOSED_CHANNEL_EXCEPTION, false); + } } } finally { inFlush0 = false; @@ -1019,6 +1036,16 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha */ protected abstract void doClose() throws Exception; + /** + * Called when conditions justify shutting down the output portion of the channel. This may happen if a write + * operation throws an exception. + * @param cause The cause for the shutdown. + */ + @UnstableApi + protected void doShutdownOutput(Throwable cause) throws Exception { + ((AbstractUnsafe) unsafe).shutdownOutput(cause); + } + /** * Deregister the {@link Channel} from its {@link EventLoop}. * diff --git a/transport/src/main/java/io/netty/channel/AbstractServerChannel.java b/transport/src/main/java/io/netty/channel/AbstractServerChannel.java index 126ce5fe62..ad21e92638 100644 --- a/transport/src/main/java/io/netty/channel/AbstractServerChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractServerChannel.java @@ -15,6 +15,8 @@ */ package io.netty.channel; +import io.netty.util.internal.UnstableApi; + import java.net.SocketAddress; /** @@ -58,6 +60,16 @@ public abstract class AbstractServerChannel extends AbstractChannel implements S throw new UnsupportedOperationException(); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + try { + super.doShutdownOutput(cause); + } finally { + close(); + } + } + @Override protected AbstractUnsafe newUnsafe() { return new DefaultServerUnsafe(); diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 06494f8250..9708773bec 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -40,6 +40,7 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.RecyclableArrayList; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -695,6 +696,13 @@ public class EmbeddedChannel extends AbstractChannel { state = State.CLOSED; } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + super.doShutdownOutput(cause); + close(); + } + @Override protected void doBeginRead() throws Exception { // NOOP diff --git a/transport/src/main/java/io/netty/channel/local/LocalChannel.java b/transport/src/main/java/io/netty/channel/local/LocalChannel.java index 39cf1a7588..be1684c616 100644 --- a/transport/src/main/java/io/netty/channel/local/LocalChannel.java +++ b/transport/src/main/java/io/netty/channel/local/LocalChannel.java @@ -32,6 +32,7 @@ import io.netty.util.concurrent.SingleThreadEventExecutor; import io.netty.util.internal.InternalThreadLocalMap; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -291,6 +292,13 @@ public class LocalChannel extends AbstractChannel { } } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + super.doShutdownOutput(cause); + close(); + } + private void tryClose(boolean isActive) { if (isActive) { unsafe().close(unsafe().voidPromise()); diff --git a/transport/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java b/transport/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java index 0b0e5ebd0b..c03bd947eb 100644 --- a/transport/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java +++ b/transport/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java @@ -21,6 +21,7 @@ import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPipeline; import io.netty.channel.RecvByteBufAllocator; import io.netty.channel.ServerChannel; +import io.netty.util.internal.UnstableApi; import java.io.IOException; import java.net.PortUnreachableException; @@ -179,6 +180,13 @@ public abstract class AbstractNioMessageChannel extends AbstractNioChannel { !(this instanceof ServerChannel); } + @UnstableApi + @Override + protected void doShutdownOutput(Throwable cause) throws Exception { + super.doShutdownOutput(cause); + close(); + } + /** * Read messages into the given array and return the amount which was read. */ diff --git a/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java b/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java index 0543b83c13..4208335a55 100644 --- a/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java +++ b/transport/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java @@ -19,6 +19,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelPipeline; import io.netty.channel.RecvByteBufAllocator; +import io.netty.util.internal.UnstableApi; import java.io.IOException; import java.util.ArrayList; @@ -103,6 +104,13 @@ public abstract class AbstractOioMessageChannel extends AbstractOioChannel { } } + @UnstableApi + @Override + protected void doShutdownOutput(Throwable cause) throws Exception { + super.doShutdownOutput(cause); + close(); + } + /** * Read messages into the given array and return the amount which was read. */ diff --git a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java index 7ad7e4785b..500a5aab48 100644 --- a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java +++ b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java @@ -29,4 +29,8 @@ public final class ChannelOutputShutdownException extends IOException { public ChannelOutputShutdownException(String msg) { super(msg); } + + public ChannelOutputShutdownException(String msg, Throwable cause) { + super(msg, cause); + } } diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java index 3e3b1245ec..3b446b81cd 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java @@ -33,6 +33,7 @@ import io.netty.channel.socket.InternetProtocolFamily; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; import java.io.IOException; import java.net.InetAddress; @@ -234,6 +235,16 @@ public final class NioDatagramChannel javaChannel().close(); } + @UnstableApi + @Override + protected void doShutdownOutput(Throwable cause) throws Exception { + // UDP sockets are not connected. A write failure may just be temporary or doDisconnect() was called. + ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer(); + if (channelOutboundBuffer != null) { + channelOutboundBuffer.remove(cause); + } + } + @Override protected int doReadMessages(List buf) throws Exception { DatagramChannel ch = javaChannel(); diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index 4677153144..ed3938b5fd 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelException; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOutboundBuffer; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; @@ -31,6 +32,7 @@ import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.SocketChannelConfig; import io.netty.util.concurrent.GlobalEventExecutor; import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -146,6 +148,22 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty return (InetSocketAddress) super.remoteAddress(); } + @UnstableApi + @Override + protected final void doShutdownOutput(final Throwable cause) throws Exception { + ChannelFuture future = shutdownOutput(); + if (future.isDone()) { + super.doShutdownOutput(cause); + } else { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + NioSocketChannel.super.doShutdownOutput(cause); + } + }); + } + } + @Override public ChannelFuture shutdownOutput() { return shutdownOutput(newPromise()); @@ -254,12 +272,15 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } private void shutdownOutput0() throws Exception { - if (PlatformDependent.javaVersion() >= 7) { - javaChannel().shutdownOutput(); - } else { - javaChannel().socket().shutdownOutput(); + try { + if (PlatformDependent.javaVersion() >= 7) { + javaChannel().shutdownOutput(); + } else { + javaChannel().socket().shutdownOutput(); + } + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); } - ((AbstractUnsafe) unsafe()).shutdownOutput(); } private void shutdownInput0(final ChannelPromise promise) { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java index 5534ae41f3..5763f75572 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java @@ -32,6 +32,7 @@ import io.netty.channel.socket.DatagramPacket; import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -198,6 +199,16 @@ public class OioDatagramChannel extends AbstractOioMessageChannel socket.disconnect(); } + @UnstableApi + @Override + protected final void doShutdownOutput(Throwable cause) throws Exception { + // UDP sockets are not connected. A write failure may just be temporary or {@link #doDisconnect()} was called. + ChannelOutboundBuffer channelOutboundBuffer = unsafe().outboundBuffer(); + if (channelOutboundBuffer != null) { + channelOutboundBuffer.remove(cause); + } + } + @Override protected void doClose() throws Exception { socket.close(); diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java index 54f8d818d9..8423230726 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java @@ -26,6 +26,7 @@ import io.netty.channel.oio.OioByteStreamChannel; import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.socket.SocketChannel; import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -128,6 +129,13 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive(); } + @UnstableApi + @Override + protected final void doShutdownOutput(final Throwable cause) throws Exception { + shutdownOutput0(voidPromise()); + super.doShutdownOutput(cause); + } + @Override public ChannelFuture shutdownOutput() { return shutdownOutput(newPromise()); @@ -181,8 +189,11 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan } private void shutdownOutput0() throws IOException { - socket.shutdownOutput(); - ((AbstractUnsafe) unsafe()).shutdownOutput(); + try { + socket.shutdownOutput(); + } finally { + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } } @Override