[#107] Add support for closing either input or output part of a channel

- Add shutdownOutput() and isOutputShutdown() to SocketChannel
This commit is contained in:
Trustin Lee 2012-08-29 13:26:29 +09:00
parent 47f26d219d
commit 02f3df55a8
7 changed files with 252 additions and 22 deletions

View File

@ -0,0 +1,66 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.logging.InternalLogger;
import io.netty.logging.InternalLoggerFactory;
import io.netty.testsuite.transport.socket.SocketTestPermutation.Factory;
import io.netty.testsuite.util.TestUtils;
import io.netty.util.NetworkConstants;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.List;
import org.junit.Rule;
import org.junit.rules.TestName;
public abstract class AbstractClientSocketTest {
private static final List<Factory<Bootstrap>> COMBO = SocketTestPermutation.clientSocket();
@Rule
public final TestName testName = new TestName();
protected final InternalLogger logger = InternalLoggerFactory.getInstance(getClass());
protected volatile Bootstrap cb;
protected volatile InetSocketAddress addr;
protected void run() throws Throwable {
int i = 0;
for (Factory<Bootstrap> e: COMBO) {
cb = e.newInstance();
addr = new InetSocketAddress(
NetworkConstants.LOCALHOST, TestUtils.getFreePort());
cb.remoteAddress(addr);
logger.info(String.format(
"Running: %s %d of %d", testName.getMethodName(), ++ i, COMBO.size()));
try {
Method m = getClass().getDeclaredMethod(
testName.getMethodName(), Bootstrap.class);
m.invoke(this, cb);
} catch (InvocationTargetException ex) {
throw ex.getCause();
} finally {
cb.shutdown();
}
}
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.testsuite.transport.socket;
import static org.junit.Assert.*;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.logging.ByteLoggingHandler;
import java.net.ServerSocket;
import java.net.Socket;
import org.junit.Test;
public class SocketShutdownOutputTest extends AbstractClientSocketTest {
@Test
public void testShutdownOutput() throws Throwable {
run();
}
public void testShutdownOutput(Bootstrap cb) throws Throwable {
ServerSocket ss = new ServerSocket();
Socket s = null;
try {
ss.bind(addr);
SocketChannel ch = (SocketChannel) cb.handler(new ByteLoggingHandler()).connect().sync().channel();
assertTrue(ch.isActive());
assertFalse(ch.isOutputShutdown());
s = ss.accept();
ch.write(Unpooled.wrappedBuffer(new byte[] { 1 })).sync();
assertEquals(1, s.getInputStream().read());
ch.shutdownOutput();
assertEquals(-1, s.getInputStream().read());
assertTrue(s.isConnected());
} finally {
if (s != null) {
s.close();
}
ss.close();
}
}
}

View File

@ -44,27 +44,7 @@ final class SocketTestPermutation {
List<Factory<ServerBootstrap>> sbfs = serverSocket();
// Make the list of Bootstrap factories.
List<Factory<Bootstrap>> cbfs =
new ArrayList<Factory<Bootstrap>>();
cbfs.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(new NioEventLoopGroup()).channel(new NioSocketChannel());
}
});
cbfs.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
AioEventLoopGroup loop = new AioEventLoopGroup();
return new Bootstrap().group(loop).channel(new AioSocketChannel(loop));
}
});
cbfs.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(new OioEventLoopGroup()).channel(new OioSocketChannel());
}
});
List<Factory<Bootstrap>> cbfs = clientSocket();
// Populate the combinations
for (Factory<ServerBootstrap> sbf: sbfs) {
@ -178,6 +158,30 @@ final class SocketTestPermutation {
return list;
}
static List<Factory<Bootstrap>> clientSocket() {
List<Factory<Bootstrap>> list = new ArrayList<Factory<Bootstrap>>();
list.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(new NioEventLoopGroup()).channel(new NioSocketChannel());
}
});
list.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
AioEventLoopGroup loop = new AioEventLoopGroup();
return new Bootstrap().group(loop).channel(new AioSocketChannel(loop));
}
});
list.add(new Factory<Bootstrap>() {
@Override
public Bootstrap newInstance() {
return new Bootstrap().group(new OioEventLoopGroup()).channel(new OioSocketChannel());
}
});
return list;
}
private SocketTestPermutation() {}
interface Factory<T> {

View File

@ -16,8 +16,10 @@
package io.netty.channel.socket;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import java.net.InetSocketAddress;
import java.net.Socket;
/**
* A TCP/IP socket {@link Channel} which was either accepted by
@ -32,4 +34,14 @@ public interface SocketChannel extends Channel {
InetSocketAddress localAddress();
@Override
InetSocketAddress remoteAddress();
/**
* @see Socket#isOutputShutdown()
*/
boolean isOutputShutdown();
/**
* @see Socket#shutdownOutput()
*/
ChannelFuture shutdownOutput();
}

View File

@ -22,6 +22,7 @@ import io.netty.channel.ChannelFlushFutureNotifier;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.channel.socket.SocketChannel;
import java.io.IOException;
@ -56,8 +57,9 @@ public class AioSocketChannel extends AbstractAioChannel implements SocketChanne
}
private final AioSocketChannelConfig config;
private boolean flushing;
private volatile boolean outputShutdown;
private boolean flushing;
private final AtomicBoolean readSuspended = new AtomicBoolean();
private final AtomicBoolean readInProgress = new AtomicBoolean();
@ -94,6 +96,34 @@ public class AioSocketChannel extends AbstractAioChannel implements SocketChanne
return METADATA;
}
@Override
public boolean isOutputShutdown() {
return outputShutdown;
}
@Override
public ChannelFuture shutdownOutput() {
ChannelFuture future = newFuture();
EventLoop loop = eventLoop();
if (loop.inEventLoop()) {
try {
javaChannel().shutdownOutput();
outputShutdown = true;
future.setSuccess();
} catch (Throwable t) {
future.setFailure(t);
}
} else {
loop.execute(new Runnable() {
@Override
public void run() {
shutdownOutput();
}
});
}
return future;
}
@Override
protected void doConnect(SocketAddress remoteAddress, SocketAddress localAddress, final ChannelFuture future) {
if (localAddress != null) {

View File

@ -19,7 +19,9 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ChannelBufType;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.EventLoop;
import io.netty.channel.socket.DefaultSocketChannelConfig;
import io.netty.channel.socket.SocketChannelConfig;
import io.netty.logging.InternalLogger;
@ -96,6 +98,33 @@ public class NioSocketChannel extends AbstractNioByteChannel implements io.netty
return ch.isOpen() && ch.isConnected();
}
@Override
public boolean isOutputShutdown() {
return javaChannel().socket().isOutputShutdown();
}
@Override
public ChannelFuture shutdownOutput() {
ChannelFuture future = newFuture();
EventLoop loop = eventLoop();
if (loop.inEventLoop()) {
try {
javaChannel().socket().shutdownOutput();
future.setSuccess();
} catch (Throwable t) {
future.setFailure(t);
}
} else {
loop.execute(new Runnable() {
@Override
public void run() {
shutdownOutput();
}
});
}
return future;
}
@Override
protected SocketAddress localAddress0() {
return javaChannel().socket().getLocalSocketAddress();

View File

@ -19,7 +19,9 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ChannelBufType;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.EventLoop;
import io.netty.channel.socket.DefaultSocketChannelConfig;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.SocketChannelConfig;
@ -101,6 +103,33 @@ public class OioSocketChannel extends AbstractOioByteChannel
return !socket.isClosed() && socket.isConnected();
}
@Override
public boolean isOutputShutdown() {
return socket.isOutputShutdown();
}
@Override
public ChannelFuture shutdownOutput() {
ChannelFuture future = newFuture();
EventLoop loop = eventLoop();
if (loop.inEventLoop()) {
try {
socket.shutdownOutput();
future.setSuccess();
} catch (Throwable t) {
future.setFailure(t);
}
} else {
loop.execute(new Runnable() {
@Override
public void run() {
shutdownOutput();
}
});
}
return future;
}
@Override
protected SocketAddress localAddress0() {
return socket.getLocalSocketAddress();