Stop sending DNS queries if promise is cancelled (#10171)

Motivation:
A user might want to cancel DNS resolution when it takes too long.
Currently, there's no way to cancel the internal DNS queries especially when there's a lot of search domains.

Modification:
- Stop sending a DNS query if the original `Promise`, which was passed calling `resolve()`, is canceled.

Result:
- You can now stop sending DNS queries by cancelling the `Promise`.
This commit is contained in:
minux 2020-04-07 19:25:02 +09:00 committed by Norman Maurer
parent 3a2bf416bd
commit 35e95e9a6e
5 changed files with 85 additions and 29 deletions

View File

@ -32,23 +32,26 @@ final class DnsAddressResolveContext extends DnsResolveContext<InetAddress> {
private final AuthoritativeDnsServerCache authoritativeDnsServerCache; private final AuthoritativeDnsServerCache authoritativeDnsServerCache;
private final boolean completeEarlyIfPossible; private final boolean completeEarlyIfPossible;
DnsAddressResolveContext(DnsNameResolver parent, String hostname, DnsRecord[] additionals, DnsAddressResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs, DnsCache resolveCache, DnsServerAddressStream nameServerAddrs, DnsCache resolveCache,
AuthoritativeDnsServerCache authoritativeDnsServerCache, AuthoritativeDnsServerCache authoritativeDnsServerCache,
boolean completeEarlyIfPossible) { boolean completeEarlyIfPossible) {
super(parent, hostname, DnsRecord.CLASS_IN, parent.resolveRecordTypes(), additionals, nameServerAddrs); super(parent, originalPromise, hostname, DnsRecord.CLASS_IN,
parent.resolveRecordTypes(), additionals, nameServerAddrs);
this.resolveCache = resolveCache; this.resolveCache = resolveCache;
this.authoritativeDnsServerCache = authoritativeDnsServerCache; this.authoritativeDnsServerCache = authoritativeDnsServerCache;
this.completeEarlyIfPossible = completeEarlyIfPossible; this.completeEarlyIfPossible = completeEarlyIfPossible;
} }
@Override @Override
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, String hostname, DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) { DnsServerAddressStream nameServerAddrs) {
return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache, return new DnsAddressResolveContext(parent, originalPromise, hostname, additionals, nameServerAddrs,
authoritativeDnsServerCache, completeEarlyIfPossible); resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible);
} }
@Override @Override

View File

@ -711,7 +711,7 @@ public class DnsNameResolver extends InetNameResolver {
* @return the list of the address as the result of the resolution * @return the list of the address as the result of the resolution
*/ */
public final Future<List<InetAddress>> resolveAll(String inetHost, Iterable<DnsRecord> additionals, public final Future<List<InetAddress>> resolveAll(String inetHost, Iterable<DnsRecord> additionals,
Promise<List<InetAddress>> promise) { Promise<List<InetAddress>> promise) {
requireNonNull(promise, "promise"); requireNonNull(promise, "promise");
DnsRecord[] additionalsArray = toArray(additionals, true); DnsRecord[] additionalsArray = toArray(additionals, true);
try { try {
@ -813,7 +813,7 @@ public class DnsNameResolver extends InetNameResolver {
// It was not A/AAAA question or there was no entry in /etc/hosts. // It was not A/AAAA question or there was no entry in /etc/hosts.
final DnsServerAddressStream nameServerAddrs = final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname); dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsRecordResolveContext(this, question, additionals, nameServerAddrs).resolve(promise); new DnsRecordResolveContext(this, promise, question, additionals, nameServerAddrs).resolve(promise);
return promise; return promise;
} }
@ -937,7 +937,7 @@ public class DnsNameResolver extends InetNameResolver {
final Promise<InetAddress> promise, final Promise<InetAddress> promise,
DnsCache resolveCache, boolean completeEarlyIfPossible) { DnsCache resolveCache, boolean completeEarlyIfPossible) {
final Promise<List<InetAddress>> allPromise = executor().newPromise(); final Promise<List<InetAddress>> allPromise = executor().newPromise();
doResolveAllUncached(hostname, additionals, allPromise, resolveCache, true); doResolveAllUncached(hostname, additionals, promise, allPromise, resolveCache, true);
allPromise.addListener((FutureListener<List<InetAddress>>) future -> { allPromise.addListener((FutureListener<List<InetAddress>>) future -> {
if (future.isSuccess()) { if (future.isSuccess()) {
trySuccess(promise, future.getNow().get(0)); trySuccess(promise, future.getNow().get(0));
@ -981,7 +981,8 @@ public class DnsNameResolver extends InetNameResolver {
} }
if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) { if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) {
doResolveAllUncached(hostname, additionals, promise, resolveCache, completeOncePreferredResolved); doResolveAllUncached(hostname, additionals, promise, promise,
resolveCache, completeOncePreferredResolved);
} }
} }
@ -1023,6 +1024,7 @@ public class DnsNameResolver extends InetNameResolver {
private void doResolveAllUncached(final String hostname, private void doResolveAllUncached(final String hostname,
final DnsRecord[] additionals, final DnsRecord[] additionals,
final Promise<?> originalPromise,
final Promise<List<InetAddress>> promise, final Promise<List<InetAddress>> promise,
final DnsCache resolveCache, final DnsCache resolveCache,
final boolean completeEarlyIfPossible) { final boolean completeEarlyIfPossible) {
@ -1030,15 +1032,18 @@ public class DnsNameResolver extends InetNameResolver {
// to submit multiple Runnable at the end if we are not already on the EventLoop. // to submit multiple Runnable at the end if we are not already on the EventLoop.
EventExecutor executor = executor(); EventExecutor executor = executor();
if (executor.inEventLoop()) { if (executor.inEventLoop()) {
doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible); doResolveAllUncached0(hostname, additionals, originalPromise,
promise, resolveCache, completeEarlyIfPossible);
} else { } else {
executor.execute(() -> executor.execute(() ->
doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible)); doResolveAllUncached0(hostname, additionals, originalPromise,
promise, resolveCache, completeEarlyIfPossible));
} }
} }
private void doResolveAllUncached0(String hostname, private void doResolveAllUncached0(String hostname,
DnsRecord[] additionals, DnsRecord[] additionals,
Promise<?> originalPromise,
Promise<List<InetAddress>> promise, Promise<List<InetAddress>> promise,
DnsCache resolveCache, DnsCache resolveCache,
boolean completeEarlyIfPossible) { boolean completeEarlyIfPossible) {
@ -1047,8 +1052,9 @@ public class DnsNameResolver extends InetNameResolver {
final DnsServerAddressStream nameServerAddrs = final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname); dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, resolveCache, new DnsAddressResolveContext(this, originalPromise, hostname, additionals, nameServerAddrs,
authoritativeDnsServerCache, completeEarlyIfPossible).resolve(promise); resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible)
.resolve(promise);
} }
private static String hostname(String inetHost) { private static String hostname(String inetHost) {

View File

@ -23,29 +23,32 @@ import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
final class DnsRecordResolveContext extends DnsResolveContext<DnsRecord> { final class DnsRecordResolveContext extends DnsResolveContext<DnsRecord> {
DnsRecordResolveContext(DnsNameResolver parent, DnsQuestion question, DnsRecord[] additionals, DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, DnsQuestion question,
DnsServerAddressStream nameServerAddrs) { DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) {
this(parent, question.name(), question.dnsClass(), this(parent, originalPromise, question.name(), question.dnsClass(),
new DnsRecordType[] { question.type() }, new DnsRecordType[] { question.type() },
additionals, nameServerAddrs); additionals, nameServerAddrs);
} }
private DnsRecordResolveContext(DnsNameResolver parent, String hostname, private DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, String hostname,
int dnsClass, DnsRecordType[] expectedTypes, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) { DnsServerAddressStream nameServerAddrs) {
super(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); super(parent, originalPromise, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
} }
@Override @Override
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, String hostname, DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) { DnsServerAddressStream nameServerAddrs) {
return new DnsRecordResolveContext(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs); return new DnsRecordResolveContext(parent, originalPromise, hostname, dnsClass,
expectedTypes, additionals, nameServerAddrs);
} }
@Override @Override

View File

@ -84,6 +84,7 @@ abstract class DnsResolveContext<T> {
"tryToFinishResolve(..)"); "tryToFinishResolve(..)");
final DnsNameResolver parent; final DnsNameResolver parent;
private final Promise<?> originalPromise;
private final DnsServerAddressStream nameServerAddrs; private final DnsServerAddressStream nameServerAddrs;
private final String hostname; private final String hostname;
private final int dnsClass; private final int dnsClass;
@ -100,13 +101,13 @@ abstract class DnsResolveContext<T> {
private boolean triedCNAME; private boolean triedCNAME;
private boolean completeEarly; private boolean completeEarly;
DnsResolveContext(DnsNameResolver parent, DnsResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname, int dnsClass, DnsRecordType[] expectedTypes, String hostname, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) { DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) {
assert expectedTypes.length > 0; assert expectedTypes.length > 0;
this.parent = parent; this.parent = parent;
this.originalPromise = originalPromise;
this.hostname = hostname; this.hostname = hostname;
this.dnsClass = dnsClass; this.dnsClass = dnsClass;
this.expectedTypes = expectedTypes; this.expectedTypes = expectedTypes;
@ -152,7 +153,8 @@ abstract class DnsResolveContext<T> {
/** /**
* Creates a new context with the given parameters. * Creates a new context with the given parameters.
*/ */
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, String hostname, abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs); DnsServerAddressStream nameServerAddrs);
@ -251,8 +253,8 @@ abstract class DnsResolveContext<T> {
} }
void doSearchDomainQuery(String hostname, Promise<List<T>> nextPromise) { void doSearchDomainQuery(String hostname, Promise<List<T>> nextPromise) {
DnsResolveContext<T> nextContext = newResolverContext(parent, hostname, dnsClass, expectedTypes, DnsResolveContext<T> nextContext = newResolverContext(parent, originalPromise, hostname, dnsClass,
additionals, nameServerAddrs); expectedTypes, additionals, nameServerAddrs);
nextContext.internalResolve(hostname, nextPromise); nextContext.internalResolve(hostname, nextPromise);
} }
@ -341,7 +343,7 @@ abstract class DnsResolveContext<T> {
final Promise<List<T>> promise, final Promise<List<T>> promise,
final Throwable cause) { final Throwable cause) {
if (completeEarly || nameServerAddrStreamIndex >= nameServerAddrStream.size() || if (completeEarly || nameServerAddrStreamIndex >= nameServerAddrStream.size() ||
allowedQueries == 0 || promise.isCancelled()) { allowedQueries == 0 || originalPromise.isCancelled() || promise.isCancelled()) {
tryToFinishResolve(nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver, tryToFinishResolve(nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver,
promise, cause); promise, cause);
return; return;
@ -351,7 +353,7 @@ abstract class DnsResolveContext<T> {
final InetSocketAddress nameServerAddr = nameServerAddrStream.next(); final InetSocketAddress nameServerAddr = nameServerAddrStream.next();
if (nameServerAddr.isUnresolved()) { if (nameServerAddr.isUnresolved()) {
queryUnresolvedNameserver(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question, queryUnresolvedNameServer(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question,
queryLifecycleObserver, promise, cause); queryLifecycleObserver, promise, cause);
return; return;
} }
@ -402,7 +404,7 @@ abstract class DnsResolveContext<T> {
}); });
} }
private void queryUnresolvedNameserver(final InetSocketAddress nameServerAddr, private void queryUnresolvedNameServer(final InetSocketAddress nameServerAddr,
final DnsServerAddressStream nameServerAddrStream, final DnsServerAddressStream nameServerAddrStream,
final int nameServerAddrStreamIndex, final int nameServerAddrStreamIndex,
final DnsQuestion question, final DnsQuestion question,
@ -437,7 +439,7 @@ abstract class DnsResolveContext<T> {
if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache(), if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache(),
parent.resolvedInternetProtocolFamiliesUnsafe())) { parent.resolvedInternetProtocolFamiliesUnsafe())) {
final AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache(); final AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache();
new DnsAddressResolveContext(parent, nameServerName, additionals, new DnsAddressResolveContext(parent, originalPromise, nameServerName, additionals,
parent.newNameServerAddressStream(nameServerName), parent.newNameServerAddressStream(nameServerName),
resolveCache(), new AuthoritativeDnsServerCache() { resolveCache(), new AuthoritativeDnsServerCache() {
@Override @Override

View File

@ -46,6 +46,7 @@ import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil; import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.SocketUtils; import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
@ -96,6 +97,7 @@ import java.util.Map.Entry;
import java.util.Queue; import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
@ -2862,4 +2864,44 @@ public class DnsNameResolverTest {
} }
} }
} }
@Test
public void testCancelPromise() throws Exception {
final EventLoop eventLoop = group.next();
final Promise<InetAddress> promise = eventLoop.newPromise();
final TestDnsServer dnsServer1 = new TestDnsServer(Collections.<String>emptySet()) {
@Override
protected DnsMessage filterMessage(DnsMessage message) {
promise.cancel(true);
return message;
}
};
dnsServer1.start();
final AtomicBoolean isQuerySentToSecondServer = new AtomicBoolean();
final TestDnsServer dnsServer2 = new TestDnsServer(Collections.<String>emptySet()) {
@Override
protected DnsMessage filterMessage(DnsMessage message) {
isQuerySentToSecondServer.set(true);
return message;
}
};
dnsServer2.start();
DnsServerAddressStreamProvider nameServerProvider =
new SequentialDnsServerAddressStreamProvider(dnsServer1.localAddress(),
dnsServer2.localAddress());
final DnsNameResolver resolver = new DnsNameResolverBuilder(group.next())
.dnsQueryLifecycleObserverFactory(new TestRecursiveCacheDnsQueryLifecycleObserverFactory())
.channelType(NioDatagramChannel.class)
.optResourceEnabled(false)
.nameServerProvider(nameServerProvider)
.build();
try {
resolver.resolve("non-existent.netty.io", promise).sync();
fail();
} catch (Exception e) {
assertThat(e, is(instanceOf(CancellationException.class)));
}
assertThat(isQuerySentToSecondServer.get(), is(false));
}
} }