diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java index 751584fd0e..e7f288b146 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketShutdownOutputBySelfTest.java @@ -18,22 +18,31 @@ package io.netty.testsuite.transport.socket; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOption; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.oio.OioSocketChannel; import org.junit.Test; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.nio.channels.ClosedChannelException; +import java.util.concurrent.BlockingDeque; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.junit.Assume.assumeFalse; public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { @@ -121,15 +130,85 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { } } + @Test(timeout = 30000) + public void testWriteAfterShutdownOutputNoWritabilityChange() throws Throwable { + run(); + } + + public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable { + final TestHandler h = new TestHandler(); + ServerSocket ss = new ServerSocket(); + Socket s = null; + SocketChannel ch = null; + try { + ss.bind(newSocketAddress()); + cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4)); + ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel(); + assumeFalse(ch instanceof OioSocketChannel); + assertTrue(ch.isActive()); + assertFalse(ch.isOutputShutdown()); + + s = ss.accept(); + + byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 }; + ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes)); + h.assertWritability(false); + ch.flush(); + writeFuture.sync(); + h.assertWritability(true); + for (int i = 0; i < expectedBytes.length; ++i) { + assertEquals(expectedBytes[i], s.getInputStream().read()); + } + + assertTrue(h.ch.isOpen()); + assertTrue(h.ch.isActive()); + assertFalse(h.ch.isInputShutdown()); + assertFalse(h.ch.isOutputShutdown()); + + // Make the connection half-closed and ensure read() returns -1. + ch.shutdownOutput().sync(); + assertEquals(-1, s.getInputStream().read()); + + assertTrue(h.ch.isOpen()); + assertTrue(h.ch.isActive()); + assertFalse(h.ch.isInputShutdown()); + assertTrue(h.ch.isOutputShutdown()); + + try { + // If half-closed, the local endpoint shouldn't be able to write + ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync(); + fail(); + } catch (Throwable cause) { + checkThrowable(cause); + } + assertNull(h.writabilityQueue.poll()); + } finally { + if (s != null) { + s.close(); + } + if (ch != null) { + ch.close(); + } + ss.close(); + } + } + private static void checkThrowable(Throwable cause) throws Throwable { // Depending on OIO / NIO both are ok if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) { throw cause; } } - private static class TestHandler extends SimpleChannelInboundHandler { + + private static final class TestHandler extends SimpleChannelInboundHandler { volatile SocketChannel ch; final BlockingQueue queue = new LinkedBlockingQueue(); + final BlockingDeque writabilityQueue = new LinkedBlockingDeque(); + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + writabilityQueue.add(ctx.channel().isWritable()); + } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { @@ -140,5 +219,22 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest { public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { queue.offer(msg.readByte()); } + + private void drainWritabilityQueue() throws InterruptedException { + while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) { + // Just drain the queue. + } + } + + void assertWritability(boolean isWritable) throws InterruptedException { + try { + Boolean writability = writabilityQueue.takeLast(); + assertEquals(isWritable, writability); + // TODO(scott): why do we get multiple writability changes here ... race condition? + drainWritabilityQueue(); + } catch (Throwable c) { + c.printStackTrace(); + } + } } } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index cc4ed88f44..4430736a8c 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -545,6 +545,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private void shutdownOutput0(final ChannelPromise promise) { try { socket.shutdown(false, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -563,6 +564,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im private void shutdown0(final ChannelPromise promise) { try { socket.shutdown(true, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index b3a302cd82..2944e99153 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -373,6 +373,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel private void shutdownOutput0(final ChannelPromise promise) { try { socket.shutdown(false, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); @@ -391,6 +392,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel private void shutdown0(final ChannelPromise promise) { try { socket.shutdown(true, true); + ((AbstractUnsafe) unsafe()).shutdownOutput(); promise.setSuccess(); } catch (Throwable cause) { promise.setFailure(cause); diff --git a/transport/src/main/java/io/netty/channel/AbstractChannel.java b/transport/src/main/java/io/netty/channel/AbstractChannel.java index c0dbd41e18..d3a2584c4d 100644 --- a/transport/src/main/java/io/netty/channel/AbstractChannel.java +++ b/transport/src/main/java/io/netty/channel/AbstractChannel.java @@ -16,10 +16,13 @@ package io.netty.channel; import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.channel.socket.ChannelOutputShutdownException; import io.netty.util.DefaultAttributeMap; import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -607,6 +610,19 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha close(promise, CLOSE_CLOSED_CHANNEL_EXCEPTION, CLOSE_CLOSED_CHANNEL_EXCEPTION, false); } + @UnstableApi + public void shutdownOutput() { + final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer; + if (outboundBuffer == null) { + return; + } + this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer. + ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output explicitly shutdown"); + outboundBuffer.failFlushed(e, false); + outboundBuffer.close(e, true); + pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE); + } + private void close(final ChannelPromise promise, final Throwable cause, final ClosedChannelException closeCause, final boolean notify) { if (!promise.setUncancellable()) { diff --git a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java index 34fbde3556..28b308f4f1 100644 --- a/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java +++ b/transport/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -622,12 +622,12 @@ public final class ChannelOutboundBuffer { } } - void close(final ClosedChannelException cause) { + void close(final Throwable cause, final boolean allowChannelOpen) { if (inFail) { channel.eventLoop().execute(new Runnable() { @Override public void run() { - close(cause); + close(cause, allowChannelOpen); } }); return; @@ -635,7 +635,7 @@ public final class ChannelOutboundBuffer { inFail = true; - if (channel.isOpen()) { + if (!allowChannelOpen && channel.isOpen()) { throw new IllegalStateException("close() must be invoked after the channel is closed."); } @@ -663,6 +663,10 @@ public final class ChannelOutboundBuffer { clearNioBuffers(); } + void close(ClosedChannelException cause) { + close(cause, false); + } + private static void safeSuccess(ChannelPromise promise) { // Only log if the given promise is not of type VoidChannelPromise as trySuccess(...) is expected to return // false. diff --git a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java new file mode 100644 index 0000000000..49d882f92e --- /dev/null +++ b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java @@ -0,0 +1,33 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.util.internal.UnstableApi; + +/** + * Special event which will be fired and passed to the + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the output of + * a {@link SocketChannel} was shutdown. + */ +@UnstableApi +public final class ChannelOutputShutdownEvent { + public static final ChannelOutputShutdownEvent INSTANCE = new ChannelOutputShutdownEvent(); + + private ChannelOutputShutdownEvent() { + } +} diff --git a/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java new file mode 100644 index 0000000000..7ad7e4785b --- /dev/null +++ b/transport/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java @@ -0,0 +1,32 @@ + + +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; + +/** + * Used to fail pending writes when a channel's output has been shutdown. + */ +@UnstableApi +public final class ChannelOutputShutdownException extends IOException { + public ChannelOutputShutdownException(String msg) { + super(msg); + } +} diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java index e29f4e34ab..4677153144 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -259,6 +259,7 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty } else { javaChannel().socket().shutdownOutput(); } + ((AbstractUnsafe) unsafe()).shutdownOutput(); } private void shutdownInput0(final ChannelPromise promise) { diff --git a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java index b3090ce0a9..54f8d818d9 100644 --- a/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java @@ -173,13 +173,18 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan private void shutdownOutput0(ChannelPromise promise) { try { - socket.shutdownOutput(); + shutdownOutput0(); promise.setSuccess(); } catch (Throwable t) { promise.setFailure(t); } } + private void shutdownOutput0() throws IOException { + socket.shutdownOutput(); + ((AbstractUnsafe) unsafe()).shutdownOutput(); + } + @Override public ChannelFuture shutdownInput(final ChannelPromise promise) { EventLoop loop = eventLoop(); @@ -224,7 +229,7 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan private void shutdown0(ChannelPromise promise) { Throwable cause = null; try { - socket.shutdownOutput(); + shutdownOutput0(); } catch (Throwable t) { cause = t; }