CombinedChannelDuplexHandler.removeOutboundHandler() cause connect(...) to not pass the correct parameters (#11414)

Motivation:

Due a bug we did not pass the correct remote and localaddress to the next handler if the outbound portion of the CombinedChannelDuplexHandler was removed

Modifications:

- Call the correct connect(...) method
- Refactor tests to test that the parameters are correctly passed on
- Remvoe some code duplication in the tests

Result:

CombinedChannelDuplexHandler correctly pass parameters on
This commit is contained in:
Norman Maurer 2021-06-24 13:58:17 +02:00 committed by GitHub
parent a71ec15fc4
commit 6c618e30af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 233 additions and 190 deletions

View File

@ -295,7 +295,7 @@ public class CombinedChannelDuplexHandler<I extends ChannelInboundHandler, O ext
if (!outboundCtx.removed) {
outboundHandler.connect(outboundCtx, remoteAddress, localAddress, promise);
} else {
outboundCtx.connect(localAddress, promise);
outboundCtx.connect(remoteAddress, localAddress, promise);
}
}

View File

@ -28,14 +28,19 @@ import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class CombinedChannelDuplexHandlerTest {
private static final Object MSG = new Object();
private static final SocketAddress ADDRESS = new InetSocketAddress(0);
private static final SocketAddress LOCAL_ADDRESS = new InetSocketAddress(0);
private static final SocketAddress REMOTE_ADDRESS = new InetSocketAddress(0);
private static final Throwable CAUSE = new Throwable();
private static final Object USER_EVENT = new Object();
private enum Event {
REGISTERED,
@ -162,245 +167,131 @@ public class CombinedChannelDuplexHandlerTest {
@Test
public void testInboundEvents() {
final Queue<Event> queue = new ArrayDeque<Event>();
ChannelInboundHandler inboundHandler = new ChannelInboundHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) {
queue.add(Event.REGISTERED);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) {
queue.add(Event.UNREGISTERED);
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
queue.add(Event.ACTIVE);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
queue.add(Event.INACTIVE);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
queue.add(Event.CHANNEL_READ);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_READ_COMPLETE);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
queue.add(Event.USER_EVENT_TRIGGERED);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_WRITABILITY_CHANGED);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
queue.add(Event.EXCEPTION_CAUGHT);
}
};
InboundEventHandler inboundHandler = new InboundEventHandler();
CombinedChannelDuplexHandler<ChannelInboundHandler, ChannelOutboundHandler> handler =
new CombinedChannelDuplexHandler<ChannelInboundHandler, ChannelOutboundHandler>(
inboundHandler, new ChannelOutboundHandlerAdapter());
EmbeddedChannel channel = new EmbeddedChannel(handler);
channel.pipeline().fireChannelWritabilityChanged();
channel.pipeline().fireUserEventTriggered(MSG);
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
assertEquals(Event.HANDLER_ADDED, queue.poll());
assertEquals(Event.REGISTERED, queue.poll());
assertEquals(Event.ACTIVE, queue.poll());
assertEquals(Event.CHANNEL_WRITABILITY_CHANGED, queue.poll());
assertEquals(Event.USER_EVENT_TRIGGERED, queue.poll());
assertEquals(Event.CHANNEL_READ, queue.poll());
assertEquals(Event.CHANNEL_READ_COMPLETE, queue.poll());
EmbeddedChannel channel = new EmbeddedChannel();
channel.pipeline().addLast(handler);
assertEquals(Event.HANDLER_ADDED, inboundHandler.pollEvent());
doInboundOperations(channel);
assertInboundOperations(inboundHandler);
handler.removeInboundHandler();
assertEquals(Event.HANDLER_REMOVED, queue.poll());
assertEquals(Event.HANDLER_REMOVED, inboundHandler.pollEvent());
// These should not be handled by the inboundHandler anymore as it was removed before
channel.pipeline().fireChannelWritabilityChanged();
channel.pipeline().fireUserEventTriggered(MSG);
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
doInboundOperations(channel);
// Should have not received any more events as it was removed before via removeInboundHandler()
assertTrue(queue.isEmpty());
assertNull(inboundHandler.pollEvent());
try {
channel.checkException();
fail();
} catch (Throwable cause) {
assertSame(CAUSE, cause);
}
assertTrue(channel.finish());
assertTrue(queue.isEmpty());
assertNull(inboundHandler.pollEvent());
}
@Test
public void testOutboundEvents() {
final Queue<Event> queue = new ArrayDeque<Event>();
ChannelInboundHandler inboundHandler = new ChannelInboundHandlerAdapter();
ChannelOutboundHandler outboundHandler = new ChannelOutboundHandlerAdapter() {
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
queue.add(Event.BIND);
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) {
queue.add(Event.CONNECT);
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DISCONNECT);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.CLOSE);
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DEREGISTER);
}
@Override
public void read(ChannelHandlerContext ctx) {
queue.add(Event.READ);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
queue.add(Event.WRITE);
}
@Override
public void flush(ChannelHandlerContext ctx) {
queue.add(Event.FLUSH);
}
};
OutboundEventHandler outboundHandler = new OutboundEventHandler();
CombinedChannelDuplexHandler<ChannelInboundHandler, ChannelOutboundHandler> handler =
new CombinedChannelDuplexHandler<ChannelInboundHandler, ChannelOutboundHandler>(
inboundHandler, outboundHandler);
EmbeddedChannel channel = new EmbeddedChannel();
channel.pipeline().addFirst(handler);
channel.pipeline().addLast(new OutboundEventHandler());
channel.pipeline().addLast(handler);
assertEquals(Event.HANDLER_ADDED, outboundHandler.pollEvent());
doOutboundOperations(channel);
assertEquals(Event.HANDLER_ADDED, queue.poll());
assertEquals(Event.BIND, queue.poll());
assertEquals(Event.CONNECT, queue.poll());
assertEquals(Event.WRITE, queue.poll());
assertEquals(Event.FLUSH, queue.poll());
assertEquals(Event.READ, queue.poll());
assertEquals(Event.CLOSE, queue.poll());
assertEquals(Event.CLOSE, queue.poll());
assertEquals(Event.DEREGISTER, queue.poll());
assertOutboundOperations(outboundHandler);
handler.removeOutboundHandler();
assertEquals(Event.HANDLER_REMOVED, queue.poll());
assertEquals(Event.HANDLER_REMOVED, outboundHandler.pollEvent());
// These should not be handled by the inboundHandler anymore as it was removed before
doOutboundOperations(channel);
// Should have not received any more events as it was removed before via removeInboundHandler()
assertTrue(queue.isEmpty());
assertTrue(channel.finish());
assertTrue(queue.isEmpty());
assertNull(outboundHandler.pollEvent());
assertFalse(channel.finish());
assertNull(outboundHandler.pollEvent());
}
private static void doOutboundOperations(Channel channel) {
channel.pipeline().bind(ADDRESS);
channel.pipeline().connect(ADDRESS);
channel.pipeline().write(MSG);
channel.pipeline().bind(LOCAL_ADDRESS).syncUninterruptibly();
channel.pipeline().connect(REMOTE_ADDRESS, LOCAL_ADDRESS).syncUninterruptibly();
channel.pipeline().write(MSG).syncUninterruptibly();
channel.pipeline().flush();
channel.pipeline().read();
channel.pipeline().disconnect();
channel.pipeline().close();
channel.pipeline().deregister();
channel.pipeline().disconnect().syncUninterruptibly();
channel.pipeline().close().syncUninterruptibly();
channel.pipeline().deregister().syncUninterruptibly();
}
private static void assertOutboundOperations(OutboundEventHandler outboundHandler) {
assertEquals(Event.BIND, outboundHandler.pollEvent());
assertEquals(Event.CONNECT, outboundHandler.pollEvent());
assertEquals(Event.WRITE, outboundHandler.pollEvent());
assertEquals(Event.FLUSH, outboundHandler.pollEvent());
assertEquals(Event.READ, outboundHandler.pollEvent());
assertEquals(Event.CLOSE, outboundHandler.pollEvent());
assertEquals(Event.CLOSE, outboundHandler.pollEvent());
assertEquals(Event.DEREGISTER, outboundHandler.pollEvent());
}
private static void doInboundOperations(Channel channel) {
channel.pipeline().fireChannelRegistered();
channel.pipeline().fireChannelActive();
channel.pipeline().fireChannelRead(MSG);
channel.pipeline().fireChannelReadComplete();
channel.pipeline().fireExceptionCaught(CAUSE);
channel.pipeline().fireUserEventTriggered(USER_EVENT);
channel.pipeline().fireChannelWritabilityChanged();
channel.pipeline().fireChannelInactive();
channel.pipeline().fireChannelUnregistered();
}
private static void assertInboundOperations(InboundEventHandler handler) {
assertEquals(Event.REGISTERED, handler.pollEvent());
assertEquals(Event.ACTIVE, handler.pollEvent());
assertEquals(Event.CHANNEL_READ, handler.pollEvent());
assertEquals(Event.CHANNEL_READ_COMPLETE, handler.pollEvent());
assertEquals(Event.EXCEPTION_CAUGHT, handler.pollEvent());
assertEquals(Event.USER_EVENT_TRIGGERED, handler.pollEvent());
assertEquals(Event.CHANNEL_WRITABILITY_CHANGED, handler.pollEvent());
assertEquals(Event.INACTIVE, handler.pollEvent());
assertEquals(Event.UNREGISTERED, handler.pollEvent());
}
@Test
@Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
public void testPromisesPassed() {
ChannelOutboundHandler outboundHandler = new ChannelOutboundHandlerAdapter() {
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress,
ChannelPromise promise) {
promise.setSuccess();
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) {
promise.setSuccess();
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
promise.setSuccess();
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
promise.setSuccess();
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
promise.setSuccess();
}
};
OutboundEventHandler outboundHandler = new OutboundEventHandler();
EmbeddedChannel ch = new EmbeddedChannel(outboundHandler,
new CombinedChannelDuplexHandler<ChannelInboundHandler, ChannelOutboundHandler>(
new ChannelInboundHandlerAdapter(), new ChannelOutboundHandlerAdapter()));
ChannelPipeline pipeline = ch.pipeline();
ChannelPromise promise = ch.newPromise();
pipeline.connect(new InetSocketAddress(0), null, promise);
pipeline.bind(LOCAL_ADDRESS, promise);
promise.syncUninterruptibly();
promise = ch.newPromise();
pipeline.bind(new InetSocketAddress(0), promise);
pipeline.connect(REMOTE_ADDRESS, LOCAL_ADDRESS, promise);
promise.syncUninterruptibly();
promise = ch.newPromise();
@ -412,7 +303,7 @@ public class CombinedChannelDuplexHandlerTest {
promise.syncUninterruptibly();
promise = ch.newPromise();
pipeline.write("test", promise);
pipeline.write(MSG, promise);
promise.syncUninterruptibly();
promise = ch.newPromise();
@ -435,4 +326,156 @@ public class CombinedChannelDuplexHandlerTest {
}
});
}
private static final class InboundEventHandler extends ChannelInboundHandlerAdapter {
private final Queue<Object> queue = new ArrayDeque<Object>();
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void channelRegistered(ChannelHandlerContext ctx) {
queue.add(Event.REGISTERED);
}
@Override
public void channelUnregistered(ChannelHandlerContext ctx) {
queue.add(Event.UNREGISTERED);
}
@Override
public void channelActive(ChannelHandlerContext ctx) {
queue.add(Event.ACTIVE);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
queue.add(Event.INACTIVE);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
queue.add(Event.CHANNEL_READ);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_READ_COMPLETE);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
queue.add(Event.USER_EVENT_TRIGGERED);
}
@Override
public void channelWritabilityChanged(ChannelHandlerContext ctx) {
queue.add(Event.CHANNEL_WRITABILITY_CHANGED);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
queue.add(Event.EXCEPTION_CAUGHT);
}
Event pollEvent() {
Object o = queue.poll();
if (o instanceof AssertionError) {
throw (AssertionError) o;
}
return (Event) o;
}
}
private static final class OutboundEventHandler extends ChannelOutboundHandlerAdapter {
private final Queue<Object> queue = new ArrayDeque<Object>();
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_ADDED);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
queue.add(Event.HANDLER_REMOVED);
}
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
try {
assertSame(LOCAL_ADDRESS, localAddress);
queue.add(Event.BIND);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress,
SocketAddress localAddress, ChannelPromise promise) {
try {
assertSame(REMOTE_ADDRESS, remoteAddress);
assertSame(LOCAL_ADDRESS, localAddress);
queue.add(Event.CONNECT);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DISCONNECT);
promise.setSuccess();
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.CLOSE);
promise.setSuccess();
}
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
queue.add(Event.DEREGISTER);
promise.setSuccess();
}
@Override
public void read(ChannelHandlerContext ctx) {
queue.add(Event.READ);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
try {
assertSame(MSG, msg);
queue.add(Event.WRITE);
promise.setSuccess();
} catch (AssertionError e) {
promise.setFailure(e);
}
}
@Override
public void flush(ChannelHandlerContext ctx) {
queue.add(Event.FLUSH);
}
Event pollEvent() {
Object o = queue.poll();
if (o instanceof AssertionError) {
throw (AssertionError) o;
}
return (Event) o;
}
}
}