From 237a4da1b76ef7f303eca8f6ec8539efc9fe4d26 Mon Sep 17 00:00:00 2001 From: Scott Mitchell Date: Tue, 1 Aug 2017 09:18:08 -0700 Subject: [PATCH] Shutting down the outbound side of the channel should not accept future writes Motivation: Implementations of DuplexChannel delegate the shutdownOutput to the underlying transport, but do not take any action on the ChannelOutboundBuffer. In the event of a write failure due to the underlying transport failing and application may attempt to shutdown the output and allow the read side the transport to finish and detect the close. However this may result in an issue where writes are failed, this generates a writability change, we continue to write more data, and this may lead to another writability change, and this loop may continue. Shutting down the output should fail all pending writes and not allow any future writes to avoid this scenario. Modifications: - Implementations of DuplexChannel should null out the ChannelOutboundBuffer and fail all pending writes Result: More controlled sequencing for shutting down the output side of a channel. --- .../SocketShutdownOutputBySelfTest.java | 98 ++++++++++++++++++- .../epoll/AbstractEpollStreamChannel.java | 2 + .../kqueue/AbstractKQueueStreamChannel.java | 2 + .../io/netty/channel/AbstractChannel.java | 16 +++ .../netty/channel/ChannelOutboundBuffer.java | 10 +- .../socket/ChannelOutputShutdownEvent.java | 33 +++++++ .../ChannelOutputShutdownException.java | 32 ++++++ .../channel/socket/nio/NioSocketChannel.java | 1 + .../channel/socket/oio/OioSocketChannel.java | 9 +- 9 files changed, 197 insertions(+), 6 deletions(-) create mode 100644 transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java create mode 100644 transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java index 751584fd0e..e7f288b146 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java @@ -18,22 +18,31 @@ package io.netty.testsuite.transport.socket; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.oio.OioSocketChannel; import org.junit.Test; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.nio.channels.ClosedChannelException; +import java.util.concurrent.BlockingDeque; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { @@ -121,15 +130,85 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { } } + @Test(timeout = 30000) + public void testWriteAfterShutdownOutputNoWritabilityChange() throws Throwable { + run(); + } + + public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable { + final TestHandler h = new TestHandler(); + ServerSocket ss = new ServerSocket(); + Socket s = null; + SocketChannel ch = null; + try { + ss.bind(newSocketAddress()); + cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4)); + ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel(); + assumeFalse(ch instanceof OioSocketChannel); + assertTrue(ch.isActive()); + assertFalse(ch.isOutputShutdown()); + + s = ss.accept(); + + byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 }; + ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes)); + h.assertWritability(false); + ch.flush(); + writeFuture.sync(); + h.assertWritability(true); + for (int i = 0; i < expectedBytes.length; ++i) { + assertEquals(expectedBytes[i], s.getInputStream().read()); + } + + assertTrue(h.ch.isOpen()); + assertTrue(h.ch.isActive()); + assertFalse(h.ch.isInputShutdown()); + assertFalse(h.ch.isOutputShutdown()); + + // Make the connection half-closed and ensure read() returns -1. + ch.shutdownOutput().sync(); + assertEquals(-1, s.getInputStream().read()); + + assertTrue(h.ch.isOpen()); + assertTrue(h.ch.isActive()); + assertFalse(h.ch.isInputShutdown()); + assertTrue(h.ch.isOutputShutdown()); + + try { + // If half-closed, the local endpoint shouldn't be able to write + ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync(); + fail(); + } catch (Throwable cause) { + checkThrowable(cause); + } + assertNull(h.writabilityQueue.poll()); + } finally { + if (s != null) { + s.close(); + } + if (ch != null) { + ch.close(); + } + ss.close(); + } + } + private static void checkThrowable(Throwable cause) throws Throwable { // Depending on OIO / NIO both are ok if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) { throw cause; } } - private static class TestHandler extends SimpleChannelInboundHandler { + + private static final class TestHandler extends SimpleChannelInboundHandler { volatile SocketChannel ch; final BlockingQueue queue = new LinkedBlockingQueue(); + final BlockingDeque writabilityQueue = new LinkedBlockingDeque(); + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + writabilityQueue.add(ctx.channel().isWritable()); + } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { @@ -140,5 +219,22 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { queue.offer(msg.readByte()); } + + private void drainWritabilityQueue() throws InterruptedException { + while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) { + // Just drain the queue. + } + } + + void assertWritability(boolean isWritable) throws InterruptedException { + try { + Boolean writability = writabilityQueue.takeLast(); + assertEquals(isWritable, writability); + // TODO(scott): why do we get multiple writability changes here ... race condition? + drainWritabilityQueue(); + } catch (Throwable c) { + c.printStackTrace(); + } + } } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index cc4ed88f44..4430736a8c 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -545,6 +545,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private void shutdownOutput0(final ChannelPromise promise) { try { socket.shutdown(false, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -563,6 +564,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private void shutdown0(final ChannelPromise promise) { try { socket.shutdown(true, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index b3a302cd82..2944e99153 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -373,6 +373,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel private void shutdownOutput0(final ChannelPromise promise) { try { socket.shutdown(false, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -391,6 +392,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel private void shutdown0(final ChannelPromise promise) { try { socket.shutdown(true, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index c0dbd41e18..d3a2584c4d 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -16,10 +16,13 @@ package io.netty.channel; import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.channel.socket.ChannelOutputShutdownException; import io.netty.util.DefaultAttributeMap; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -607,6 +610,19 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha close(promise, CLOSE_CLOSED_CHANNEL_EXCEPTION, CLOSE_CLOSED_CHANNEL_EXCEPTION, false); } + @UnstableApi + public void shutdownOutput() { + final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer; + if (outboundBuffer == null) { + return; + } + this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer. + ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output explicitly shutdown"); + outboundBuffer.failFlushed(e, false); + outboundBuffer.close(e, true); + pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE); + } + private void close(final ChannelPromise promise, final Throwable cause, final ClosedChannelException closeCause, final boolean notify) { if (!promise.setUncancellable()) { diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index 34fbde3556..28b308f4f1 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -622,12 +622,12 @@ public final class ChannelOutboundBuffer { } } - void close(final ClosedChannelException cause) { + void close(final Throwable cause, final boolean allowChannelOpen) { if (inFail) { channel.eventLoop().execute(new Runnable() { @Override public void run() { - close(cause); + close(cause, allowChannelOpen); } }); return; @@ -635,7 +635,7 @@ public final class ChannelOutboundBuffer { inFail = true; - if (channel.isOpen()) { + if (!allowChannelOpen && channel.isOpen()) { throw new IllegalStateException("close() must be invoked after the channel is closed."); } @@ -663,6 +663,10 @@ public final class ChannelOutboundBuffer { clearNioBuffers(); } + void close(ClosedChannelException cause) { + close(cause, false); + } + private static void safeSuccess(ChannelPromise promise) { // Only log if the given promise is not of type VoidChannelPromise as trySuccess(...) is expected to return // false. diff --git a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java new file mode 100644 index 0000000000..49d882f92e --- /dev/null +++ b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java @@ -0,0 +1,33 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.util.internal.UnstableApi; + +/** + * Special event which will be fired and passed to the + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the output of + * a {@link SocketChannel} was shutdown. + */ +@UnstableApi +public final class ChannelOutputShutdownEvent { + public static final ChannelOutputShutdownEvent INSTANCE = new ChannelOutputShutdownEvent(); + + private ChannelOutputShutdownEvent() { + } +} diff --git a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java new file mode 100644 index 0000000000..7ad7e4785b --- /dev/null +++ b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java @@ -0,0 +1,32 @@ + + +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; + +/** + * Used to fail pending writes when a channel's output has been shutdown. + */ +@UnstableApi +public final class ChannelOutputShutdownException extends IOException { + public ChannelOutputShutdownException(String msg) { + super(msg); + } +} diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index e29f4e34ab..4677153144 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -259,6 +259,7 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } else { javaChannel().socket().shutdownOutput(); } + ((AbstractUnsafe) unsafe()).shutdownOutput(); } private void shutdownInput0(final ChannelPromise promise) { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java index b3090ce0a9..54f8d818d9 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java @@ -173,13 +173,18 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan private void shutdownOutput0(ChannelPromise promise) { try { - socket.shutdownOutput(); + shutdownOutput0(); promise.setSuccess(); } catch (Throwable t) { promise.setFailure(t); } } + private void shutdownOutput0() throws IOException { + socket.shutdownOutput(); + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } + @Override public ChannelFuture shutdownInput(final ChannelPromise promise) { EventLoop loop = eventLoop(); @@ -224,7 +229,7 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan private void shutdown0(ChannelPromise promise) { Throwable cause = null; try { - socket.shutdownOutput(); + shutdownOutput0(); } catch (Throwable t) { cause = t; }