Allow to change if EmbeddedChannel should handle close() and disconnect() different.

Motivation:

At the moment EmbeddedChannel always handle close() and disconnect() the same way which also means that ChannelOutboundHandler.disconnect(...) will never called. We should allow to specify if these are handle different or not to make the use of EmbeddedChannel more flexible.

Modifications:

Add 2 other constructors which allow to specify if disconnect / close are handled the same way or differently.

Result:

More flexible usage of EmbeddedChannel possible.
This commit is contained in:
Norman Maurer 2016-01-12 19:09:40 +01:00
parent 45674baf3e
commit 4d854cc149
3 changed files with 118 additions and 41 deletions

View File

@ -31,6 +31,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.logging.InternalLogger;
@ -54,10 +55,12 @@ public class EmbeddedChannel extends AbstractChannel {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
private static final ChannelMetadata METADATA = new ChannelMetadata(false);
private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
private final ChannelConfig config = new DefaultChannelConfig(this);
private final ChannelMetadata metadata;
private final ChannelConfig config;
private Queue<Object> inboundMessages;
private Queue<Object> outboundMessages;
@ -85,10 +88,21 @@ public class EmbeddedChannel extends AbstractChannel {
*
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(final ChannelHandler... handlers) {
public EmbeddedChannel(ChannelHandler... handlers) {
this(EmbeddedChannelId.INSTANCE, handlers);
}
/**
* Create a new instance with the pipeline initialized with the specified handlers.
*
* @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()}
* to {@link #close()}, {@link false} otherwise.
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) {
this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers);
}
/**
* Create a new instance with the channel ID set to the given ID and the pipeline
* initialized with the specified handlers.
@ -97,11 +111,24 @@ public class EmbeddedChannel extends AbstractChannel {
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(ChannelId channelId, final ChannelHandler... handlers) {
this(channelId, false, handlers);
}
/**
* Create a new instance with the channel ID set to the given ID and the pipeline
* initialized with the specified handlers.
*
* @param channelId the {@link ChannelId} that will be used to identify this channel
* @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()}
* to {@link #close()}, {@link false} otherwise.
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelHandler... handlers) {
super(null, channelId);
if (handlers == null) {
throw new NullPointerException("handlers");
}
ObjectUtil.checkNotNull(handlers, "handlers");
metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
config = new DefaultChannelConfig(this);
ChannelPipeline p = pipeline();
p.addLast(new ChannelInitializer<Channel>() {
@ -124,7 +151,7 @@ public class EmbeddedChannel extends AbstractChannel {
@Override
public ChannelMetadata metadata() {
return METADATA;
return metadata;
}
@Override
@ -268,37 +295,35 @@ public class EmbeddedChannel extends AbstractChannel {
return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
}
private void finishPendingTasks() {
private void finishPendingTasks(boolean cancel) {
runPendingTasks();
if (cancel) {
// Cancel all scheduled tasks that are left.
loop.cancelScheduledTasks();
}
}
@Override
public final ChannelFuture close() {
ChannelFuture future = super.close();
finishPendingTasks();
return future;
return close(newPromise());
}
@Override
public final ChannelFuture disconnect() {
ChannelFuture future = super.disconnect();
finishPendingTasks();
return future;
return disconnect(newPromise());
}
@Override
public final ChannelFuture close(ChannelPromise promise) {
ChannelFuture future = super.close(promise);
finishPendingTasks();
finishPendingTasks(true);
return future;
}
@Override
public final ChannelFuture disconnect(ChannelPromise promise) {
ChannelFuture future = super.disconnect(promise);
finishPendingTasks();
finishPendingTasks(!metadata.hasDisconnect());
return future;
}
@ -403,8 +428,10 @@ public class EmbeddedChannel extends AbstractChannel {
@Override
protected void doDisconnect() throws Exception {
if (!metadata.hasDisconnect()) {
doClose();
}
}
@Override
protected void doClose() throws Exception {

View File

@ -23,17 +23,22 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
public class EmbeddedChannelTest {
@Test
@ -55,12 +60,12 @@ public class EmbeddedChannelTest {
}
});
ChannelPipeline pipeline = channel.pipeline();
Assert.assertSame(handler, pipeline.firstContext().handler());
Assert.assertTrue(channel.writeInbound(3));
Assert.assertTrue(channel.finish());
Assert.assertSame(first, channel.readInbound());
Assert.assertSame(second, channel.readInbound());
Assert.assertNull(channel.readInbound());
assertSame(handler, pipeline.firstContext().handler());
assertTrue(channel.writeInbound(3));
assertTrue(channel.finish());
assertSame(first, channel.readInbound());
assertSame(second, channel.readInbound());
assertNull(channel.readInbound());
}
@SuppressWarnings({ "rawtypes", "unchecked" })
@ -81,11 +86,11 @@ public class EmbeddedChannelTest {
}
});
long next = ch.runScheduledPendingTasks();
Assert.assertTrue(next > 0);
assertTrue(next > 0);
// Sleep for the nanoseconds but also give extra 50ms as the clock my not be very precise and so fail the test
// otherwise.
Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50);
Assert.assertEquals(-1, ch.runScheduledPendingTasks());
assertEquals(-1, ch.runScheduledPendingTasks());
latch.await();
}
@ -97,7 +102,7 @@ public class EmbeddedChannelTest {
public void run() { }
}, 1, TimeUnit.DAYS);
ch.finish();
Assert.assertTrue(future.isCancelled());
assertTrue(future.isCancelled());
}
@Test(timeout = 3000)
@ -108,7 +113,7 @@ public class EmbeddedChannelTest {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
try {
Assert.assertTrue(ctx.executor().inEventLoop());
assertTrue(ctx.executor().inEventLoop());
} catch (Throwable cause) {
error.set(cause);
} finally {
@ -117,7 +122,7 @@ public class EmbeddedChannelTest {
}
};
EmbeddedChannel channel = new EmbeddedChannel(handler);
Assert.assertFalse(channel.finish());
assertFalse(channel.finish());
latch.await();
Throwable cause = error.get();
if (cause != null) {
@ -128,20 +133,20 @@ public class EmbeddedChannelTest {
@Test
public void testConstructWithOutHandler() {
EmbeddedChannel channel = new EmbeddedChannel();
Assert.assertTrue(channel.writeInbound(1));
Assert.assertTrue(channel.writeOutbound(2));
Assert.assertTrue(channel.finish());
Assert.assertSame(1, channel.readInbound());
Assert.assertNull(channel.readInbound());
Assert.assertSame(2, channel.readOutbound());
Assert.assertNull(channel.readOutbound());
assertTrue(channel.writeInbound(1));
assertTrue(channel.writeOutbound(2));
assertTrue(channel.finish());
assertSame(1, channel.readInbound());
assertNull(channel.readInbound());
assertSame(2, channel.readOutbound());
assertNull(channel.readOutbound());
}
@Test
public void testConstructWithChannelId() {
ChannelId channelId = new CustomChannelId(1);
EmbeddedChannel channel = new EmbeddedChannel(channelId);
Assert.assertSame(channelId, channel.id());
assertSame(channelId, channel.id());
}
// See https://github.com/netty/netty/issues/4316.
@ -205,4 +210,49 @@ public class EmbeddedChannelTest {
private interface Action {
ChannelFuture doRun(Channel channel);
}
@Test
public void testHasDisconnect() {
EventOutboundHandler handler = new EventOutboundHandler();
EmbeddedChannel channel = new EmbeddedChannel(true, handler);
assertTrue(channel.disconnect().isSuccess());
assertTrue(channel.close().isSuccess());
assertEquals(EventOutboundHandler.DISCONNECT, handler.pollEvent());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertNull(handler.pollEvent());
}
@Test
public void testHasNoDisconnect() {
EventOutboundHandler handler = new EventOutboundHandler();
EmbeddedChannel channel = new EmbeddedChannel(false, handler);
assertTrue(channel.disconnect().isSuccess());
assertTrue(channel.close().isSuccess());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertNull(handler.pollEvent());
}
private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter {
static final Integer DISCONNECT = 0;
static final Integer CLOSE = 1;
private final Queue<Integer> queue = new ArrayDeque<Integer>();
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(DISCONNECT);
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(CLOSE);
promise.setSuccess();
}
Integer pollEvent() {
return queue.poll();
}
}
}