diff --git a/transport/src/main/java/io/netty/channel/ChannelMetadata.java b/transport/src/main/java/io/netty/channel/ChannelMetadata.java index 384aaa8cc2..c77f530916 100644 --- a/transport/src/main/java/io/netty/channel/ChannelMetadata.java +++ b/transport/src/main/java/io/netty/channel/ChannelMetadata.java @@ -30,7 +30,7 @@ public final class ChannelMetadata { * * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} - * again, such as UDP/IP. + * again, such as UDP/IP. */ public ChannelMetadata(boolean hasDisconnect) { this(hasDisconnect, 1); @@ -41,7 +41,7 @@ public final class ChannelMetadata { * * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} - * again, such as UDP/IP. + * again, such as UDP/IP. * @param defaultMaxMessagesPerRead If a {@link MaxMessagesRecvByteBufAllocator} is in use, then this value will be * set for {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. Must be {@code > 0}. */ diff --git a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java index 83f841565e..2d78ccf1cc 100644 --- a/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java +++ b/transport/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java @@ -31,6 +31,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.EventLoop; import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.logging.InternalLogger; @@ -54,10 +55,12 @@ public class EmbeddedChannel extends AbstractChannel { private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class); - private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false); + private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true); private final EmbeddedEventLoop loop = new EmbeddedEventLoop(); - private final ChannelConfig config = new DefaultChannelConfig(this); + private final ChannelMetadata metadata; + private final ChannelConfig config; private Queue inboundMessages; private Queue outboundMessages; @@ -85,10 +88,21 @@ public class EmbeddedChannel extends AbstractChannel { * * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} */ - public EmbeddedChannel(final ChannelHandler... handlers) { + public EmbeddedChannel(ChannelHandler... handlers) { this(EmbeddedChannelId.INSTANCE, handlers); } + /** + * Create a new instance with the pipeline initialized with the specified handlers. + * + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@link false} otherwise. + * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} + */ + public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) { + this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers); + } + /** * Create a new instance with the channel ID set to the given ID and the pipeline * initialized with the specified handlers. @@ -97,11 +111,24 @@ public class EmbeddedChannel extends AbstractChannel { * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} */ public EmbeddedChannel(ChannelId channelId, final ChannelHandler... handlers) { + this(channelId, false, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@link false} otherwise. + * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelHandler... handlers) { super(null, channelId); - if (handlers == null) { - throw new NullPointerException("handlers"); - } + ObjectUtil.checkNotNull(handlers, "handlers"); + metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT; + config = new DefaultChannelConfig(this); ChannelPipeline p = pipeline(); p.addLast(new ChannelInitializer() { @@ -124,7 +151,7 @@ public class EmbeddedChannel extends AbstractChannel { @Override public ChannelMetadata metadata() { - return METADATA; + return metadata; } @Override @@ -268,37 +295,35 @@ public class EmbeddedChannel extends AbstractChannel { return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); } - private void finishPendingTasks() { + private void finishPendingTasks(boolean cancel) { runPendingTasks(); - // Cancel all scheduled tasks that are left. - loop.cancelScheduledTasks(); + if (cancel) { + // Cancel all scheduled tasks that are left. + loop.cancelScheduledTasks(); + } } @Override public final ChannelFuture close() { - ChannelFuture future = super.close(); - finishPendingTasks(); - return future; + return close(newPromise()); } @Override public final ChannelFuture disconnect() { - ChannelFuture future = super.disconnect(); - finishPendingTasks(); - return future; + return disconnect(newPromise()); } @Override public final ChannelFuture close(ChannelPromise promise) { ChannelFuture future = super.close(promise); - finishPendingTasks(); + finishPendingTasks(true); return future; } @Override public final ChannelFuture disconnect(ChannelPromise promise) { ChannelFuture future = super.disconnect(promise); - finishPendingTasks(); + finishPendingTasks(!metadata.hasDisconnect()); return future; } @@ -403,7 +428,9 @@ public class EmbeddedChannel extends AbstractChannel { @Override protected void doDisconnect() throws Exception { - doClose(); + if (!metadata.hasDisconnect()) { + doClose(); + } } @Override diff --git a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java index 9a919534b4..7c59676278 100644 --- a/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java +++ b/transport/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -23,17 +23,22 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelId; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.ScheduledFuture; -import org.junit.Assert; import org.junit.Test; +import java.util.ArrayDeque; +import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import static org.junit.Assert.*; + public class EmbeddedChannelTest { @Test @@ -55,12 +60,12 @@ public class EmbeddedChannelTest { } }); ChannelPipeline pipeline = channel.pipeline(); - Assert.assertSame(handler, pipeline.firstContext().handler()); - Assert.assertTrue(channel.writeInbound(3)); - Assert.assertTrue(channel.finish()); - Assert.assertSame(first, channel.readInbound()); - Assert.assertSame(second, channel.readInbound()); - Assert.assertNull(channel.readInbound()); + assertSame(handler, pipeline.firstContext().handler()); + assertTrue(channel.writeInbound(3)); + assertTrue(channel.finish()); + assertSame(first, channel.readInbound()); + assertSame(second, channel.readInbound()); + assertNull(channel.readInbound()); } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -81,11 +86,11 @@ public class EmbeddedChannelTest { } }); long next = ch.runScheduledPendingTasks(); - Assert.assertTrue(next > 0); + assertTrue(next > 0); // Sleep for the nanoseconds but also give extra 50ms as the clock my not be very precise and so fail the test // otherwise. Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50); - Assert.assertEquals(-1, ch.runScheduledPendingTasks()); + assertEquals(-1, ch.runScheduledPendingTasks()); latch.await(); } @@ -97,7 +102,7 @@ public class EmbeddedChannelTest { public void run() { } }, 1, TimeUnit.DAYS); ch.finish(); - Assert.assertTrue(future.isCancelled()); + assertTrue(future.isCancelled()); } @Test(timeout = 3000) @@ -108,7 +113,7 @@ public class EmbeddedChannelTest { @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { try { - Assert.assertTrue(ctx.executor().inEventLoop()); + assertTrue(ctx.executor().inEventLoop()); } catch (Throwable cause) { error.set(cause); } finally { @@ -117,7 +122,7 @@ public class EmbeddedChannelTest { } }; EmbeddedChannel channel = new EmbeddedChannel(handler); - Assert.assertFalse(channel.finish()); + assertFalse(channel.finish()); latch.await(); Throwable cause = error.get(); if (cause != null) { @@ -128,20 +133,20 @@ public class EmbeddedChannelTest { @Test public void testConstructWithOutHandler() { EmbeddedChannel channel = new EmbeddedChannel(); - Assert.assertTrue(channel.writeInbound(1)); - Assert.assertTrue(channel.writeOutbound(2)); - Assert.assertTrue(channel.finish()); - Assert.assertSame(1, channel.readInbound()); - Assert.assertNull(channel.readInbound()); - Assert.assertSame(2, channel.readOutbound()); - Assert.assertNull(channel.readOutbound()); + assertTrue(channel.writeInbound(1)); + assertTrue(channel.writeOutbound(2)); + assertTrue(channel.finish()); + assertSame(1, channel.readInbound()); + assertNull(channel.readInbound()); + assertSame(2, channel.readOutbound()); + assertNull(channel.readOutbound()); } @Test public void testConstructWithChannelId() { ChannelId channelId = new CustomChannelId(1); EmbeddedChannel channel = new EmbeddedChannel(channelId); - Assert.assertSame(channelId, channel.id()); + assertSame(channelId, channel.id()); } // See https://github.com/netty/netty/issues/4316. @@ -205,4 +210,49 @@ public class EmbeddedChannelTest { private interface Action { ChannelFuture doRun(Channel channel); } + + @Test + public void testHasDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(true, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.DISCONNECT, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + @Test + public void testHasNoDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(false, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter { + static final Integer DISCONNECT = 0; + static final Integer CLOSE = 1; + + private final Queue queue = new ArrayDeque(); + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(DISCONNECT); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(CLOSE); + promise.setSuccess(); + } + + Integer pollEvent() { + return queue.poll(); + } + } }