Shutting down the outbound side of the channel should not accept future writes

Motivation:
Implementations of DuplexChannel delegate the shutdownOutput to the underlying transport, but do not take any action on the ChannelOutboundBuffer. In the event of a write failure due to the underlying transport failing and application may attempt to shutdown the output and allow the read side the transport to finish and detect the close. However this may result in an issue where writes are failed, this generates a writability change, we continue to write more data, and this may lead to another writability change, and this loop may continue. Shutting down the output should fail all pending writes and not allow any future writes to avoid this scenario.

Modifications:
- Implementations of DuplexChannel should null out the ChannelOutboundBuffer and fail all pending writes

Result:
More controlled sequencing for shutting down the output side of a channel.
This commit is contained in:
Scott Mitchell 2017-08-01 09:18:08 -07:00
parent d56f403c69
commit 237a4da1b7
9 changed files with 197 additions and 6 deletions

View File

@ -18,22 +18,31 @@ package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.WriteBufferWaterMark;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.junit.Test;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeFalse;
public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
@ -121,15 +130,85 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
}
}
@Test(timeout = 30000)
public void testWriteAfterShutdownOutputNoWritabilityChange() throws Throwable {
run();
}
public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
final TestHandler h = new TestHandler();
ServerSocket ss = new ServerSocket();
Socket s = null;
SocketChannel ch = null;
try {
ss.bind(newSocketAddress());
cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
assumeFalse(ch instanceof OioSocketChannel);
assertTrue(ch.isActive());
assertFalse(ch.isOutputShutdown());
s = ss.accept();
byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 };
ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes));
h.assertWritability(false);
ch.flush();
writeFuture.sync();
h.assertWritability(true);
for (int i = 0; i < expectedBytes.length; ++i) {
assertEquals(expectedBytes[i], s.getInputStream().read());
}
assertTrue(h.ch.isOpen());
assertTrue(h.ch.isActive());
assertFalse(h.ch.isInputShutdown());
assertFalse(h.ch.isOutputShutdown());
// Make the connection half-closed and ensure read() returns -1.
ch.shutdownOutput().sync();
assertEquals(-1, s.getInputStream().read());
assertTrue(h.ch.isOpen());
assertTrue(h.ch.isActive());
assertFalse(h.ch.isInputShutdown());
assertTrue(h.ch.isOutputShutdown());
try {
// If half-closed, the local endpoint shouldn't be able to write
ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync();
fail();
} catch (Throwable cause) {
checkThrowable(cause);
}
assertNull(h.writabilityQueue.poll());
} finally {
if (s != null) {
s.close();
}
if (ch != null) {
ch.close();
}
ss.close();
}
}
private static void checkThrowable(Throwable cause) throws Throwable {
// Depending on OIO / NIO both are ok
if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
throw cause;
}
}
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
volatile SocketChannel ch;
final BlockingQueue<Byte> queue = new LinkedBlockingQueue<Byte>();
final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<Boolean>();
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
writabilityQueue.add(ctx.channel().isWritable());
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
@ -140,5 +219,22 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
queue.offer(msg.readByte());
}
private void drainWritabilityQueue() throws InterruptedException {
while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
// Just drain the queue.
}
}
void assertWritability(boolean isWritable) throws InterruptedException {
try {
Boolean writability = writabilityQueue.takeLast();
assertEquals(isWritable, writability);
// TODO(scott): why do we get multiple writability changes here ... race condition?
drainWritabilityQueue();
} catch (Throwable c) {
c.printStackTrace();
}
}
}
}

View File

@ -545,6 +545,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im
private void shutdownOutput0(final ChannelPromise promise) {
try {
socket.shutdown(false, true);
((AbstractUnsafe) unsafe()).shutdownOutput();
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
@ -563,6 +564,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im
private void shutdown0(final ChannelPromise promise) {
try {
socket.shutdown(true, true);
((AbstractUnsafe) unsafe()).shutdownOutput();
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);

View File

@ -373,6 +373,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel
private void shutdownOutput0(final ChannelPromise promise) {
try {
socket.shutdown(false, true);
((AbstractUnsafe) unsafe()).shutdownOutput();
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);
@ -391,6 +392,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel
private void shutdown0(final ChannelPromise promise) {
try {
socket.shutdown(true, true);
((AbstractUnsafe) unsafe()).shutdownOutput();
promise.setSuccess();
} catch (Throwable cause) {
promise.setFailure(cause);

View File

@ -16,10 +16,13 @@
package io.netty.channel;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.socket.ChannelOutputShutdownEvent;
import io.netty.channel.socket.ChannelOutputShutdownException;
import io.netty.util.DefaultAttributeMap;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.UnstableApi;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -607,6 +610,19 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
close(promise, CLOSE_CLOSED_CHANNEL_EXCEPTION, CLOSE_CLOSED_CHANNEL_EXCEPTION, false);
}
@UnstableApi
public void shutdownOutput() {
final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
if (outboundBuffer == null) {
return;
}
this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer.
ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output explicitly shutdown");
outboundBuffer.failFlushed(e, false);
outboundBuffer.close(e, true);
pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE);
}
private void close(final ChannelPromise promise, final Throwable cause,
final ClosedChannelException closeCause, final boolean notify) {
if (!promise.setUncancellable()) {

View File

@ -622,12 +622,12 @@ public final class ChannelOutboundBuffer {
}
}
void close(final ClosedChannelException cause) {
void close(final Throwable cause, final boolean allowChannelOpen) {
if (inFail) {
channel.eventLoop().execute(new Runnable() {
@Override
public void run() {
close(cause);
close(cause, allowChannelOpen);
}
});
return;
@ -635,7 +635,7 @@ public final class ChannelOutboundBuffer {
inFail = true;
if (channel.isOpen()) {
if (!allowChannelOpen && channel.isOpen()) {
throw new IllegalStateException("close() must be invoked after the channel is closed.");
}
@ -663,6 +663,10 @@ public final class ChannelOutboundBuffer {
clearNioBuffers();
}
void close(ClosedChannelException cause) {
close(cause, false);
}
private static void safeSuccess(ChannelPromise promise) {
// Only log if the given promise is not of type VoidChannelPromise as trySuccess(...) is expected to return
// false.

View File

@ -0,0 +1,33 @@
/*
* Copyright 2017 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.socket;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.util.internal.UnstableApi;
/**
* Special event which will be fired and passed to the
* {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the output of
* a {@link SocketChannel} was shutdown.
*/
@UnstableApi
public final class ChannelOutputShutdownEvent {
public static final ChannelOutputShutdownEvent INSTANCE = new ChannelOutputShutdownEvent();
private ChannelOutputShutdownEvent() {
}
}

View File

@ -0,0 +1,32 @@
/*
* Copyright 2017 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.socket;
import io.netty.util.internal.UnstableApi;
import java.io.IOException;
/**
* Used to fail pending writes when a channel's output has been shutdown.
*/
@UnstableApi
public final class ChannelOutputShutdownException extends IOException {
public ChannelOutputShutdownException(String msg) {
super(msg);
}
}

View File

@ -259,6 +259,7 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
} else {
javaChannel().socket().shutdownOutput();
}
((AbstractUnsafe) unsafe()).shutdownOutput();
}
private void shutdownInput0(final ChannelPromise promise) {

View File

@ -173,13 +173,18 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan
private void shutdownOutput0(ChannelPromise promise) {
try {
socket.shutdownOutput();
shutdownOutput0();
promise.setSuccess();
} catch (Throwable t) {
promise.setFailure(t);
}
}
private void shutdownOutput0() throws IOException {
socket.shutdownOutput();
((AbstractUnsafe) unsafe()).shutdownOutput();
}
@Override
public ChannelFuture shutdownInput(final ChannelPromise promise) {
EventLoop loop = eventLoop();
@ -224,7 +229,7 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan
private void shutdown0(ChannelPromise promise) {
Throwable cause = null;
try {
socket.shutdownOutput();
shutdownOutput0();
} catch (Throwable t) {
cause = t;
}