Allow to change if EmbeddedChannel should handle close() and disconnect() different.

Motivation:

At the moment EmbeddedChannel always handle close() and disconnect() the same way which also means that ChannelOutboundHandler.disconnect(...) will never called. We should allow to specify if these are handle different or not to make the use of EmbeddedChannel more flexible.

Modifications:

Add 2 other constructors which allow to specify if disconnect / close are handled the same way or differently.

Result:

More flexible usage of EmbeddedChannel possible.
This commit is contained in:
Norman Maurer 2016-01-12 19:09:40 +01:00
parent 8fdf2f5120
commit 951bacb0ca
3 changed files with 104 additions and 39 deletions

View File

@ -29,7 +29,7 @@ public final class ChannelMetadata {
* *
* @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation
* that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)}
* again, such as UDP/IP. * again, such as UDP/IP.
*/ */
public ChannelMetadata(boolean hasDisconnect) { public ChannelMetadata(boolean hasDisconnect) {
this.hasDisconnect = hasDisconnect; this.hasDisconnect = hasDisconnect;

View File

@ -30,6 +30,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig; import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop; import io.netty.channel.EventLoop;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.RecyclableArrayList; import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
@ -47,12 +48,15 @@ public class EmbeddedChannel extends AbstractChannel {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class); private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class);
private static final ChannelMetadata METADATA = new ChannelMetadata(false); private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false);
private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true);
private final EmbeddedEventLoop loop = new EmbeddedEventLoop(); private final EmbeddedEventLoop loop = new EmbeddedEventLoop();
private final ChannelConfig config = new DefaultChannelConfig(this); private final ChannelMetadata metadata;
private final ChannelConfig config;
private final SocketAddress localAddress = new EmbeddedSocketAddress(); private final SocketAddress localAddress = new EmbeddedSocketAddress();
private final SocketAddress remoteAddress = new EmbeddedSocketAddress(); private final SocketAddress remoteAddress = new EmbeddedSocketAddress();
private Queue<Object> inboundMessages; private Queue<Object> inboundMessages;
private Queue<Object> outboundMessages; private Queue<Object> outboundMessages;
private Throwable lastException; private Throwable lastException;
@ -64,11 +68,22 @@ public class EmbeddedChannel extends AbstractChannel {
* @param handlers the @link ChannelHandler}s which will be add in the {@link ChannelPipeline} * @param handlers the @link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/ */
public EmbeddedChannel(final ChannelHandler... handlers) { public EmbeddedChannel(final ChannelHandler... handlers) {
super(null); this(false, handlers);
}
if (handlers == null) { /**
throw new NullPointerException("handlers"); * Create a new instance with the channel ID set to the given ID and the pipeline
} * initialized with the specified handlers.
*
* @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()}
* to {@link #close()}, {@link false} otherwise.
* @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline}
*/
public EmbeddedChannel(boolean hasDisconnect, final ChannelHandler... handlers) {
super(null);
ObjectUtil.checkNotNull(handlers, "handlers");
metadata = hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT;
config = new DefaultChannelConfig(this);
ChannelPipeline p = pipeline(); ChannelPipeline p = pipeline();
p.addLast(new ChannelInitializer<Channel>() { p.addLast(new ChannelInitializer<Channel>() {
@ -91,7 +106,7 @@ public class EmbeddedChannel extends AbstractChannel {
@Override @Override
public ChannelMetadata metadata() { public ChannelMetadata metadata() {
return METADATA; return metadata;
} }
@Override @Override
@ -233,37 +248,35 @@ public class EmbeddedChannel extends AbstractChannel {
return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages);
} }
private void finishPendingTasks() { private void finishPendingTasks(boolean cancel) {
runPendingTasks(); runPendingTasks();
// Cancel all scheduled tasks that are left. if (cancel) {
loop.cancelScheduledTasks(); // Cancel all scheduled tasks that are left.
loop.cancelScheduledTasks();
}
} }
@Override @Override
public final ChannelFuture close() { public final ChannelFuture close() {
ChannelFuture future = super.close(); return close(newPromise());
finishPendingTasks();
return future;
} }
@Override @Override
public final ChannelFuture disconnect() { public final ChannelFuture disconnect() {
ChannelFuture future = super.disconnect(); return disconnect(newPromise());
finishPendingTasks();
return future;
} }
@Override @Override
public final ChannelFuture close(ChannelPromise promise) { public final ChannelFuture close(ChannelPromise promise) {
ChannelFuture future = super.close(promise); ChannelFuture future = super.close(promise);
finishPendingTasks(); finishPendingTasks(true);
return future; return future;
} }
@Override @Override
public final ChannelFuture disconnect(ChannelPromise promise) { public final ChannelFuture disconnect(ChannelPromise promise) {
ChannelFuture future = super.disconnect(promise); ChannelFuture future = super.disconnect(promise);
finishPendingTasks(); finishPendingTasks(!metadata.hasDisconnect());
return future; return future;
} }
@ -368,7 +381,9 @@ public class EmbeddedChannel extends AbstractChannel {
@Override @Override
protected void doDisconnect() throws Exception { protected void doDisconnect() throws Exception {
doClose(); if (!metadata.hasDisconnect()) {
doClose();
}
} }
@Override @Override

View File

@ -22,17 +22,22 @@ import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.concurrent.ScheduledFuture;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
public class EmbeddedChannelTest { public class EmbeddedChannelTest {
@Test @Test
@ -54,12 +59,12 @@ public class EmbeddedChannelTest {
} }
}); });
ChannelPipeline pipeline = channel.pipeline(); ChannelPipeline pipeline = channel.pipeline();
Assert.assertSame(handler, pipeline.firstContext().handler()); assertSame(handler, pipeline.firstContext().handler());
Assert.assertTrue(channel.writeInbound(3)); assertTrue(channel.writeInbound(3));
Assert.assertTrue(channel.finish()); assertTrue(channel.finish());
Assert.assertSame(first, channel.readInbound()); assertSame(first, channel.readInbound());
Assert.assertSame(second, channel.readInbound()); assertSame(second, channel.readInbound());
Assert.assertNull(channel.readInbound()); assertNull(channel.readInbound());
} }
@SuppressWarnings({ "rawtypes", "unchecked" }) @SuppressWarnings({ "rawtypes", "unchecked" })
@ -80,11 +85,11 @@ public class EmbeddedChannelTest {
} }
}); });
long next = ch.runScheduledPendingTasks(); long next = ch.runScheduledPendingTasks();
Assert.assertTrue(next > 0); assertTrue(next > 0);
// Sleep for the nanoseconds but also give extra 50ms as the clock my not be very precise and so fail the test // Sleep for the nanoseconds but also give extra 50ms as the clock my not be very precise and so fail the test
// otherwise. // otherwise.
Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50); Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50);
Assert.assertEquals(-1, ch.runScheduledPendingTasks()); assertEquals(-1, ch.runScheduledPendingTasks());
latch.await(); latch.await();
} }
@ -96,7 +101,7 @@ public class EmbeddedChannelTest {
public void run() { } public void run() { }
}, 1, TimeUnit.DAYS); }, 1, TimeUnit.DAYS);
ch.finish(); ch.finish();
Assert.assertTrue(future.isCancelled()); assertTrue(future.isCancelled());
} }
@Test(timeout = 3000) @Test(timeout = 3000)
@ -107,7 +112,7 @@ public class EmbeddedChannelTest {
@Override @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception { public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
try { try {
Assert.assertTrue(ctx.executor().inEventLoop()); assertTrue(ctx.executor().inEventLoop());
} catch (Throwable cause) { } catch (Throwable cause) {
error.set(cause); error.set(cause);
} finally { } finally {
@ -116,7 +121,7 @@ public class EmbeddedChannelTest {
} }
}; };
EmbeddedChannel channel = new EmbeddedChannel(handler); EmbeddedChannel channel = new EmbeddedChannel(handler);
Assert.assertFalse(channel.finish()); assertFalse(channel.finish());
latch.await(); latch.await();
Throwable cause = error.get(); Throwable cause = error.get();
if (cause != null) { if (cause != null) {
@ -127,13 +132,13 @@ public class EmbeddedChannelTest {
@Test @Test
public void testConstructWithOutHandler() { public void testConstructWithOutHandler() {
EmbeddedChannel channel = new EmbeddedChannel(); EmbeddedChannel channel = new EmbeddedChannel();
Assert.assertTrue(channel.writeInbound(1)); assertTrue(channel.writeInbound(1));
Assert.assertTrue(channel.writeOutbound(2)); assertTrue(channel.writeOutbound(2));
Assert.assertTrue(channel.finish()); assertTrue(channel.finish());
Assert.assertSame(1, channel.readInbound()); assertSame(1, channel.readInbound());
Assert.assertNull(channel.readInbound()); assertNull(channel.readInbound());
Assert.assertSame(2, channel.readOutbound()); assertSame(2, channel.readOutbound());
Assert.assertNull(channel.readOutbound()); assertNull(channel.readOutbound());
} }
// See https://github.com/netty/netty/issues/4316. // See https://github.com/netty/netty/issues/4316.
@ -197,4 +202,49 @@ public class EmbeddedChannelTest {
private interface Action { private interface Action {
ChannelFuture doRun(Channel channel); ChannelFuture doRun(Channel channel);
} }
@Test
public void testHasDisconnect() {
EventOutboundHandler handler = new EventOutboundHandler();
EmbeddedChannel channel = new EmbeddedChannel(true, handler);
assertTrue(channel.disconnect().isSuccess());
assertTrue(channel.close().isSuccess());
assertEquals(EventOutboundHandler.DISCONNECT, handler.pollEvent());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertNull(handler.pollEvent());
}
@Test
public void testHasNoDisconnect() {
EventOutboundHandler handler = new EventOutboundHandler();
EmbeddedChannel channel = new EmbeddedChannel(false, handler);
assertTrue(channel.disconnect().isSuccess());
assertTrue(channel.close().isSuccess());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent());
assertNull(handler.pollEvent());
}
private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter {
static final Integer DISCONNECT = 0;
static final Integer CLOSE = 1;
private final Queue<Integer> queue = new ArrayDeque<Integer>();
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(DISCONNECT);
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(CLOSE);
promise.setSuccess();
}
Integer pollEvent() {
return queue.poll();
}
}
} }