Provide more control over DnsNameResolver.query() / Add NameResolver.resolveAll()

Related issues:
- #3971
- #3973
- #3976
- #4035

Motivation:

1. Previously, DnsNameResolver.query() retried the request query by its
own. It prevents a user from deciding when to retry or stop. It is also
impossible to get the response object whose code is not NOERROR.

2. NameResolver does not have an operation that resolves a host name
into multiple addresses, like InetAddress.getAllByName()

Modifications:

- Changes related with DnsNameResolver.query()
  - Make query() not retry
    - Move the retry logic to DnsNameResolver.resolve() instead.
  - Make query() fail the promise only when I/O error occurred or it
    failed to get a response
  - Add DnsNameResolverException and use it when query() fails so that
    the resolver can give more information about the failure
  - query() does not cache anymore.

- Changes related with NameResolver.resolveAll()
  - Add NameResolver.resolveAll()
  - Add SimpleNameResolver.doResolveAll()

- Changes related with DnsNameResolver.resolve() and resolveAll()
  - Make DnsNameResolveContext abstract so that DnsNameResolver can
    decide to get single or multiple addresses from it
  - Re-implement cache so that the cache works for resolve() and
    resolveAll()
  - Add 'traceEnabled' property to enable/disable trace information

- Miscellaneous changes
  - Use ObjectUtil.checkNotNull() wherever possible
  - Add InternetProtocolFamily.addressType() to remove repetitive
    switch-case blocks in DnsNameResolver(Context)
  - Do not raise an exception when decoding a truncated DNS response

Result:

- Full control over query()
- A user can now retrieve all addresses via (Dns)NameResolver.resolveAll()
- DNS cache works only for resolve() and resolveAll() now.
This commit is contained in:
Trustin Lee 2015-07-12 19:34:05 +09:00
parent fd9ca8bbb3
commit fdfe3149ba
15 changed files with 866 additions and 483 deletions

View File

@ -105,7 +105,13 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder<Datagram
private void decodeRecords(
DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; i --) {
response.addRecord(section, recordDecoder.decodeRecord(buf));
final DnsRecord r = recordDecoder.decodeRecord(buf);
if (r == null) {
// Truncated response
break;
}
response.addRecord(section, r);
}
}
}

View File

@ -42,13 +42,28 @@ public class DefaultDnsRecordDecoder implements DnsRecordDecoder {
@Override
public final <T extends DnsRecord> T decodeRecord(ByteBuf in) throws Exception {
final int startOffset = in.readerIndex();
final String name = decodeName(in);
final int endOffset = in.writerIndex();
if (endOffset - startOffset < 10) {
// Not enough data
in.readerIndex(startOffset);
return null;
}
final DnsRecordType type = DnsRecordType.valueOf(in.readUnsignedShort());
final int aClass = in.readUnsignedShort();
final long ttl = in.readUnsignedInt();
final int length = in.readUnsignedShort();
final int offset = in.readerIndex();
if (endOffset - offset < length) {
// Not enough data
in.readerIndex(startOffset);
return null;
}
@SuppressWarnings("unchecked")
T record = (T) decodeRecord(name, type, aClass, ttl, in, offset, length);
in.readerIndex(offset + length);

View File

@ -37,6 +37,8 @@ public interface DnsRecordDecoder {
* Decodes a DNS record into its object representation.
*
* @param in the input buffer which contains a DNS record at its reader index
*
* @return the decoded record, or {@code null} if there are not enough data in the input buffer
*/
<T extends DnsRecord> T decodeRecord(ByteBuf in) throws Exception;
}

View File

@ -16,73 +16,61 @@
package io.netty.resolver.dns;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
import java.net.InetSocketAddress;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
final class DnsCacheEntry {
private enum State {
INIT,
SCHEDULED_EXPIRATION,
RELEASED
}
private final AddressedEnvelope<DnsResponse, InetSocketAddress> response;
private final String hostname;
private final InetAddress address;
private final Throwable cause;
private volatile ScheduledFuture<?> expirationFuture;
private boolean released;
@SuppressWarnings("unchecked")
DnsCacheEntry(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> response) {
this.response = (AddressedEnvelope<DnsResponse, InetSocketAddress>) response.retain();
DnsCacheEntry(String hostname, InetAddress address) {
this.hostname = hostname;
this.address = address;
cause = null;
}
DnsCacheEntry(Throwable cause) {
DnsCacheEntry(String hostname, Throwable cause) {
this.hostname = hostname;
this.cause = cause;
response = null;
address = null;
}
String hostname() {
return hostname;
}
InetAddress address() {
return address;
}
Throwable cause() {
return cause;
}
synchronized AddressedEnvelope<DnsResponse, InetSocketAddress> retainedResponse() {
if (released) {
// Released by other thread via either the expiration task or clearCache()
return null;
}
return response.retain();
}
void scheduleExpiration(EventLoop loop, Runnable task, long delay, TimeUnit unit) {
assert expirationFuture == null: "expiration task scheduled already";
expirationFuture = loop.schedule(task, delay, unit);
}
void release() {
synchronized (this) {
if (released) {
return;
}
released = true;
ReferenceCountUtil.safeRelease(response);
}
void cancelExpiration() {
ScheduledFuture<?> expirationFuture = this.expirationFuture;
if (expirationFuture != null) {
expirationFuture.cancel(false);
}
}
@Override
public String toString() {
if (cause != null) {
return hostname + '/' + cause;
} else {
return address.toString();
}
}
}

View File

@ -31,19 +31,17 @@ import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
import io.netty.handler.codec.dns.DatagramDnsResponse;
import io.netty.handler.codec.dns.DatagramDnsResponseDecoder;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsResponseCode;
import io.netty.resolver.NameResolver;
import io.netty.resolver.SimpleNameResolver;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.InternalThreadLocalMap;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SystemPropertyUtil;
@ -56,6 +54,7 @@ import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
@ -63,6 +62,8 @@ import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReferenceArray;
import static io.netty.util.internal.ObjectUtil.*;
/**
* A DNS-based {@link NameResolver}.
*/
@ -102,9 +103,17 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
final AtomicReferenceArray<DnsQueryContext> promises = new AtomicReferenceArray<DnsQueryContext>(65536);
/**
* The cache for {@link #query(DnsQuestion)}
* Cache for {@link #doResolve(InetSocketAddress, Promise)} and {@link #doResolveAll(InetSocketAddress, Promise)}.
*/
final ConcurrentMap<DnsQuestion, DnsCacheEntry> queryCache = PlatformDependent.newConcurrentHashMap();
final ConcurrentMap<String, List<DnsCacheEntry>> resolveCache = PlatformDependent.newConcurrentHashMap();
private final FastThreadLocal<Iterator<InetSocketAddress>> nameServerAddrIterator =
new FastThreadLocal<Iterator<InetSocketAddress>>() {
@Override
protected Iterator<InetSocketAddress> initialValue() throws Exception {
return nameServerAddresses.iterator();
}
};
private final DnsResponseHandler responseHandler = new DnsResponseHandler();
@ -114,11 +123,11 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
private volatile int minTtl;
private volatile int maxTtl = Integer.MAX_VALUE;
private volatile int negativeTtl;
private volatile int maxTriesPerQuery = 2;
private volatile int maxQueriesPerResolve = 3;
private volatile boolean traceEnabled = true;
private volatile InternetProtocolFamily[] resolveAddressTypes = DEFAULT_RESOLVE_ADDRESS_TYPES;
private volatile boolean recursionDesired = true;
private volatile int maxQueriesPerResolve = 8;
private volatile int maxPayloadSize;
@ -238,17 +247,12 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
super(eventLoop);
if (channelFactory == null) {
throw new NullPointerException("channelFactory");
}
if (nameServerAddresses == null) {
throw new NullPointerException("nameServerAddresses");
}
checkNotNull(channelFactory, "channelFactory");
checkNotNull(nameServerAddresses, "nameServerAddresses");
checkNotNull(localAddress, "localAddress");
if (!nameServerAddresses.iterator().hasNext()) {
throw new NullPointerException("nameServerAddresses is empty");
}
if (localAddress == null) {
throw new NullPointerException("localAddress");
throw new IllegalArgumentException("nameServerAddresses is empty");
}
this.nameServerAddresses = nameServerAddresses;
@ -386,32 +390,6 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
return this;
}
/**
* Returns the maximum number of tries for each query. The default value is 2 times.
*
* @see #setMaxTriesPerQuery(int)
*/
public int maxTriesPerQuery() {
return maxTriesPerQuery;
}
/**
* Sets the maximum number of tries for each query.
*
* @return {@code this}
*
* @see #maxTriesPerQuery()
*/
public DnsNameResolver setMaxTriesPerQuery(int maxTriesPerQuery) {
if (maxTriesPerQuery < 1) {
throw new IllegalArgumentException("maxTries: " + maxTriesPerQuery + " (expected: > 0)");
}
this.maxTriesPerQuery = maxTriesPerQuery;
return this;
}
/**
* Returns the list of the protocol families of the address resolved by {@link #resolve(SocketAddress)}
* in the order of preference.
@ -438,9 +416,7 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
* @see #resolveAddressTypes()
*/
public DnsNameResolver setResolveAddressTypes(InternetProtocolFamily... resolveAddressTypes) {
if (resolveAddressTypes == null) {
throw new NullPointerException("resolveAddressTypes");
}
checkNotNull(resolveAddressTypes, "resolveAddressTypes");
final List<InternetProtocolFamily> list =
new ArrayList<InternetProtocolFamily>(InternetProtocolFamily.values().length);
@ -478,9 +454,7 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
* @see #resolveAddressTypes()
*/
public DnsNameResolver setResolveAddressTypes(Iterable<InternetProtocolFamily> resolveAddressTypes) {
if (resolveAddressTypes == null) {
throw new NullPointerException("resolveAddressTypes");
}
checkNotNull(resolveAddressTypes, "resolveAddressTypes");
final List<InternetProtocolFamily> list =
new ArrayList<InternetProtocolFamily>(InternetProtocolFamily.values().length);
@ -556,6 +530,23 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
return this;
}
/**
* Returns if this resolver should generate the detailed trace information in an exception message so that
* it is easier to understand the cause of resolution failure. The default value if {@code true}.
*/
public boolean isTraceEnabled() {
return traceEnabled;
}
/**
* Sets if this resolver should generate the detailed trace information in an exception message so that
* it is easier to understand the cause of resolution failure.
*/
public DnsNameResolver setTraceEnabled(boolean traceEnabled) {
this.traceEnabled = traceEnabled;
return this;
}
/**
* Returns the capacity of the datagram packet buffer (in bytes). The default value is {@code 4096} bytes.
*
@ -589,32 +580,45 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
}
/**
* Clears all the DNS resource records cached by this resolver.
* Clears all the resolved addresses cached by this resolver.
*
* @return {@code this}
*
* @see #clearCache(DnsQuestion)
* @see #clearCache(String)
*/
public DnsNameResolver clearCache() {
for (Iterator<Entry<DnsQuestion, DnsCacheEntry>> i = queryCache.entrySet().iterator(); i.hasNext();) {
Entry<DnsQuestion, DnsCacheEntry> e = i.next();
for (Iterator<Entry<String, List<DnsCacheEntry>>> i = resolveCache.entrySet().iterator(); i.hasNext();) {
final Entry<String, List<DnsCacheEntry>> e = i.next();
i.remove();
e.getValue().release();
cancelExpiration(e);
}
return this;
}
/**
* Clears the DNS resource record of the specified DNS question from the cache of this resolver.
* Clears the resolved addresses of the specified host name from the cache of this resolver.
*
* @return {@code true} if and only if there was an entry for the specified host name in the cache and
* it has been removed by this method
*/
public boolean clearCache(DnsQuestion question) {
DnsCacheEntry e = queryCache.remove(question);
if (e != null) {
e.release();
return true;
} else {
return false;
public boolean clearCache(String hostname) {
boolean removed = false;
for (Iterator<Entry<String, List<DnsCacheEntry>>> i = resolveCache.entrySet().iterator(); i.hasNext();) {
final Entry<String, List<DnsCacheEntry>> e = i.next();
if (e.getKey().equals(hostname)) {
i.remove();
cancelExpiration(e);
removed = true;
}
}
return removed;
}
private static void cancelExpiration(Entry<String, List<DnsCacheEntry>> e) {
final List<DnsCacheEntry> entries = e.getValue();
final int numEntries = entries.size();
for (int i = 0; i < numEntries; i++) {
entries.get(i).cancelExpiration();
}
}
@ -640,34 +644,273 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
@Override
protected void doResolve(InetSocketAddress unresolvedAddress, Promise<InetSocketAddress> promise) throws Exception {
byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName());
if (bytes == null) {
final String hostname = IDN.toASCII(hostname(unresolvedAddress));
final int port = unresolvedAddress.getPort();
final DnsNameResolverContext ctx = new DnsNameResolverContext(this, hostname, port, promise);
ctx.resolve();
} else {
final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName());
if (bytes != null) {
// The unresolvedAddress was created via a String that contains an ipaddress.
promise.setSuccess(new InetSocketAddress(InetAddress.getByAddress(bytes), unresolvedAddress.getPort()));
return;
}
final String hostname = hostname(unresolvedAddress);
final int port = unresolvedAddress.getPort();
if (!doResolveCached(hostname, port, promise)) {
doResolveUncached(hostname, port, promise);
}
}
private boolean doResolveCached(String hostname, int port, Promise<InetSocketAddress> promise) {
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname);
if (cachedEntries == null) {
return false;
}
InetAddress address = null;
Throwable cause = null;
synchronized (cachedEntries) {
final int numEntries = cachedEntries.size();
assert numEntries > 0;
if (cachedEntries.get(0).cause() != null) {
cause = cachedEntries.get(0).cause();
} else {
// Find the first entry with the preferred address type.
for (InternetProtocolFamily f : resolveAddressTypes) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
address = e.address();
break;
}
}
}
}
}
if (address != null) {
setSuccess(promise, new InetSocketAddress(address, port));
} else if (cause != null) {
if (!promise.tryFailure(cause)) {
logger.warn("Failed to notify failure to a promise: {}", promise, cause);
}
} else {
return false;
}
return true;
}
private static void setSuccess(Promise<InetSocketAddress> promise, InetSocketAddress result) {
if (!promise.trySuccess(result)) {
logger.warn("Failed to notify success ({}) to a promise: {}", result, promise);
}
}
private void doResolveUncached(String hostname, final int port, Promise<InetSocketAddress> promise) {
final DnsNameResolverContext<InetSocketAddress> ctx =
new DnsNameResolverContext<InetSocketAddress>(this, hostname, promise) {
@Override
protected boolean finishResolve(
Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries) {
final int numEntries = resolvedEntries.size();
for (int i = 0; i < numEntries; i++) {
final InetAddress a = resolvedEntries.get(i).address();
if (addressType.isInstance(a)) {
setSuccess(promise(), new InetSocketAddress(a, port));
return true;
}
}
return false;
}
};
ctx.resolve();
}
@Override
protected void doResolveAll(
InetSocketAddress unresolvedAddress, Promise<List<InetSocketAddress>> promise) throws Exception {
final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName());
if (bytes != null) {
// The unresolvedAddress was created via a String that contains an ipaddress.
promise.setSuccess(Collections.singletonList(
new InetSocketAddress(InetAddress.getByAddress(bytes), unresolvedAddress.getPort())));
return;
}
final String hostname = hostname(unresolvedAddress);
final int port = unresolvedAddress.getPort();
if (!doResolveAllCached(hostname, port, promise)) {
doResolveAllUncached(hostname, port, promise);
}
}
private boolean doResolveAllCached(String hostname, int port, Promise<List<InetSocketAddress>> promise) {
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname);
if (cachedEntries == null) {
return false;
}
List<InetSocketAddress> result = null;
Throwable cause = null;
synchronized (cachedEntries) {
final int numEntries = cachedEntries.size();
assert numEntries > 0;
if (cachedEntries.get(0).cause() != null) {
cause = cachedEntries.get(0).cause();
} else {
for (InternetProtocolFamily f : resolveAddressTypes) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
if (result == null) {
result = new ArrayList<InetSocketAddress>(numEntries);
}
result.add(new InetSocketAddress(e.address(), port));
}
}
}
}
}
if (result != null) {
promise.trySuccess(result);
} else if (cause != null) {
promise.tryFailure(cause);
} else {
return false;
}
return true;
}
private void doResolveAllUncached(final String hostname, final int port,
final Promise<List<InetSocketAddress>> promise) {
final DnsNameResolverContext<List<InetSocketAddress>> ctx =
new DnsNameResolverContext<List<InetSocketAddress>>(this, hostname, promise) {
@Override
protected boolean finishResolve(
Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries) {
List<InetSocketAddress> result = null;
final int numEntries = resolvedEntries.size();
for (int i = 0; i < numEntries; i++) {
final InetAddress a = resolvedEntries.get(i).address();
if (addressType.isInstance(a)) {
if (result == null) {
result = new ArrayList<InetSocketAddress>(numEntries);
}
result.add(new InetSocketAddress(a, port));
}
}
if (result != null) {
promise().trySuccess(result);
return true;
}
return false;
}
};
ctx.resolve();
}
private static String hostname(InetSocketAddress addr) {
// InetSocketAddress.getHostString() is available since Java 7.
final String hostname;
if (PlatformDependent.javaVersion() < 7) {
return addr.getHostName();
hostname = addr.getHostName();
} else {
return addr.getHostString();
hostname = addr.getHostString();
}
return IDN.toASCII(hostname);
}
void cache(String hostname, InetAddress address, long originalTtl) {
final int maxTtl = maxTtl();
if (maxTtl == 0) {
return;
}
final int ttl = Math.max(minTtl(), (int) Math.min(maxTtl, originalTtl));
final List<DnsCacheEntry> entries = cachedEntries(hostname);
final DnsCacheEntry e = new DnsCacheEntry(hostname, address);
synchronized (entries) {
if (!entries.isEmpty()) {
final DnsCacheEntry firstEntry = entries.get(0);
if (firstEntry.cause() != null) {
assert entries.size() == 1;
firstEntry.cancelExpiration();
entries.clear();
}
}
entries.add(e);
}
scheduleCacheExpiration(entries, e, ttl);
}
void cache(String hostname, Throwable cause) {
final int negativeTtl = negativeTtl();
if (negativeTtl == 0) {
return;
}
final List<DnsCacheEntry> entries = cachedEntries(hostname);
final DnsCacheEntry e = new DnsCacheEntry(hostname, cause);
synchronized (entries) {
final int numEntries = entries.size();
for (int i = 0; i < numEntries; i ++) {
entries.get(i).cancelExpiration();
}
entries.clear();
entries.add(e);
}
scheduleCacheExpiration(entries, e, negativeTtl);
}
private List<DnsCacheEntry> cachedEntries(String hostname) {
List<DnsCacheEntry> oldEntries = resolveCache.get(hostname);
final List<DnsCacheEntry> entries;
if (oldEntries == null) {
List<DnsCacheEntry> newEntries = new ArrayList<DnsCacheEntry>();
oldEntries = resolveCache.putIfAbsent(hostname, newEntries);
entries = oldEntries != null? oldEntries : newEntries;
} else {
entries = oldEntries;
}
return entries;
}
private void scheduleCacheExpiration(final List<DnsCacheEntry> entries, final DnsCacheEntry e, int ttl) {
e.scheduleExpiration(
ch.eventLoop(),
new OneTimeTask() {
@Override
public void run() {
synchronized (entries) {
entries.remove(e);
if (entries.isEmpty()) {
resolveCache.remove(e.hostname());
}
}
}
}, ttl, TimeUnit.SECONDS);
}
/**
* Sends a DNS query with the specified question.
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(DnsQuestion question) {
return query(nameServerAddresses, question);
return query(nextNameServerAddress(), question);
}
/**
@ -675,132 +918,64 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
DnsQuestion question, Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
return query(nameServerAddresses, question, promise);
return query(nextNameServerAddress(), question, promise);
}
private InetSocketAddress nextNameServerAddress() {
final InternalThreadLocalMap tlm = InternalThreadLocalMap.get();
Iterator<InetSocketAddress> i = nameServerAddrIterator.get(tlm);
if (i.hasNext()) {
return i.next();
}
// The iterator has reached at its end, create a new iterator.
// We should not reach here if a user created nameServerAddresses via DnsServerAddresses, but just in case ..
i = nameServerAddresses.iterator();
nameServerAddrIterator.set(tlm, i);
return i.next();
}
/**
* Sends a DNS query with the specified question using the specified name server list.
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
Iterable<InetSocketAddress> nameServerAddresses, DnsQuestion question) {
if (nameServerAddresses == null) {
throw new NullPointerException("nameServerAddresses");
}
if (question == null) {
throw new NullPointerException("question");
}
InetSocketAddress nameServerAddr, DnsQuestion question) {
final EventLoop eventLoop = ch.eventLoop();
final DnsCacheEntry cachedResult = queryCache.get(question);
if (cachedResult != null) {
AddressedEnvelope<DnsResponse, InetSocketAddress> response = cachedResult.retainedResponse();
if (response != null) {
return eventLoop.newSucceededFuture(response);
} else {
Throwable cause = cachedResult.cause();
if (cause != null) {
return eventLoop.newFailedFuture(cause);
}
}
}
return query0(
nameServerAddresses, question,
eventLoop.<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
return query0(checkNotNull(nameServerAddr, "nameServerAddr"),
checkNotNull(question, "question"),
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
}
/**
* Sends a DNS query with the specified question using the specified name server list.
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
Iterable<InetSocketAddress> nameServerAddresses, DnsQuestion question,
InetSocketAddress nameServerAddr, DnsQuestion question,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
if (nameServerAddresses == null) {
throw new NullPointerException("nameServerAddresses");
}
if (question == null) {
throw new NullPointerException("question");
}
if (promise == null) {
throw new NullPointerException("promise");
}
final DnsCacheEntry cachedResult = queryCache.get(question);
if (cachedResult != null) {
AddressedEnvelope<DnsResponse, InetSocketAddress> response = cachedResult.retainedResponse();
if (response != null) {
return cast(promise).setSuccess(response);
} else {
Throwable cause = cachedResult.cause();
if (cause != null) {
return cast(promise).setFailure(cause);
}
}
}
return query0(nameServerAddresses, question, promise);
return query0(checkNotNull(nameServerAddr, "nameServerAddr"),
checkNotNull(question, "question"),
checkNotNull(promise, "promise"));
}
private Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
Iterable<InetSocketAddress> nameServerAddresses, DnsQuestion question,
InetSocketAddress nameServerAddr, DnsQuestion question,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> castPromise = cast(promise);
try {
new DnsQueryContext(this, nameServerAddresses, question, castPromise).query();
new DnsQueryContext(this, nameServerAddr, question, castPromise).query();
return castPromise;
} catch (Exception e) {
return castPromise.setFailure(e);
}
}
void cacheSuccess(
DnsQuestion question, AddressedEnvelope<? extends DnsResponse, InetSocketAddress> res, long delaySeconds) {
cache(question, new DnsCacheEntry(res), delaySeconds);
}
void cacheFailure(DnsQuestion question, Throwable cause, long delaySeconds) {
cache(question, new DnsCacheEntry(cause), delaySeconds);
}
private void cache(final DnsQuestion question, DnsCacheEntry entry, long delaySeconds) {
DnsCacheEntry oldEntry = queryCache.put(question, entry);
if (oldEntry != null) {
oldEntry.release();
}
boolean scheduled = false;
try {
entry.scheduleExpiration(
ch.eventLoop(),
new OneTimeTask() {
@Override
public void run() {
clearCache(question);
}
},
delaySeconds, TimeUnit.SECONDS);
scheduled = true;
} finally {
if (!scheduled) {
// If failed to schedule the expiration task,
// remove the entry from the cache so that it does not leak.
clearCache(question);
entry.release();
}
}
}
@SuppressWarnings("unchecked")
private static Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> cast(Promise<?> promise) {
return (Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>>) promise;
}
@SuppressWarnings("unchecked")
private static AddressedEnvelope<DnsResponse, InetSocketAddress> cast(AddressedEnvelope<?, ?> envelope) {
return (AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope;
}
private final class DnsResponseHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
@ -813,78 +988,22 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
}
final DnsQueryContext qCtx = promises.get(queryId);
if (qCtx == null) {
if (logger.isWarnEnabled()) {
logger.warn("Received a DNS response with an unknown ID: {}", queryId);
logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId);
}
return;
}
if (res.count(DnsSection.QUESTION) != 1) {
logger.warn("Received a DNS response with invalid number of questions: {}", res);
return;
}
final DnsQuestion q = qCtx.question();
if (!q.equals(res.recordAt(DnsSection.QUESTION))) {
logger.warn("Received a mismatching DNS response: {}", res);
return;
}
// Cancel the timeout task.
final ScheduledFuture<?> timeoutFuture = qCtx.timeoutFuture();
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
}
if (res.code() == DnsResponseCode.NOERROR) {
cache(q, res);
promises.set(queryId, null);
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> qPromise = qCtx.promise();
if (qPromise.setUncancellable()) {
qPromise.setSuccess(cast(res.retain()));
}
} else {
qCtx.retry(res.sender(),
"response code: " + res.code() +
" with " + res.count(DnsSection.ANSWER) + " answer(s) and " +
res.count(DnsSection.AUTHORITY) + " authority resource(s)");
}
qCtx.finish(res);
} finally {
ReferenceCountUtil.safeRelease(msg);
}
}
private void cache(DnsQuestion question, AddressedEnvelope<? extends DnsResponse, InetSocketAddress> res) {
final int maxTtl = maxTtl();
if (maxTtl == 0) {
return;
}
long ttl = Long.MAX_VALUE;
// Find the smallest TTL value returned by the server.
final DnsResponse resc = res.content();
final int answerCount = resc.count(DnsSection.ANSWER);
for (int i = 0; i < answerCount; i ++) {
final DnsRecord r = resc.recordAt(DnsSection.ANSWER, i);
final long rTtl = r.timeToLive();
if (ttl > rTtl) {
ttl = rTtl;
}
}
// Ensure that the found TTL is between minTtl and maxTtl.
ttl = Math.max(minTtl(), Math.min(maxTtl, ttl));
DnsNameResolver.this.cacheSuccess(question, res, ttl);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.warn("Unexpected exception: ", cause);
logger.warn("{} Unexpected exception: ", ch, cause);
}
}
}

View File

@ -22,6 +22,7 @@ import io.netty.channel.AddressedEnvelope;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
import io.netty.handler.codec.dns.DnsResponseCode;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
@ -50,7 +51,7 @@ import java.util.Locale;
import java.util.Map;
import java.util.Set;
final class DnsNameResolverContext {
abstract class DnsNameResolverContext<T> {
private static final int INADDRSZ4 = 4;
private static final int INADDRSZ6 = 16;
@ -66,9 +67,10 @@ final class DnsNameResolverContext {
};
private final DnsNameResolver parent;
private final Promise<InetSocketAddress> promise;
private final Iterator<InetSocketAddress> nameServerAddrs;
private final Promise<T> promise;
private final String hostname;
private final int port;
private final boolean traceEnabled;
private final int maxAllowedQueries;
private final InternetProtocolFamily[] resolveAddressTypes;
@ -76,22 +78,27 @@ final class DnsNameResolverContext {
Collections.newSetFromMap(
new IdentityHashMap<Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>, Boolean>());
private List<InetAddress> resolvedAddresses;
private List<DnsCacheEntry> resolvedEntries;
private StringBuilder trace;
private int allowedQueries;
private boolean triedCNAME;
DnsNameResolverContext(DnsNameResolver parent, String hostname, int port, Promise<InetSocketAddress> promise) {
protected DnsNameResolverContext(DnsNameResolver parent, String hostname, Promise<T> promise) {
this.parent = parent;
this.promise = promise;
this.hostname = hostname;
this.port = port;
nameServerAddrs = parent.nameServerAddresses.iterator();
maxAllowedQueries = parent.maxQueriesPerResolve();
resolveAddressTypes = parent.resolveAddressTypesUnsafe();
traceEnabled = parent.isTraceEnabled();
allowedQueries = maxAllowedQueries;
}
protected Promise<T> promise() {
return promise;
}
void resolve() {
for (InternetProtocolFamily f: resolveAddressTypes) {
final DnsRecordType type;
@ -106,18 +113,18 @@ final class DnsNameResolverContext {
throw new Error();
}
query(parent.nameServerAddresses, new DefaultDnsQuestion(hostname, type));
query(nameServerAddrs.next(), new DefaultDnsQuestion(hostname, type));
}
}
private void query(Iterable<InetSocketAddress> nameServerAddresses, final DnsQuestion question) {
private void query(InetSocketAddress nameServerAddr, final DnsQuestion question) {
if (allowedQueries == 0 || promise.isCancelled()) {
return;
}
allowedQueries --;
final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> f = parent.query(nameServerAddresses, question);
final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> f = parent.query(nameServerAddr, question);
queriesInProgress.add(f);
f.addListener(new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@ -133,7 +140,11 @@ final class DnsNameResolverContext {
if (future.isSuccess()) {
onResponse(question, future.getNow());
} else {
addTrace(future.cause());
// Server did not respond or I/O error occurred; try again.
if (traceEnabled) {
addTrace(future.cause());
}
query(nameServerAddrs.next(), question);
}
} finally {
tryToFinishResolve();
@ -142,16 +153,32 @@ final class DnsNameResolverContext {
});
}
void onResponse(final DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> response) {
final DnsRecordType type = question.type();
void onResponse(final DnsQuestion question, AddressedEnvelope<DnsResponse, InetSocketAddress> envelope) {
try {
if (type == DnsRecordType.A || type == DnsRecordType.AAAA) {
onResponseAorAAAA(type, question, response);
} else if (type == DnsRecordType.CNAME) {
onResponseCNAME(question, response);
final DnsResponse res = envelope.content();
final DnsResponseCode code = res.code();
if (code == DnsResponseCode.NOERROR) {
final DnsRecordType type = question.type();
if (type == DnsRecordType.A || type == DnsRecordType.AAAA) {
onResponseAorAAAA(type, question, envelope);
} else if (type == DnsRecordType.CNAME) {
onResponseCNAME(question, envelope);
}
return;
}
if (traceEnabled) {
addTrace(envelope.sender(),
"response code: " + code + " with " + res.count(DnsSection.ANSWER) + " answer(s) and " +
res.count(DnsSection.AUTHORITY) + " authority resource(s)");
}
// Retry with the next server if the server did not tell us that the domain does not exist.
if (code != DnsResponseCode.NXDOMAIN) {
query(nameServerAddrs.next(), question);
}
} finally {
ReferenceCountUtil.safeRelease(response);
ReferenceCountUtil.safeRelease(envelope);
}
}
@ -203,24 +230,33 @@ final class DnsNameResolverContext {
final byte[] addrBytes = new byte[contentLen];
content.getBytes(content.readerIndex(), addrBytes);
final InetAddress resolved;
try {
InetAddress resolved = InetAddress.getByAddress(hostname, addrBytes);
if (resolvedAddresses == null) {
resolvedAddresses = new ArrayList<InetAddress>();
}
resolvedAddresses.add(resolved);
found = true;
resolved = InetAddress.getByAddress(hostname, addrBytes);
} catch (UnknownHostException e) {
// Should never reach here.
throw new Error(e);
}
if (resolvedEntries == null) {
resolvedEntries = new ArrayList<DnsCacheEntry>(8);
}
final DnsCacheEntry e = new DnsCacheEntry(hostname, resolved);
parent.cache(hostname, resolved, r.timeToLive());
resolvedEntries.add(e);
found = true;
// Note that we do not break from the loop here, so we decode/cache all A/AAAA records.
}
if (found) {
return;
}
addTrace(envelope.sender(), "no matching " + qType + " record found");
if (traceEnabled) {
addTrace(envelope.sender(), "no matching " + qType + " record found");
}
// We aked for A/AAAA but we got only CNAME.
if (!cnames.isEmpty()) {
@ -252,7 +288,7 @@ final class DnsNameResolverContext {
if (found) {
followCname(response.sender(), name, resolved);
} else if (trace) {
} else if (trace && traceEnabled) {
addTrace(response.sender(), "no matching CNAME record found");
}
}
@ -300,12 +336,12 @@ final class DnsNameResolverContext {
}
// There are no queries left to try.
if (resolvedAddresses == null) {
if (resolvedEntries == null) {
// .. and we could not find any A/AAAA records.
if (!triedCNAME) {
// As the last resort, try to query CNAME, just in case the name server has it.
triedCNAME = true;
query(parent.nameServerAddresses, new DefaultDnsQuestion(hostname, DnsRecordType.CNAME));
query(nameServerAddrs.next(), new DefaultDnsQuestion(hostname, DnsRecordType.CNAME));
return;
}
}
@ -315,22 +351,22 @@ final class DnsNameResolverContext {
}
private boolean gotPreferredAddress() {
if (resolvedAddresses == null) {
if (resolvedEntries == null) {
return false;
}
final int size = resolvedAddresses.size();
final int size = resolvedEntries.size();
switch (resolveAddressTypes[0]) {
case IPv4:
for (int i = 0; i < size; i ++) {
if (resolvedAddresses.get(i) instanceof Inet4Address) {
if (resolvedEntries.get(i).address() instanceof Inet4Address) {
return true;
}
}
break;
case IPv6:
for (int i = 0; i < size; i ++) {
if (resolvedAddresses.get(i) instanceof Inet6Address) {
if (resolvedEntries.get(i).address() instanceof Inet6Address) {
return true;
}
}
@ -354,67 +390,46 @@ final class DnsNameResolverContext {
}
}
if (resolvedAddresses != null) {
if (resolvedEntries != null) {
// Found at least one resolved address.
for (InternetProtocolFamily f: resolveAddressTypes) {
switch (f) {
case IPv4:
if (finishResolveWithIPv4()) {
return;
}
break;
case IPv6:
if (finishResolveWithIPv6()) {
return;
}
break;
if (finishResolve(f.addressType(), resolvedEntries)) {
return;
}
}
}
// No resolved address found.
int tries = maxAllowedQueries - allowedQueries;
UnknownHostException cause;
final int tries = maxAllowedQueries - allowedQueries;
final StringBuilder buf = new StringBuilder(64);
buf.append("failed to resolve ");
buf.append(hostname);
if (tries > 1) {
cause = new UnknownHostException(
"failed to resolve " + hostname + " after " + tries + " queries:" +
trace);
buf.append(" after ");
buf.append(tries);
if (trace != null) {
buf.append(" queries:");
buf.append(trace);
} else {
buf.append(" queries");
}
} else {
cause = new UnknownHostException("failed to resolve " + hostname + ':' + trace);
if (trace != null) {
buf.append(':');
buf.append(trace);
}
}
final UnknownHostException cause = new UnknownHostException(buf.toString());
parent.cache(hostname, cause);
promise.tryFailure(cause);
}
private boolean finishResolveWithIPv4() {
final List<InetAddress> resolvedAddresses = this.resolvedAddresses;
final int size = resolvedAddresses.size();
for (int i = 0; i < size; i ++) {
InetAddress a = resolvedAddresses.get(i);
if (a instanceof Inet4Address) {
promise.trySuccess(new InetSocketAddress(a, port));
return true;
}
}
return false;
}
private boolean finishResolveWithIPv6() {
final List<InetAddress> resolvedAddresses = this.resolvedAddresses;
final int size = resolvedAddresses.size();
for (int i = 0; i < size; i ++) {
InetAddress a = resolvedAddresses.get(i);
if (a instanceof Inet6Address) {
promise.trySuccess(new InetSocketAddress(a, port));
return true;
}
}
return false;
}
protected abstract boolean finishResolve(
Class<? extends InetAddress> addressType, List<DnsCacheEntry> resolvedEntries);
/**
* Adapted from {@link DefaultDnsRecordDecoder#decodeName(ByteBuf)}.
@ -466,26 +481,30 @@ final class DnsNameResolverContext {
}
}
private void followCname(
InetSocketAddress nameServerAddr, String name, String cname) {
private void followCname(InetSocketAddress nameServerAddr, String name, String cname) {
if (trace == null) {
trace = new StringBuilder(128);
if (traceEnabled) {
if (trace == null) {
trace = new StringBuilder(128);
}
trace.append(StringUtil.NEWLINE);
trace.append("\tfrom ");
trace.append(nameServerAddr);
trace.append(": ");
trace.append(name);
trace.append(" CNAME ");
trace.append(cname);
}
trace.append(StringUtil.NEWLINE);
trace.append("\tfrom ");
trace.append(nameServerAddr);
trace.append(": ");
trace.append(name);
trace.append(" CNAME ");
trace.append(cname);
query(parent.nameServerAddresses, new DefaultDnsQuestion(cname, DnsRecordType.A));
query(parent.nameServerAddresses, new DefaultDnsQuestion(cname, DnsRecordType.AAAA));
final InetSocketAddress nextAddr = nameServerAddrs.next();
query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.A));
query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.AAAA));
}
private void addTrace(InetSocketAddress nameServerAddr, String msg) {
assert traceEnabled;
if (trace == null) {
trace = new StringBuilder(128);
}
@ -498,6 +517,8 @@ final class DnsNameResolverContext {
}
private void addTrace(Throwable cause) {
assert traceEnabled;
if (trace == null) {
trace = new StringBuilder(128);
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ObjectUtil;
import java.net.InetSocketAddress;
/**
* A {@link RuntimeException} raised when {@link DnsNameResolver} failed to perform a successful query.
*/
public final class DnsNameResolverException extends RuntimeException {
private static final long serialVersionUID = -8826717909627131850L;
private final InetSocketAddress remoteAddress;
private final DnsQuestion question;
public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question) {
this.remoteAddress = validateRemoteAddress(remoteAddress);
this.question = validateQuestion(question);
}
public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question, String message) {
super(message);
this.remoteAddress = validateRemoteAddress(remoteAddress);
this.question = validateQuestion(question);
}
public DnsNameResolverException(
InetSocketAddress remoteAddress, DnsQuestion question, String message, Throwable cause) {
super(message, cause);
this.remoteAddress = validateRemoteAddress(remoteAddress);
this.question = validateQuestion(question);
}
public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question, Throwable cause) {
super(cause);
this.remoteAddress = validateRemoteAddress(remoteAddress);
this.question = validateQuestion(question);
}
private static InetSocketAddress validateRemoteAddress(InetSocketAddress remoteAddress) {
return ObjectUtil.checkNotNull(remoteAddress, "remoteAddress");
}
private static DnsQuestion validateQuestion(DnsQuestion question) {
return ObjectUtil.checkNotNull(question, "question");
}
/**
* Returns the {@link DnsQuestion} of the DNS query that has failed.
*/
public DnsQuestion question() {
return question;
}
@Override
public Throwable fillInStackTrace() {
setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
return this;
}
}

View File

@ -13,7 +13,6 @@
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.buffer.Unpooled;
@ -22,12 +21,12 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.OneTimeTask;
@ -37,8 +36,6 @@ import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;
final class DnsQueryContext {
@ -50,30 +47,24 @@ final class DnsQueryContext {
private final int id;
private final DnsQuestion question;
private final DnsRecord optResource;
private final Iterator<InetSocketAddress> nameServerAddresses;
private final InetSocketAddress nameServerAddr;
private final boolean recursionDesired;
private final int maxTries;
private int remainingTries;
private volatile ScheduledFuture<?> timeoutFuture;
private StringBuilder trace;
DnsQueryContext(DnsNameResolver parent,
Iterable<InetSocketAddress> nameServerAddresses,
InetSocketAddress nameServerAddr,
DnsQuestion question, Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
this.parent = parent;
this.promise = promise;
this.nameServerAddr = nameServerAddr;
this.question = question;
this.promise = promise;
id = allocateId();
recursionDesired = parent.isRecursionDesired();
maxTries = parent.maxTriesPerQuery();
remainingTries = maxTries;
optResource = new DefaultDnsRawRecord(
StringUtil.EMPTY_STRING, DnsRecordType.OPT, parent.maxPayloadSize(), 0, Unpooled.EMPTY_BUFFER);
this.nameServerAddresses = nameServerAddresses.iterator();
}
private int allocateId() {
@ -93,42 +84,8 @@ final class DnsQueryContext {
}
}
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise() {
return promise;
}
DnsQuestion question() {
return question;
}
ScheduledFuture<?> timeoutFuture() {
return timeoutFuture;
}
void query() {
final DnsQuestion question = this.question;
if (remainingTries <= 0 || !nameServerAddresses.hasNext()) {
parent.promises.set(id, null);
int tries = maxTries - remainingTries;
UnknownHostException cause;
if (tries > 1) {
cause = new UnknownHostException(
"failed to resolve " + question + " after " + tries + " attempts:" +
trace);
} else {
cause = new UnknownHostException("failed to resolve " + question + ':' + trace);
}
cache(question, cause);
promise.tryFailure(cause);
return;
}
remainingTries --;
final InetSocketAddress nameServerAddr = nameServerAddresses.next();
final DatagramDnsQuery query = new DatagramDnsQuery(null, nameServerAddr, id);
query.setRecursionDesired(recursionDesired);
query.setRecord(DnsSection.QUESTION, question);
@ -138,43 +95,43 @@ final class DnsQueryContext {
logger.debug("{} WRITE: [{}: {}], {}", parent.ch, id, nameServerAddr, question);
}
sendQuery(query, nameServerAddr);
sendQuery(query);
}
private void sendQuery(final DnsQuery query, final InetSocketAddress nameServerAddr) {
private void sendQuery(final DnsQuery query) {
if (parent.bindFuture.isDone()) {
writeQuery(query, nameServerAddr);
writeQuery(query);
} else {
parent.bindFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
writeQuery(query, nameServerAddr);
writeQuery(query);
} else {
promise.tryFailure(future.cause());
}
}
});
}
}
private void writeQuery(final DnsQuery query, final InetSocketAddress nameServerAddr) {
final ChannelFuture writeFuture = parent.ch.writeAndFlush(query);
if (writeFuture.isDone()) {
onQueryWriteCompletion(writeFuture, nameServerAddr);
} else {
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
onQueryWriteCompletion(writeFuture, nameServerAddr);
}
});
}
}
private void onQueryWriteCompletion(ChannelFuture writeFuture, final InetSocketAddress nameServerAddr) {
private void writeQuery(final DnsQuery query) {
final ChannelFuture writeFuture = parent.ch.writeAndFlush(query);
if (writeFuture.isDone()) {
onQueryWriteCompletion(writeFuture);
} else {
writeFuture.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
onQueryWriteCompletion(writeFuture);
}
});
}
}
private void onQueryWriteCompletion(ChannelFuture writeFuture) {
if (!writeFuture.isSuccess()) {
retry(nameServerAddr, "failed to send a query: " + writeFuture.cause());
setFailure("failed to send a query", writeFuture.cause());
return;
}
@ -189,35 +146,62 @@ final class DnsQueryContext {
return;
}
retry(nameServerAddr, "query timed out after " + queryTimeoutMillis + " milliseconds");
setFailure("query timed out after " + queryTimeoutMillis + " milliseconds", null);
}
}, queryTimeoutMillis, TimeUnit.MILLISECONDS);
}
}
void retry(InetSocketAddress nameServerAddr, String message) {
if (promise.isCancelled()) {
void finish(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
DnsResponse res = envelope.content();
if (res.count(DnsSection.QUESTION) != 1) {
logger.warn("Received a DNS response with invalid number of questions: {}", envelope);
return;
}
if (trace == null) {
trace = new StringBuilder(128);
if (!question.equals(res.recordAt(DnsSection.QUESTION))) {
logger.warn("Received a mismatching DNS response: {}", envelope);
return;
}
trace.append(StringUtil.NEWLINE);
trace.append("\tfrom ");
trace.append(nameServerAddr);
trace.append(": ");
trace.append(message);
query();
setSuccess(envelope);
}
private void cache(final DnsQuestion question, Throwable cause) {
final int negativeTtl = parent.negativeTtl();
if (negativeTtl == 0) {
return;
private void setSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
parent.promises.set(id, null);
// Cancel the timeout task.
final ScheduledFuture<?> timeoutFuture = this.timeoutFuture;
if (timeoutFuture != null) {
timeoutFuture.cancel(false);
}
parent.cacheFailure(question, cause, negativeTtl);
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise = this.promise;
if (promise.setUncancellable()) {
@SuppressWarnings("unchecked")
AddressedEnvelope<DnsResponse, InetSocketAddress> castResponse =
(AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope.retain();
promise.setSuccess(castResponse);
}
}
private void setFailure(String message, Throwable cause) {
parent.promises.set(id, null);
final StringBuilder buf = new StringBuilder(message.length() + 64);
buf.append('[')
.append(nameServerAddr)
.append("] ")
.append(message)
.append(" (no stack trace available)");
final DnsNameResolverException e;
if (cause != null) {
e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause);
} else {
e = new DnsNameResolverException(nameServerAddr, question, buf.toString());
}
promise.tryFailure(e);
}
}

View File

@ -247,7 +247,7 @@ public class DnsNameResolverTest {
group.next(), NioDatagramChannel.class, DnsServerAddresses.shuffled(SERVERS));
static {
resolver.setMaxTriesPerQuery(SERVERS.size());
resolver.setMaxQueriesPerResolve(SERVERS.size());
}
@AfterClass
@ -293,6 +293,10 @@ public class DnsNameResolverTest {
for (Entry<String, InetAddress> e: resultA.entrySet()) {
InetAddress expected = e.getValue();
InetAddress actual = resultB.get(e.getKey());
if (!actual.equals(expected)) {
// Print the content of the cache when test failure is expected.
System.err.println("Cache for " + e.getKey() + ": " + resolver.resolveAll(e.getKey(), 0).getNow());
}
assertThat(actual, is(expected));
}
} finally {
@ -342,17 +346,8 @@ public class DnsNameResolverTest {
boolean typeMatches = false;
for (InternetProtocolFamily f: famililies) {
Class<?> resolvedType = resolved.getAddress().getClass();
switch (f) {
case IPv4:
if (Inet4Address.class.isAssignableFrom(resolvedType)) {
typeMatches = true;
}
break;
case IPv6:
if (Inet6Address.class.isAssignableFrom(resolvedType)) {
typeMatches = true;
}
break;
if (f.addressType().isAssignableFrom(resolvedType)) {
typeMatches = true;
}
}
@ -383,9 +378,18 @@ public class DnsNameResolverTest {
for (Entry<String, Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>> e: futures.entrySet()) {
String hostname = e.getKey();
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = e.getValue().sync().getNow();
DnsResponse response = envelope.content();
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> f = e.getValue().awaitUninterruptibly();
if (!f.isSuccess()) {
// Try again a couple more times because the DNS servers might be throttling us down.
for (int i = 0; i < 2; i++) {
f = queryMx(hostname).awaitUninterruptibly();
if (f.isSuccess()) {
break;
}
}
}
DnsResponse response = f.getNow().content();
assertThat(response.code(), is(DnsResponseCode.NOERROR));
final int answerCount = response.count(DnsSection.ANSWER);
@ -441,4 +445,8 @@ public class DnsNameResolverTest {
String hostname) throws Exception {
futures.put(hostname, resolver.query(new DefaultDnsQuestion(hostname, DnsRecordType.MX)));
}
private static Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> queryMx(String hostname) throws Exception {
return resolver.query(new DefaultDnsQuestion(hostname, DnsRecordType.MX));
}
}

View File

@ -22,6 +22,8 @@ import io.netty.util.concurrent.Promise;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
/**
* A {@link NameResolver} that resolves an {@link InetSocketAddress} using JDK's built-in domain name lookup mechanism.
@ -49,4 +51,22 @@ public class DefaultNameResolver extends SimpleNameResolver<InetSocketAddress> {
promise.setFailure(e);
}
}
@Override
protected void doResolveAll(
InetSocketAddress unresolvedAddress, Promise<List<InetSocketAddress>> promise) throws Exception {
try {
// Note that InetSocketAddress.getHostName() will never incur a reverse lookup here,
// because an unresolved address always has a host name.
final InetAddress[] resolved = InetAddress.getAllByName(unresolvedAddress.getHostName());
final List<InetSocketAddress> result = new ArrayList<InetSocketAddress>(resolved.length);
for (InetAddress a: resolved) {
result.add(new InetSocketAddress(a, unresolvedAddress.getPort()));
}
promise.setSuccess(result);
} catch (UnknownHostException e) {
promise.setFailure(e);
}
}
}

View File

@ -22,6 +22,7 @@ import io.netty.util.concurrent.Promise;
import java.io.Closeable;
import java.net.SocketAddress;
import java.nio.channels.UnsupportedAddressTypeException;
import java.util.List;
/**
* Resolves an arbitrary string that represents the name of an endpoint into a {@link SocketAddress}.
@ -82,6 +83,48 @@ public interface NameResolver<T extends SocketAddress> extends Closeable {
*/
Future<T> resolve(SocketAddress address, Promise<T> promise);
/**
* Resolves the specified host name and port into a list of {@link SocketAddress}es.
*
* @param inetHost the name to resolve
* @param inetPort the port number
*
* @return the list of the {@link SocketAddress}es as the result of the resolution
*/
Future<List<T>> resolveAll(String inetHost, int inetPort);
/**
* Resolves the specified host name and port into a list of {@link SocketAddress}es.
*
* @param inetHost the name to resolve
* @param inetPort the port number
* @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
*
* @return the list of the {@link SocketAddress}es as the result of the resolution
*/
Future<List<T>> resolveAll(String inetHost, int inetPort, Promise<List<T>> promise);
/**
* Resolves the specified address. If the specified address is resolved already, this method does nothing
* but returning the original address.
*
* @param address the address to resolve
*
* @return the list of the {@link SocketAddress}es as the result of the resolution
*/
Future<List<T>> resolveAll(SocketAddress address);
/**
* Resolves the specified address. If the specified address is resolved already, this method does nothing
* but returning the original address.
*
* @param address the address to resolve
* @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
*
* @return the list of the {@link SocketAddress}es as the result of the resolution
*/
Future<List<T>> resolveAll(SocketAddress address, Promise<List<T>> promise);
/**
* Closes all the resources allocated and used by this resolver.
*/

View File

@ -20,6 +20,8 @@ import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Promise;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.List;
/**
* A {@link NameResolver} that does not perform any resolution but always reports successful resolution.
@ -40,4 +42,10 @@ public class NoopNameResolver extends SimpleNameResolver<SocketAddress> {
protected void doResolve(SocketAddress unresolvedAddress, Promise<SocketAddress> promise) throws Exception {
promise.setSuccess(unresolvedAddress);
}
@Override
protected void doResolveAll(
SocketAddress unresolvedAddress, Promise<List<SocketAddress>> promise) throws Exception {
promise.setSuccess(Collections.singletonList(unresolvedAddress));
}
}

View File

@ -24,6 +24,10 @@ import io.netty.util.internal.TypeParameterMatcher;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.UnsupportedAddressTypeException;
import java.util.Collections;
import java.util.List;
import static io.netty.util.internal.ObjectUtil.*;
/**
* A skeletal {@link NameResolver} implementation.
@ -92,29 +96,17 @@ public abstract class SimpleNameResolver<T extends SocketAddress> implements Nam
@Override
public final Future<T> resolve(String inetHost, int inetPort) {
if (inetHost == null) {
throw new NullPointerException("inetHost");
}
return resolve(InetSocketAddress.createUnresolved(inetHost, inetPort));
return resolve(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort));
}
@Override
public Future<T> resolve(String inetHost, int inetPort, Promise<T> promise) {
if (inetHost == null) {
throw new NullPointerException("inetHost");
}
return resolve(InetSocketAddress.createUnresolved(inetHost, inetPort), promise);
return resolve(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort), promise);
}
@Override
public final Future<T> resolve(SocketAddress address) {
if (address == null) {
throw new NullPointerException("unresolvedAddress");
}
if (!isSupported(address)) {
if (!isSupported(checkNotNull(address, "address"))) {
// Address type not supported by the resolver
return executor().newFailedFuture(new UnsupportedAddressTypeException());
}
@ -139,12 +131,8 @@ public abstract class SimpleNameResolver<T extends SocketAddress> implements Nam
@Override
public final Future<T> resolve(SocketAddress address, Promise<T> promise) {
if (address == null) {
throw new NullPointerException("unresolvedAddress");
}
if (promise == null) {
throw new NullPointerException("promise");
}
checkNotNull(address, "address");
checkNotNull(promise, "promise");
if (!isSupported(address)) {
// Address type not supported by the resolver
@ -168,12 +156,80 @@ public abstract class SimpleNameResolver<T extends SocketAddress> implements Nam
}
}
@Override
public final Future<List<T>> resolveAll(String inetHost, int inetPort) {
return resolveAll(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort));
}
@Override
public Future<List<T>> resolveAll(String inetHost, int inetPort, Promise<List<T>> promise) {
return resolveAll(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort), promise);
}
@Override
public final Future<List<T>> resolveAll(SocketAddress address) {
if (!isSupported(checkNotNull(address, "address"))) {
// Address type not supported by the resolver
return executor().newFailedFuture(new UnsupportedAddressTypeException());
}
if (isResolved(address)) {
// Resolved already; no need to perform a lookup
@SuppressWarnings("unchecked")
final T cast = (T) address;
return executor.newSucceededFuture(Collections.singletonList(cast));
}
try {
@SuppressWarnings("unchecked")
final T cast = (T) address;
final Promise<List<T>> promise = executor().newPromise();
doResolveAll(cast, promise);
return promise;
} catch (Exception e) {
return executor().newFailedFuture(e);
}
}
@Override
public final Future<List<T>> resolveAll(SocketAddress address, Promise<List<T>> promise) {
checkNotNull(address, "address");
checkNotNull(promise, "promise");
if (!isSupported(address)) {
// Address type not supported by the resolver
return promise.setFailure(new UnsupportedAddressTypeException());
}
if (isResolved(address)) {
// Resolved already; no need to perform a lookup
@SuppressWarnings("unchecked")
final T cast = (T) address;
return promise.setSuccess(Collections.singletonList(cast));
}
try {
@SuppressWarnings("unchecked")
final T cast = (T) address;
doResolveAll(cast, promise);
return promise;
} catch (Exception e) {
return promise.setFailure(e);
}
}
/**
* Invoked by {@link #resolve(SocketAddress)} and {@link #resolve(String, int)} to perform the actual name
* resolution.
*/
protected abstract void doResolve(T unresolvedAddress, Promise<T> promise) throws Exception;
/**
* Invoked by {@link #resolveAll(SocketAddress)} and {@link #resolveAll(String, int)} to perform the actual name
* resolution.
*/
protected abstract void doResolveAll(T unresolvedAddress, Promise<List<T>> promise) throws Exception;
@Override
public void close() { }
}

View File

@ -15,10 +15,27 @@
*/
package io.netty.channel.socket;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
/**
* Internet Protocol (IP) families used byte the {@link DatagramChannel}
*/
public enum InternetProtocolFamily {
IPv4,
IPv6
IPv4(Inet4Address.class),
IPv6(Inet6Address.class);
private final Class<? extends InetAddress> addressType;
InternetProtocolFamily(Class<? extends InetAddress> addressType) {
this.addressType = addressType;
}
/**
* Returns the address type of this protocol family.
*/
public Class<? extends InetAddress> addressType() {
return addressType;
}
}

View File

@ -36,6 +36,7 @@ import io.netty.resolver.SimpleNameResolver;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.OneTimeTask;
import org.junit.AfterClass;
import org.junit.Test;
@ -43,6 +44,7 @@ import java.net.SocketAddress;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
@ -302,13 +304,29 @@ public class BootstrapTest {
@Override
protected void doResolve(
final SocketAddress unresolvedAddress, final Promise<SocketAddress> promise) {
executor().execute(new Runnable() {
executor().execute(new OneTimeTask() {
@Override
public void run() {
if (success) {
promise.setSuccess(unresolvedAddress);
} else {
promise.setFailure(new UnknownHostException());
promise.setFailure(new UnknownHostException(unresolvedAddress.toString()));
}
}
});
}
@Override
protected void doResolveAll(
final SocketAddress unresolvedAddress, final Promise<List<SocketAddress>> promise)
throws Exception {
executor().execute(new OneTimeTask() {
@Override
public void run() {
if (success) {
promise.setSuccess(Collections.singletonList(unresolvedAddress));
} else {
promise.setFailure(new UnknownHostException(unresolvedAddress.toString()));
}
}
});