Make sure writing to a closed channel does not trigger an UnsupportedOperationException

- Fixes #1442
This commit is contained in:
Trustin Lee 2013-06-14 11:15:46 +09:00
parent 25c51279cf
commit fe40d4b67f
2 changed files with 67 additions and 5 deletions

View File

@ -29,6 +29,7 @@ import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NotYetConnectedException;
import java.util.Random;
import java.util.concurrent.ConcurrentMap;
@ -671,7 +672,20 @@ public abstract class AbstractChannel extends DefaultAttributeMap implements Cha
}
inFlushNow = true;
final ChannelOutboundBuffer outboundBuffer = AbstractChannel.this.outboundBuffer;
// Mark all pending write requests as failure if the channel is inactive.
if (!isActive()) {
if (isOpen()) {
outboundBuffer.fail(new NotYetConnectedException());
} else {
outboundBuffer.fail(new ClosedChannelException());
}
inFlushNow = false;
return;
}
try {
for (;;) {
ChannelPromise promise = outboundBuffer.currentPromise;

View File

@ -17,6 +17,7 @@ package io.netty.channel.local;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@ -27,20 +28,20 @@ import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.Test;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
public class LocalChannelRegistryTest {
public class LocalChannelTest {
private static final InternalLogger logger =
InternalLoggerFactory.getInstance(LocalChannelRegistryTest.class);
private static final InternalLogger logger = InternalLoggerFactory.getInstance(LocalChannelTest.class);
private static final String LOCAL_ADDR_ID = "test.id";
@Test
public void testLocalAddressReuse() throws Exception {
for (int i = 0; i < 2; i ++) {
EventLoopGroup clientGroup = new LocalEventLoopGroup();
EventLoopGroup serverGroup = new LocalEventLoopGroup();
@ -94,10 +95,57 @@ public class LocalChannelRegistryTest {
}
}
@Test
public void testWriteFailsFastOnClosedChannel() throws Exception {
EventLoopGroup clientGroup = new LocalEventLoopGroup();
EventLoopGroup serverGroup = new LocalEventLoopGroup();
LocalAddress addr = new LocalAddress(LOCAL_ADDR_ID);
Bootstrap cb = new Bootstrap();
ServerBootstrap sb = new ServerBootstrap();
cb.group(clientGroup)
.channel(LocalChannel.class)
.handler(new TestHandler());
sb.group(serverGroup)
.channel(LocalServerChannel.class)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
public void initChannel(LocalChannel ch) throws Exception {
ch.pipeline().addLast(new TestHandler());
}
});
// Start server
sb.bind(addr).sync();
// Connect to the server
final Channel cc = cb.connect(addr).sync().channel();
// Close the channel and write something.
cc.close().sync();
try {
cc.write(new Object()).sync();
fail("must raise a ClosedChannelException");
} catch (Exception e) {
assertThat(e, is(instanceOf(ClosedChannelException.class)));
// Ensure that the actual write attempt on a closed channel was never made by asserting that
// the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations.
assertThat(e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() + "$AbstractUnsafe"));
e.printStackTrace();
}
serverGroup.shutdownGracefully();
clientGroup.shutdownGracefully();
serverGroup.terminationFuture().sync();
clientGroup.terminationFuture().sync();
}
static class TestHandler extends ChannelInboundHandlerAdapter {
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageList<Object> msgs) throws Exception {
for (int i = 0; i < msgs.size(); i ++) {
final int size = msgs.size();
for (int i = 0; i < size; i ++) {
logger.info(String.format("Received mesage: %s", msgs.get(i)));
}
}