Respect resolvedAddressTypes when follow CNAME records.

Motivation:

When we follow CNAME records we should respect resolvedAddressTypes and only query A / AAAA depending on which address types are expected.

Modifications:

Check if we should query A / AAAA when follow CNAMEs depending on resolvedAddressTypes.

Result:

Correct behaviour when follow CNAMEs.
This commit is contained in:
Norman Maurer 2017-01-19 13:18:19 +01:00
parent 71efb0b39e
commit ead87b7df8
2 changed files with 59 additions and 16 deletions

View File

@ -35,6 +35,7 @@ import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DatagramDnsResponseDecoder; import io.netty.handler.codec.dns.DatagramDnsResponseDecoder;
import io.netty.handler.codec.dns.DnsQuestion; import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse; import io.netty.handler.codec.dns.DnsResponse;
import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.InetNameResolver; import io.netty.resolver.InetNameResolver;
@ -59,7 +60,9 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import static io.netty.util.internal.ObjectUtil.*; import static io.netty.util.internal.ObjectUtil.*;
@ -148,6 +151,10 @@ public class DnsNameResolver extends InetNameResolver {
private final HostsFileEntriesResolver hostsFileEntriesResolver; private final HostsFileEntriesResolver hostsFileEntriesResolver;
private final String[] searchDomains; private final String[] searchDomains;
private final int ndots; private final int ndots;
private final boolean cnameFollowARecords;
private final boolean cnameFollowAAAARecords;
private final InternetProtocolFamily preferredAddressType;
private final DnsRecordType[] resolveRecordTypes;
/** /**
* Creates a new DNS-based name resolver that communicates with the specified list of DNS servers. * Creates a new DNS-based name resolver that communicates with the specified list of DNS servers.
@ -200,6 +207,32 @@ public class DnsNameResolver extends InetNameResolver {
this.searchDomains = checkNotNull(searchDomains, "searchDomains").clone(); this.searchDomains = checkNotNull(searchDomains, "searchDomains").clone();
this.ndots = checkPositiveOrZero(ndots, "ndots"); this.ndots = checkPositiveOrZero(ndots, "ndots");
boolean cnameFollowARecords = false;
boolean cnameFollowAAAARecords = false;
// Use LinkedHashSet to maintain correct ordering.
Set<DnsRecordType> recordTypes = new LinkedHashSet<DnsRecordType>(resolvedAddressTypes.length);
for (InternetProtocolFamily family: resolvedAddressTypes) {
switch (family) {
case IPv4:
cnameFollowARecords = true;
recordTypes.add(DnsRecordType.A);
break;
case IPv6:
cnameFollowAAAARecords = true;
recordTypes.add(DnsRecordType.AAAA);
break;
default:
throw new Error();
}
}
// One of both must be always true.
assert cnameFollowARecords || cnameFollowAAAARecords;
this.cnameFollowAAAARecords = cnameFollowAAAARecords;
this.cnameFollowARecords = cnameFollowARecords;
resolveRecordTypes = recordTypes.toArray(new DnsRecordType[recordTypes.size()]);
preferredAddressType = resolvedAddressTypes[0];
Bootstrap b = new Bootstrap(); Bootstrap b = new Bootstrap();
b.group(executor()); b.group(executor());
b.channelFactory(channelFactory); b.channelFactory(channelFactory);
@ -260,6 +293,22 @@ public class DnsNameResolver extends InetNameResolver {
return ndots; return ndots;
} }
final boolean isCnameFollowAAAARecords() {
return cnameFollowAAAARecords;
}
final boolean isCnameFollowARecords() {
return cnameFollowARecords;
}
final InternetProtocolFamily preferredAddressType() {
return preferredAddressType;
}
final DnsRecordType[] resolveRecordTypes() {
return resolveRecordTypes;
}
/** /**
* Returns {@code true} if and only if this resolver sends a DNS query with the RD (recursion desired) flag set. * Returns {@code true} if and only if this resolver sends a DNS query with the RD (recursion desired) flag set.
* The default value is {@code true}. * The default value is {@code true}.

View File

@ -145,19 +145,7 @@ abstract class DnsNameResolverContext<T> {
private void internalResolve(Promise<T> promise) { private void internalResolve(Promise<T> promise) {
InetSocketAddress nameServerAddrToTry = nameServerAddrs.next(); InetSocketAddress nameServerAddrToTry = nameServerAddrs.next();
for (InternetProtocolFamily f: resolveAddressTypes) { for (DnsRecordType type: parent.resolveRecordTypes()) {
final DnsRecordType type;
switch (f) {
case IPv4:
type = DnsRecordType.A;
break;
case IPv6:
type = DnsRecordType.AAAA;
break;
default:
throw new Error();
}
query(nameServerAddrToTry, new DefaultDnsQuestion(hostname, type), promise); query(nameServerAddrToTry, new DefaultDnsQuestion(hostname, type), promise);
} }
} }
@ -409,7 +397,7 @@ abstract class DnsNameResolverContext<T> {
} }
final int size = resolvedEntries.size(); final int size = resolvedEntries.size();
switch (resolveAddressTypes[0]) { switch (parent.preferredAddressType()) {
case IPv4: case IPv4:
for (int i = 0; i < size; i ++) { for (int i = 0; i < size; i ++) {
if (resolvedEntries.get(i).address() instanceof Inet4Address) { if (resolvedEntries.get(i).address() instanceof Inet4Address) {
@ -424,6 +412,8 @@ abstract class DnsNameResolverContext<T> {
} }
} }
break; break;
default:
throw new Error();
} }
return false; return false;
@ -519,8 +509,12 @@ abstract class DnsNameResolverContext<T> {
} }
final InetSocketAddress nextAddr = nameServerAddrs.next(); final InetSocketAddress nextAddr = nameServerAddrs.next();
query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.A), promise); if (parent.isCnameFollowARecords()) {
query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.AAAA), promise); query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.A), promise);
}
if (parent.isCnameFollowAAAARecords()) {
query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.AAAA), promise);
}
} }
private void addTrace(InetSocketAddress nameServerAddr, String msg) { private void addTrace(InetSocketAddress nameServerAddr, String msg) {