From 951bacb0ca7bc41a8d19cf9c942247ec6e823452 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 12 Jan 2016 19:09:40 +0100 Subject: [PATCH] Allow to change if EmbeddedChannel should handle close() and disconnect() different. Motivation: At the moment EmbeddedChannel always handle close() and disconnect() the same way which also means that ChannelOutboundHandler.disconnect(...) will never called. We should allow to specify if these are handle different or not to make the use of EmbeddedChannel more flexible. Modifications: Add 2 other constructors which allow to specify if disconnect / close are handled the same way or differently. Result: More flexible usage of EmbeddedChannel possible. --- .../io/netty/channel/ChannelMetadata.java | 2 +- .../channel/embedded/EmbeddedChannel.java | 53 +++++++---- .../channel/embedded/EmbeddedChannelTest.java | 88 +++++++++++++++---- 3 files changed, 104 insertions(+), 39 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/ChannelMetadata.java b/transport/src/main/java/io/netty/channel/ChannelMetadata.java index 628ad43fdd..b959188823 100644 --- a/transport/src/main/java/io/netty/channel/ChannelMetadata.java +++ b/transport/src/main/java/io/netty/channel/ChannelMetadata.java @@ -29,7 +29,7 @@ public final class ChannelMetadata { * * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} - * again, such as UDP/IP. + * again, such as UDP/IP. */ public ChannelMetadata(boolean hasDisconnect) { this.hasDisconnect = hasDisconnect; diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 21e8bd4495..ea2d7e1185 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -30,6 +30,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.EventLoop; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.logging.InternalLogger; @@ -47,12 +48,15 @@ public class EmbeddedChannel extends AbstractChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class); - private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false); + private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true); private final EmbeddedEventLoop loop = new EmbeddedEventLoop(); - private final ChannelConfig config = new DefaultChannelConfig(this); + private final ChannelMetadata metadata; + private final ChannelConfig config; private final SocketAddress localAddress = new EmbeddedSocketAddress(); private final SocketAddress remoteAddress = new EmbeddedSocketAddress(); + private Queue inboundMessages; private Queue outboundMessages; private Throwable lastException; @@ -64,11 +68,22 @@ public class EmbeddedChannel extends AbstractChannel { * @param handlers the @link ChannelHandler}s which will be add in the {@link ChannelPipeline} */ public EmbeddedChannel(final ChannelHandler... handlers) { - super(null); + this(false, handlers); + } - if (handlers == null) { - throw new NullPointerException("handlers"); - } + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@link false} otherwise. + * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} + */ + public EmbeddedChannel(boolean hasDisconnect, final ChannelHandler... handlers) { + super(null); + ObjectUtil.checkNotNull(handlers, "handlers"); + metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT; + config = new DefaultChannelConfig(this); ChannelPipeline p = pipeline(); p.addLast(new ChannelInitializer() { @@ -91,7 +106,7 @@ public class EmbeddedChannel extends AbstractChannel { @Override public ChannelMetadata metadata() { - return METADATA; + return metadata; } @Override @@ -233,37 +248,35 @@ public class EmbeddedChannel extends AbstractChannel { return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); } - private void finishPendingTasks() { + private void finishPendingTasks(boolean cancel) { runPendingTasks(); - // Cancel all scheduled tasks that are left. - loop.cancelScheduledTasks(); + if (cancel) { + // Cancel all scheduled tasks that are left. + loop.cancelScheduledTasks(); + } } @Override public final ChannelFuture close() { - ChannelFuture future = super.close(); - finishPendingTasks(); - return future; + return close(newPromise()); } @Override public final ChannelFuture disconnect() { - ChannelFuture future = super.disconnect(); - finishPendingTasks(); - return future; + return disconnect(newPromise()); } @Override public final ChannelFuture close(ChannelPromise promise) { ChannelFuture future = super.close(promise); - finishPendingTasks(); + finishPendingTasks(true); return future; } @Override public final ChannelFuture disconnect(ChannelPromise promise) { ChannelFuture future = super.disconnect(promise); - finishPendingTasks(); + finishPendingTasks(!metadata.hasDisconnect()); return future; } @@ -368,7 +381,9 @@ public class EmbeddedChannel extends AbstractChannel { @Override protected void doDisconnect() throws Exception { - doClose(); + if (!metadata.hasDisconnect()) { + doClose(); + } } @Override diff --git a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java index 12abe5fae0..a55b705803 100644 --- a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java +++ b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -22,17 +22,22 @@ import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.ScheduledFuture; -import org.junit.Assert; import org.junit.Test; +import java.util.ArrayDeque; +import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.junit.Assert.*; + public class EmbeddedChannelTest { @Test @@ -54,12 +59,12 @@ public class EmbeddedChannelTest { } }); ChannelPipeline pipeline = channel.pipeline(); - Assert.assertSame(handler, pipeline.firstContext().handler()); - Assert.assertTrue(channel.writeInbound(3)); - Assert.assertTrue(channel.finish()); - Assert.assertSame(first, channel.readInbound()); - Assert.assertSame(second, channel.readInbound()); - Assert.assertNull(channel.readInbound()); + assertSame(handler, pipeline.firstContext().handler()); + assertTrue(channel.writeInbound(3)); + assertTrue(channel.finish()); + assertSame(first, channel.readInbound()); + assertSame(second, channel.readInbound()); + assertNull(channel.readInbound()); } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -80,11 +85,11 @@ public class EmbeddedChannelTest { } }); long next = ch.runScheduledPendingTasks(); - Assert.assertTrue(next > 0); + assertTrue(next > 0); // Sleep for the nanoseconds but also give extra 50ms as the clock my not be very precise and so fail the test // otherwise. Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50); - Assert.assertEquals(-1, ch.runScheduledPendingTasks()); + assertEquals(-1, ch.runScheduledPendingTasks()); latch.await(); } @@ -96,7 +101,7 @@ public class EmbeddedChannelTest { public void run() { } }, 1, TimeUnit.DAYS); ch.finish(); - Assert.assertTrue(future.isCancelled()); + assertTrue(future.isCancelled()); } @Test(timeout = 3000) @@ -107,7 +112,7 @@ public class EmbeddedChannelTest { @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { try { - Assert.assertTrue(ctx.executor().inEventLoop()); + assertTrue(ctx.executor().inEventLoop()); } catch (Throwable cause) { error.set(cause); } finally { @@ -116,7 +121,7 @@ public class EmbeddedChannelTest { } }; EmbeddedChannel channel = new EmbeddedChannel(handler); - Assert.assertFalse(channel.finish()); + assertFalse(channel.finish()); latch.await(); Throwable cause = error.get(); if (cause != null) { @@ -127,13 +132,13 @@ public class EmbeddedChannelTest { @Test public void testConstructWithOutHandler() { EmbeddedChannel channel = new EmbeddedChannel(); - Assert.assertTrue(channel.writeInbound(1)); - Assert.assertTrue(channel.writeOutbound(2)); - Assert.assertTrue(channel.finish()); - Assert.assertSame(1, channel.readInbound()); - Assert.assertNull(channel.readInbound()); - Assert.assertSame(2, channel.readOutbound()); - Assert.assertNull(channel.readOutbound()); + assertTrue(channel.writeInbound(1)); + assertTrue(channel.writeOutbound(2)); + assertTrue(channel.finish()); + assertSame(1, channel.readInbound()); + assertNull(channel.readInbound()); + assertSame(2, channel.readOutbound()); + assertNull(channel.readOutbound()); } // See https://github.com/netty/netty/issues/4316. @@ -197,4 +202,49 @@ public class EmbeddedChannelTest { private interface Action { ChannelFuture doRun(Channel channel); } + + @Test + public void testHasDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(true, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.DISCONNECT, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + @Test + public void testHasNoDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(false, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter { + static final Integer DISCONNECT = 0; + static final Integer CLOSE = 1; + + private final Queue queue = new ArrayDeque(); + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(DISCONNECT); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(CLOSE); + promise.setSuccess(); + } + + Integer pollEvent() { + return queue.poll(); + } + } }