Add DnsNameResolver.resolveAll(DnsQuestion) (#7803)

* Add DnsNameResolver.resolveAll(DnsQuestion)

Motivation:

A user is currently expected to use DnsNameResolver.query() when he or
she wants to look up the full DNS records rather than just InetAddres.

However, query() only performs a single query. It does not handle
/etc/hosts file, redirection, CNAMEs or multiple name servers.

As a result, such a user has to duplicate all the logic in
DnsNameResolverContext.

Modifications:

- Refactor DnsNameResolverContext so that it can send queries for
  arbitrary record types.
  - Rename DnsNameResolverContext to DnsResolveContext
  - Add DnsAddressResolveContext which extends DnsResolveContext for
    A/AAAA lookup
  - Add DnsRecordResolveContext which extends DnsResolveContext for
    arbitrary lookup
- Add DnsNameResolverContext.resolveAll(DnsQuestion) and its variants
- Change DnsNameResolverContext.resolve() delegates the resolve request
  to resolveAll() for simplicity
- Move the code that decodes A/AAAA record content to DnsAddressDecoder

Result:

- Fixes #7795
- A user does not have to duplicate DnsNameResolverContext in his or her
  own code to implement the usual DNS resolver behavior.
This commit is contained in:
Trustin Lee 2018-03-30 05:01:25 +09:00 committed by Norman Maurer
parent 965734a1eb
commit cd4594d292
6 changed files with 536 additions and 209 deletions

View File

@ -0,0 +1,67 @@
/*
* Copyright 2018 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import java.net.IDN;
import java.net.InetAddress;
import java.net.UnknownHostException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
/**
* Decodes an {@link InetAddress} from an A or AAAA {@link DnsRawRecord}.
*/
final class DnsAddressDecoder {
private static final int INADDRSZ4 = 4;
private static final int INADDRSZ6 = 16;
/**
* Decodes an {@link InetAddress} from an A or AAAA {@link DnsRawRecord}.
*
* @param record the {@link DnsRecord}, most likely a {@link DnsRawRecord}
* @param name the host name of the decoded address
* @param decodeIdn whether to convert {@code name} to a unicode host name
*
* @return the {@link InetAddress}, or {@code null} if {@code record} is not a {@link DnsRawRecord} or
* its content is malformed
*/
static InetAddress decodeAddress(DnsRecord record, String name, boolean decodeIdn) {
if (!(record instanceof DnsRawRecord)) {
return null;
}
final ByteBuf content = ((ByteBufHolder) record).content();
final int contentLen = content.readableBytes();
if (contentLen != INADDRSZ4 && contentLen != INADDRSZ6) {
return null;
}
final byte[] addrBytes = new byte[contentLen];
content.getBytes(content.readerIndex(), addrBytes);
try {
return InetAddress.getByAddress(decodeIdn ? IDN.toUnicode(name) : name, addrBytes);
} catch (UnknownHostException e) {
// Should never reach here.
throw new Error(e);
}
}
private DnsAddressDecoder() { }
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2018 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import static io.netty.resolver.dns.DnsAddressDecoder.decodeAddress;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.List;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
final class DnsAddressResolveContext extends DnsResolveContext<InetAddress> {
private final DnsCache resolveCache;
DnsAddressResolveContext(DnsNameResolver parent, String hostname, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs, DnsCache resolveCache) {
super(parent, hostname, DnsRecord.CLASS_IN, parent.resolveRecordTypes(), additionals, nameServerAddrs);
this.resolveCache = resolveCache;
}
@Override
DnsResolveContext<InetAddress> newResolverContext(DnsNameResolver parent, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
return new DnsAddressResolveContext(parent, hostname, additionals, nameServerAddrs, resolveCache);
}
@Override
InetAddress convertRecord(DnsRecord record, String hostname, DnsRecord[] additionals, EventLoop eventLoop) {
return decodeAddress(record, hostname, parent.isDecodeIdn());
}
@Override
boolean containsExpectedResult(List<InetAddress> finalResult) {
final int size = finalResult.size();
final Class<? extends InetAddress> inetAddressType = parent.preferredAddressType().addressType();
for (int i = 0; i < size; i++) {
InetAddress address = finalResult.get(i);
if (inetAddressType.isInstance(address)) {
return true;
}
}
return false;
}
@Override
void cache(String hostname, DnsRecord[] additionals,
DnsRecord result, InetAddress convertedResult) {
resolveCache.cache(hostname, additionals, convertedResult, result.timeToLive(), parent.ch.eventLoop());
}
@Override
void cache(String hostname, DnsRecord[] additionals, UnknownHostException cause) {
resolveCache.cache(hostname, additionals, cause, parent.ch.eventLoop());
}
}

View File

@ -16,6 +16,8 @@
package io.netty.resolver.dns;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFactory;
@ -33,11 +35,13 @@ import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
import io.netty.handler.codec.dns.DatagramDnsResponse;
import io.netty.handler.codec.dns.DatagramDnsResponseDecoder;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.resolver.HostsFileEntries;
import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.InetNameResolver;
import io.netty.resolver.ResolvedAddressTypes;
@ -45,6 +49,7 @@ import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent;
@ -55,6 +60,8 @@ import io.netty.util.internal.logging.InternalLoggerFactory;
import java.lang.reflect.Method;
import java.net.IDN;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
@ -527,6 +534,96 @@ public class DnsNameResolver extends InetNameResolver {
doResolve(inetHost, EMPTY_ADDITIONALS, promise, resolveCache);
}
/**
* Resolves the {@link DnsRecord}s that are matched by the specified {@link DnsQuestion}. Unlike
* {@link #query(DnsQuestion)}, this method handles redirection, CNAMEs and multiple name servers.
* If the specified {@link DnsQuestion} is {@code A} or {@code AAAA}, this method looks up the configured
* {@link HostsFileEntries} before sending a query to the name servers. If a match is found in the
* {@link HostsFileEntries}, a synthetic {@code A} or {@code AAAA} record will be returned.
*
* @param question the question
*
* @return the list of the {@link DnsRecord}s as the result of the resolution
*/
public final Future<List<DnsRecord>> resolveAll(DnsQuestion question) {
return resolveAll(question, EMPTY_ADDITIONALS, executor().<List<DnsRecord>>newPromise());
}
/**
* Resolves the {@link DnsRecord}s that are matched by the specified {@link DnsQuestion}. Unlike
* {@link #query(DnsQuestion)}, this method handles redirection, CNAMEs and multiple name servers.
* If the specified {@link DnsQuestion} is {@code A} or {@code AAAA}, this method looks up the configured
* {@link HostsFileEntries} before sending a query to the name servers. If a match is found in the
* {@link HostsFileEntries}, a synthetic {@code A} or {@code AAAA} record will be returned.
*
* @param question the question
* @param additionals additional records ({@code OPT})
*
* @return the list of the {@link DnsRecord}s as the result of the resolution
*/
public final Future<List<DnsRecord>> resolveAll(DnsQuestion question, Iterable<DnsRecord> additionals) {
return resolveAll(question, additionals, executor().<List<DnsRecord>>newPromise());
}
/**
* Resolves the {@link DnsRecord}s that are matched by the specified {@link DnsQuestion}. Unlike
* {@link #query(DnsQuestion)}, this method handles redirection, CNAMEs and multiple name servers.
* If the specified {@link DnsQuestion} is {@code A} or {@code AAAA}, this method looks up the configured
* {@link HostsFileEntries} before sending a query to the name servers. If a match is found in the
* {@link HostsFileEntries}, a synthetic {@code A} or {@code AAAA} record will be returned.
*
* @param question the question
* @param additionals additional records ({@code OPT})
* @param promise the {@link Promise} which will be fulfilled when the resolution is finished
*
* @return the list of the {@link DnsRecord}s as the result of the resolution
*/
public final Future<List<DnsRecord>> resolveAll(DnsQuestion question, Iterable<DnsRecord> additionals,
Promise<List<DnsRecord>> promise) {
final DnsRecord[] additionalsArray = toArray(additionals, true);
return resolveAll(question, additionalsArray, promise);
}
private Future<List<DnsRecord>> resolveAll(DnsQuestion question, DnsRecord[] additionals,
Promise<List<DnsRecord>> promise) {
checkNotNull(question, "question");
checkNotNull(promise, "promise");
// Respect /etc/hosts as well if the record type is A or AAAA.
final DnsRecordType type = question.type();
final String hostname = question.name();
if (type == DnsRecordType.A || type == DnsRecordType.AAAA) {
final InetAddress hostsFileEntry = resolveHostsFileEntry(hostname);
if (hostsFileEntry != null) {
ByteBuf content = null;
if (hostsFileEntry instanceof Inet4Address) {
if (type == DnsRecordType.A) {
content = Unpooled.wrappedBuffer(hostsFileEntry.getAddress());
}
} else if (hostsFileEntry instanceof Inet6Address) {
if (type == DnsRecordType.AAAA) {
content = Unpooled.wrappedBuffer(hostsFileEntry.getAddress());
}
}
if (content != null) {
// Our current implementation does not support reloading the hosts file,
// so use a fairly large TTL (1 day, i.e. 86400 seconds).
trySuccess(promise, Collections.<DnsRecord>singletonList(
new DefaultDnsRawRecord(hostname, type, 86400, content)));
return promise;
}
}
}
// It was not A/AAAA question or there was no entry in /etc/hosts.
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsRecordResolveContext(this, question, additionals, nameServerAddrs).resolve(promise);
return promise;
}
private static DnsRecord[] toArray(Iterable<DnsRecord> additionals, boolean validateType) {
checkNotNull(additionals, "additionals");
if (additionals instanceof Collection) {
@ -624,7 +721,7 @@ public class DnsNameResolver extends InetNameResolver {
}
}
private static <T> void trySuccess(Promise<T> promise, T result) {
static <T> void trySuccess(Promise<T> promise, T result) {
if (!promise.trySuccess(result)) {
logger.warn("Failed to notify success ({}) to a promise: {}", result, promise);
}
@ -638,40 +735,20 @@ public class DnsNameResolver extends InetNameResolver {
private void doResolveUncached(String hostname,
DnsRecord[] additionals,
Promise<InetAddress> promise,
final Promise<InetAddress> promise,
DnsCache resolveCache) {
new SingleResolverContext(this, hostname, additionals, resolveCache,
dnsServerAddressStreamProvider.nameServerAddressStream(hostname)).resolve(promise);
}
static final class SingleResolverContext extends DnsNameResolverContext<InetAddress> {
SingleResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache, DnsServerAddressStream nameServerAddrs) {
super(parent, hostname, additionals, resolveCache, nameServerAddrs);
}
@Override
DnsNameResolverContext<InetAddress> newResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache,
DnsServerAddressStream nameServerAddrs) {
return new SingleResolverContext(parent, hostname, additionals, resolveCache, nameServerAddrs);
}
@Override
boolean finishResolve(
Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries,
Promise<InetAddress> promise) {
final int numEntries = resolvedEntries.size();
for (int i = 0; i < numEntries; i++) {
final InetAddress a = resolvedEntries.get(i).address();
if (addressType.isInstance(a)) {
trySuccess(promise, a);
return true;
final Promise<List<InetAddress>> allPromise = executor().newPromise();
doResolveAllUncached(hostname, additionals, allPromise, resolveCache);
allPromise.addListener(new FutureListener<List<InetAddress>>() {
@Override
public void operationComplete(Future<List<InetAddress>> future) {
if (future.isSuccess()) {
trySuccess(promise, future.getNow().get(0));
} else {
tryFailure(promise, future.cause());
}
}
return false;
}
});
}
@Override
@ -747,50 +824,13 @@ public class DnsNameResolver extends InetNameResolver {
}
}
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
ListResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache, DnsServerAddressStream nameServerAddrs) {
super(parent, hostname, additionals, resolveCache, nameServerAddrs);
}
@Override
DnsNameResolverContext<List<InetAddress>> newResolverContext(
DnsNameResolver parent, String hostname, DnsRecord[] additionals, DnsCache resolveCache,
DnsServerAddressStream nameServerAddrs) {
return new ListResolverContext(parent, hostname, additionals, resolveCache, nameServerAddrs);
}
@Override
boolean finishResolve(
Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries,
Promise<List<InetAddress>> promise) {
List<InetAddress> result = null;
final int numEntries = resolvedEntries.size();
for (int i = 0; i < numEntries; i++) {
final InetAddress a = resolvedEntries.get(i).address();
if (addressType.isInstance(a)) {
if (result == null) {
result = new ArrayList<InetAddress>(numEntries);
}
result.add(a);
}
}
if (result != null) {
promise.trySuccess(result);
return true;
}
return false;
}
}
private void doResolveAllUncached(String hostname,
DnsRecord[] additionals,
Promise<List<InetAddress>> promise,
DnsCache resolveCache) {
new ListResolverContext(this, hostname, additionals, resolveCache,
dnsServerAddressStreamProvider.nameServerAddressStream(hostname)).resolve(promise);
final DnsServerAddressStream nameServerAddrs =
dnsServerAddressStreamProvider.nameServerAddressStream(hostname);
new DnsAddressResolveContext(this, hostname, additionals, nameServerAddrs, resolveCache).resolve(promise);
}
private static String hostname(String inetHost) {

View File

@ -0,0 +1,75 @@
/*
* Copyright 2018 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import static io.netty.resolver.dns.DnsAddressDecoder.decodeAddress;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.List;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.util.ReferenceCountUtil;
final class DnsRecordResolveContext extends DnsResolveContext<DnsRecord> {
DnsRecordResolveContext(DnsNameResolver parent, DnsQuestion question, DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
this(parent, question.name(), question.dnsClass(),
new DnsRecordType[] { question.type() },
additionals, nameServerAddrs);
}
private DnsRecordResolveContext(DnsNameResolver parent, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
super(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
}
@Override
DnsResolveContext<DnsRecord> newResolverContext(DnsNameResolver parent, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs) {
return new DnsRecordResolveContext(parent, hostname, dnsClass, expectedTypes, additionals, nameServerAddrs);
}
@Override
DnsRecord convertRecord(DnsRecord record, String hostname, DnsRecord[] additionals, EventLoop eventLoop) {
return ReferenceCountUtil.retain(record);
}
@Override
boolean containsExpectedResult(List<DnsRecord> finalResult) {
return true;
}
@Override
void cache(String hostname, DnsRecord[] additionals, DnsRecord result, DnsRecord convertedResult) {
// Do not cache.
// XXX: When we implement cache, we would need to retain the reference count of the result record.
}
@Override
void cache(String hostname, DnsRecord[] additionals, UnknownHostException cause) {
// Do not cache.
// XXX: When we implement cache, we would need to retain the reference count of the result record.
}
}

View File

@ -20,7 +20,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
@ -40,7 +40,6 @@ import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.ThrowableUtil;
import java.net.IDN;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
@ -54,13 +53,12 @@ import java.util.Locale;
import java.util.Map;
import java.util.Set;
import static io.netty.resolver.dns.DnsAddressDecoder.decodeAddress;
import static io.netty.resolver.dns.DnsNameResolver.trySuccess;
import static java.lang.Math.min;
import static java.util.Collections.unmodifiableList;
abstract class DnsNameResolverContext<T> {
private static final int INADDRSZ4 = 4;
private static final int INADDRSZ6 = 16;
abstract class DnsResolveContext<T> {
private static final FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>> RELEASE_RESPONSE =
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@ -73,58 +71,90 @@ abstract class DnsNameResolverContext<T> {
};
private static final RuntimeException NXDOMAIN_QUERY_FAILED_EXCEPTION = ThrowableUtil.unknownStackTrace(
new RuntimeException("No answer found and NXDOMAIN response code returned"),
DnsNameResolverContext.class,
DnsResolveContext.class,
"onResponse(..)");
private static final RuntimeException CNAME_NOT_FOUND_QUERY_FAILED_EXCEPTION = ThrowableUtil.unknownStackTrace(
new RuntimeException("No matching CNAME record found"),
DnsNameResolverContext.class,
DnsResolveContext.class,
"onResponseCNAME(..)");
private static final RuntimeException NO_MATCHING_RECORD_QUERY_FAILED_EXCEPTION = ThrowableUtil.unknownStackTrace(
new RuntimeException("No matching record type found"),
DnsNameResolverContext.class,
DnsResolveContext.class,
"onResponseAorAAAA(..)");
private static final RuntimeException UNRECOGNIZED_TYPE_QUERY_FAILED_EXCEPTION = ThrowableUtil.unknownStackTrace(
new RuntimeException("Response type was unrecognized"),
DnsNameResolverContext.class,
DnsResolveContext.class,
"onResponse(..)");
private static final RuntimeException NAME_SERVERS_EXHAUSTED_EXCEPTION = ThrowableUtil.unknownStackTrace(
new RuntimeException("No name servers returned an answer"),
DnsNameResolverContext.class,
DnsResolveContext.class,
"tryToFinishResolve(..)");
private final DnsNameResolver parent;
final DnsNameResolver parent;
private final DnsServerAddressStream nameServerAddrs;
private final String hostname;
private final DnsCache resolveCache;
private final int dnsClass;
private final DnsRecordType[] expectedTypes;
private final int maxAllowedQueries;
private final InternetProtocolFamily[] resolvedInternetProtocolFamilies;
private final DnsRecord[] additionals;
private final Set<Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>> queriesInProgress =
Collections.newSetFromMap(
new IdentityHashMap<Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>, Boolean>());
private List<DnsCacheEntry> resolvedEntries;
private List<T> finalResult;
private int allowedQueries;
private boolean triedCNAME;
DnsNameResolverContext(DnsNameResolver parent,
String hostname,
DnsRecord[] additionals,
DnsCache resolveCache,
DnsServerAddressStream nameServerAddrs) {
DnsResolveContext(DnsNameResolver parent,
String hostname, int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals, DnsServerAddressStream nameServerAddrs) {
assert expectedTypes.length > 0;
this.parent = parent;
this.hostname = hostname;
this.dnsClass = dnsClass;
this.expectedTypes = expectedTypes;
this.additionals = additionals;
this.resolveCache = resolveCache;
this.nameServerAddrs = ObjectUtil.checkNotNull(nameServerAddrs, "nameServerAddrs");
maxAllowedQueries = parent.maxQueriesPerResolve();
resolvedInternetProtocolFamilies = parent.resolvedInternetProtocolFamiliesUnsafe();
allowedQueries = maxAllowedQueries;
}
void resolve(final Promise<T> promise) {
/**
* Creates a new context with the given parameters.
*/
abstract DnsResolveContext<T> newResolverContext(DnsNameResolver parent, String hostname,
int dnsClass, DnsRecordType[] expectedTypes,
DnsRecord[] additionals,
DnsServerAddressStream nameServerAddrs);
/**
* Converts the given {@link DnsRecord} into {@code T}.
*/
abstract T convertRecord(DnsRecord record, String hostname, DnsRecord[] additionals, EventLoop eventLoop);
/**
* Returns {@code true} if the given list contains any expected records. {@code finalResult} always contains
* at least one element.
*/
abstract boolean containsExpectedResult(List<T> finalResult);
/**
* Caches a successful resolution.
*/
abstract void cache(String hostname, DnsRecord[] additionals,
DnsRecord result, T convertedResult);
/**
* Caches a failed resolution.
*/
abstract void cache(String hostname, DnsRecord[] additionals,
UnknownHostException cause);
void resolve(final Promise<List<T>> promise) {
final String[] searchDomains = parent.searchDomains();
if (searchDomains.length == 0 || parent.ndots() == 0 || StringUtil.endsWith(hostname, '.')) {
internalResolve(promise);
@ -133,10 +163,10 @@ abstract class DnsNameResolverContext<T> {
final String initialHostname = startWithoutSearchDomain ? hostname : hostname + '.' + searchDomains[0];
final int initialSearchDomainIdx = startWithoutSearchDomain ? 0 : 1;
doSearchDomainQuery(initialHostname, new FutureListener<T>() {
doSearchDomainQuery(initialHostname, new FutureListener<List<T>>() {
private int searchDomainIdx = initialSearchDomainIdx;
@Override
public void operationComplete(Future<T> future) throws Exception {
public void operationComplete(Future<List<T>> future) throws Exception {
Throwable cause = future.cause();
if (cause == null) {
promise.trySuccess(future.getNow());
@ -182,26 +212,24 @@ abstract class DnsNameResolverContext<T> {
}
}
private void doSearchDomainQuery(String hostname, FutureListener<T> listener) {
DnsNameResolverContext<T> nextContext = newResolverContext(parent, hostname, additionals, resolveCache,
nameServerAddrs);
Promise<T> nextPromise = parent.executor().newPromise();
private void doSearchDomainQuery(String hostname, FutureListener<List<T>> listener) {
DnsResolveContext<T> nextContext = newResolverContext(parent, hostname, dnsClass, expectedTypes,
additionals, nameServerAddrs);
Promise<List<T>> nextPromise = parent.executor().newPromise();
nextContext.internalResolve(nextPromise);
nextPromise.addListener(listener);
}
private void internalResolve(Promise<T> promise) {
private void internalResolve(Promise<List<T>> promise) {
DnsServerAddressStream nameServerAddressStream = getNameServers(hostname);
DnsRecordType[] recordTypes = parent.resolveRecordTypes();
assert recordTypes.length > 0;
final int end = recordTypes.length - 1;
final int end = expectedTypes.length - 1;
for (int i = 0; i < end; ++i) {
if (!query(hostname, recordTypes[i], nameServerAddressStream.duplicate(), promise, null)) {
if (!query(hostname, expectedTypes[i], nameServerAddressStream.duplicate(), promise, null)) {
return;
}
}
query(hostname, recordTypes[end], nameServerAddressStream, promise, null);
query(hostname, expectedTypes[end], nameServerAddressStream, promise, null);
}
/**
@ -291,7 +319,7 @@ abstract class DnsNameResolverContext<T> {
private void query(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex,
final DnsQuestion question,
final Promise<T> promise, Throwable cause) {
final Promise<List<T>> promise, Throwable cause) {
query(nameServerAddrStream, nameServerAddrStreamIndex, question,
parent.dnsQueryLifecycleObserverFactory().newDnsQueryLifecycleObserver(question), promise, cause);
}
@ -300,7 +328,7 @@ abstract class DnsNameResolverContext<T> {
final int nameServerAddrStreamIndex,
final DnsQuestion question,
final DnsQueryLifecycleObserver queryLifecycleObserver,
final Promise<T> promise,
final Promise<List<T>> promise,
final Throwable cause) {
if (nameServerAddrStreamIndex >= nameServerAddrStream.size() || allowedQueries == 0 || promise.isCancelled()) {
tryToFinishResolve(nameServerAddrStream, nameServerAddrStreamIndex, question, queryLifecycleObserver,
@ -359,7 +387,7 @@ abstract class DnsNameResolverContext<T> {
void onResponse(final DnsServerAddressStream nameServerAddrStream, final int nameServerAddrStreamIndex,
final DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope,
final DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<T> promise) {
Promise<List<T>> promise) {
try {
final DnsResponse res = envelope.content();
final DnsResponseCode code = res.code();
@ -370,13 +398,19 @@ abstract class DnsNameResolverContext<T> {
}
final DnsRecordType type = question.type();
if (type == DnsRecordType.A || type == DnsRecordType.AAAA) {
onResponseAorAAAA(type, question, envelope, queryLifecycleObserver, promise);
} else if (type == DnsRecordType.CNAME) {
if (type == DnsRecordType.CNAME) {
onResponseCNAME(question, envelope, queryLifecycleObserver, promise);
} else {
queryLifecycleObserver.queryFailed(UNRECOGNIZED_TYPE_QUERY_FAILED_EXCEPTION);
return;
}
for (DnsRecordType expectedType : expectedTypes) {
if (type == expectedType) {
onExpectedResponse(question, envelope, queryLifecycleObserver, promise);
return;
}
}
queryLifecycleObserver.queryFailed(UNRECOGNIZED_TYPE_QUERY_FAILED_EXCEPTION);
return;
}
@ -397,7 +431,7 @@ abstract class DnsNameResolverContext<T> {
*/
private boolean handleRedirect(
DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope,
final DnsQueryLifecycleObserver queryLifecycleObserver, Promise<T> promise) {
final DnsQueryLifecycleObserver queryLifecycleObserver, Promise<List<T>> promise) {
final DnsResponse res = envelope.content();
// Check if we have answers, if not this may be an non authority NS and so redirects must be handled.
@ -417,15 +451,14 @@ abstract class DnsNameResolverContext<T> {
}
final String recordName = r.name();
AuthoritativeNameServer authoritativeNameServer =
serverNames.remove(recordName);
final AuthoritativeNameServer authoritativeNameServer = serverNames.remove(recordName);
if (authoritativeNameServer == null) {
// Not a server we are interested in.
continue;
}
InetAddress resolved = parseAddress(r, recordName);
InetAddress resolved = decodeAddress(r, recordName, parent.isDecodeIdn());
if (resolved == null) {
// Could not parse it, move to the next.
continue;
@ -462,10 +495,9 @@ abstract class DnsNameResolverContext<T> {
return serverNames;
}
private void onResponseAorAAAA(
DnsRecordType qType, DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope,
final DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<T> promise) {
private void onExpectedResponse(
DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope,
final DnsQueryLifecycleObserver queryLifecycleObserver, Promise<List<T>> promise) {
// We often get a bunch of CNAMES as well when we asked for A/AAAA.
final DnsResponse response = envelope.content();
@ -476,7 +508,15 @@ abstract class DnsNameResolverContext<T> {
for (int i = 0; i < answerCount; i ++) {
final DnsRecord r = response.recordAt(DnsSection.ANSWER, i);
final DnsRecordType type = r.type();
if (type != DnsRecordType.A && type != DnsRecordType.AAAA) {
boolean matches = false;
for (DnsRecordType expectedType : expectedTypes) {
if (type == expectedType) {
matches = true;
break;
}
}
if (!matches) {
continue;
}
@ -499,17 +539,17 @@ abstract class DnsNameResolverContext<T> {
}
}
InetAddress resolved = parseAddress(r, hostname);
if (resolved == null) {
final T converted = convertRecord(r, hostname, additionals, parent.ch.eventLoop());
if (converted == null) {
continue;
}
if (resolvedEntries == null) {
resolvedEntries = new ArrayList<DnsCacheEntry>(8);
if (finalResult == null) {
finalResult = new ArrayList<T>(8);
}
finalResult.add(converted);
resolvedEntries.add(
resolveCache.cache(hostname, additionals, resolved, r.timeToLive(), parent.ch.eventLoop()));
cache(hostname, additionals, r, converted);
found = true;
// Note that we do not break from the loop here, so we decode/cache all A/AAAA records.
@ -523,47 +563,23 @@ abstract class DnsNameResolverContext<T> {
if (cnames.isEmpty()) {
queryLifecycleObserver.queryFailed(NO_MATCHING_RECORD_QUERY_FAILED_EXCEPTION);
} else {
// We asked for A/AAAA but we got only CNAME.
onResponseCNAME(question, envelope, cnames, queryLifecycleObserver, promise);
}
}
private InetAddress parseAddress(DnsRecord r, String name) {
if (!(r instanceof DnsRawRecord)) {
return null;
}
final ByteBuf content = ((ByteBufHolder) r).content();
final int contentLen = content.readableBytes();
if (contentLen != INADDRSZ4 && contentLen != INADDRSZ6) {
return null;
}
final byte[] addrBytes = new byte[contentLen];
content.getBytes(content.readerIndex(), addrBytes);
try {
return InetAddress.getByAddress(
parent.isDecodeIdn() ? IDN.toUnicode(name) : name, addrBytes);
} catch (UnknownHostException e) {
// Should never reach here.
throw new Error(e);
// We got only CNAME, not one of expectedTypes.
onResponseCNAME(question, cnames, queryLifecycleObserver, promise);
}
}
private void onResponseCNAME(DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope,
final DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<T> promise) {
onResponseCNAME(question, envelope, buildAliasMap(envelope.content()), queryLifecycleObserver, promise);
final DnsQueryLifecycleObserver queryLifecycleObserver, Promise<List<T>> promise) {
onResponseCNAME(question, buildAliasMap(envelope.content()), queryLifecycleObserver, promise);
}
private void onResponseCNAME(
DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> response,
Map<String, String> cnames, final DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<T> promise) {
DnsQuestion question, Map<String, String> cnames,
final DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<List<T>> promise) {
// Resolve the host name in the question into the real host name.
final String name = question.name().toLowerCase(Locale.US);
String resolved = name;
String resolved = question.name().toLowerCase(Locale.US);
boolean found = false;
while (!cnames.isEmpty()) { // Do not attempt to call Map.remove() when the Map is empty
// because it can be Collections.emptyMap()
@ -618,24 +634,24 @@ abstract class DnsNameResolverContext<T> {
final int nameServerAddrStreamIndex,
final DnsQuestion question,
final DnsQueryLifecycleObserver queryLifecycleObserver,
final Promise<T> promise,
final Promise<List<T>> promise,
final Throwable cause) {
// There are no queries left to try.
if (!queriesInProgress.isEmpty()) {
queryLifecycleObserver.queryCancelled(allowedQueries);
// There are still some queries we did not receive responses for.
if (gotPreferredAddress()) {
// But it's OK to finish the resolution process if we got a resolved address of the preferred type.
if (finalResult != null && containsExpectedResult(finalResult)) {
// But it's OK to finish the resolution process if we got something expected.
finishResolve(promise, cause);
}
// We did not get any resolved address of the preferred type, so we can't finish the resolution process.
// We did not get an expected result yet, so we can't finish the resolution process.
return;
}
// There are no queries left to try.
if (resolvedEntries == null) {
if (finalResult == null) {
if (nameServerAddrStreamIndex < nameServerAddrStream.size()) {
if (queryLifecycleObserver == NoopDnsQueryLifecycleObserver.INSTANCE) {
// If the queryLifecycleObserver has already been terminated we should create a new one for this
@ -650,11 +666,11 @@ abstract class DnsNameResolverContext<T> {
queryLifecycleObserver.queryFailed(NAME_SERVERS_EXHAUSTED_EXCEPTION);
// .. and we could not find any A/AAAA records.
// .. and we could not find any expected records.
// If cause != null we know this was caused by a timeout / cancel / transport exception. In this case we
// won't try to resolve the CNAME as we only should do this if we could not get the A/AAAA records because
// these not exists and the DNS server did probably signal it.
// won't try to resolve the CNAME as we only should do this if we could not get the expected records
// because they do not exist and the DNS server did probably signal it.
if (cause == null && !triedCNAME) {
// As the last resort, try to query CNAME, just in case the name server has it.
triedCNAME = true;
@ -666,27 +682,11 @@ abstract class DnsNameResolverContext<T> {
queryLifecycleObserver.queryCancelled(allowedQueries);
}
// We have at least one resolved address or tried CNAME as the last resort..
// We have at least one resolved record or tried CNAME as the last resort..
finishResolve(promise, cause);
}
private boolean gotPreferredAddress() {
if (resolvedEntries == null) {
return false;
}
final int size = resolvedEntries.size();
final Class<? extends InetAddress> inetAddressType = parent.preferredAddressType().addressType();
for (int i = 0; i < size; i++) {
InetAddress address = resolvedEntries.get(i).address();
if (inetAddressType.isInstance(address)) {
return true;
}
}
return false;
}
private void finishResolve(Promise<T> promise, Throwable cause) {
private void finishResolve(Promise<List<T>> promise, Throwable cause) {
if (!queriesInProgress.isEmpty()) {
// If there are queries in progress, we should cancel it because we already finished the resolution.
for (Iterator<Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>> i = queriesInProgress.iterator();
@ -700,13 +700,10 @@ abstract class DnsNameResolverContext<T> {
}
}
if (resolvedEntries != null) {
// Found at least one resolved address.
for (InternetProtocolFamily f: resolvedInternetProtocolFamilies) {
if (finishResolve(f.addressType(), resolvedEntries, promise)) {
return;
}
}
if (finalResult != null) {
// Found at least one resolved record.
trySuccess(promise, finalResult);
return;
}
// No resolved address found.
@ -729,20 +726,13 @@ abstract class DnsNameResolverContext<T> {
if (cause == null) {
// Only cache if the failure was not because of an IO error / timeout that was caused by the query
// itself.
resolveCache.cache(hostname, additionals, unknownHostException, parent.ch.eventLoop());
cache(hostname, additionals, unknownHostException);
} else {
unknownHostException.initCause(cause);
}
promise.tryFailure(unknownHostException);
}
abstract boolean finishResolve(Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries,
Promise<T> promise);
abstract DnsNameResolverContext<T> newResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache,
DnsServerAddressStream nameServerAddrs);
static String decodeDomainName(ByteBuf in) {
in.markReaderIndex();
try {
@ -760,8 +750,8 @@ abstract class DnsNameResolverContext<T> {
return stream == null ? nameServerAddrs.duplicate() : stream;
}
private void followCname(
DnsQuestion question, String cname, DnsQueryLifecycleObserver queryLifecycleObserver, Promise<T> promise) {
private void followCname(DnsQuestion question, String cname, DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<List<T>> promise) {
DnsServerAddressStream stream = getNameServers(cname);
final DnsQuestion cnameQuestion;
@ -776,7 +766,7 @@ abstract class DnsNameResolverContext<T> {
}
private boolean query(String hostname, DnsRecordType type, DnsServerAddressStream dnsServerAddressStream,
Promise<T> promise, Throwable cause) {
Promise<List<T>> promise, Throwable cause) {
final DnsQuestion question = newQuestion(hostname, type);
if (question == null) {
return false;
@ -785,9 +775,9 @@ abstract class DnsNameResolverContext<T> {
return true;
}
private static DnsQuestion newQuestion(String hostname, DnsRecordType type) {
private DnsQuestion newQuestion(String hostname, DnsRecordType type) {
try {
return new DefaultDnsQuestion(hostname, type);
return new DefaultDnsQuestion(hostname, type, dnsClass);
} catch (IllegalArgumentException e) {
// java.net.IDN.toASCII(...) may throw an IllegalArgumentException if it fails to parse the hostname
return null;

View File

@ -28,6 +28,7 @@ import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
@ -36,6 +37,7 @@ import io.netty.handler.codec.dns.DnsSection;
import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SocketUtils;
@ -52,6 +54,7 @@ import org.apache.directory.server.dns.messages.ResourceRecordModifier;
import org.apache.directory.server.dns.messages.ResponseCode;
import org.apache.directory.server.dns.store.DnsAttribute;
import org.apache.directory.server.dns.store.RecordStore;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@ -608,7 +611,7 @@ public class DnsNameResolverTest {
buf.append(' ');
buf.append(recordContent.readUnsignedShort());
buf.append(' ');
buf.append(DnsNameResolverContext.decodeDomainName(recordContent));
buf.append(DnsResolveContext.decodeDomainName(recordContent));
}
logger.info("{} has the following MX records:{}", hostname, buf);
@ -945,6 +948,84 @@ public class DnsNameResolverTest {
}
}
@Test
public void testResolveAllMx() {
final DnsNameResolver resolver = newResolver().build();
try {
assertThat(resolver.isRecursionDesired(), is(true));
final Map<String, Future<List<DnsRecord>>> futures = new LinkedHashMap<String, Future<List<DnsRecord>>>();
for (String name: DOMAINS) {
if (EXCLUSIONS_QUERY_MX.contains(name)) {
continue;
}
futures.put(name, resolver.resolveAll(new DefaultDnsQuestion(name, DnsRecordType.MX)));
}
for (Entry<String, Future<List<DnsRecord>>> e: futures.entrySet()) {
String hostname = e.getKey();
Future<List<DnsRecord>> f = e.getValue().awaitUninterruptibly();
final List<DnsRecord> mxList = f.getNow();
assertThat(mxList.size(), is(greaterThan(0)));
StringBuilder buf = new StringBuilder();
for (DnsRecord r: mxList) {
ByteBuf recordContent = ((ByteBufHolder) r).content();
buf.append(StringUtil.NEWLINE);
buf.append('\t');
buf.append(r.name());
buf.append(' ');
buf.append(r.type().name());
buf.append(' ');
buf.append(recordContent.readUnsignedShort());
buf.append(' ');
buf.append(DnsResolveContext.decodeDomainName(recordContent));
ReferenceCountUtil.release(r);
}
logger.info("{} has the following MX records:{}", hostname, buf);
}
} finally {
resolver.close();
}
}
@Test
public void testResolveAllHostsFile() {
final DnsNameResolver resolver = new DnsNameResolverBuilder(group.next())
.channelType(NioDatagramChannel.class)
.hostsFileEntriesResolver(new HostsFileEntriesResolver() {
@Override
public InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddressTypes) {
if ("foo.com.".equals(inetHost)) {
try {
return InetAddress.getByAddress("foo.com", new byte[] { 1, 2, 3, 4 });
} catch (UnknownHostException e) {
throw new Error(e);
}
}
return null;
}
}).build();
final List<DnsRecord> records = resolver.resolveAll(new DefaultDnsQuestion("foo.com.", A))
.syncUninterruptibly().getNow();
assertThat(records, Matchers.<DnsRecord>hasSize(1));
assertThat(records.get(0), Matchers.<DnsRecord>instanceOf(DnsRawRecord.class));
final DnsRawRecord record = (DnsRawRecord) records.get(0);
final ByteBuf content = record.content();
assertThat(record.name(), is("foo.com."));
assertThat(record.dnsClass(), is(DnsRecord.CLASS_IN));
assertThat(record.type(), is(A));
assertThat(content.readableBytes(), is(4));
assertThat(content.readInt(), is(0x01020304));
record.release();
}
@Test
public void testResolveDecodeUnicode() {
testResolveUnicode(true);