diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 76a82759b8..39f85e9fc1 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -227,6 +227,7 @@ public class DnsNameResolver extends InetNameResolver { private final DnsRecordType[] resolveRecordTypes; private final boolean decodeIdn; private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory; + private final boolean completeOncePreferredResolved; /** * Creates a new DNS-based name resolver that communicates with the specified list of DNS servers. @@ -328,7 +329,7 @@ public class DnsNameResolver extends InetNameResolver { this(eventLoop, channelFactory, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver, - dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn); + dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false); } DnsNameResolver( @@ -349,7 +350,8 @@ public class DnsNameResolver extends InetNameResolver { DnsServerAddressStreamProvider dnsServerAddressStreamProvider, String[] searchDomains, int ndots, - boolean decodeIdn) { + boolean decodeIdn, + boolean completeOncePreferredResolved) { super(eventLoop); this.queryTimeoutMillis = checkPositive(queryTimeoutMillis, "queryTimeoutMillis"); this.resolvedAddressTypes = resolvedAddressTypes != null ? resolvedAddressTypes : DEFAULT_RESOLVE_ADDRESS_TYPES; @@ -371,6 +373,7 @@ public class DnsNameResolver extends InetNameResolver { this.searchDomains = searchDomains != null ? searchDomains.clone() : DEFAULT_SEARCH_DOMAINS; this.ndots = ndots >= 0 ? ndots : DEFAULT_NDOTS; this.decodeIdn = decodeIdn; + this.completeOncePreferredResolved = completeOncePreferredResolved; switch (this.resolvedAddressTypes) { case IPV4_ONLY: @@ -948,7 +951,7 @@ public class DnsNameResolver extends InetNameResolver { } if (!doResolveAllCached(hostname, additionals, promise, resolveCache, resolvedInternetProtocolFamilies)) { - doResolveAllUncached(hostname, additionals, promise, resolveCache, false); + doResolveAllUncached(hostname, additionals, promise, resolveCache, completeOncePreferredResolved); } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java index 6b42464bf5..fa9f9ed72d 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java @@ -22,6 +22,7 @@ import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.ResolvedAddressTypes; +import io.netty.util.concurrent.Future; import io.netty.util.internal.UnstableApi; import java.util.ArrayList; @@ -47,6 +48,7 @@ public final class DnsNameResolverBuilder { private Integer negativeTtl; private long queryTimeoutMillis = 5000; private ResolvedAddressTypes resolvedAddressTypes = DnsNameResolver.DEFAULT_RESOLVE_ADDRESS_TYPES; + private boolean completeOncePreferredResolved; private boolean recursionDesired = true; private int maxQueriesPerResolve = 16; private boolean traceEnabled; @@ -253,6 +255,18 @@ public final class DnsNameResolverBuilder { return this; } + /** + * If {@code true} {@link DnsNameResolver#resolveAll(String)} will notify the returned {@link Future} as + * soon as all queries for the preferred address-type are complete. + * + * @param completeOncePreferredResolved {@code true} to enable, {@code false} to disable. + * @return {@code this} + */ + public DnsNameResolverBuilder completeOncePreferredResolved(boolean completeOncePreferredResolved) { + this.completeOncePreferredResolved = completeOncePreferredResolved; + return this; + } + /** * Sets if this resolver has to send a DNS query with the RD (recursion desired) flag set. * @@ -445,7 +459,8 @@ public final class DnsNameResolverBuilder { dnsServerAddressStreamProvider, searchDomains, ndots, - decodeIdn); + decodeIdn, + completeOncePreferredResolved); } /** @@ -506,6 +521,7 @@ public final class DnsNameResolverBuilder { copiedBuilder.ndots(ndots); copiedBuilder.decodeIdn(decodeIdn); + copiedBuilder.completeOncePreferredResolved(completeOncePreferredResolved); return copiedBuilder; } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java index 0656852693..93420702f6 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java @@ -36,7 +36,6 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.Promise; -import io.netty.util.internal.ObjectUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.ThrowableUtil; @@ -644,6 +643,7 @@ abstract class DnsResolveContext { final int answerCount = response.count(DnsSection.ANSWER); boolean found = false; + boolean completeEarly = this.completeEarly; for (int i = 0; i < answerCount; i ++) { final DnsRecord r = response.recordAt(DnsSection.ANSWER, i); final DnsRecordType type = r.type(); @@ -688,27 +688,19 @@ abstract class DnsResolveContext { // Check if we did determine we wanted to complete early before. If this is the case we want to not // include the result if (!completeEarly) { - boolean completeEarly = isCompleteEarly(converted); - if (completeEarly) { - this.completeEarly = true; - } - // We want to ensure we do not have duplicates in finalResult as this may be unexpected. - // - // While using a LinkedHashSet or HashSet may sound like the perfect fit for this we will use an - // ArrayList here as duplicates should be found quite unfrequently in the wild and we dont want to pay - // for the extra memory copy and allocations in this cases later on. - if (finalResult == null) { - if (completeEarly) { - finalResult = Collections.singletonList(converted); - } else { - finalResult = new ArrayList<>(8); - finalResult.add(converted); - } - } else if (!finalResult.contains(converted)) { - finalResult.add(converted); - } else { - shouldRelease = true; - } + completeEarly = isCompleteEarly(converted); + } + + // We want to ensure we do not have duplicates in finalResult as this may be unexpected. + // + // While using a LinkedHashSet or HashSet may sound like the perfect fit for this we will use an + // ArrayList here as duplicates should be found quite unfrequently in the wild and we dont want to pay + // for the extra memory copy and allocations in this cases later on. + if (finalResult == null) { + finalResult = new ArrayList<>(8); + finalResult.add(converted); + } else if (!finalResult.contains(converted)) { + finalResult.add(converted); } else { shouldRelease = true; } @@ -724,6 +716,9 @@ abstract class DnsResolveContext { if (cnames.isEmpty()) { if (found) { + if (completeEarly) { + this.completeEarly = true; + } queryLifecycleObserver.querySucceed(); return; } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 3d9c83bb9b..c4b94c49f8 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -65,6 +65,7 @@ import org.junit.rules.ExpectedException; import java.io.IOException; import java.net.DatagramSocket; +import java.net.Inet4Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; @@ -2593,7 +2594,7 @@ public class DnsNameResolverTest { } @Test(timeout = 2000) - public void testDropAAAAResolveFastFast() throws IOException { + public void testDropAAAAResolveFast() throws IOException { String host = "somehost.netty.io"; TestDnsServer dnsServer2 = new TestDnsServer(Collections.singleton(host)); dnsServer2.start(true); @@ -2616,4 +2617,50 @@ public class DnsNameResolverTest { } } } + + @Test(timeout = 2000) + public void testDropAAAAResolveAllFast() throws IOException { + final String host = "somehost.netty.io"; + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + @Override + public Set getRecords(QuestionRecord question) throws DnsException { + String name = question.getDomainName(); + if (name.equals(host)) { + Set records = new HashSet(2); + records.add(new TestDnsServer.TestResourceRecord(name, RecordType.A, + Collections.singletonMap(DnsAttribute.IP_ADDRESS.toLowerCase(), + "10.0.0.1"))); + records.add(new TestDnsServer.TestResourceRecord(name, RecordType.A, + Collections.singletonMap(DnsAttribute.IP_ADDRESS.toLowerCase(), + "10.0.0.2"))); + return records; + } + return null; + } + }); + dnsServer2.start(true); + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver() + .recursionDesired(false) + .queryTimeoutMillis(10000) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED) + .completeOncePreferredResolved(true) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())); + + resolver = builder.build(); + List addresses = resolver.resolveAll(host).syncUninterruptibly().getNow(); + assertEquals(2, addresses.size()); + for (InetAddress address: addresses) { + assertThat(address, instanceOf(Inet4Address.class)); + assertEquals(host, address.getHostName()); + } + } finally { + dnsServer2.stop(); + if (resolver != null) { + resolver.close(); + } + } + } }