From 369d64c3e6050b5f006c359523a60fcc15eb6110 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 9 May 2018 13:48:20 +0200 Subject: [PATCH] Always follow cnames even if a matching A or AAAA record was found. (#7919) Motivation: At the moment if you do a resolveAll and at least one A / AAAA record is present we will not follow any CNAMEs that are also present. This is different to how the JDK behaves. Modifications: - Allows follow CNAMEs. - Add unit test. Result: Fixes https://github.com/netty/netty/issues/7915. --- .../dns/DnsAddressResolveContext.java | 13 ----- .../resolver/dns/DnsRecordResolveContext.java | 8 --- .../netty/resolver/dns/DnsResolveContext.java | 54 +++++++------------ .../resolver/dns/DnsNameResolverTest.java | 52 +++++++++++++++++- 4 files changed, 71 insertions(+), 56 deletions(-) diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java index 7803d733f9..63490efa4a 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsAddressResolveContext.java @@ -49,19 +49,6 @@ final class DnsAddressResolveContext extends DnsResolveContext { return decodeAddress(record, hostname, parent.isDecodeIdn()); } - @Override - boolean containsExpectedResult(List finalResult) { - final int size = finalResult.size(); - final Class inetAddressType = parent.preferredAddressType().addressType(); - for (int i = 0; i < size; i++) { - InetAddress address = finalResult.get(i); - if (inetAddressType.isInstance(address)) { - return true; - } - } - return false; - } - @Override List filterResults(List unfiltered) { final Class inetAddressType = parent.preferredAddressType().addressType(); diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java index 056fa078f8..e633eb8820 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsRecordResolveContext.java @@ -15,9 +15,6 @@ */ package io.netty.resolver.dns; -import static io.netty.resolver.dns.DnsAddressDecoder.decodeAddress; - -import java.net.InetAddress; import java.net.UnknownHostException; import java.util.List; @@ -56,11 +53,6 @@ final class DnsRecordResolveContext extends DnsResolveContext { return ReferenceCountUtil.retain(record); } - @Override - boolean containsExpectedResult(List finalResult) { - return true; - } - @Override List filterResults(List unfiltered) { return unfiltered; diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java index 6300a35898..bc63e710e1 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsResolveContext.java @@ -137,12 +137,6 @@ abstract class DnsResolveContext { */ abstract T convertRecord(DnsRecord record, String hostname, DnsRecord[] additionals, EventLoop eventLoop); - /** - * Returns {@code true} if the given list contains any expected records. {@code finalResult} always contains - * at least one element. - */ - abstract boolean containsExpectedResult(List finalResult); - /** * Returns a filtered list of results which should be the final result of DNS resolution. This must take into * account JDK semantics such as {@link NetUtil#isIpV6AddressesPreferred()}. @@ -173,7 +167,7 @@ abstract class DnsResolveContext { doSearchDomainQuery(initialHostname, new FutureListener>() { private int searchDomainIdx = initialSearchDomainIdx; @Override - public void operationComplete(Future> future) throws Exception { + public void operationComplete(Future> future) { Throwable cause = future.cause(); if (cause == null) { promise.trySuccess(future.getNow()); @@ -232,11 +226,11 @@ abstract class DnsResolveContext { final int end = expectedTypes.length - 1; for (int i = 0; i < end; ++i) { - if (!query(hostname, expectedTypes[i], nameServerAddressStream.duplicate(), promise, null)) { + if (!query(hostname, expectedTypes[i], nameServerAddressStream.duplicate(), promise)) { return; } } - query(hostname, expectedTypes[end], nameServerAddressStream, promise, null); + query(hostname, expectedTypes[end], nameServerAddressStream, promise); } /** @@ -391,7 +385,7 @@ abstract class DnsResolveContext { }); } - void onResponse(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, + private void onResponse(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, final DnsQuestion question, AddressedEnvelope envelope, final DnsQueryLifecycleObserver queryLifecycleObserver, Promise> promise) { @@ -406,7 +400,7 @@ abstract class DnsResolveContext { final DnsRecordType type = question.type(); if (type == DnsRecordType.CNAME) { - onResponseCNAME(question, envelope, queryLifecycleObserver, promise); + onResponseCNAME(question, buildAliasMap(envelope.content()), queryLifecycleObserver, promise); return; } @@ -562,24 +556,20 @@ abstract class DnsResolveContext { // Note that we do not break from the loop here, so we decode/cache all A/AAAA records. } - if (found) { - queryLifecycleObserver.querySucceed(); - return; - } - if (cnames.isEmpty()) { + if (found) { + queryLifecycleObserver.querySucceed(); + return; + } queryLifecycleObserver.queryFailed(NO_MATCHING_RECORD_QUERY_FAILED_EXCEPTION); } else { - // We got only CNAME, not one of expectedTypes. - onResponseCNAME(question, cnames, queryLifecycleObserver, promise); + queryLifecycleObserver.querySucceed(); + // We also got a CNAME so we need to ensure we also query it. + onResponseCNAME(question, cnames, + parent.dnsQueryLifecycleObserverFactory().newDnsQueryLifecycleObserver(question), promise); } } - private void onResponseCNAME(DnsQuestion question, AddressedEnvelope envelope, - final DnsQueryLifecycleObserver queryLifecycleObserver, Promise> promise) { - onResponseCNAME(question, buildAliasMap(envelope.content()), queryLifecycleObserver, promise); - } - private void onResponseCNAME( DnsQuestion question, Map cnames, final DnsQueryLifecycleObserver queryLifecycleObserver, @@ -637,23 +627,19 @@ abstract class DnsResolveContext { return cnames != null? cnames : Collections.emptyMap(); } - void tryToFinishResolve(final DnsServerAddressStream nameServerAddrStream, + private void tryToFinishResolve(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex, final DnsQuestion question, final DnsQueryLifecycleObserver queryLifecycleObserver, final Promise> promise, final Throwable cause) { + // There are no queries left to try. if (!queriesInProgress.isEmpty()) { queryLifecycleObserver.queryCancelled(allowedQueries); - // There are still some queries we did not receive responses for. - if (finalResult != null && containsExpectedResult(finalResult)) { - // But it's OK to finish the resolution process if we got something expected. - finishResolve(promise, cause); - } - - // We did not get an expected result yet, so we can't finish the resolution process. + // There are still some queries in process, we will try to notify once the next one finishes until + // all are finished. return; } @@ -682,7 +668,7 @@ abstract class DnsResolveContext { // As the last resort, try to query CNAME, just in case the name server has it. triedCNAME = true; - query(hostname, DnsRecordType.CNAME, getNameServers(hostname), promise, null); + query(hostname, DnsRecordType.CNAME, getNameServers(hostname), promise); return; } } else { @@ -773,12 +759,12 @@ abstract class DnsResolveContext { } private boolean query(String hostname, DnsRecordType type, DnsServerAddressStream dnsServerAddressStream, - Promise> promise, Throwable cause) { + Promise> promise) { final DnsQuestion question = newQuestion(hostname, type); if (question == null) { return false; } - query(dnsServerAddressStream, 0, question, promise, cause); + query(dnsServerAddressStream, 0, question, promise, null); return true; } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index c7a416da8c..dcda33c8d5 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -72,6 +72,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Locale; import java.util.Map; @@ -567,7 +568,7 @@ public class DnsNameResolverTest { } @Test - public void testQueryMx() throws Exception { + public void testQueryMx() { DnsNameResolver resolver = newResolver().build(); try { assertThat(resolver.isRecursionDesired(), is(true)); @@ -1632,4 +1633,53 @@ public class DnsNameResolverTest { assertEquals(channelFactory, copiedBuilder.channelFactory()); assertEquals(newChannelFactory, builder.channelFactory()); } + + @Test + public void testFollowCNAMEEvenIfARecordIsPresent() throws IOException { + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + + @Override + public Set getRecords(QuestionRecord question) { + if (question.getDomainName().equals("cname.netty.io")) { + Map map1 = new HashMap(); + map1.put(DnsAttribute.IP_ADDRESS.toLowerCase(), "10.0.0.99"); + return Collections.singleton( + new TestDnsServer.TestResourceRecord(question.getDomainName(), RecordType.A, map1)); + } else { + Set records = new LinkedHashSet(2); + Map map = new HashMap(); + map.put(DnsAttribute.DOMAIN_NAME.toLowerCase(), "cname.netty.io"); + records.add(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.CNAME, map)); + + Map map1 = new HashMap(); + map1.put(DnsAttribute.IP_ADDRESS.toLowerCase(), "10.0.0.2"); + records.add(new TestDnsServer.TestResourceRecord( + question.getDomainName(), RecordType.A, map1)); + return records; + } + } + }); + dnsServer2.start(); + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver() + .recursionDesired(true) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())); + + resolver = builder.build(); + List resolvedAddresses = + resolver.resolveAll("somehost.netty.io").syncUninterruptibly().getNow(); + assertEquals(2, resolvedAddresses.size()); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 99 }))); + assertTrue(resolvedAddresses.contains(InetAddress.getByAddress(new byte[] { 10, 0, 0, 2 }))); + } finally { + dnsServer2.stop(); + if (resolver != null) { + resolver.close(); + } + } + } }