From 661ff2538e48977b974a78a870fec0712452dd93 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 25 Jan 2017 09:28:22 +0100 Subject: [PATCH] Implement correct handling of recursive DNS Motivation: DnsNameResolver does not handle recursive DNS and so fails if you query a DNS server (for example a ROOT dns server) which provides the correct redirect for a domain. Modification: Add support for redirects (a.k.a. handling of AUTHORITY section'). Result: Its now possible to use a DNS server that redirects. --- .../netty/resolver/dns/DnsNameResolver.java | 56 ++- .../resolver/dns/DnsNameResolverBuilder.java | 28 +- .../resolver/dns/DnsNameResolverContext.java | 371 ++++++++++++++++-- .../resolver/dns/DnsServerAddresses.java | 2 +- .../resolver/dns/DnsNameResolverTest.java | 170 ++++++++ .../io/netty/resolver/dns/TestDnsServer.java | 16 +- 6 files changed, 572 insertions(+), 71 deletions(-) diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 10c0065e54..b61e7aa6d4 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -132,6 +132,7 @@ public class DnsNameResolver extends InetNameResolver { * Cache for {@link #doResolve(String, Promise)} and {@link #doResolveAll(String, Promise)}. */ private final DnsCache resolveCache; + private final DnsCache authoritativeDnsServerCache; private final FastThreadLocal nameServerAddrStream = new FastThreadLocal() { @@ -151,8 +152,8 @@ public class DnsNameResolver extends InetNameResolver { private final HostsFileEntriesResolver hostsFileEntriesResolver; private final String[] searchDomains; private final int ndots; - private final boolean cnameFollowARecords; - private final boolean cnameFollowAAAARecords; + private final boolean supportsAAAARecords; + private final boolean supportsARecords; private final InternetProtocolFamily preferredAddressType; private final DnsRecordType[] resolveRecordTypes; private final boolean decodeIdn; @@ -176,9 +177,9 @@ public class DnsNameResolver extends InetNameResolver { * @param hostsFileEntriesResolver the {@link HostsFileEntriesResolver} used to check for local aliases * @param searchDomains the list of search domain * @param ndots the ndots value - * @deprecated use {@link #DnsNameResolver(EventLoop, ChannelFactory, DnsServerAddresses, DnsCache, long, - * InternetProtocolFamily[], boolean, int, boolean, int, - * boolean, HostsFileEntriesResolver, String[], int, boolean)} + * @deprecated use {@link DnsNameResolver#DnsNameResolver(EventLoop, ChannelFactory, DnsServerAddresses, DnsCache, + * DnsCache, long, InternetProtocolFamily[], boolean, int, boolean, int, boolean, + * HostsFileEntriesResolver, String[], int, boolean)} */ @Deprecated public DnsNameResolver( @@ -196,11 +197,10 @@ public class DnsNameResolver extends InetNameResolver { HostsFileEntriesResolver hostsFileEntriesResolver, String[] searchDomains, int ndots) { - this(eventLoop, channelFactory, nameServerAddresses, resolveCache, queryTimeoutMillis, resolvedAddressTypes, - recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, - hostsFileEntriesResolver, searchDomains, ndots, true); + this(eventLoop, channelFactory, nameServerAddresses, resolveCache, NoopDnsCache.INSTANCE, queryTimeoutMillis, + resolvedAddressTypes, recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, + optResourceEnabled, hostsFileEntriesResolver, searchDomains, ndots, true); } - /** * Creates a new DNS-based name resolver that communicates with the specified list of DNS servers. * @@ -210,6 +210,7 @@ public class DnsNameResolver extends InetNameResolver { * this to determine which DNS server should be contacted for the next retry in case * of failure. * @param resolveCache the DNS resolved entries cache + * @param authoritativeDnsServerCache the cache used to find the authoritative DNS server for a domain * @param queryTimeoutMillis timeout of each DNS query in millis * @param resolvedAddressTypes list of the protocol families * @param recursionDesired if recursion desired flag must be set @@ -228,6 +229,7 @@ public class DnsNameResolver extends InetNameResolver { ChannelFactory channelFactory, DnsServerAddresses nameServerAddresses, final DnsCache resolveCache, + DnsCache authoritativeDnsServerCache, long queryTimeoutMillis, InternetProtocolFamily[] resolvedAddressTypes, boolean recursionDesired, @@ -239,7 +241,6 @@ public class DnsNameResolver extends InetNameResolver { String[] searchDomains, int ndots, boolean decodeIdn) { - super(eventLoop); checkNotNull(channelFactory, "channelFactory"); this.nameServerAddresses = checkNotNull(nameServerAddresses, "nameServerAddresses"); @@ -252,22 +253,23 @@ public class DnsNameResolver extends InetNameResolver { this.optResourceEnabled = optResourceEnabled; this.hostsFileEntriesResolver = checkNotNull(hostsFileEntriesResolver, "hostsFileEntriesResolver"); this.resolveCache = checkNotNull(resolveCache, "resolveCache"); + this.authoritativeDnsServerCache = checkNotNull(authoritativeDnsServerCache, "authoritativeDnsServerCache"); this.searchDomains = checkNotNull(searchDomains, "searchDomains").clone(); this.ndots = checkPositiveOrZero(ndots, "ndots"); this.decodeIdn = decodeIdn; - boolean cnameFollowARecords = false; - boolean cnameFollowAAAARecords = false; + boolean supportsARecords = false; + boolean supportsAAAARecords = false; // Use LinkedHashSet to maintain correct ordering. Set recordTypes = new LinkedHashSet(resolvedAddressTypes.length); for (InternetProtocolFamily family: resolvedAddressTypes) { switch (family) { case IPv4: - cnameFollowARecords = true; + supportsARecords = true; recordTypes.add(DnsRecordType.A); break; case IPv6: - cnameFollowAAAARecords = true; + supportsAAAARecords = true; recordTypes.add(DnsRecordType.AAAA); break; default: @@ -276,9 +278,9 @@ public class DnsNameResolver extends InetNameResolver { } // One of both must be always true. - assert cnameFollowARecords || cnameFollowAAAARecords; - this.cnameFollowAAAARecords = cnameFollowAAAARecords; - this.cnameFollowARecords = cnameFollowARecords; + assert supportsARecords || supportsAAAARecords; + this.supportsAAAARecords = supportsAAAARecords; + this.supportsARecords = supportsARecords; resolveRecordTypes = recordTypes.toArray(new DnsRecordType[recordTypes.size()]); preferredAddressType = resolvedAddressTypes[0]; @@ -306,6 +308,11 @@ public class DnsNameResolver extends InetNameResolver { }); } + // Only here to override in unit tests. + int dnsRedirectPort(@SuppressWarnings("unused") InetAddress server) { + return DnsServerAddresses.DNS_PORT; + } + /** * Returns the resolution cache. */ @@ -313,6 +320,13 @@ public class DnsNameResolver extends InetNameResolver { return resolveCache; } + /** + * Returns the cache used for authoritative DNS servers for a domain. + */ + public DnsCache authoritativeDnsServerCache() { + return authoritativeDnsServerCache; + } + /** * Returns the timeout of each DNS query performed by this resolver (in milliseconds). * The default value is 5 seconds. @@ -342,12 +356,12 @@ public class DnsNameResolver extends InetNameResolver { return ndots; } - final boolean isCnameFollowAAAARecords() { - return cnameFollowAAAARecords; + final boolean supportsAAAARecords() { + return supportsAAAARecords; } - final boolean isCnameFollowARecords() { - return cnameFollowARecords; + final boolean supportsARecords() { + return supportsARecords; } final InternetProtocolFamily preferredAddressType() { diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java index 1b4561d904..2eeae7ee1c 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java @@ -40,6 +40,7 @@ public final class DnsNameResolverBuilder { private ChannelFactory channelFactory; private DnsServerAddresses nameServerAddresses = DnsServerAddresses.defaultAddresses(); private DnsCache resolveCache; + private DnsCache authoritativeDnsServerCache; private Integer minTtl; private Integer maxTtl; private Integer negativeTtl; @@ -109,6 +110,17 @@ public final class DnsNameResolverBuilder { return this; } + /** + * Sets the cache for authoritive NS servers + * + * @param authoritativeDnsServerCache the authoritive NS servers cache + * @return {@code this} + */ + public DnsNameResolverBuilder authoritativeDnsServerCache(DnsCache authoritativeDnsServerCache) { + this.authoritativeDnsServerCache = authoritativeDnsServerCache; + return this; + } + /** * Sets the minimum and maximum TTL of the cached DNS resource records (in seconds). If the TTL of the DNS * resource record returned by the DNS server is less than the minimum TTL or greater than the maximum TTL, @@ -331,6 +343,10 @@ public final class DnsNameResolverBuilder { return this; } + private DnsCache newCache() { + return new DefaultDnsCache(intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE), intValue(negativeTtl, 0)); + } + /** * Set if domain / host names should be decoded to unicode when received. * See rfc3492. @@ -349,19 +365,23 @@ public final class DnsNameResolverBuilder { * @return a {@link DnsNameResolver} */ public DnsNameResolver build() { - if (resolveCache != null && (minTtl != null || maxTtl != null || negativeTtl != null)) { throw new IllegalStateException("resolveCache and TTLs are mutually exclusive"); } - DnsCache cache = resolveCache != null ? resolveCache : - new DefaultDnsCache(intValue(minTtl, 0), intValue(maxTtl, Integer.MAX_VALUE), intValue(negativeTtl, 0)); + if (authoritativeDnsServerCache != null && (minTtl != null || maxTtl != null || negativeTtl != null)) { + throw new IllegalStateException("authoritativeDnsServerCache and TTLs are mutually exclusive"); + } + DnsCache resolveCache = this.resolveCache != null ? this.resolveCache : newCache(); + DnsCache authoritativeDnsServerCache = this.authoritativeDnsServerCache != null ? + this.authoritativeDnsServerCache : newCache(); return new DnsNameResolver( eventLoop, channelFactory, nameServerAddresses, - cache, + resolveCache, + authoritativeDnsServerCache, queryTimeoutMillis, resolvedAddressTypes, recursionDesired, diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java index 5884513d84..258b2f3586 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java @@ -145,15 +145,103 @@ abstract class DnsNameResolverContext { } private void internalResolve(Promise promise) { - InetSocketAddress nameServerAddrToTry = nameServerAddrs.next(); + DnsServerAddressStream nameServerAddressStream = getNameServers(hostname); + for (DnsRecordType type: parent.resolveRecordTypes()) { - if (!query(hostname, type, nameServerAddrToTry, promise)) { + if (!query(hostname, type, nameServerAddressStream, promise)) { return; } } } - private void query(InetSocketAddress nameServerAddr, final DnsQuestion question, final Promise promise) { + /** + * Add an authoritative nameserver to the cache if its not a root server. + */ + private void addNameServerToCache( + AuthoritativeNameServer name, InetAddress resolved, long ttl) { + if (!name.isRootServer()) { + // Cache NS record if not for a root server as we should never cache for root servers. + parent.authoritativeDnsServerCache().cache(name.domainName(), + additionals, resolved, ttl, parent.ch.eventLoop()); + } + } + + /** + * Returns the {@link DnsServerAddressStream} that was cached for the given hostname or {@code null} if non + * could be found. + */ + private DnsServerAddressStream getNameServersFromCache(String hostname) { + int len = hostname.length(); + + if (len == 0) { + // We never cache for root servers. + return null; + } + + // We always store in the cache with a trailing '.'. + if (hostname.charAt(len - 1) != '.') { + hostname += "."; + } + + int idx = hostname.indexOf('.'); + if (idx == hostname.length() - 1) { + // We are not interested in handling '.' as we should never serve the root servers from cache. + return null; + } + + // We start from the closed match and then move down. + for (;;) { + // Skip '.' as well. + hostname = hostname.substring(idx + 1); + + int idx2 = hostname.indexOf('.'); + if (idx2 <= 0 || idx2 == hostname.length() - 1) { + // We are not interested in handling '.TLD.' as we should never serve the root servers from cache. + return null; + } + idx = idx2; + + List entries = parent.authoritativeDnsServerCache().get(hostname, additionals); + if (entries != null && !entries.isEmpty()) { + // Found a match in the cache... Also shuffle them so we not always use the same order for lookup. + return DnsServerAddresses.shuffled(new DnsCacheIterable(entries)).stream(); + } + } + } + + private final class DnsCacheIterable implements Iterable { + private final List entries; + + DnsCacheIterable(List entries) { + this.entries = entries; + } + + @Override + public Iterator iterator() { + return new Iterator() { + Iterator entryIterator = entries.iterator(); + + @Override + public boolean hasNext() { + return entryIterator.hasNext(); + } + + @Override + public InetSocketAddress next() { + InetAddress address = entryIterator.next().address(); + return new InetSocketAddress(address, parent.dnsRedirectPort(address)); + } + + @Override + public void remove() { + entryIterator.remove(); + } + }; + } + } + + private void query(final DnsServerAddressStream nameServerAddrStream, final DnsQuestion question, + final Promise promise) { if (allowedQueries == 0 || promise.isCancelled()) { tryToFinishResolve(promise); return; @@ -162,7 +250,7 @@ abstract class DnsNameResolverContext { allowedQueries --; final Future> f = parent.query0( - nameServerAddr, question, additionals, + nameServerAddrStream.next(), question, additionals, parent.ch.eventLoop().>newPromise()); queriesInProgress.add(f); @@ -177,13 +265,13 @@ abstract class DnsNameResolverContext { try { if (future.isSuccess()) { - onResponse(question, future.getNow(), promise); + onResponse(nameServerAddrStream, question, future.getNow(), promise); } else { // Server did not respond or I/O error occurred; try again. if (traceEnabled) { addTrace(future.cause()); } - query(nameServerAddrs.next(), question, promise); + query(nameServerAddrStream, question, promise); } } finally { tryToFinishResolve(promise); @@ -192,13 +280,18 @@ abstract class DnsNameResolverContext { }); } - void onResponse(final DnsQuestion question, AddressedEnvelope envelope, - Promise promise) { + void onResponse(final DnsServerAddressStream nameServerAddrStream, final DnsQuestion question, + AddressedEnvelope envelope, Promise promise) { try { final DnsResponse res = envelope.content(); final DnsResponseCode code = res.code(); if (code == DnsResponseCode.NOERROR) { + if (handleRedirect(question, envelope, promise)) { + // Was a redirect so return here as everything else is handled in handleRedirect(...) + return; + } final DnsRecordType type = question.type(); + if (type == DnsRecordType.A || type == DnsRecordType.AAAA) { onResponseAorAAAA(type, question, envelope, promise); } else if (type == DnsRecordType.CNAME) { @@ -215,13 +308,86 @@ abstract class DnsNameResolverContext { // Retry with the next server if the server did not tell us that the domain does not exist. if (code != DnsResponseCode.NXDOMAIN) { - query(nameServerAddrs.next(), question, promise); + query(nameServerAddrStream, question, promise); } } finally { ReferenceCountUtil.safeRelease(envelope); } } + /** + * Handle redirects if needed and returns {@code true} if there was a redirect handled. If {@code true} is returned + * this method took ownership of the {@code promise} otherwise {@code false} is returned. + */ + private boolean handleRedirect( + DnsQuestion question, AddressedEnvelope envelope, Promise promise) { + final DnsResponse res = envelope.content(); + + // Check if we have answers, if not this may be an non authority NS and so redirects must be handled. + if (res.count(DnsSection.ANSWER) == 0) { + AuthoritativeNameServerList serverNames = extractAuthoritativeNameServers(question.name(), res); + + if (serverNames != null) { + List nameServers = new ArrayList(serverNames.size()); + int additionalCount = res.count(DnsSection.ADDITIONAL); + + for (int i = 0; i < additionalCount; i++) { + final DnsRecord r = res.recordAt(DnsSection.ADDITIONAL, i); + + if ((r.type() == DnsRecordType.A && !parent.supportsARecords()) || + r.type() == DnsRecordType.AAAA && !parent.supportsAAAARecords()) { + continue; + } + + final String recordName = r.name(); + AuthoritativeNameServer authoritativeNameServer = + serverNames.remove(recordName); + + if (authoritativeNameServer == null) { + // Not a server we are interested in. + continue; + } + + InetAddress resolved = parseAddress(r, recordName); + if (resolved == null) { + // Could not parse it, move to the next. + continue; + } + + nameServers.add(new InetSocketAddress(resolved, parent.dnsRedirectPort(resolved))); + addNameServerToCache(authoritativeNameServer, resolved, r.timeToLive()); + } + + if (nameServers.isEmpty()) { + promise.tryFailure( + new UnknownHostException("Unable to find correct name server for " + hostname)); + } else { + // Shuffle as we want to re-distribute the load across name servers. + query(DnsServerAddresses.shuffled(nameServers).stream(), question, promise); + } + return true; + } + } + return false; + } + + /** + * Returns the {@code {@link AuthoritativeNameServerList} which were included in {@link DnsSection#AUTHORITY} + * or {@code null} if non are found. + */ + private static AuthoritativeNameServerList extractAuthoritativeNameServers(String questionName, DnsResponse res) { + int authorityCount = res.count(DnsSection.AUTHORITY); + if (authorityCount == 0) { + return null; + } + + AuthoritativeNameServerList serverNames = new AuthoritativeNameServerList(questionName); + for (int i = 0; i < authorityCount; i++) { + serverNames.add(res.recordAt(DnsSection.AUTHORITY, i)); + } + return serverNames; + } + private void onResponseAorAAAA( DnsRecordType qType, DnsQuestion question, AddressedEnvelope envelope, Promise promise) { @@ -239,16 +405,16 @@ abstract class DnsNameResolverContext { continue; } - final String qName = question.name().toLowerCase(Locale.US); - final String rName = r.name().toLowerCase(Locale.US); + final String questionName = question.name().toLowerCase(Locale.US); + final String recordName = r.name().toLowerCase(Locale.US); // Make sure the record is for the questioned domain. - if (!rName.equals(qName)) { + if (!recordName.equals(questionName)) { // Even if the record's name is not exactly same, it might be an alias defined in the CNAME records. - String resolved = qName; + String resolved = questionName; do { resolved = cnames.get(resolved); - if (rName.equals(resolved)) { + if (recordName.equals(resolved)) { break; } } while (resolved != null); @@ -258,28 +424,11 @@ abstract class DnsNameResolverContext { } } - if (!(r instanceof DnsRawRecord)) { + InetAddress resolved = parseAddress(r, hostname); + if (resolved == null) { continue; } - final ByteBuf content = ((ByteBufHolder) r).content(); - final int contentLen = content.readableBytes(); - if (contentLen != INADDRSZ4 && contentLen != INADDRSZ6) { - continue; - } - - final byte[] addrBytes = new byte[contentLen]; - content.getBytes(content.readerIndex(), addrBytes); - - final InetAddress resolved; - try { - resolved = InetAddress.getByAddress( - parent.isDecodeIdn() ? IDN.toUnicode(hostname) : hostname, addrBytes); - } catch (UnknownHostException e) { - // Should never reach here. - throw new Error(e); - } - if (resolvedEntries == null) { resolvedEntries = new ArrayList(8); } @@ -306,6 +455,28 @@ abstract class DnsNameResolverContext { } } + private InetAddress parseAddress(DnsRecord r, String name) { + if (!(r instanceof DnsRawRecord)) { + return null; + } + final ByteBuf content = ((ByteBufHolder) r).content(); + final int contentLen = content.readableBytes(); + if (contentLen != INADDRSZ4 && contentLen != INADDRSZ6) { + return null; + } + + final byte[] addrBytes = new byte[contentLen]; + content.getBytes(content.readerIndex(), addrBytes); + + try { + return InetAddress.getByAddress( + parent.isDecodeIdn() ? IDN.toUnicode(name) : name, addrBytes); + } catch (UnknownHostException e) { + // Should never reach here. + throw new Error(e); + } + } + private void onResponseCNAME(DnsQuestion question, AddressedEnvelope envelope, Promise promise) { onResponseCNAME(question, envelope, buildAliasMap(envelope.content()), true, promise); @@ -386,7 +557,8 @@ abstract class DnsNameResolverContext { if (!triedCNAME) { // As the last resort, try to query CNAME, just in case the name server has it. triedCNAME = true; - query(hostname, DnsRecordType.CNAME, nameServerAddrs.next(), promise); + + query(hostname, DnsRecordType.CNAME, getNameServers(hostname), promise); return; } } @@ -496,6 +668,11 @@ abstract class DnsNameResolverContext { } } + private DnsServerAddressStream getNameServers(String hostame) { + DnsServerAddressStream stream = getNameServersFromCache(hostame); + return stream == null ? nameServerAddrs : stream; + } + private void followCname(InetSocketAddress nameServerAddr, String name, String cname, Promise promise) { if (traceEnabled) { @@ -512,16 +689,18 @@ abstract class DnsNameResolverContext { trace.append(cname); } - final InetSocketAddress nextAddr = nameServerAddrs.next(); - if (parent.isCnameFollowARecords() && !query(hostname, DnsRecordType.A, nextAddr, promise)) { + // Use the same server for both CNAME queries + DnsServerAddressStream stream = DnsServerAddresses.singleton(getNameServers(cname).next()).stream(); + + if (parent.supportsARecords() && !query(hostname, DnsRecordType.A, stream, promise)) { return; } - if (parent.isCnameFollowAAAARecords()) { - query(hostname, DnsRecordType.AAAA, nextAddr, promise); + if (parent.supportsAAAARecords()) { + query(hostname, DnsRecordType.AAAA, stream, promise); } } - private boolean query(String hostname, DnsRecordType type, final InetSocketAddress nextAddr, Promise promise) { + private boolean query(String hostname, DnsRecordType type, DnsServerAddressStream nextAddr, Promise promise) { final DnsQuestion question; try { question = new DefaultDnsQuestion(hostname, type); @@ -559,4 +738,118 @@ abstract class DnsNameResolverContext { trace.append("Caused by: "); trace.append(cause); } + + /** + * Holds the closed DNS Servers for a domain. + */ + private static final class AuthoritativeNameServerList { + + private final String questionName; + + // We not expect the linked-list to be very long so a double-linked-list is overkill. + private AuthoritativeNameServer head; + private int count; + + AuthoritativeNameServerList(String questionName) { + this.questionName = questionName.toLowerCase(Locale.US); + } + + void add(DnsRecord r) { + if (r.type() != DnsRecordType.NS || !(r instanceof DnsRawRecord)) { + return; + } + + // Only include servers that serve the correct domain. + if (questionName.length() < r.name().length()) { + return; + } + + String recordName = r.name().toLowerCase(Locale.US); + + int dots = 0; + for (int a = recordName.length() - 1, b = questionName.length() - 1; a >= 0; a--, b--) { + char c = recordName.charAt(a); + if (questionName.charAt(b) != c) { + return; + } + if (c == '.') { + dots++; + } + } + + if (head != null && head.dots > dots) { + // We already have a closer match so ignore this one, no need to parse the domainName etc. + return; + } + + final ByteBuf recordContent = ((ByteBufHolder) r).content(); + final String domainName = decodeDomainName(recordContent); + if (domainName == null) { + // Could not be parsed, ignore. + return; + } + + // We are only interested in preserving the nameservers which are the closest to our qName, so ensure + // we drop servers that have a smaller dots count. + if (head == null || head.dots < dots) { + count = 1; + head = new AuthoritativeNameServer(dots, recordName, domainName); + } else if (head.dots == dots) { + AuthoritativeNameServer serverName = head; + while (serverName.next != null) { + serverName = serverName.next; + } + serverName.next = new AuthoritativeNameServer(dots, recordName, domainName); + count++; + } + } + + // Just walk the linked-list and mark the entry as removed when matched, so next lookup will need to process + // one node less. + AuthoritativeNameServer remove(String nsName) { + AuthoritativeNameServer serverName = head; + + while (serverName != null) { + if (!serverName.removed && serverName.nsName.equalsIgnoreCase(nsName)) { + serverName.removed = true; + return serverName; + } + serverName = serverName.next; + } + return null; + } + + int size() { + return count; + } + } + + static final class AuthoritativeNameServer { + final int dots; + final String nsName; + final String domainName; + + AuthoritativeNameServer next; + boolean removed; + + AuthoritativeNameServer(int dots, String domainName, String nsName) { + this.dots = dots; + this.nsName = nsName; + this.domainName = domainName; + } + + /** + * Returns {@code true} if its a root server. + */ + boolean isRootServer() { + return dots == 1; + } + + /** + * The domain for which the {@link AuthoritativeNameServer} is responsible. + */ + String domainName() { + return domainName; + } + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java index 56c9070a1f..ceb87fe36b 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsServerAddresses.java @@ -40,9 +40,9 @@ public abstract class DnsServerAddresses { private static final List DEFAULT_NAME_SERVER_LIST; private static final InetSocketAddress[] DEFAULT_NAME_SERVER_ARRAY; private static final DnsServerAddresses DEFAULT_NAME_SERVERS; + static final int DNS_PORT = 53; static { - final int DNS_PORT = 53; final List defaultNameServers = new ArrayList(2); try { Class configClass = Class.forName("sun.net.dns.ResolverConfiguration"); diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index d5fd45d79c..693685049d 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -18,8 +18,11 @@ package io.netty.resolver.dns; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufHolder; import io.netty.channel.AddressedEnvelope; +import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; +import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.handler.codec.dns.DefaultDnsQuestion; @@ -29,10 +32,17 @@ import io.netty.handler.codec.dns.DnsResponse; import io.netty.handler.codec.dns.DnsResponseCode; import io.netty.handler.codec.dns.DnsSection; import io.netty.util.NetUtil; +import io.netty.resolver.HostsFileEntriesResolver; import io.netty.util.concurrent.Future; import io.netty.util.internal.StringUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import org.apache.directory.server.dns.messages.DnsMessage; +import org.apache.directory.server.dns.messages.RecordClass; +import org.apache.directory.server.dns.messages.RecordType; +import org.apache.directory.server.dns.messages.ResourceRecord; +import org.apache.directory.server.dns.messages.ResourceRecordModifier; +import org.apache.directory.server.dns.store.DnsAttribute; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -50,12 +60,16 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.hasToString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -599,6 +613,72 @@ public class DnsNameResolverTest { } } + @Test + public void testRecursiveResolveNoCache() throws Exception { + testRecursiveResolveCache(false); + } + + @Test + public void testRecursiveResolveCache() throws Exception { + testRecursiveResolveCache(true); + } + + private static void testRecursiveResolveCache(boolean cache) throws Exception { + final String hostname = "some.record.netty.io"; + final String hostname2 = "some2.record.netty.io"; + + final TestDnsServer dnsServerAuthority = new TestDnsServer(new HashSet( + Arrays.asList(hostname, hostname2))); + dnsServerAuthority.start(); + + TestDnsServer dnsServer = new RedirectingTestDnsServer(hostname, + dnsServerAuthority.localAddress().getAddress().getHostAddress()); + dnsServer.start(); + + TestDnsCache nsCache = new TestDnsCache(cache ? new DefaultDnsCache() : NoopDnsCache.INSTANCE); + EventLoopGroup group = new NioEventLoopGroup(1); + DnsNameResolver resolver = new DnsNameResolver( + group.next(), new ReflectiveChannelFactory(NioDatagramChannel.class), + DnsServerAddresses.singleton(dnsServer.localAddress()), NoopDnsCache.INSTANCE, nsCache, + 3000, new InternetProtocolFamily[] { InternetProtocolFamily.IPv4 }, true, 10, + true, 4096, false, HostsFileEntriesResolver.DEFAULT, + DnsNameResolver.DEFAULT_SEACH_DOMAINS, 0, true) { + @Override + int dnsRedirectPort(InetAddress server) { + return server.equals(dnsServerAuthority.localAddress().getAddress()) ? + dnsServerAuthority.localAddress().getPort() : DnsServerAddresses.DNS_PORT; + } + }; + + try { + resolver.resolveAll(hostname).syncUninterruptibly(); + + if (cache) { + assertNull(nsCache.cache.get("io.", null)); + assertNull(nsCache.cache.get("netty.io.", null)); + List entries = nsCache.cache.get("record.netty.io.", null); + assertEquals(1, entries.size()); + + assertNull(nsCache.cache.get(hostname, null)); + + // Test again via cache. + resolver.resolveAll(hostname).syncUninterruptibly(); + resolver.resolveAll(hostname2).syncUninterruptibly(); + + // Check that it only queried the cache for record.netty.io. + assertNull(nsCache.cacheHits.get("io.")); + assertNull(nsCache.cacheHits.get("netty.io.")); + assertNotNull(nsCache.cacheHits.get("record.netty.io.")); + assertNull(nsCache.cacheHits.get("some.record.netty.io.")); + } + } finally { + resolver.close(); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + dnsServer.stop(); + dnsServerAuthority.stop(); + } + } + private static void resolve(DnsNameResolver resolver, Map> futures, String hostname) { futures.put(hostname, resolver.resolve(hostname)); } @@ -609,4 +689,94 @@ public class DnsNameResolverTest { String hostname) throws Exception { futures.put(hostname, resolver.query(new DefaultDnsQuestion(hostname, DnsRecordType.MX))); } + + private static final class TestDnsCache implements DnsCache { + private final DnsCache cache; + final Map> cacheHits = new HashMap>(); + + TestDnsCache(DnsCache cache) { + this.cache = cache; + } + + @Override + public void clear() { + cache.clear(); + } + + @Override + public boolean clear(String hostname) { + return cache.clear(hostname); + } + + @Override + public List get(String hostname, DnsRecord[] additionals) { + List cacheEntries = cache.get(hostname, additionals); + cacheHits.put(hostname, cacheEntries); + return cacheEntries; + } + + @Override + public void cache( + String hostname, DnsRecord[] additionals, InetAddress address, long originalTtl, EventLoop loop) { + cache.cache(hostname, additionals, address, originalTtl, loop); + } + + @Override + public void cache( + String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) { + cache.cache(hostname, additionals, cause, loop); + } + } + + private static class RedirectingTestDnsServer extends TestDnsServer { + + private final String dnsAddress; + private final String domain; + + RedirectingTestDnsServer(String domain, String dnsAddress) { + super(Collections.singleton(domain)); + this.domain = domain; + this.dnsAddress = dnsAddress; + } + + @Override + protected DnsMessage filterMessage(DnsMessage message) { + // Clear the answers as we want to add our own stuff to test dns redirects. + message.getAnswerRecords().clear(); + + String name = domain; + for (int i = 0 ;; i++) { + int idx = name.indexOf('.'); + if (idx <= 0) { + break; + } + name = name.substring(idx + 1); // skip the '.' as well. + String dnsName = "dns" + idx + '.' + domain; + message.getAuthorityRecords().add(newNsRecord(name, dnsName)); + message.getAdditionalRecords().add(newARecord(dnsName, i == 0 ? dnsAddress : "1.2.3." + idx)); + } + + return message; + } + + private static ResourceRecord newARecord(String dnsname, String ipAddress) { + ResourceRecordModifier rm = new ResourceRecordModifier(); + rm.setDnsClass(RecordClass.IN); + rm.setDnsName(dnsname); + rm.setDnsTtl(100); + rm.setDnsType(RecordType.A); + rm.put(DnsAttribute.IP_ADDRESS, ipAddress); + return rm.getEntry(); + } + + private static ResourceRecord newNsRecord(String dnsname, String domainName) { + ResourceRecordModifier rm = new ResourceRecordModifier(); + rm.setDnsClass(RecordClass.IN); + rm.setDnsName(dnsname); + rm.setDnsTtl(100); + rm.setDnsType(RecordType.NS); + rm.put(DnsAttribute.DOMAIN_NAME, domainName); + return rm.getEntry(); + } + } } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java b/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java index ade7b729ed..e410c00090 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/TestDnsServer.java @@ -54,7 +54,7 @@ import java.util.List; import java.util.Map; import java.util.Set; -final class TestDnsServer extends DnsServer { +class TestDnsServer extends DnsServer { private static final Map BYTES = new HashMap(); private static final String[] IPV6_ADDRESSES; @@ -74,7 +74,7 @@ final class TestDnsServer extends DnsServer { private final RecordStore store; TestDnsServer(Set domains) { - this.store = new TestRecordStore(domains); + this(new TestRecordStore(domains)); } TestDnsServer(RecordStore store) { @@ -83,7 +83,7 @@ final class TestDnsServer extends DnsServer { @Override public void start() throws IOException { - InetSocketAddress address = new InetSocketAddress(NetUtil.LOCALHOST4, 50000); + InetSocketAddress address = new InetSocketAddress(NetUtil.LOCALHOST4, 0); UdpTransport transport = new UdpTransport(address.getHostName(), address.getPort()); setTransports(transport); @@ -108,10 +108,14 @@ final class TestDnsServer extends DnsServer { return (InetSocketAddress) getTransports()[0].getAcceptor().getLocalAddress(); } + protected DnsMessage filterMessage(DnsMessage message) { + return message; + } + /** * {@link ProtocolCodecFactory} which allows to test AAAA resolution. */ - private static final class TestDnsProtocolUdpCodecFactory implements ProtocolCodecFactory { + private final class TestDnsProtocolUdpCodecFactory implements ProtocolCodecFactory { private final DnsMessageEncoder encoder = new DnsMessageEncoder(); private final TestAAAARecordEncoder recordEncoder = new TestAAAARecordEncoder(); @@ -122,7 +126,7 @@ final class TestDnsServer extends DnsServer { @Override public void encode(IoSession session, Object message, ProtocolEncoderOutput out) { IoBuffer buf = IoBuffer.allocate(1024); - DnsMessage dnsMessage = (DnsMessage) message; + DnsMessage dnsMessage = filterMessage((DnsMessage) message); encoder.encode(buf, dnsMessage); for (ResourceRecord record : dnsMessage.getAnswerRecords()) { // This is a hack to allow to also test for AAAA resolution as DnsMessageEncoder @@ -150,7 +154,7 @@ final class TestDnsServer extends DnsServer { return new DnsUdpDecoder(); } - private static final class TestAAAARecordEncoder extends ResourceRecordEncoder { + private final class TestAAAARecordEncoder extends ResourceRecordEncoder { @Override protected void putResourceRecordData(IoBuffer ioBuffer, ResourceRecord resourceRecord) {