CombinedChannelDuplexHandler.removeOutboundHandler() cause connect(...) to not pass the correct parameters (#11414)

Motivation:

Due a bug we did not pass the correct remote and localaddress to the next handler if the outbound portion of the CombinedChannelDuplexHandler was removed

Modifications:

- Call the correct connect(...) method
- Refactor tests to test that the parameters are correctly passed on
- Remvoe some code duplication in the tests

Result:

CombinedChannelDuplexHandler correctly pass parameters on
This commit is contained in:
Norman Maurer 2021-06-24 13:58:17 +02:00
parent 0a3ffc59e3
commit 39d08dbf0c
2 changed files with 234 additions and 192 deletions

View File

@ -261,7 +261,7 @@ public class CombinedChannelDuplexHandler<I extends ChannelHandler, O extends Ch
if (!outboundCtx.removed) { if (!outboundCtx.removed) {
outboundHandler.connect(outboundCtx, remoteAddress, localAddress, promise); outboundHandler.connect(outboundCtx, remoteAddress, localAddress, promise);
} else { } else {
outboundCtx.connect(localAddress, promise); outboundCtx.connect(remoteAddress, localAddress, promise);
} }
} }

View File

@ -27,14 +27,19 @@ import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class CombinedChannelDuplexHandlerTest { public class CombinedChannelDuplexHandlerTest {
private static final Object MSG = new Object(); private static final Object MSG = new Object();
private static final SocketAddress ADDRESS = new InetSocketAddress(0); private static final SocketAddress LOCAL_ADDRESS = new InetSocketAddress(0);
private static final SocketAddress REMOTE_ADDRESS = new InetSocketAddress(0);
private static final Throwable CAUSE = new Throwable();
private static final Object USER_EVENT = new Object();
private enum Event { private enum Event {
REGISTERED, REGISTERED,
@ -138,235 +143,120 @@ public class CombinedChannelDuplexHandlerTest {
@Test @Test
public void testInboundEvents() { public void testInboundEvents() {
final Queue<Event> queue = new ArrayDeque<>(); InboundEventHandler inboundHandler = new InboundEventHandler();
ChannelHandler inboundHandler = new ChannelHandler() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.REGISTERED);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.UNREGISTERED);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.ACTIVE);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.INACTIVE);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
queue.add(Event.CHANNEL_READ);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.CHANNEL_READ_COMPLETE);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
queue.add(Event.USER_EVENT_TRIGGERED);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.CHANNEL_WRITABILITY_CHANGED);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
queue.add(Event.EXCEPTION_CAUGHT);
}
};
CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler> handler = CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler> handler =
new CombinedChannelDuplexHandler<>( new CombinedChannelDuplexHandler<>(
inboundHandler, new ChannelHandler() { }); inboundHandler, new ChannelHandler() { });
EmbeddedChannel channel = new EmbeddedChannel(handler); EmbeddedChannel channel = new EmbeddedChannel();
channel.pipeline().fireChannelWritabilityChanged(); channel.pipeline().addLast(handler);
channel.pipeline().fireUserEventTriggered(MSG); assertEquals(Event.HANDLER_ADDED, inboundHandler.pollEvent());
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
assertEquals(Event.HANDLER_ADDED, queue.poll());
assertEquals(Event.REGISTERED, queue.poll());
assertEquals(Event.ACTIVE, queue.poll());
assertEquals(Event.CHANNEL_WRITABILITY_CHANGED, queue.poll());
assertEquals(Event.USER_EVENT_TRIGGERED, queue.poll());
assertEquals(Event.CHANNEL_READ, queue.poll());
assertEquals(Event.CHANNEL_READ_COMPLETE, queue.poll());
doInboundOperations(channel);
assertInboundOperations(inboundHandler);
handler.removeInboundHandler(); handler.removeInboundHandler();
assertEquals(Event.HANDLER_REMOVED, queue.poll());
assertEquals(Event.HANDLER_REMOVED, inboundHandler.pollEvent());
// These should not be handled by the inboundHandler anymore as it was removed before // These should not be handled by the inboundHandler anymore as it was removed before
channel.pipeline().fireChannelWritabilityChanged(); doInboundOperations(channel);
channel.pipeline().fireUserEventTriggered(MSG);
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
// Should have not received any more events as it was removed before via removeInboundHandler() // Should have not received any more events as it was removed before via removeInboundHandler()
assertTrue(queue.isEmpty()); assertNull(inboundHandler.pollEvent());
try {
channel.checkException();
fail();
} catch (Throwable cause) {
assertSame(CAUSE, cause);
}
assertTrue(channel.finish()); assertTrue(channel.finish());
assertTrue(queue.isEmpty()); assertNull(inboundHandler.pollEvent());
} }
@Test @Test
public void testOutboundEvents() { public void testOutboundEvents() {
final Queue<Event> queue = new ArrayDeque<>(); ChannelInboundHandler inboundHandler = new ChannelInboundHandlerAdapter();
OutboundEventHandler outboundHandler = new OutboundEventHandler();
ChannelHandler inboundHandler = new ChannelHandler() { };
ChannelHandler outboundHandler = new ChannelHandler() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise)
throws Exception {
queue.add(Event.BIND);
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) throws Exception {
queue.add(Event.CONNECT);
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(Event.DISCONNECT);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(Event.CLOSE);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
queue.add(Event.DEREGISTER);
}
@Override
public void read(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.READ);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
queue.add(Event.WRITE);
}
@Override
public void flush(ChannelHandlerContext ctx) throws Exception {
queue.add(Event.FLUSH);
}
};
CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler> handler = CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler> handler =
new CombinedChannelDuplexHandler<>( new CombinedChannelDuplexHandler<>(
inboundHandler, outboundHandler); inboundHandler, outboundHandler);
EmbeddedChannel channel = new EmbeddedChannel(); EmbeddedChannel channel = new EmbeddedChannel();
channel.pipeline().addFirst(handler); channel.pipeline().addLast(new OutboundEventHandler());
channel.pipeline().addLast(handler);
assertEquals(Event.HANDLER_ADDED, outboundHandler.pollEvent());
doOutboundOperations(channel); doOutboundOperations(channel);
assertEquals(Event.HANDLER_ADDED, queue.poll()); assertOutboundOperations(outboundHandler);
assertEquals(Event.BIND, queue.poll());
assertEquals(Event.CONNECT, queue.poll());
assertEquals(Event.WRITE, queue.poll());
assertEquals(Event.FLUSH, queue.poll());
assertEquals(Event.READ, queue.poll());
assertEquals(Event.CLOSE, queue.poll());
assertEquals(Event.CLOSE, queue.poll());
assertEquals(Event.DEREGISTER, queue.poll());
handler.removeOutboundHandler(); handler.removeOutboundHandler();
assertEquals(Event.HANDLER_REMOVED, queue.poll());
assertEquals(Event.HANDLER_REMOVED, outboundHandler.pollEvent());
// These should not be handled by the inboundHandler anymore as it was removed before // These should not be handled by the inboundHandler anymore as it was removed before
doOutboundOperations(channel); doOutboundOperations(channel);
// Should have not received any more events as it was removed before via removeInboundHandler() // Should have not received any more events as it was removed before via removeInboundHandler()
assertTrue(queue.isEmpty()); assertNull(outboundHandler.pollEvent());
assertTrue(channel.finish()); assertFalse(channel.finish());
assertTrue(queue.isEmpty()); assertNull(outboundHandler.pollEvent());
} }
private static void doOutboundOperations(Channel channel) { private static void doOutboundOperations(Channel channel) {
channel.pipeline().bind(ADDRESS); channel.pipeline().bind(LOCAL_ADDRESS).syncUninterruptibly();
channel.pipeline().connect(ADDRESS); channel.pipeline().connect(REMOTE_ADDRESS, LOCAL_ADDRESS).syncUninterruptibly();
channel.pipeline().write(MSG); channel.pipeline().write(MSG).syncUninterruptibly();
channel.pipeline().flush(); channel.pipeline().flush();
channel.pipeline().read(); channel.pipeline().read();
channel.pipeline().disconnect(); channel.pipeline().disconnect().syncUninterruptibly();
channel.pipeline().close(); channel.pipeline().close().syncUninterruptibly();
channel.pipeline().deregister(); channel.pipeline().deregister().syncUninterruptibly();
}
private static void assertOutboundOperations(OutboundEventHandler outboundHandler) {
assertEquals(Event.BIND, outboundHandler.pollEvent());
assertEquals(Event.CONNECT, outboundHandler.pollEvent());
assertEquals(Event.WRITE, outboundHandler.pollEvent());
assertEquals(Event.FLUSH, outboundHandler.pollEvent());
assertEquals(Event.READ, outboundHandler.pollEvent());
assertEquals(Event.CLOSE, outboundHandler.pollEvent());
assertEquals(Event.CLOSE, outboundHandler.pollEvent());
assertEquals(Event.DEREGISTER, outboundHandler.pollEvent());
}
private static void doInboundOperations(Channel channel) {
channel.pipeline().fireChannelRegistered();
channel.pipeline().fireChannelActive();
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
channel.pipeline().fireExceptionCaught(CAUSE);
channel.pipeline().fireUserEventTriggered(USER_EVENT);
channel.pipeline().fireChannelWritabilityChanged();
channel.pipeline().fireChannelInactive();
channel.pipeline().fireChannelUnregistered();
}
private static void assertInboundOperations(InboundEventHandler handler) {
assertEquals(Event.REGISTERED, handler.pollEvent());
assertEquals(Event.ACTIVE, handler.pollEvent());
assertEquals(Event.CHANNEL_READ, handler.pollEvent());
assertEquals(Event.CHANNEL_READ_COMPLETE, handler.pollEvent());
assertEquals(Event.EXCEPTION_CAUGHT, handler.pollEvent());
assertEquals(Event.USER_EVENT_TRIGGERED, handler.pollEvent());
assertEquals(Event.CHANNEL_WRITABILITY_CHANGED, handler.pollEvent());
assertEquals(Event.INACTIVE, handler.pollEvent());
assertEquals(Event.UNREGISTERED, handler.pollEvent());
} }
@Test @Test
@Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
public void testPromisesPassed() { public void testPromisesPassed() {
ChannelHandler outboundHandler = new ChannelHandler() { OutboundEventHandler outboundHandler = new OutboundEventHandler();
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
ChannelPromise promise) throws Exception {
promise.setSuccess();
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) throws Exception {
promise.setSuccess();
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
promise.setSuccess();
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
promise.setSuccess();
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
promise.setSuccess();
}
};
EmbeddedChannel ch = new EmbeddedChannel(outboundHandler, EmbeddedChannel ch = new EmbeddedChannel(outboundHandler,
new CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler>( new CombinedChannelDuplexHandler<ChannelHandler, ChannelHandler>(
new ChannelHandler() { new ChannelHandler() {
@ -374,11 +264,11 @@ public class CombinedChannelDuplexHandlerTest {
ChannelPipeline pipeline = ch.pipeline(); ChannelPipeline pipeline = ch.pipeline();
ChannelPromise promise = ch.newPromise(); ChannelPromise promise = ch.newPromise();
pipeline.connect(new InetSocketAddress(0), null, promise); pipeline.bind(LOCAL_ADDRESS, promise);
promise.syncUninterruptibly(); promise.syncUninterruptibly();
promise = ch.newPromise(); promise = ch.newPromise();
pipeline.bind(new InetSocketAddress(0), promise); pipeline.connect(REMOTE_ADDRESS, LOCAL_ADDRESS, promise);
promise.syncUninterruptibly(); promise.syncUninterruptibly();
promise = ch.newPromise(); promise = ch.newPromise();
@ -390,7 +280,7 @@ public class CombinedChannelDuplexHandlerTest {
promise.syncUninterruptibly(); promise.syncUninterruptibly();
promise = ch.newPromise(); promise = ch.newPromise();
pipeline.write("test", promise); pipeline.write(MSG, promise);
promise.syncUninterruptibly(); promise.syncUninterruptibly();
promise = ch.newPromise(); promise = ch.newPromise();
@ -409,4 +299,156 @@ public class CombinedChannelDuplexHandlerTest {
} }
}); });
} }
private static final class InboundEventHandler implements ChannelHandler {
private final Queue<Object> queue = new ArrayDeque<Object>();
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) {
queue.add(Event.REGISTERED);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) {
queue.add(Event.UNREGISTERED);
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
queue.add(Event.ACTIVE);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
queue.add(Event.INACTIVE);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
queue.add(Event.CHANNEL_READ);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_READ_COMPLETE);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
queue.add(Event.USER_EVENT_TRIGGERED);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_WRITABILITY_CHANGED);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
queue.add(Event.EXCEPTION_CAUGHT);
}
Event pollEvent() {
Object o = queue.poll();
if (o instanceof AssertionError) {
throw (AssertionError) o;
}
return (Event) o;
}
}
private static final class OutboundEventHandler implements ChannelHandler {
private final Queue<Object> queue = new ArrayDeque<Object>();
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
try {
assertSame(LOCAL_ADDRESS, localAddress);
queue.add(Event.BIND);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) {
try {
assertSame(REMOTE_ADDRESS, remoteAddress);
assertSame(LOCAL_ADDRESS, localAddress);
queue.add(Event.CONNECT);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DISCONNECT);
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.CLOSE);
promise.setSuccess();
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DEREGISTER);
promise.setSuccess();
}
@Override
public void read(ChannelHandlerContext ctx) {
queue.add(Event.READ);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
try {
assertSame(MSG, msg);
queue.add(Event.WRITE);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void flush(ChannelHandlerContext ctx) {
queue.add(Event.FLUSH);
}
Event pollEvent() {
Object o = queue.poll();
if (o instanceof AssertionError) {
throw (AssertionError) o;
}
return (Event) o;
}
}
} }