diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java index 944e4d7096..da6a97d461 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java @@ -33,23 +33,26 @@ final class DnsAddressResolveContext extends DnsResolveContext { private final AuthoritativeDnsServerCache authoritativeDnsServerCache; private final boolean completeEarlyIfPossible; - DnsAddressResolveContext(DnsNameResolver parent, String hostname, DnsRecord[] additionals, + DnsAddressResolveContext(DnsNameResolver parent, Promise originalPromise, + String hostname, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs, DnsCache resolveCache, AuthoritativeDnsServerCache authoritativeDnsServerCache, boolean completeEarlyIfPossible) { - super(parent, hostname, DnsRecord.CLASS_IN, parent.resolveRecordTypes(), additionals, nameServerAddrs); + super(parent, originalPromise, hostname, DnsRecord.CLASS_IN, + parent.resolveRecordTypes(), additionals, nameServerAddrs); this.resolveCache = resolveCache; this.authoritativeDnsServerCache = authoritativeDnsServerCache; this.completeEarlyIfPossible = completeEarlyIfPossible; } @Override - DnsResolveContext newResolverContext(DnsNameResolver parent, String hostname, + DnsResolveContext newResolverContext(DnsNameResolver parent, Promise originalPromise, + String hostname, int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { - return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache, - authoritativeDnsServerCache, completeEarlyIfPossible); + return new DnsAddressResolveContext(parent, originalPromise, hostname, additionals, nameServerAddrs, + resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 1852e16934..67f85dcce6 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -714,7 +714,7 @@ public class DnsNameResolver extends InetNameResolver { * @return the list of the address as the result of the resolution */ public final Future> resolveAll(String inetHost, Iterable additionals, - Promise> promise) { + Promise> promise) { checkNotNull(promise, "promise"); DnsRecord[] additionalsArray = toArray(additionals, true); try { @@ -816,7 +816,7 @@ public class DnsNameResolver extends InetNameResolver { // It was not A/AAAA question or there was no entry in /etc/hosts. final DnsServerAddressStream nameServerAddrs = dnsServerAddressStreamProvider.nameServerAddressStream(hostname); - new DnsRecordResolveContext(this, question, additionals, nameServerAddrs).resolve(promise); + new DnsRecordResolveContext(this, promise, question, additionals, nameServerAddrs).resolve(promise); return promise; } @@ -940,7 +940,7 @@ public class DnsNameResolver extends InetNameResolver { final Promise promise, DnsCache resolveCache, boolean completeEarlyIfPossible) { final Promise> allPromise = executor().newPromise(); - doResolveAllUncached(hostname, additionals, allPromise, resolveCache, true); + doResolveAllUncached(hostname, additionals, promise, allPromise, resolveCache, true); allPromise.addListener(new FutureListener>() { @Override public void operationComplete(Future> future) { @@ -987,7 +987,8 @@ public class DnsNameResolver extends InetNameResolver { } if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) { - doResolveAllUncached(hostname, additionals, promise, resolveCache, completeOncePreferredResolved); + doResolveAllUncached(hostname, additionals, promise, promise, + resolveCache, completeOncePreferredResolved); } } @@ -1029,6 +1030,7 @@ public class DnsNameResolver extends InetNameResolver { private void doResolveAllUncached(final String hostname, final DnsRecord[] additionals, + final Promise originalPromise, final Promise> promise, final DnsCache resolveCache, final boolean completeEarlyIfPossible) { @@ -1036,12 +1038,14 @@ public class DnsNameResolver extends InetNameResolver { // to submit multiple Runnable at the end if we are not already on the EventLoop. EventExecutor executor = executor(); if (executor.inEventLoop()) { - doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible); + doResolveAllUncached0(hostname, additionals, originalPromise, + promise, resolveCache, completeEarlyIfPossible); } else { executor.execute(new Runnable() { @Override public void run() { - doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible); + doResolveAllUncached0(hostname, additionals, originalPromise, + promise, resolveCache, completeEarlyIfPossible); } }); } @@ -1049,6 +1053,7 @@ public class DnsNameResolver extends InetNameResolver { private void doResolveAllUncached0(String hostname, DnsRecord[] additionals, + Promise originalPromise, Promise> promise, DnsCache resolveCache, boolean completeEarlyIfPossible) { @@ -1057,8 +1062,9 @@ public class DnsNameResolver extends InetNameResolver { final DnsServerAddressStream nameServerAddrs = dnsServerAddressStreamProvider.nameServerAddressStream(hostname); - new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, resolveCache, - authoritativeDnsServerCache, completeEarlyIfPossible).resolve(promise); + new DnsAddressResolveContext(this, originalPromise, hostname, additionals, nameServerAddrs, + resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible) + .resolve(promise); } private static String hostname(String inetHost) { diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java index 93814e2102..ed88ae7492 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java @@ -23,29 +23,32 @@ import io.netty.handler.codec.dns.DnsQuestion; import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecordType; import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Promise; final class DnsRecordResolveContext extends DnsResolveContext { - DnsRecordResolveContext(DnsNameResolver parent, DnsQuestion question, DnsRecord[] additionals, - DnsServerAddressStream nameServerAddrs) { - this(parent, question.name(), question.dnsClass(), + DnsRecordResolveContext(DnsNameResolver parent, Promise originalPromise, DnsQuestion question, + DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { + this(parent, originalPromise, question.name(), question.dnsClass(), new DnsRecordType[] { question.type() }, additionals, nameServerAddrs); } - private DnsRecordResolveContext(DnsNameResolver parent, String hostname, + private DnsRecordResolveContext(DnsNameResolver parent, Promise originalPromise, String hostname, int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { - super(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); + super(parent, originalPromise, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); } @Override - DnsResolveContext newResolverContext(DnsNameResolver parent, String hostname, + DnsResolveContext newResolverContext(DnsNameResolver parent, Promise originalPromise, + String hostname, int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { - return new DnsRecordResolveContext(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); + return new DnsRecordResolveContext(parent, originalPromise, hostname, dnsClass, + expectedTypes, additionals, nameServerAddrs); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java index 38fbafcaa2..51f93d4006 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java @@ -85,6 +85,7 @@ abstract class DnsResolveContext { "tryToFinishResolve(..)"); final DnsNameResolver parent; + private final Promise originalPromise; private final DnsServerAddressStream nameServerAddrs; private final String hostname; private final int dnsClass; @@ -101,13 +102,13 @@ abstract class DnsResolveContext { private boolean triedCNAME; private boolean completeEarly; - DnsResolveContext(DnsNameResolver parent, + DnsResolveContext(DnsNameResolver parent, Promise originalPromise, String hostname, int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { - assert expectedTypes.length > 0; this.parent = parent; + this.originalPromise = originalPromise; this.hostname = hostname; this.dnsClass = dnsClass; this.expectedTypes = expectedTypes; @@ -163,7 +164,8 @@ abstract class DnsResolveContext { /** * Creates a new context with the given parameters. */ - abstract DnsResolveContext newResolverContext(DnsNameResolver parent, String hostname, + abstract DnsResolveContext newResolverContext(DnsNameResolver parent, Promise originalPromise, + String hostname, int dnsClass, DnsRecordType[] expectedTypes, DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs); @@ -262,8 +264,8 @@ abstract class DnsResolveContext { } void doSearchDomainQuery(String hostname, Promise> nextPromise) { - DnsResolveContext nextContext = newResolverContext(parent, hostname, dnsClass, expectedTypes, - additionals, nameServerAddrs); + DnsResolveContext nextContext = newResolverContext(parent, originalPromise, hostname, dnsClass, + expectedTypes, additionals, nameServerAddrs); nextContext.internalResolve(hostname, nextPromise); } @@ -352,7 +354,7 @@ abstract class DnsResolveContext { final Promise> promise, final Throwable cause) { if (completeEarly || nameServerAddrStreamIndex >= nameServerAddrStream.size() || - allowedQueries == 0 || promise.isCancelled()) { + allowedQueries == 0 || originalPromise.isCancelled() || promise.isCancelled()) { tryToFinishResolve(nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver, promise, cause); return; @@ -362,7 +364,7 @@ abstract class DnsResolveContext { final InetSocketAddress nameServerAddr = nameServerAddrStream.next(); if (nameServerAddr.isUnresolved()) { - queryUnresolvedNameserver(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question, + queryUnresolvedNameServer(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver, promise, cause); return; } @@ -416,7 +418,7 @@ abstract class DnsResolveContext { }); } - private void queryUnresolvedNameserver(final InetSocketAddress nameServerAddr, + private void queryUnresolvedNameServer(final InetSocketAddress nameServerAddr, final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, final DnsQuestion question, @@ -455,7 +457,7 @@ abstract class DnsResolveContext { if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache(), parent.resolvedInternetProtocolFamiliesUnsafe())) { final AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache(); - new DnsAddressResolveContext(parent, nameServerName, additionals, + new DnsAddressResolveContext(parent, originalPromise, nameServerName, additionals, parent.newNameServerAddressStream(nameServerName), resolveCache(), new AuthoritativeDnsServerCache() { @Override diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 3e35e288a5..0e0319ca16 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -45,6 +45,7 @@ import io.netty.util.CharsetUtil; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SocketUtils; import io.netty.util.internal.StringUtil; @@ -95,6 +96,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Queue; import java.util.Set; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CopyOnWriteArrayList; @@ -2906,4 +2908,44 @@ public class DnsNameResolverTest { } } } + + @Test + public void testCancelPromise() throws Exception { + final EventLoop eventLoop = group.next(); + final Promise promise = eventLoop.newPromise(); + final TestDnsServer dnsServer1 = new TestDnsServer(Collections.emptySet()) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + promise.cancel(true); + return message; + } + }; + dnsServer1.start(); + final AtomicBoolean isQuerySentToSecondServer = new AtomicBoolean(); + final TestDnsServer dnsServer2 = new TestDnsServer(Collections.emptySet()) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + isQuerySentToSecondServer.set(true); + return message; + } + }; + dnsServer2.start(); + DnsServerAddressStreamProvider nameServerProvider = + new SequentialDnsServerAddressStreamProvider(dnsServer1.localAddress(), + dnsServer2.localAddress()); + final DnsNameResolver resolver = new DnsNameResolverBuilder(group.next()) + .dnsQueryLifecycleObserverFactory(new TestRecursiveCacheDnsQueryLifecycleObserverFactory()) + .channelType(NioDatagramChannel.class) + .optResourceEnabled(false) + .nameServerProvider(nameServerProvider) + .build(); + + try { + resolver.resolve("non-existent.netty.io", promise).sync(); + fail(); + } catch (Exception e) { + assertThat(e, is(instanceOf(CancellationException.class))); + } + assertThat(isQuerySentToSecondServer.get(), is(false)); + } }