Shutting down the outbound side of the channel should not accept future writes
Motivation: Implementations of DuplexChannel delegate the shutdownOutput to the underlying transport, but do not take any action on the ChannelOutboundBuffer. In the event of a write failure due to the underlying transport failing and application may attempt to shutdown the output and allow the read side the transport to finish and detect the close. However this may result in an issue where writes are failed, this generates a writability change, we continue to write more data, and this may lead to another writability change, and this loop may continue. Shutting down the output should fail all pending writes and not allow any future writes to avoid this scenario. Modifications: - Implementations of DuplexChannel should null out the ChannelOutboundBuffer and fail all pending writes Result: More controlled sequencing for shutting down the output side of a channel.
This commit is contained in:
parent
d56f403c69
commit
237a4da1b7
@ -18,22 +18,31 @@ package io.netty.testsuite.transport.socket;
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOption;
|
||||
import io.netty.channel.SimpleChannelInboundHandler;
|
||||
import io.netty.channel.WriteBufferWaterMark;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.channel.socket.oio.OioSocketChannel;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
import java.net.SocketException;
|
||||
import java.nio.channels.ClosedChannelException;
|
||||
import java.util.concurrent.BlockingDeque;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.LinkedBlockingDeque;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertNull;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.junit.Assume.assumeFalse;
|
||||
|
||||
public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
|
||||
|
||||
@ -121,15 +130,85 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 30000)
|
||||
public void testWriteAfterShutdownOutputNoWritabilityChange() throws Throwable {
|
||||
run();
|
||||
}
|
||||
|
||||
public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
|
||||
final TestHandler h = new TestHandler();
|
||||
ServerSocket ss = new ServerSocket();
|
||||
Socket s = null;
|
||||
SocketChannel ch = null;
|
||||
try {
|
||||
ss.bind(newSocketAddress());
|
||||
cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
|
||||
ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
|
||||
assumeFalse(ch instanceof OioSocketChannel);
|
||||
assertTrue(ch.isActive());
|
||||
assertFalse(ch.isOutputShutdown());
|
||||
|
||||
s = ss.accept();
|
||||
|
||||
byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 };
|
||||
ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes));
|
||||
h.assertWritability(false);
|
||||
ch.flush();
|
||||
writeFuture.sync();
|
||||
h.assertWritability(true);
|
||||
for (int i = 0; i < expectedBytes.length; ++i) {
|
||||
assertEquals(expectedBytes[i], s.getInputStream().read());
|
||||
}
|
||||
|
||||
assertTrue(h.ch.isOpen());
|
||||
assertTrue(h.ch.isActive());
|
||||
assertFalse(h.ch.isInputShutdown());
|
||||
assertFalse(h.ch.isOutputShutdown());
|
||||
|
||||
// Make the connection half-closed and ensure read() returns -1.
|
||||
ch.shutdownOutput().sync();
|
||||
assertEquals(-1, s.getInputStream().read());
|
||||
|
||||
assertTrue(h.ch.isOpen());
|
||||
assertTrue(h.ch.isActive());
|
||||
assertFalse(h.ch.isInputShutdown());
|
||||
assertTrue(h.ch.isOutputShutdown());
|
||||
|
||||
try {
|
||||
// If half-closed, the local endpoint shouldn't be able to write
|
||||
ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync();
|
||||
fail();
|
||||
} catch (Throwable cause) {
|
||||
checkThrowable(cause);
|
||||
}
|
||||
assertNull(h.writabilityQueue.poll());
|
||||
} finally {
|
||||
if (s != null) {
|
||||
s.close();
|
||||
}
|
||||
if (ch != null) {
|
||||
ch.close();
|
||||
}
|
||||
ss.close();
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkThrowable(Throwable cause) throws Throwable {
|
||||
// Depending on OIO / NIO both are ok
|
||||
if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
|
||||
throw cause;
|
||||
}
|
||||
}
|
||||
private static class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
|
||||
|
||||
private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
|
||||
volatile SocketChannel ch;
|
||||
final BlockingQueue<Byte> queue = new LinkedBlockingQueue<Byte>();
|
||||
final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<Boolean>();
|
||||
|
||||
@Override
|
||||
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
|
||||
writabilityQueue.add(ctx.channel().isWritable());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
||||
@ -140,5 +219,22 @@ public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
|
||||
public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
|
||||
queue.offer(msg.readByte());
|
||||
}
|
||||
|
||||
private void drainWritabilityQueue() throws InterruptedException {
|
||||
while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
|
||||
// Just drain the queue.
|
||||
}
|
||||
}
|
||||
|
||||
void assertWritability(boolean isWritable) throws InterruptedException {
|
||||
try {
|
||||
Boolean writability = writabilityQueue.takeLast();
|
||||
assertEquals(isWritable, writability);
|
||||
// TODO(scott): why do we get multiple writability changes here ... race condition?
|
||||
drainWritabilityQueue();
|
||||
} catch (Throwable c) {
|
||||
c.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -545,6 +545,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im
|
||||
private void shutdownOutput0(final ChannelPromise promise) {
|
||||
try {
|
||||
socket.shutdown(false, true);
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
@ -563,6 +564,7 @@ public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel im
|
||||
private void shutdown0(final ChannelPromise promise) {
|
||||
try {
|
||||
socket.shutdown(true, true);
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
|
@ -373,6 +373,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel
|
||||
private void shutdownOutput0(final ChannelPromise promise) {
|
||||
try {
|
||||
socket.shutdown(false, true);
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
@ -391,6 +392,7 @@ public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel
|
||||
private void shutdown0(final ChannelPromise promise) {
|
||||
try {
|
||||
socket.shutdown(true, true);
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
promise.setSuccess();
|
||||
} catch (Throwable cause) {
|
||||
promise.setFailure(cause);
|
||||
|
@ -16,10 +16,13 @@
|
||||
package io.netty.channel;
|
||||
|
||||
import io.netty.buffer.ByteBufAllocator;
|
||||
import io.netty.channel.socket.ChannelOutputShutdownEvent;
|
||||
import io.netty.channel.socket.ChannelOutputShutdownException;
|
||||
import io.netty.util.DefaultAttributeMap;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.internal.PlatformDependent;
|
||||
import io.netty.util.internal.ThrowableUtil;
|
||||
import io.netty.util.internal.UnstableApi;
|
||||
import io.netty.util.internal.logging.InternalLogger;
|
||||
import io.netty.util.internal.logging.InternalLoggerFactory;
|
||||
|
||||
@ -607,6 +610,19 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
|
||||
close(promise, CLOSE_CLOSED_CHANNEL_EXCEPTION, CLOSE_CLOSED_CHANNEL_EXCEPTION, false);
|
||||
}
|
||||
|
||||
@UnstableApi
|
||||
public void shutdownOutput() {
|
||||
final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
|
||||
if (outboundBuffer == null) {
|
||||
return;
|
||||
}
|
||||
this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer.
|
||||
ChannelOutputShutdownException e = new ChannelOutputShutdownException("Channel output explicitly shutdown");
|
||||
outboundBuffer.failFlushed(e, false);
|
||||
outboundBuffer.close(e, true);
|
||||
pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE);
|
||||
}
|
||||
|
||||
private void close(final ChannelPromise promise, final Throwable cause,
|
||||
final ClosedChannelException closeCause, final boolean notify) {
|
||||
if (!promise.setUncancellable()) {
|
||||
|
@ -622,12 +622,12 @@ public final class ChannelOutboundBuffer {
|
||||
}
|
||||
}
|
||||
|
||||
void close(final ClosedChannelException cause) {
|
||||
void close(final Throwable cause, final boolean allowChannelOpen) {
|
||||
if (inFail) {
|
||||
channel.eventLoop().execute(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
close(cause);
|
||||
close(cause, allowChannelOpen);
|
||||
}
|
||||
});
|
||||
return;
|
||||
@ -635,7 +635,7 @@ public final class ChannelOutboundBuffer {
|
||||
|
||||
inFail = true;
|
||||
|
||||
if (channel.isOpen()) {
|
||||
if (!allowChannelOpen && channel.isOpen()) {
|
||||
throw new IllegalStateException("close() must be invoked after the channel is closed.");
|
||||
}
|
||||
|
||||
@ -663,6 +663,10 @@ public final class ChannelOutboundBuffer {
|
||||
clearNioBuffers();
|
||||
}
|
||||
|
||||
void close(ClosedChannelException cause) {
|
||||
close(cause, false);
|
||||
}
|
||||
|
||||
private static void safeSuccess(ChannelPromise promise) {
|
||||
// Only log if the given promise is not of type VoidChannelPromise as trySuccess(...) is expected to return
|
||||
// false.
|
||||
|
@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright 2017 The Netty Project
|
||||
*
|
||||
* The Netty Project licenses this file to you under the Apache License,
|
||||
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at:
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.netty.channel.socket;
|
||||
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandler;
|
||||
import io.netty.util.internal.UnstableApi;
|
||||
|
||||
/**
|
||||
* Special event which will be fired and passed to the
|
||||
* {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the output of
|
||||
* a {@link SocketChannel} was shutdown.
|
||||
*/
|
||||
@UnstableApi
|
||||
public final class ChannelOutputShutdownEvent {
|
||||
public static final ChannelOutputShutdownEvent INSTANCE = new ChannelOutputShutdownEvent();
|
||||
|
||||
private ChannelOutputShutdownEvent() {
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
|
||||
|
||||
/*
|
||||
* Copyright 2017 The Netty Project
|
||||
*
|
||||
* The Netty Project licenses this file to you under the Apache License,
|
||||
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at:
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.netty.channel.socket;
|
||||
|
||||
import io.netty.util.internal.UnstableApi;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Used to fail pending writes when a channel's output has been shutdown.
|
||||
*/
|
||||
@UnstableApi
|
||||
public final class ChannelOutputShutdownException extends IOException {
|
||||
public ChannelOutputShutdownException(String msg) {
|
||||
super(msg);
|
||||
}
|
||||
}
|
@ -259,6 +259,7 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
|
||||
} else {
|
||||
javaChannel().socket().shutdownOutput();
|
||||
}
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
}
|
||||
|
||||
private void shutdownInput0(final ChannelPromise promise) {
|
||||
|
@ -173,13 +173,18 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan
|
||||
|
||||
private void shutdownOutput0(ChannelPromise promise) {
|
||||
try {
|
||||
socket.shutdownOutput();
|
||||
shutdownOutput0();
|
||||
promise.setSuccess();
|
||||
} catch (Throwable t) {
|
||||
promise.setFailure(t);
|
||||
}
|
||||
}
|
||||
|
||||
private void shutdownOutput0() throws IOException {
|
||||
socket.shutdownOutput();
|
||||
((AbstractUnsafe) unsafe()).shutdownOutput();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChannelFuture shutdownInput(final ChannelPromise promise) {
|
||||
EventLoop loop = eventLoop();
|
||||
@ -224,7 +229,7 @@ public class OioSocketChannel extends OioByteStreamChannel implements SocketChan
|
||||
private void shutdown0(ChannelPromise promise) {
|
||||
Throwable cause = null;
|
||||
try {
|
||||
socket.shutdownOutput();
|
||||
shutdownOutput0();
|
||||
} catch (Throwable t) {
|
||||
cause = t;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user