diff --git a/handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java b/handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java new file mode 100644 index 0000000000..e6d53952f2 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java @@ -0,0 +1,76 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.address; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; + +import java.net.NetworkInterface; +import java.net.SocketAddress; + +/** + * {@link ChannelHandler} implementation which allows to dynamically replace the used + * {@code remoteAddress} and / or {@code localAddress} when making a connection attempt. + *
+ * This can be useful to for example bind to a specific {@link NetworkInterface} based on + * the {@code remoteAddress}. + */ +public abstract class DynamicAddressConnectHandler implements ChannelHandler { + + @Override + public final void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + final SocketAddress remote; + final SocketAddress local; + try { + remote = remoteAddress(remoteAddress, localAddress); + local = localAddress(remoteAddress, localAddress); + } catch (Exception e) { + promise.setFailure(e); + return; + } + ctx.connect(remote, local, promise).addListener(future -> { + if (future.isSuccess()) { + // We only remove this handler from the pipeline once the connect was successful as otherwise + // the user may try to connect again. + ctx.pipeline().remove(DynamicAddressConnectHandler.this); + } + }); + } + + /** + * Returns the local {@link SocketAddress} to use for + * {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress)} based on the original {@code remoteAddress} + * and {@code localAddress}. + * By default, this method returns the given {@code localAddress}. + */ + protected SocketAddress localAddress( + @SuppressWarnings("unused") SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + return localAddress; + } + + /** + * Returns the remote {@link SocketAddress} to use for + * {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress)} based on the original {@code remoteAddress} + * and {@code localAddress}. + * By default, this method returns the given {@code remoteAddress}. + */ + protected SocketAddress remoteAddress( + SocketAddress remoteAddress, @SuppressWarnings("unused") SocketAddress localAddress) throws Exception { + return remoteAddress; + } +} diff --git a/handler/src/main/java/io/netty/handler/address/package-info.java b/handler/src/main/java/io/netty/handler/address/package-info.java new file mode 100644 index 0000000000..965faa888c --- /dev/null +++ b/handler/src/main/java/io/netty/handler/address/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Package to dynamically replace local / remote {@link java.net.SocketAddress}. + */ +package io.netty.handler.address; diff --git a/handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java b/handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java new file mode 100644 index 0000000000..ad241a9d91 --- /dev/null +++ b/handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.address; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import java.net.SocketAddress; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class DynamicAddressConnectHandlerTest { + private static final SocketAddress LOCAL = new SocketAddress() { }; + private static final SocketAddress LOCAL_NEW = new SocketAddress() { }; + private static final SocketAddress REMOTE = new SocketAddress() { }; + private static final SocketAddress REMOTE_NEW = new SocketAddress() { }; + @Test + public void testReplaceAddresses() { + + EmbeddedChannel channel = new EmbeddedChannel(new ChannelHandler() { + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + try { + assertSame(REMOTE_NEW, remoteAddress); + assertSame(LOCAL_NEW, localAddress); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }, new DynamicAddressConnectHandler() { + @Override + protected SocketAddress localAddress(SocketAddress remoteAddress, SocketAddress localAddress) { + assertSame(REMOTE, remoteAddress); + assertSame(LOCAL, localAddress); + return LOCAL_NEW; + } + + @Override + protected SocketAddress remoteAddress(SocketAddress remoteAddress, SocketAddress localAddress) { + assertSame(REMOTE, remoteAddress); + assertSame(LOCAL, localAddress); + return REMOTE_NEW; + } + }); + channel.connect(REMOTE, LOCAL).syncUninterruptibly(); + assertNull(channel.pipeline().get(DynamicAddressConnectHandler.class)); + assertFalse(channel.finish()); + } + + @Test + public void testLocalAddressThrows() { + testThrows0(true); + } + + @Test + public void testRemoteAddressThrows() { + testThrows0(false); + } + + private static void testThrows0(final boolean localThrows) { + final IllegalStateException exception = new IllegalStateException(); + + EmbeddedChannel channel = new EmbeddedChannel(new DynamicAddressConnectHandler() { + @Override + protected SocketAddress localAddress( + SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (localThrows) { + throw exception; + } + return super.localAddress(remoteAddress, localAddress); + } + + @Override + protected SocketAddress remoteAddress(SocketAddress remoteAddress, SocketAddress localAddress) + throws Exception { + if (!localThrows) { + throw exception; + } + return super.remoteAddress(remoteAddress, localAddress); + } + }); + assertSame(exception, channel.connect(REMOTE, LOCAL).cause()); + assertNotNull(channel.pipeline().get(DynamicAddressConnectHandler.class)); + assertFalse(channel.finish()); + } +}