Stop sending DNS queries if promise is cancelled (#10171)

Motivation:
A user might want to cancel DNS resolution when it takes too long.
Currently, there's no way to cancel the internal DNS queries especially when there's a lot of search domains.

Modification:
- Stop sending a DNS query if the original `Promise`, which was passed calling `resolve()`, is canceled.

Result:
- You can now stop sending DNS queries by cancelling the `Promise`.
This commit is contained in:
minux 2020-04-07 19:25:02 +09:00 committed by Norman Maurer
parent 3a2bf416bd
commit 35e95e9a6e
5 changed files with 85 additions and 29 deletions

View File

@ -32,23 +32,26 @@ final class DnsAddressResolveContext extends DnsResolveContext<InetAddress> {
private final AuthoritativeDnsServerCache authoritativeDnsServerCache;
private final boolean completeEarlyIfPossible;
DnsAddressResolveContext(DnsNameResolver parent, String hostname, DnsRecord[] additionals,
DnsAddressResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs, DnsCache resolveCache,
AuthoritativeDnsServerCache authoritativeDnsServerCache,
boolean completeEarlyIfPossible) {
super(parent, hostname, DnsRecord.CLASS_IN, parent.resolveRecordTypes(), additionals, nameServerAddrs);
super(parent, originalPromise, hostname, DnsRecord.CLASS_IN,
parent.resolveRecordTypes(), additionals, nameServerAddrs);
this.resolveCache = resolveCache;
this.authoritativeDnsServerCache = authoritativeDnsServerCache;
this.completeEarlyIfPossible = completeEarlyIfPossible;
}
@Override
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, String hostname,
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache,
authoritativeDnsServerCache, completeEarlyIfPossible);
return new DnsAddressResolveContext(parent, originalPromise, hostname, additionals, nameServerAddrs,
resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible);
}
@Override

View File

@ -813,7 +813,7 @@ public class DnsNameResolver extends InetNameResolver {
// It was not A/AAAA question or there was no entry in /etc/hosts.
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsRecordResolveContext(this, question, additionals, nameServerAddrs).resolve(promise);
new DnsRecordResolveContext(this, promise, question, additionals, nameServerAddrs).resolve(promise);
return promise;
}
@ -937,7 +937,7 @@ public class DnsNameResolver extends InetNameResolver {
final Promise<InetAddress> promise,
DnsCache resolveCache, boolean completeEarlyIfPossible) {
final Promise<List<InetAddress>> allPromise = executor().newPromise();
doResolveAllUncached(hostname, additionals, allPromise, resolveCache, true);
doResolveAllUncached(hostname, additionals, promise, allPromise, resolveCache, true);
allPromise.addListener((FutureListener<List<InetAddress>>) future -> {
if (future.isSuccess()) {
trySuccess(promise, future.getNow().get(0));
@ -981,7 +981,8 @@ public class DnsNameResolver extends InetNameResolver {
}
if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) {
doResolveAllUncached(hostname, additionals, promise, resolveCache, completeOncePreferredResolved);
doResolveAllUncached(hostname, additionals, promise, promise,
resolveCache, completeOncePreferredResolved);
}
}
@ -1023,6 +1024,7 @@ public class DnsNameResolver extends InetNameResolver {
private void doResolveAllUncached(final String hostname,
final DnsRecord[] additionals,
final Promise<?> originalPromise,
final Promise<List<InetAddress>> promise,
final DnsCache resolveCache,
final boolean completeEarlyIfPossible) {
@ -1030,15 +1032,18 @@ public class DnsNameResolver extends InetNameResolver {
// to submit multiple Runnable at the end if we are not already on the EventLoop.
EventExecutor executor = executor();
if (executor.inEventLoop()) {
doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible);
doResolveAllUncached0(hostname, additionals, originalPromise,
promise, resolveCache, completeEarlyIfPossible);
} else {
executor.execute(() ->
doResolveAllUncached0(hostname, additionals, promise, resolveCache, completeEarlyIfPossible));
doResolveAllUncached0(hostname, additionals, originalPromise,
promise, resolveCache, completeEarlyIfPossible));
}
}
private void doResolveAllUncached0(String hostname,
DnsRecord[] additionals,
Promise<?> originalPromise,
Promise<List<InetAddress>> promise,
DnsCache resolveCache,
boolean completeEarlyIfPossible) {
@ -1047,8 +1052,9 @@ public class DnsNameResolver extends InetNameResolver {
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, resolveCache,
authoritativeDnsServerCache, completeEarlyIfPossible).resolve(promise);
new DnsAddressResolveContext(this, originalPromise, hostname, additionals, nameServerAddrs,
resolveCache, authoritativeDnsServerCache, completeEarlyIfPossible)
.resolve(promise);
}
private static String hostname(String inetHost) {

View File

@ -23,29 +23,32 @@ import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
final class DnsRecordResolveContext extends DnsResolveContext<DnsRecord> {
DnsRecordResolveContext(DnsNameResolver parent, DnsQuestion question, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
this(parent, question.name(), question.dnsClass(),
DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, DnsQuestion question,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) {
this(parent, originalPromise, question.name(), question.dnsClass(),
new DnsRecordType[] { question.type() },
additionals, nameServerAddrs);
}
private DnsRecordResolveContext(DnsNameResolver parent, String hostname,
private DnsRecordResolveContext(DnsNameResolver parent, Promise<?> originalPromise, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
super(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
super(parent, originalPromise, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
}
@Override
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, String hostname,
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
return new DnsRecordResolveContext(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
return new DnsRecordResolveContext(parent, originalPromise, hostname, dnsClass,
expectedTypes, additionals, nameServerAddrs);
}
@Override

View File

@ -84,6 +84,7 @@ abstract class DnsResolveContext<T> {
"tryToFinishResolve(..)");
final DnsNameResolver parent;
private final Promise<?> originalPromise;
private final DnsServerAddressStream nameServerAddrs;
private final String hostname;
private final int dnsClass;
@ -100,13 +101,13 @@ abstract class DnsResolveContext<T> {
private boolean triedCNAME;
private boolean completeEarly;
DnsResolveContext(DnsNameResolver parent,
DnsResolveContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) {
assert expectedTypes.length > 0;
this.parent = parent;
this.originalPromise = originalPromise;
this.hostname = hostname;
this.dnsClass = dnsClass;
this.expectedTypes = expectedTypes;
@ -152,7 +153,8 @@ abstract class DnsResolveContext<T> {
/**
* Creates a new context with the given parameters.
*/
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, String hostname,
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, Promise<?> originalPromise,
String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs);
@ -251,8 +253,8 @@ abstract class DnsResolveContext<T> {
}
void doSearchDomainQuery(String hostname, Promise<List<T>> nextPromise) {
DnsResolveContext<T> nextContext = newResolverContext(parent, hostname, dnsClass, expectedTypes,
additionals, nameServerAddrs);
DnsResolveContext<T> nextContext = newResolverContext(parent, originalPromise, hostname, dnsClass,
expectedTypes, additionals, nameServerAddrs);
nextContext.internalResolve(hostname, nextPromise);
}
@ -341,7 +343,7 @@ abstract class DnsResolveContext<T> {
final Promise<List<T>> promise,
final Throwable cause) {
if (completeEarly || nameServerAddrStreamIndex >= nameServerAddrStream.size() ||
allowedQueries == 0 || promise.isCancelled()) {
allowedQueries == 0 || originalPromise.isCancelled() || promise.isCancelled()) {
tryToFinishResolve(nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver,
promise, cause);
return;
@ -351,7 +353,7 @@ abstract class DnsResolveContext<T> {
final InetSocketAddress nameServerAddr = nameServerAddrStream.next();
if (nameServerAddr.isUnresolved()) {
queryUnresolvedNameserver(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question,
queryUnresolvedNameServer(nameServerAddr, nameServerAddrStream, nameServerAddrStreamIndex, question,
queryLifecycleObserver, promise, cause);
return;
}
@ -402,7 +404,7 @@ abstract class DnsResolveContext<T> {
});
}
private void queryUnresolvedNameserver(final InetSocketAddress nameServerAddr,
private void queryUnresolvedNameServer(final InetSocketAddress nameServerAddr,
final DnsServerAddressStream nameServerAddrStream,
final int nameServerAddrStreamIndex,
final DnsQuestion question,
@ -437,7 +439,7 @@ abstract class DnsResolveContext<T> {
if (!DnsNameResolver.doResolveAllCached(nameServerName, additionals, resolverPromise, resolveCache(),
parent.resolvedInternetProtocolFamiliesUnsafe())) {
final AuthoritativeDnsServerCache authoritativeDnsServerCache = authoritativeDnsServerCache();
new DnsAddressResolveContext(parent, nameServerName, additionals,
new DnsAddressResolveContext(parent, originalPromise, nameServerName, additionals,
parent.newNameServerAddressStream(nameServerName),
resolveCache(), new AuthoritativeDnsServerCache() {
@Override

View File

@ -46,6 +46,7 @@ import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
@ -96,6 +97,7 @@ import java.util.Map.Entry;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
@ -2862,4 +2864,44 @@ public class DnsNameResolverTest {
}
}
}
@Test
public void testCancelPromise() throws Exception {
final EventLoop eventLoop = group.next();
final Promise<InetAddress> promise = eventLoop.newPromise();
final TestDnsServer dnsServer1 = new TestDnsServer(Collections.<String>emptySet()) {
@Override
protected DnsMessage filterMessage(DnsMessage message) {
promise.cancel(true);
return message;
}
};
dnsServer1.start();
final AtomicBoolean isQuerySentToSecondServer = new AtomicBoolean();
final TestDnsServer dnsServer2 = new TestDnsServer(Collections.<String>emptySet()) {
@Override
protected DnsMessage filterMessage(DnsMessage message) {
isQuerySentToSecondServer.set(true);
return message;
}
};
dnsServer2.start();
DnsServerAddressStreamProvider nameServerProvider =
new SequentialDnsServerAddressStreamProvider(dnsServer1.localAddress(),
dnsServer2.localAddress());
final DnsNameResolver resolver = new DnsNameResolverBuilder(group.next())
.dnsQueryLifecycleObserverFactory(new TestRecursiveCacheDnsQueryLifecycleObserverFactory())
.channelType(NioDatagramChannel.class)
.optResourceEnabled(false)
.nameServerProvider(nameServerProvider)
.build();
try {
resolver.resolve("non-existent.netty.io", promise).sync();
fail();
} catch (Exception e) {
assertThat(e, is(instanceOf(CancellationException.class)));
}
assertThat(isQuerySentToSecondServer.get(), is(false));
}
}