From fdfe3149ba3f68678fa7f73c574c7c61aa33fa2f Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Sun, 12 Jul 2015 19:34:05 +0900 Subject: [PATCH] Provide more control over DnsNameResolver.query() / Add NameResolver.resolveAll() Related issues: - #3971 - #3973 - #3976 - #4035 Motivation: 1. Previously, DnsNameResolver.query() retried the request query by its own. It prevents a user from deciding when to retry or stop. It is also impossible to get the response object whose code is not NOERROR. 2. NameResolver does not have an operation that resolves a host name into multiple addresses, like InetAddress.getAllByName() Modifications: - Changes related with DnsNameResolver.query() - Make query() not retry - Move the retry logic to DnsNameResolver.resolve() instead. - Make query() fail the promise only when I/O error occurred or it failed to get a response - Add DnsNameResolverException and use it when query() fails so that the resolver can give more information about the failure - query() does not cache anymore. - Changes related with NameResolver.resolveAll() - Add NameResolver.resolveAll() - Add SimpleNameResolver.doResolveAll() - Changes related with DnsNameResolver.resolve() and resolveAll() - Make DnsNameResolveContext abstract so that DnsNameResolver can decide to get single or multiple addresses from it - Re-implement cache so that the cache works for resolve() and resolveAll() - Add 'traceEnabled' property to enable/disable trace information - Miscellaneous changes - Use ObjectUtil.checkNotNull() wherever possible - Add InternetProtocolFamily.addressType() to remove repetitive switch-case blocks in DnsNameResolver(Context) - Do not raise an exception when decoding a truncated DNS response Result: - Full control over query() - A user can now retrieve all addresses via (Dns)NameResolver.resolveAll() - DNS cache works only for resolve() and resolveAll() now. --- .../codec/dns/DatagramDnsResponseDecoder.java | 8 +- .../codec/dns/DefaultDnsRecordDecoder.java | 15 + .../handler/codec/dns/DnsRecordDecoder.java | 2 + .../io/netty/resolver/dns/DnsCacheEntry.java | 66 +- .../netty/resolver/dns/DnsNameResolver.java | 575 +++++++++++------- .../resolver/dns/DnsNameResolverContext.java | 205 ++++--- .../dns/DnsNameResolverException.java | 78 +++ .../netty/resolver/dns/DnsQueryContext.java | 152 +++-- .../resolver/dns/DnsNameResolverTest.java | 36 +- .../netty/resolver/DefaultNameResolver.java | 20 + .../java/io/netty/resolver/NameResolver.java | 43 ++ .../io/netty/resolver/NoopNameResolver.java | 8 + .../io/netty/resolver/SimpleNameResolver.java | 98 ++- .../socket/InternetProtocolFamily.java | 21 +- .../io/netty/bootstrap/BootstrapTest.java | 22 +- 15 files changed, 866 insertions(+), 483 deletions(-) create mode 100644 resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverException.java diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java index e9872015b1..b4c1fd09e0 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java @@ -105,7 +105,13 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder 0; i --) { - response.addRecord(section, recordDecoder.decodeRecord(buf)); + final DnsRecord r = recordDecoder.decodeRecord(buf); + if (r == null) { + // Truncated response + break; + } + + response.addRecord(section, r); } } } diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java index 2e35aebf5f..6d6beea4e4 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java @@ -42,13 +42,28 @@ public class DefaultDnsRecordDecoder implements DnsRecordDecoder { @Override public final T decodeRecord(ByteBuf in) throws Exception { + final int startOffset = in.readerIndex(); final String name = decodeName(in); + + final int endOffset = in.writerIndex(); + if (endOffset - startOffset < 10) { + // Not enough data + in.readerIndex(startOffset); + return null; + } + final DnsRecordType type = DnsRecordType.valueOf(in.readUnsignedShort()); final int aClass = in.readUnsignedShort(); final long ttl = in.readUnsignedInt(); final int length = in.readUnsignedShort(); final int offset = in.readerIndex(); + if (endOffset - offset < length) { + // Not enough data + in.readerIndex(startOffset); + return null; + } + @SuppressWarnings("unchecked") T record = (T) decodeRecord(name, type, aClass, ttl, in, offset, length); in.readerIndex(offset + length); diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsRecordDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsRecordDecoder.java index 23886487ee..a2b6315acc 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsRecordDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsRecordDecoder.java @@ -37,6 +37,8 @@ public interface DnsRecordDecoder { * Decodes a DNS record into its object representation. * * @param in the input buffer which contains a DNS record at its reader index + * + * @return the decoded record, or {@code null} if there are not enough data in the input buffer */ T decodeRecord(ByteBuf in) throws Exception; } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCacheEntry.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCacheEntry.java index dc7b9e99f2..eaeb91dd23 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCacheEntry.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsCacheEntry.java @@ -16,73 +16,61 @@ package io.netty.resolver.dns; -import io.netty.channel.AddressedEnvelope; import io.netty.channel.EventLoop; -import io.netty.handler.codec.dns.DnsResponse; -import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.ScheduledFuture; -import io.netty.util.internal.OneTimeTask; -import io.netty.util.internal.PlatformDependent; -import java.net.InetSocketAddress; +import java.net.InetAddress; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; final class DnsCacheEntry { - private enum State { - INIT, - SCHEDULED_EXPIRATION, - RELEASED - } - - private final AddressedEnvelope response; + private final String hostname; + private final InetAddress address; private final Throwable cause; private volatile ScheduledFuture expirationFuture; - private boolean released; - @SuppressWarnings("unchecked") - DnsCacheEntry(AddressedEnvelope response) { - this.response = (AddressedEnvelope) response.retain(); + DnsCacheEntry(String hostname, InetAddress address) { + this.hostname = hostname; + this.address = address; cause = null; } - DnsCacheEntry(Throwable cause) { + DnsCacheEntry(String hostname, Throwable cause) { + this.hostname = hostname; this.cause = cause; - response = null; + address = null; + } + + String hostname() { + return hostname; + } + + InetAddress address() { + return address; } Throwable cause() { return cause; } - synchronized AddressedEnvelope retainedResponse() { - if (released) { - // Released by other thread via either the expiration task or clearCache() - return null; - } - - return response.retain(); - } - void scheduleExpiration(EventLoop loop, Runnable task, long delay, TimeUnit unit) { assert expirationFuture == null: "expiration task scheduled already"; expirationFuture = loop.schedule(task, delay, unit); } - void release() { - synchronized (this) { - if (released) { - return; - } - - released = true; - ReferenceCountUtil.safeRelease(response); - } - + void cancelExpiration() { ScheduledFuture expirationFuture = this.expirationFuture; if (expirationFuture != null) { expirationFuture.cancel(false); } } + + @Override + public String toString() { + if (cause != null) { + return hostname + '/' + cause; + } else { + return address.toString(); + } + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 4419b97151..62a4d893e5 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -31,19 +31,17 @@ import io.netty.channel.socket.InternetProtocolFamily; import io.netty.handler.codec.dns.DatagramDnsQueryEncoder; import io.netty.handler.codec.dns.DatagramDnsResponse; import io.netty.handler.codec.dns.DatagramDnsResponseDecoder; -import io.netty.handler.codec.dns.DnsSection; import io.netty.handler.codec.dns.DnsQuestion; -import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsResponse; -import io.netty.handler.codec.dns.DnsResponseCode; import io.netty.resolver.NameResolver; import io.netty.resolver.SimpleNameResolver; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; -import io.netty.util.concurrent.ScheduledFuture; +import io.netty.util.internal.InternalThreadLocalMap; import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.SystemPropertyUtil; @@ -56,6 +54,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map.Entry; @@ -63,6 +62,8 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReferenceArray; +import static io.netty.util.internal.ObjectUtil.*; + /** * A DNS-based {@link NameResolver}. */ @@ -102,9 +103,17 @@ public class DnsNameResolver extends SimpleNameResolver { final AtomicReferenceArray promises = new AtomicReferenceArray(65536); /** - * The cache for {@link #query(DnsQuestion)} + * Cache for {@link #doResolve(InetSocketAddress, Promise)} and {@link #doResolveAll(InetSocketAddress, Promise)}. */ - final ConcurrentMap queryCache = PlatformDependent.newConcurrentHashMap(); + final ConcurrentMap> resolveCache = PlatformDependent.newConcurrentHashMap(); + + private final FastThreadLocal> nameServerAddrIterator = + new FastThreadLocal>() { + @Override + protected Iterator initialValue() throws Exception { + return nameServerAddresses.iterator(); + } + }; private final DnsResponseHandler responseHandler = new DnsResponseHandler(); @@ -114,11 +123,11 @@ public class DnsNameResolver extends SimpleNameResolver { private volatile int minTtl; private volatile int maxTtl = Integer.MAX_VALUE; private volatile int negativeTtl; - private volatile int maxTriesPerQuery = 2; + private volatile int maxQueriesPerResolve = 3; + private volatile boolean traceEnabled = true; private volatile InternetProtocolFamily[] resolveAddressTypes = DEFAULT_RESOLVE_ADDRESS_TYPES; private volatile boolean recursionDesired = true; - private volatile int maxQueriesPerResolve = 8; private volatile int maxPayloadSize; @@ -238,17 +247,12 @@ public class DnsNameResolver extends SimpleNameResolver { super(eventLoop); - if (channelFactory == null) { - throw new NullPointerException("channelFactory"); - } - if (nameServerAddresses == null) { - throw new NullPointerException("nameServerAddresses"); - } + checkNotNull(channelFactory, "channelFactory"); + checkNotNull(nameServerAddresses, "nameServerAddresses"); + checkNotNull(localAddress, "localAddress"); + if (!nameServerAddresses.iterator().hasNext()) { - throw new NullPointerException("nameServerAddresses is empty"); - } - if (localAddress == null) { - throw new NullPointerException("localAddress"); + throw new IllegalArgumentException("nameServerAddresses is empty"); } this.nameServerAddresses = nameServerAddresses; @@ -386,32 +390,6 @@ public class DnsNameResolver extends SimpleNameResolver { return this; } - /** - * Returns the maximum number of tries for each query. The default value is 2 times. - * - * @see #setMaxTriesPerQuery(int) - */ - public int maxTriesPerQuery() { - return maxTriesPerQuery; - } - - /** - * Sets the maximum number of tries for each query. - * - * @return {@code this} - * - * @see #maxTriesPerQuery() - */ - public DnsNameResolver setMaxTriesPerQuery(int maxTriesPerQuery) { - if (maxTriesPerQuery < 1) { - throw new IllegalArgumentException("maxTries: " + maxTriesPerQuery + " (expected: > 0)"); - } - - this.maxTriesPerQuery = maxTriesPerQuery; - - return this; - } - /** * Returns the list of the protocol families of the address resolved by {@link #resolve(SocketAddress)} * in the order of preference. @@ -438,9 +416,7 @@ public class DnsNameResolver extends SimpleNameResolver { * @see #resolveAddressTypes() */ public DnsNameResolver setResolveAddressTypes(InternetProtocolFamily... resolveAddressTypes) { - if (resolveAddressTypes == null) { - throw new NullPointerException("resolveAddressTypes"); - } + checkNotNull(resolveAddressTypes, "resolveAddressTypes"); final List list = new ArrayList(InternetProtocolFamily.values().length); @@ -478,9 +454,7 @@ public class DnsNameResolver extends SimpleNameResolver { * @see #resolveAddressTypes() */ public DnsNameResolver setResolveAddressTypes(Iterable resolveAddressTypes) { - if (resolveAddressTypes == null) { - throw new NullPointerException("resolveAddressTypes"); - } + checkNotNull(resolveAddressTypes, "resolveAddressTypes"); final List list = new ArrayList(InternetProtocolFamily.values().length); @@ -556,6 +530,23 @@ public class DnsNameResolver extends SimpleNameResolver { return this; } + /** + * Returns if this resolver should generate the detailed trace information in an exception message so that + * it is easier to understand the cause of resolution failure. The default value if {@code true}. + */ + public boolean isTraceEnabled() { + return traceEnabled; + } + + /** + * Sets if this resolver should generate the detailed trace information in an exception message so that + * it is easier to understand the cause of resolution failure. + */ + public DnsNameResolver setTraceEnabled(boolean traceEnabled) { + this.traceEnabled = traceEnabled; + return this; + } + /** * Returns the capacity of the datagram packet buffer (in bytes). The default value is {@code 4096} bytes. * @@ -589,32 +580,45 @@ public class DnsNameResolver extends SimpleNameResolver { } /** - * Clears all the DNS resource records cached by this resolver. + * Clears all the resolved addresses cached by this resolver. * * @return {@code this} * - * @see #clearCache(DnsQuestion) + * @see #clearCache(String) */ public DnsNameResolver clearCache() { - for (Iterator> i = queryCache.entrySet().iterator(); i.hasNext();) { - Entry e = i.next(); + for (Iterator>> i = resolveCache.entrySet().iterator(); i.hasNext();) { + final Entry> e = i.next(); i.remove(); - e.getValue().release(); + cancelExpiration(e); } - return this; } /** - * Clears the DNS resource record of the specified DNS question from the cache of this resolver. + * Clears the resolved addresses of the specified host name from the cache of this resolver. + * + * @return {@code true} if and only if there was an entry for the specified host name in the cache and + * it has been removed by this method */ - public boolean clearCache(DnsQuestion question) { - DnsCacheEntry e = queryCache.remove(question); - if (e != null) { - e.release(); - return true; - } else { - return false; + public boolean clearCache(String hostname) { + boolean removed = false; + for (Iterator>> i = resolveCache.entrySet().iterator(); i.hasNext();) { + final Entry> e = i.next(); + if (e.getKey().equals(hostname)) { + i.remove(); + cancelExpiration(e); + removed = true; + } + } + return removed; + } + + private static void cancelExpiration(Entry> e) { + final List entries = e.getValue(); + final int numEntries = entries.size(); + for (int i = 0; i < numEntries; i++) { + entries.get(i).cancelExpiration(); } } @@ -640,34 +644,273 @@ public class DnsNameResolver extends SimpleNameResolver { @Override protected void doResolve(InetSocketAddress unresolvedAddress, Promise promise) throws Exception { - byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName()); - if (bytes == null) { - final String hostname = IDN.toASCII(hostname(unresolvedAddress)); - final int port = unresolvedAddress.getPort(); - - final DnsNameResolverContext ctx = new DnsNameResolverContext(this, hostname, port, promise); - - ctx.resolve(); - } else { + final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName()); + if (bytes != null) { // The unresolvedAddress was created via a String that contains an ipaddress. promise.setSuccess(new InetSocketAddress(InetAddress.getByAddress(bytes), unresolvedAddress.getPort())); + return; } + + final String hostname = hostname(unresolvedAddress); + final int port = unresolvedAddress.getPort(); + + if (!doResolveCached(hostname, port, promise)) { + doResolveUncached(hostname, port, promise); + } + } + + private boolean doResolveCached(String hostname, int port, Promise promise) { + final List cachedEntries = resolveCache.get(hostname); + if (cachedEntries == null) { + return false; + } + + InetAddress address = null; + Throwable cause = null; + synchronized (cachedEntries) { + final int numEntries = cachedEntries.size(); + assert numEntries > 0; + + if (cachedEntries.get(0).cause() != null) { + cause = cachedEntries.get(0).cause(); + } else { + // Find the first entry with the preferred address type. + for (InternetProtocolFamily f : resolveAddressTypes) { + for (int i = 0; i < numEntries; i++) { + final DnsCacheEntry e = cachedEntries.get(i); + if (f.addressType().isInstance(e.address())) { + address = e.address(); + break; + } + } + } + } + } + + if (address != null) { + setSuccess(promise, new InetSocketAddress(address, port)); + } else if (cause != null) { + if (!promise.tryFailure(cause)) { + logger.warn("Failed to notify failure to a promise: {}", promise, cause); + } + } else { + return false; + } + + return true; + } + + private static void setSuccess(Promise promise, InetSocketAddress result) { + if (!promise.trySuccess(result)) { + logger.warn("Failed to notify success ({}) to a promise: {}", result, promise); + } + } + + private void doResolveUncached(String hostname, final int port, Promise promise) { + final DnsNameResolverContext ctx = + new DnsNameResolverContext(this, hostname, promise) { + @Override + protected boolean finishResolve( + Class addressType, List resolvedEntries) { + + final int numEntries = resolvedEntries.size(); + for (int i = 0; i < numEntries; i++) { + final InetAddress a = resolvedEntries.get(i).address(); + if (addressType.isInstance(a)) { + setSuccess(promise(), new InetSocketAddress(a, port)); + return true; + } + } + return false; + } + }; + + ctx.resolve(); + } + + @Override + protected void doResolveAll( + InetSocketAddress unresolvedAddress, Promise> promise) throws Exception { + + final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(unresolvedAddress.getHostName()); + if (bytes != null) { + // The unresolvedAddress was created via a String that contains an ipaddress. + promise.setSuccess(Collections.singletonList( + new InetSocketAddress(InetAddress.getByAddress(bytes), unresolvedAddress.getPort()))); + return; + } + + final String hostname = hostname(unresolvedAddress); + final int port = unresolvedAddress.getPort(); + + if (!doResolveAllCached(hostname, port, promise)) { + doResolveAllUncached(hostname, port, promise); + } + } + + private boolean doResolveAllCached(String hostname, int port, Promise> promise) { + final List cachedEntries = resolveCache.get(hostname); + if (cachedEntries == null) { + return false; + } + + List result = null; + Throwable cause = null; + synchronized (cachedEntries) { + final int numEntries = cachedEntries.size(); + assert numEntries > 0; + + if (cachedEntries.get(0).cause() != null) { + cause = cachedEntries.get(0).cause(); + } else { + for (InternetProtocolFamily f : resolveAddressTypes) { + for (int i = 0; i < numEntries; i++) { + final DnsCacheEntry e = cachedEntries.get(i); + if (f.addressType().isInstance(e.address())) { + if (result == null) { + result = new ArrayList(numEntries); + } + result.add(new InetSocketAddress(e.address(), port)); + } + } + } + } + } + + if (result != null) { + promise.trySuccess(result); + } else if (cause != null) { + promise.tryFailure(cause); + } else { + return false; + } + + return true; + } + + private void doResolveAllUncached(final String hostname, final int port, + final Promise> promise) { + final DnsNameResolverContext> ctx = + new DnsNameResolverContext>(this, hostname, promise) { + @Override + protected boolean finishResolve( + Class addressType, List resolvedEntries) { + + List result = null; + final int numEntries = resolvedEntries.size(); + for (int i = 0; i < numEntries; i++) { + final InetAddress a = resolvedEntries.get(i).address(); + if (addressType.isInstance(a)) { + if (result == null) { + result = new ArrayList(numEntries); + } + result.add(new InetSocketAddress(a, port)); + } + } + + if (result != null) { + promise().trySuccess(result); + return true; + } + return false; + } + }; + + ctx.resolve(); } private static String hostname(InetSocketAddress addr) { // InetSocketAddress.getHostString() is available since Java 7. + final String hostname; if (PlatformDependent.javaVersion() < 7) { - return addr.getHostName(); + hostname = addr.getHostName(); } else { - return addr.getHostString(); + hostname = addr.getHostString(); } + + return IDN.toASCII(hostname); + } + + void cache(String hostname, InetAddress address, long originalTtl) { + final int maxTtl = maxTtl(); + if (maxTtl == 0) { + return; + } + + final int ttl = Math.max(minTtl(), (int) Math.min(maxTtl, originalTtl)); + final List entries = cachedEntries(hostname); + final DnsCacheEntry e = new DnsCacheEntry(hostname, address); + + synchronized (entries) { + if (!entries.isEmpty()) { + final DnsCacheEntry firstEntry = entries.get(0); + if (firstEntry.cause() != null) { + assert entries.size() == 1; + firstEntry.cancelExpiration(); + entries.clear(); + } + } + entries.add(e); + } + + scheduleCacheExpiration(entries, e, ttl); + } + + void cache(String hostname, Throwable cause) { + final int negativeTtl = negativeTtl(); + if (negativeTtl == 0) { + return; + } + + final List entries = cachedEntries(hostname); + final DnsCacheEntry e = new DnsCacheEntry(hostname, cause); + + synchronized (entries) { + final int numEntries = entries.size(); + for (int i = 0; i < numEntries; i ++) { + entries.get(i).cancelExpiration(); + } + entries.clear(); + entries.add(e); + } + + scheduleCacheExpiration(entries, e, negativeTtl); + } + + private List cachedEntries(String hostname) { + List oldEntries = resolveCache.get(hostname); + final List entries; + if (oldEntries == null) { + List newEntries = new ArrayList(); + oldEntries = resolveCache.putIfAbsent(hostname, newEntries); + entries = oldEntries != null? oldEntries : newEntries; + } else { + entries = oldEntries; + } + return entries; + } + + private void scheduleCacheExpiration(final List entries, final DnsCacheEntry e, int ttl) { + e.scheduleExpiration( + ch.eventLoop(), + new OneTimeTask() { + @Override + public void run() { + synchronized (entries) { + entries.remove(e); + if (entries.isEmpty()) { + resolveCache.remove(e.hostname()); + } + } + } + }, ttl, TimeUnit.SECONDS); } /** * Sends a DNS query with the specified question. */ public Future> query(DnsQuestion question) { - return query(nameServerAddresses, question); + return query(nextNameServerAddress(), question); } /** @@ -675,132 +918,64 @@ public class DnsNameResolver extends SimpleNameResolver { */ public Future> query( DnsQuestion question, Promise> promise) { - return query(nameServerAddresses, question, promise); + return query(nextNameServerAddress(), question, promise); + } + + private InetSocketAddress nextNameServerAddress() { + final InternalThreadLocalMap tlm = InternalThreadLocalMap.get(); + Iterator i = nameServerAddrIterator.get(tlm); + if (i.hasNext()) { + return i.next(); + } + + // The iterator has reached at its end, create a new iterator. + // We should not reach here if a user created nameServerAddresses via DnsServerAddresses, but just in case .. + i = nameServerAddresses.iterator(); + nameServerAddrIterator.set(tlm, i); + return i.next(); } /** * Sends a DNS query with the specified question using the specified name server list. */ public Future> query( - Iterable nameServerAddresses, DnsQuestion question) { - if (nameServerAddresses == null) { - throw new NullPointerException("nameServerAddresses"); - } - if (question == null) { - throw new NullPointerException("question"); - } + InetSocketAddress nameServerAddr, DnsQuestion question) { - final EventLoop eventLoop = ch.eventLoop(); - final DnsCacheEntry cachedResult = queryCache.get(question); - if (cachedResult != null) { - AddressedEnvelope response = cachedResult.retainedResponse(); - if (response != null) { - return eventLoop.newSucceededFuture(response); - } else { - Throwable cause = cachedResult.cause(); - if (cause != null) { - return eventLoop.newFailedFuture(cause); - } - } - } - return query0( - nameServerAddresses, question, - eventLoop.>newPromise()); + return query0(checkNotNull(nameServerAddr, "nameServerAddr"), + checkNotNull(question, "question"), + ch.eventLoop().>newPromise()); } /** * Sends a DNS query with the specified question using the specified name server list. */ public Future> query( - Iterable nameServerAddresses, DnsQuestion question, + InetSocketAddress nameServerAddr, DnsQuestion question, Promise> promise) { - if (nameServerAddresses == null) { - throw new NullPointerException("nameServerAddresses"); - } - if (question == null) { - throw new NullPointerException("question"); - } - if (promise == null) { - throw new NullPointerException("promise"); - } - - final DnsCacheEntry cachedResult = queryCache.get(question); - if (cachedResult != null) { - AddressedEnvelope response = cachedResult.retainedResponse(); - if (response != null) { - return cast(promise).setSuccess(response); - } else { - Throwable cause = cachedResult.cause(); - if (cause != null) { - return cast(promise).setFailure(cause); - } - } - } - - return query0(nameServerAddresses, question, promise); + return query0(checkNotNull(nameServerAddr, "nameServerAddr"), + checkNotNull(question, "question"), + checkNotNull(promise, "promise")); } private Future> query0( - Iterable nameServerAddresses, DnsQuestion question, + InetSocketAddress nameServerAddr, DnsQuestion question, Promise> promise) { final Promise> castPromise = cast(promise); try { - new DnsQueryContext(this, nameServerAddresses, question, castPromise).query(); + new DnsQueryContext(this, nameServerAddr, question, castPromise).query(); return castPromise; } catch (Exception e) { return castPromise.setFailure(e); } } - void cacheSuccess( - DnsQuestion question, AddressedEnvelope res, long delaySeconds) { - cache(question, new DnsCacheEntry(res), delaySeconds); - } - - void cacheFailure(DnsQuestion question, Throwable cause, long delaySeconds) { - cache(question, new DnsCacheEntry(cause), delaySeconds); - } - - private void cache(final DnsQuestion question, DnsCacheEntry entry, long delaySeconds) { - DnsCacheEntry oldEntry = queryCache.put(question, entry); - if (oldEntry != null) { - oldEntry.release(); - } - - boolean scheduled = false; - try { - entry.scheduleExpiration( - ch.eventLoop(), - new OneTimeTask() { - @Override - public void run() { - clearCache(question); - } - }, - delaySeconds, TimeUnit.SECONDS); - scheduled = true; - } finally { - if (!scheduled) { - // If failed to schedule the expiration task, - // remove the entry from the cache so that it does not leak. - clearCache(question); - entry.release(); - } - } - } - @SuppressWarnings("unchecked") private static Promise> cast(Promise promise) { return (Promise>) promise; } - @SuppressWarnings("unchecked") - private static AddressedEnvelope cast(AddressedEnvelope envelope) { - return (AddressedEnvelope) envelope; - } - private final class DnsResponseHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { @@ -813,78 +988,22 @@ public class DnsNameResolver extends SimpleNameResolver { } final DnsQueryContext qCtx = promises.get(queryId); - if (qCtx == null) { if (logger.isWarnEnabled()) { - logger.warn("Received a DNS response with an unknown ID: {}", queryId); + logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); } return; } - if (res.count(DnsSection.QUESTION) != 1) { - logger.warn("Received a DNS response with invalid number of questions: {}", res); - return; - } - - final DnsQuestion q = qCtx.question(); - if (!q.equals(res.recordAt(DnsSection.QUESTION))) { - logger.warn("Received a mismatching DNS response: {}", res); - return; - } - - // Cancel the timeout task. - final ScheduledFuture timeoutFuture = qCtx.timeoutFuture(); - if (timeoutFuture != null) { - timeoutFuture.cancel(false); - } - - if (res.code() == DnsResponseCode.NOERROR) { - cache(q, res); - promises.set(queryId, null); - - Promise> qPromise = qCtx.promise(); - if (qPromise.setUncancellable()) { - qPromise.setSuccess(cast(res.retain())); - } - } else { - qCtx.retry(res.sender(), - "response code: " + res.code() + - " with " + res.count(DnsSection.ANSWER) + " answer(s) and " + - res.count(DnsSection.AUTHORITY) + " authority resource(s)"); - } + qCtx.finish(res); } finally { ReferenceCountUtil.safeRelease(msg); } } - private void cache(DnsQuestion question, AddressedEnvelope res) { - final int maxTtl = maxTtl(); - if (maxTtl == 0) { - return; - } - - long ttl = Long.MAX_VALUE; - // Find the smallest TTL value returned by the server. - final DnsResponse resc = res.content(); - final int answerCount = resc.count(DnsSection.ANSWER); - for (int i = 0; i < answerCount; i ++) { - final DnsRecord r = resc.recordAt(DnsSection.ANSWER, i); - final long rTtl = r.timeToLive(); - if (ttl > rTtl) { - ttl = rTtl; - } - } - - // Ensure that the found TTL is between minTtl and maxTtl. - ttl = Math.max(minTtl(), Math.min(maxTtl, ttl)); - - DnsNameResolver.this.cacheSuccess(question, res, ttl); - } - @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Unexpected exception: ", cause); + logger.warn("{} Unexpected exception: ", ch, cause); } } - } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java index 5b1b8531ba..0b0e542427 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverContext.java @@ -22,6 +22,7 @@ import io.netty.channel.AddressedEnvelope; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.handler.codec.dns.DefaultDnsQuestion; import io.netty.handler.codec.dns.DefaultDnsRecordDecoder; +import io.netty.handler.codec.dns.DnsResponseCode; import io.netty.handler.codec.dns.DnsSection; import io.netty.handler.codec.dns.DnsQuestion; import io.netty.handler.codec.dns.DnsRawRecord; @@ -50,7 +51,7 @@ import java.util.Locale; import java.util.Map; import java.util.Set; -final class DnsNameResolverContext { +abstract class DnsNameResolverContext { private static final int INADDRSZ4 = 4; private static final int INADDRSZ6 = 16; @@ -66,9 +67,10 @@ final class DnsNameResolverContext { }; private final DnsNameResolver parent; - private final Promise promise; + private final Iterator nameServerAddrs; + private final Promise promise; private final String hostname; - private final int port; + private final boolean traceEnabled; private final int maxAllowedQueries; private final InternetProtocolFamily[] resolveAddressTypes; @@ -76,22 +78,27 @@ final class DnsNameResolverContext { Collections.newSetFromMap( new IdentityHashMap>, Boolean>()); - private List resolvedAddresses; + private List resolvedEntries; private StringBuilder trace; private int allowedQueries; private boolean triedCNAME; - DnsNameResolverContext(DnsNameResolver parent, String hostname, int port, Promise promise) { + protected DnsNameResolverContext(DnsNameResolver parent, String hostname, Promise promise) { this.parent = parent; this.promise = promise; this.hostname = hostname; - this.port = port; + nameServerAddrs = parent.nameServerAddresses.iterator(); maxAllowedQueries = parent.maxQueriesPerResolve(); resolveAddressTypes = parent.resolveAddressTypesUnsafe(); + traceEnabled = parent.isTraceEnabled(); allowedQueries = maxAllowedQueries; } + protected Promise promise() { + return promise; + } + void resolve() { for (InternetProtocolFamily f: resolveAddressTypes) { final DnsRecordType type; @@ -106,18 +113,18 @@ final class DnsNameResolverContext { throw new Error(); } - query(parent.nameServerAddresses, new DefaultDnsQuestion(hostname, type)); + query(nameServerAddrs.next(), new DefaultDnsQuestion(hostname, type)); } } - private void query(Iterable nameServerAddresses, final DnsQuestion question) { + private void query(InetSocketAddress nameServerAddr, final DnsQuestion question) { if (allowedQueries == 0 || promise.isCancelled()) { return; } allowedQueries --; - final Future> f = parent.query(nameServerAddresses, question); + final Future> f = parent.query(nameServerAddr, question); queriesInProgress.add(f); f.addListener(new FutureListener>() { @@ -133,7 +140,11 @@ final class DnsNameResolverContext { if (future.isSuccess()) { onResponse(question, future.getNow()); } else { - addTrace(future.cause()); + // Server did not respond or I/O error occurred; try again. + if (traceEnabled) { + addTrace(future.cause()); + } + query(nameServerAddrs.next(), question); } } finally { tryToFinishResolve(); @@ -142,16 +153,32 @@ final class DnsNameResolverContext { }); } - void onResponse(final DnsQuestion question, AddressedEnvelope response) { - final DnsRecordType type = question.type(); + void onResponse(final DnsQuestion question, AddressedEnvelope envelope) { try { - if (type == DnsRecordType.A || type == DnsRecordType.AAAA) { - onResponseAorAAAA(type, question, response); - } else if (type == DnsRecordType.CNAME) { - onResponseCNAME(question, response); + final DnsResponse res = envelope.content(); + final DnsResponseCode code = res.code(); + if (code == DnsResponseCode.NOERROR) { + final DnsRecordType type = question.type(); + if (type == DnsRecordType.A || type == DnsRecordType.AAAA) { + onResponseAorAAAA(type, question, envelope); + } else if (type == DnsRecordType.CNAME) { + onResponseCNAME(question, envelope); + } + return; + } + + if (traceEnabled) { + addTrace(envelope.sender(), + "response code: " + code + " with " + res.count(DnsSection.ANSWER) + " answer(s) and " + + res.count(DnsSection.AUTHORITY) + " authority resource(s)"); + } + + // Retry with the next server if the server did not tell us that the domain does not exist. + if (code != DnsResponseCode.NXDOMAIN) { + query(nameServerAddrs.next(), question); } } finally { - ReferenceCountUtil.safeRelease(response); + ReferenceCountUtil.safeRelease(envelope); } } @@ -203,24 +230,33 @@ final class DnsNameResolverContext { final byte[] addrBytes = new byte[contentLen]; content.getBytes(content.readerIndex(), addrBytes); + final InetAddress resolved; try { - InetAddress resolved = InetAddress.getByAddress(hostname, addrBytes); - if (resolvedAddresses == null) { - resolvedAddresses = new ArrayList(); - } - resolvedAddresses.add(resolved); - found = true; + resolved = InetAddress.getByAddress(hostname, addrBytes); } catch (UnknownHostException e) { // Should never reach here. throw new Error(e); } + + if (resolvedEntries == null) { + resolvedEntries = new ArrayList(8); + } + + final DnsCacheEntry e = new DnsCacheEntry(hostname, resolved); + parent.cache(hostname, resolved, r.timeToLive()); + resolvedEntries.add(e); + found = true; + + // Note that we do not break from the loop here, so we decode/cache all A/AAAA records. } if (found) { return; } - addTrace(envelope.sender(), "no matching " + qType + " record found"); + if (traceEnabled) { + addTrace(envelope.sender(), "no matching " + qType + " record found"); + } // We aked for A/AAAA but we got only CNAME. if (!cnames.isEmpty()) { @@ -252,7 +288,7 @@ final class DnsNameResolverContext { if (found) { followCname(response.sender(), name, resolved); - } else if (trace) { + } else if (trace && traceEnabled) { addTrace(response.sender(), "no matching CNAME record found"); } } @@ -300,12 +336,12 @@ final class DnsNameResolverContext { } // There are no queries left to try. - if (resolvedAddresses == null) { + if (resolvedEntries == null) { // .. and we could not find any A/AAAA records. if (!triedCNAME) { // As the last resort, try to query CNAME, just in case the name server has it. triedCNAME = true; - query(parent.nameServerAddresses, new DefaultDnsQuestion(hostname, DnsRecordType.CNAME)); + query(nameServerAddrs.next(), new DefaultDnsQuestion(hostname, DnsRecordType.CNAME)); return; } } @@ -315,22 +351,22 @@ final class DnsNameResolverContext { } private boolean gotPreferredAddress() { - if (resolvedAddresses == null) { + if (resolvedEntries == null) { return false; } - final int size = resolvedAddresses.size(); + final int size = resolvedEntries.size(); switch (resolveAddressTypes[0]) { case IPv4: for (int i = 0; i < size; i ++) { - if (resolvedAddresses.get(i) instanceof Inet4Address) { + if (resolvedEntries.get(i).address() instanceof Inet4Address) { return true; } } break; case IPv6: for (int i = 0; i < size; i ++) { - if (resolvedAddresses.get(i) instanceof Inet6Address) { + if (resolvedEntries.get(i).address() instanceof Inet6Address) { return true; } } @@ -354,67 +390,46 @@ final class DnsNameResolverContext { } } - if (resolvedAddresses != null) { + if (resolvedEntries != null) { // Found at least one resolved address. for (InternetProtocolFamily f: resolveAddressTypes) { - switch (f) { - case IPv4: - if (finishResolveWithIPv4()) { - return; - } - break; - case IPv6: - if (finishResolveWithIPv6()) { - return; - } - break; + if (finishResolve(f.addressType(), resolvedEntries)) { + return; } } } // No resolved address found. - int tries = maxAllowedQueries - allowedQueries; - UnknownHostException cause; + final int tries = maxAllowedQueries - allowedQueries; + final StringBuilder buf = new StringBuilder(64); + + buf.append("failed to resolve "); + buf.append(hostname); + if (tries > 1) { - cause = new UnknownHostException( - "failed to resolve " + hostname + " after " + tries + " queries:" + - trace); + buf.append(" after "); + buf.append(tries); + if (trace != null) { + buf.append(" queries:"); + buf.append(trace); + } else { + buf.append(" queries"); + } } else { - cause = new UnknownHostException("failed to resolve " + hostname + ':' + trace); + if (trace != null) { + buf.append(':'); + buf.append(trace); + } } + final UnknownHostException cause = new UnknownHostException(buf.toString()); + + parent.cache(hostname, cause); promise.tryFailure(cause); } - private boolean finishResolveWithIPv4() { - final List resolvedAddresses = this.resolvedAddresses; - final int size = resolvedAddresses.size(); - - for (int i = 0; i < size; i ++) { - InetAddress a = resolvedAddresses.get(i); - if (a instanceof Inet4Address) { - promise.trySuccess(new InetSocketAddress(a, port)); - return true; - } - } - - return false; - } - - private boolean finishResolveWithIPv6() { - final List resolvedAddresses = this.resolvedAddresses; - final int size = resolvedAddresses.size(); - - for (int i = 0; i < size; i ++) { - InetAddress a = resolvedAddresses.get(i); - if (a instanceof Inet6Address) { - promise.trySuccess(new InetSocketAddress(a, port)); - return true; - } - } - - return false; - } + protected abstract boolean finishResolve( + Class addressType, List resolvedEntries); /** * Adapted from {@link DefaultDnsRecordDecoder#decodeName(ByteBuf)}. @@ -466,26 +481,30 @@ final class DnsNameResolverContext { } } - private void followCname( - InetSocketAddress nameServerAddr, String name, String cname) { + private void followCname(InetSocketAddress nameServerAddr, String name, String cname) { - if (trace == null) { - trace = new StringBuilder(128); + if (traceEnabled) { + if (trace == null) { + trace = new StringBuilder(128); + } + + trace.append(StringUtil.NEWLINE); + trace.append("\tfrom "); + trace.append(nameServerAddr); + trace.append(": "); + trace.append(name); + trace.append(" CNAME "); + trace.append(cname); } - trace.append(StringUtil.NEWLINE); - trace.append("\tfrom "); - trace.append(nameServerAddr); - trace.append(": "); - trace.append(name); - trace.append(" CNAME "); - trace.append(cname); - - query(parent.nameServerAddresses, new DefaultDnsQuestion(cname, DnsRecordType.A)); - query(parent.nameServerAddresses, new DefaultDnsQuestion(cname, DnsRecordType.AAAA)); + final InetSocketAddress nextAddr = nameServerAddrs.next(); + query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.A)); + query(nextAddr, new DefaultDnsQuestion(cname, DnsRecordType.AAAA)); } private void addTrace(InetSocketAddress nameServerAddr, String msg) { + assert traceEnabled; + if (trace == null) { trace = new StringBuilder(128); } @@ -498,6 +517,8 @@ final class DnsNameResolverContext { } private void addTrace(Throwable cause) { + assert traceEnabled; + if (trace == null) { trace = new StringBuilder(128); } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverException.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverException.java new file mode 100644 index 0000000000..3393ac2b1f --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverException.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.handler.codec.dns.DnsQuestion; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; + +import java.net.InetSocketAddress; + +/** + * A {@link RuntimeException} raised when {@link DnsNameResolver} failed to perform a successful query. + */ +public final class DnsNameResolverException extends RuntimeException { + + private static final long serialVersionUID = -8826717909627131850L; + + private final InetSocketAddress remoteAddress; + private final DnsQuestion question; + + public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question) { + this.remoteAddress = validateRemoteAddress(remoteAddress); + this.question = validateQuestion(question); + } + + public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question, String message) { + super(message); + this.remoteAddress = validateRemoteAddress(remoteAddress); + this.question = validateQuestion(question); + } + + public DnsNameResolverException( + InetSocketAddress remoteAddress, DnsQuestion question, String message, Throwable cause) { + super(message, cause); + this.remoteAddress = validateRemoteAddress(remoteAddress); + this.question = validateQuestion(question); + } + + public DnsNameResolverException(InetSocketAddress remoteAddress, DnsQuestion question, Throwable cause) { + super(cause); + this.remoteAddress = validateRemoteAddress(remoteAddress); + this.question = validateQuestion(question); + } + + private static InetSocketAddress validateRemoteAddress(InetSocketAddress remoteAddress) { + return ObjectUtil.checkNotNull(remoteAddress, "remoteAddress"); + } + + private static DnsQuestion validateQuestion(DnsQuestion question) { + return ObjectUtil.checkNotNull(question, "question"); + } + + /** + * Returns the {@link DnsQuestion} of the DNS query that has failed. + */ + public DnsQuestion question() { + return question; + } + + @Override + public Throwable fillInStackTrace() { + setStackTrace(EmptyArrays.EMPTY_STACK_TRACE); + return this; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 8cc10b3b9f..f7ab7696ff 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -13,7 +13,6 @@ * License for the specific language governing permissions and limitations * under the License. */ - package io.netty.resolver.dns; import io.netty.buffer.Unpooled; @@ -22,12 +21,12 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.handler.codec.dns.DatagramDnsQuery; import io.netty.handler.codec.dns.DefaultDnsRawRecord; -import io.netty.handler.codec.dns.DnsSection; import io.netty.handler.codec.dns.DnsQuery; import io.netty.handler.codec.dns.DnsQuestion; import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsResponse; +import io.netty.handler.codec.dns.DnsSection; import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.internal.OneTimeTask; @@ -37,8 +36,6 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.net.InetSocketAddress; -import java.net.UnknownHostException; -import java.util.Iterator; import java.util.concurrent.TimeUnit; final class DnsQueryContext { @@ -50,30 +47,24 @@ final class DnsQueryContext { private final int id; private final DnsQuestion question; private final DnsRecord optResource; - private final Iterator nameServerAddresses; + private final InetSocketAddress nameServerAddr; private final boolean recursionDesired; - private final int maxTries; - private int remainingTries; private volatile ScheduledFuture timeoutFuture; - private StringBuilder trace; DnsQueryContext(DnsNameResolver parent, - Iterable nameServerAddresses, + InetSocketAddress nameServerAddr, DnsQuestion question, Promise> promise) { this.parent = parent; - this.promise = promise; + this.nameServerAddr = nameServerAddr; this.question = question; + this.promise = promise; id = allocateId(); recursionDesired = parent.isRecursionDesired(); - maxTries = parent.maxTriesPerQuery(); - remainingTries = maxTries; optResource = new DefaultDnsRawRecord( StringUtil.EMPTY_STRING, DnsRecordType.OPT, parent.maxPayloadSize(), 0, Unpooled.EMPTY_BUFFER); - - this.nameServerAddresses = nameServerAddresses.iterator(); } private int allocateId() { @@ -93,42 +84,8 @@ final class DnsQueryContext { } } - Promise> promise() { - return promise; - } - - DnsQuestion question() { - return question; - } - - ScheduledFuture timeoutFuture() { - return timeoutFuture; - } - void query() { final DnsQuestion question = this.question; - - if (remainingTries <= 0 || !nameServerAddresses.hasNext()) { - parent.promises.set(id, null); - - int tries = maxTries - remainingTries; - UnknownHostException cause; - if (tries > 1) { - cause = new UnknownHostException( - "failed to resolve " + question + " after " + tries + " attempts:" + - trace); - } else { - cause = new UnknownHostException("failed to resolve " + question + ':' + trace); - } - - cache(question, cause); - promise.tryFailure(cause); - return; - } - - remainingTries --; - - final InetSocketAddress nameServerAddr = nameServerAddresses.next(); final DatagramDnsQuery query = new DatagramDnsQuery(null, nameServerAddr, id); query.setRecursionDesired(recursionDesired); query.setRecord(DnsSection.QUESTION, question); @@ -138,43 +95,43 @@ final class DnsQueryContext { logger.debug("{} WRITE: [{}: {}], {}", parent.ch, id, nameServerAddr, question); } - sendQuery(query, nameServerAddr); + sendQuery(query); } - private void sendQuery(final DnsQuery query, final InetSocketAddress nameServerAddr) { + private void sendQuery(final DnsQuery query) { if (parent.bindFuture.isDone()) { - writeQuery(query, nameServerAddr); + writeQuery(query); } else { parent.bindFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { - writeQuery(query, nameServerAddr); + writeQuery(query); } else { promise.tryFailure(future.cause()); } - } - }); - } - } - - private void writeQuery(final DnsQuery query, final InetSocketAddress nameServerAddr) { - final ChannelFuture writeFuture = parent.ch.writeAndFlush(query); - if (writeFuture.isDone()) { - onQueryWriteCompletion(writeFuture, nameServerAddr); - } else { - writeFuture.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - onQueryWriteCompletion(writeFuture, nameServerAddr); } }); } } - private void onQueryWriteCompletion(ChannelFuture writeFuture, final InetSocketAddress nameServerAddr) { + private void writeQuery(final DnsQuery query) { + final ChannelFuture writeFuture = parent.ch.writeAndFlush(query); + if (writeFuture.isDone()) { + onQueryWriteCompletion(writeFuture); + } else { + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + onQueryWriteCompletion(writeFuture); + } + }); + } + } + + private void onQueryWriteCompletion(ChannelFuture writeFuture) { if (!writeFuture.isSuccess()) { - retry(nameServerAddr, "failed to send a query: " + writeFuture.cause()); + setFailure("failed to send a query", writeFuture.cause()); return; } @@ -189,35 +146,62 @@ final class DnsQueryContext { return; } - retry(nameServerAddr, "query timed out after " + queryTimeoutMillis + " milliseconds"); + setFailure("query timed out after " + queryTimeoutMillis + " milliseconds", null); } }, queryTimeoutMillis, TimeUnit.MILLISECONDS); } } - void retry(InetSocketAddress nameServerAddr, String message) { - if (promise.isCancelled()) { + void finish(AddressedEnvelope envelope) { + DnsResponse res = envelope.content(); + if (res.count(DnsSection.QUESTION) != 1) { + logger.warn("Received a DNS response with invalid number of questions: {}", envelope); return; } - if (trace == null) { - trace = new StringBuilder(128); + if (!question.equals(res.recordAt(DnsSection.QUESTION))) { + logger.warn("Received a mismatching DNS response: {}", envelope); + return; } - trace.append(StringUtil.NEWLINE); - trace.append("\tfrom "); - trace.append(nameServerAddr); - trace.append(": "); - trace.append(message); - query(); + setSuccess(envelope); } - private void cache(final DnsQuestion question, Throwable cause) { - final int negativeTtl = parent.negativeTtl(); - if (negativeTtl == 0) { - return; + private void setSuccess(AddressedEnvelope envelope) { + parent.promises.set(id, null); + + // Cancel the timeout task. + final ScheduledFuture timeoutFuture = this.timeoutFuture; + if (timeoutFuture != null) { + timeoutFuture.cancel(false); } - parent.cacheFailure(question, cause, negativeTtl); + Promise> promise = this.promise; + if (promise.setUncancellable()) { + @SuppressWarnings("unchecked") + AddressedEnvelope castResponse = + (AddressedEnvelope) envelope.retain(); + promise.setSuccess(castResponse); + } + } + + private void setFailure(String message, Throwable cause) { + parent.promises.set(id, null); + + final StringBuilder buf = new StringBuilder(message.length() + 64); + buf.append('[') + .append(nameServerAddr) + .append("] ") + .append(message) + .append(" (no stack trace available)"); + + final DnsNameResolverException e; + if (cause != null) { + e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause); + } else { + e = new DnsNameResolverException(nameServerAddr, question, buf.toString()); + } + + promise.tryFailure(e); } } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 6cbf8e70d1..58c658c308 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -247,7 +247,7 @@ public class DnsNameResolverTest { group.next(), NioDatagramChannel.class, DnsServerAddresses.shuffled(SERVERS)); static { - resolver.setMaxTriesPerQuery(SERVERS.size()); + resolver.setMaxQueriesPerResolve(SERVERS.size()); } @AfterClass @@ -293,6 +293,10 @@ public class DnsNameResolverTest { for (Entry e: resultA.entrySet()) { InetAddress expected = e.getValue(); InetAddress actual = resultB.get(e.getKey()); + if (!actual.equals(expected)) { + // Print the content of the cache when test failure is expected. + System.err.println("Cache for " + e.getKey() + ": " + resolver.resolveAll(e.getKey(), 0).getNow()); + } assertThat(actual, is(expected)); } } finally { @@ -342,17 +346,8 @@ public class DnsNameResolverTest { boolean typeMatches = false; for (InternetProtocolFamily f: famililies) { Class resolvedType = resolved.getAddress().getClass(); - switch (f) { - case IPv4: - if (Inet4Address.class.isAssignableFrom(resolvedType)) { - typeMatches = true; - } - break; - case IPv6: - if (Inet6Address.class.isAssignableFrom(resolvedType)) { - typeMatches = true; - } - break; + if (f.addressType().isAssignableFrom(resolvedType)) { + typeMatches = true; } } @@ -383,9 +378,18 @@ public class DnsNameResolverTest { for (Entry>> e: futures.entrySet()) { String hostname = e.getKey(); - AddressedEnvelope envelope = e.getValue().sync().getNow(); - DnsResponse response = envelope.content(); + Future> f = e.getValue().awaitUninterruptibly(); + if (!f.isSuccess()) { + // Try again a couple more times because the DNS servers might be throttling us down. + for (int i = 0; i < 2; i++) { + f = queryMx(hostname).awaitUninterruptibly(); + if (f.isSuccess()) { + break; + } + } + } + DnsResponse response = f.getNow().content(); assertThat(response.code(), is(DnsResponseCode.NOERROR)); final int answerCount = response.count(DnsSection.ANSWER); @@ -441,4 +445,8 @@ public class DnsNameResolverTest { String hostname) throws Exception { futures.put(hostname, resolver.query(new DefaultDnsQuestion(hostname, DnsRecordType.MX))); } + + private static Future> queryMx(String hostname) throws Exception { + return resolver.query(new DefaultDnsQuestion(hostname, DnsRecordType.MX)); + } } diff --git a/resolver/src/main/java/io/netty/resolver/DefaultNameResolver.java b/resolver/src/main/java/io/netty/resolver/DefaultNameResolver.java index fde957c95c..a34d86918b 100644 --- a/resolver/src/main/java/io/netty/resolver/DefaultNameResolver.java +++ b/resolver/src/main/java/io/netty/resolver/DefaultNameResolver.java @@ -22,6 +22,8 @@ import io.netty.util.concurrent.Promise; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; /** * A {@link NameResolver} that resolves an {@link InetSocketAddress} using JDK's built-in domain name lookup mechanism. @@ -49,4 +51,22 @@ public class DefaultNameResolver extends SimpleNameResolver { promise.setFailure(e); } } + + @Override + protected void doResolveAll( + InetSocketAddress unresolvedAddress, Promise> promise) throws Exception { + + try { + // Note that InetSocketAddress.getHostName() will never incur a reverse lookup here, + // because an unresolved address always has a host name. + final InetAddress[] resolved = InetAddress.getAllByName(unresolvedAddress.getHostName()); + final List result = new ArrayList(resolved.length); + for (InetAddress a: resolved) { + result.add(new InetSocketAddress(a, unresolvedAddress.getPort())); + } + promise.setSuccess(result); + } catch (UnknownHostException e) { + promise.setFailure(e); + } + } } diff --git a/resolver/src/main/java/io/netty/resolver/NameResolver.java b/resolver/src/main/java/io/netty/resolver/NameResolver.java index 7760d1f460..4938a43815 100644 --- a/resolver/src/main/java/io/netty/resolver/NameResolver.java +++ b/resolver/src/main/java/io/netty/resolver/NameResolver.java @@ -22,6 +22,7 @@ import io.netty.util.concurrent.Promise; import java.io.Closeable; import java.net.SocketAddress; import java.nio.channels.UnsupportedAddressTypeException; +import java.util.List; /** * Resolves an arbitrary string that represents the name of an endpoint into a {@link SocketAddress}. @@ -82,6 +83,48 @@ public interface NameResolver extends Closeable { */ Future resolve(SocketAddress address, Promise promise); + /** + * Resolves the specified host name and port into a list of {@link SocketAddress}es. + * + * @param inetHost the name to resolve + * @param inetPort the port number + * + * @return the list of the {@link SocketAddress}es as the result of the resolution + */ + Future> resolveAll(String inetHost, int inetPort); + + /** + * Resolves the specified host name and port into a list of {@link SocketAddress}es. + * + * @param inetHost the name to resolve + * @param inetPort the port number + * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished + * + * @return the list of the {@link SocketAddress}es as the result of the resolution + */ + Future> resolveAll(String inetHost, int inetPort, Promise> promise); + + /** + * Resolves the specified address. If the specified address is resolved already, this method does nothing + * but returning the original address. + * + * @param address the address to resolve + * + * @return the list of the {@link SocketAddress}es as the result of the resolution + */ + Future> resolveAll(SocketAddress address); + + /** + * Resolves the specified address. If the specified address is resolved already, this method does nothing + * but returning the original address. + * + * @param address the address to resolve + * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished + * + * @return the list of the {@link SocketAddress}es as the result of the resolution + */ + Future> resolveAll(SocketAddress address, Promise> promise); + /** * Closes all the resources allocated and used by this resolver. */ diff --git a/resolver/src/main/java/io/netty/resolver/NoopNameResolver.java b/resolver/src/main/java/io/netty/resolver/NoopNameResolver.java index 8605aae894..8d5999727d 100644 --- a/resolver/src/main/java/io/netty/resolver/NoopNameResolver.java +++ b/resolver/src/main/java/io/netty/resolver/NoopNameResolver.java @@ -20,6 +20,8 @@ import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Promise; import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; /** * A {@link NameResolver} that does not perform any resolution but always reports successful resolution. @@ -40,4 +42,10 @@ public class NoopNameResolver extends SimpleNameResolver { protected void doResolve(SocketAddress unresolvedAddress, Promise promise) throws Exception { promise.setSuccess(unresolvedAddress); } + + @Override + protected void doResolveAll( + SocketAddress unresolvedAddress, Promise> promise) throws Exception { + promise.setSuccess(Collections.singletonList(unresolvedAddress)); + } } diff --git a/resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java b/resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java index 9bc5329bc6..3c48d1e3d6 100644 --- a/resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java +++ b/resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java @@ -24,6 +24,10 @@ import io.netty.util.internal.TypeParameterMatcher; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.channels.UnsupportedAddressTypeException; +import java.util.Collections; +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.*; /** * A skeletal {@link NameResolver} implementation. @@ -92,29 +96,17 @@ public abstract class SimpleNameResolver implements Nam @Override public final Future resolve(String inetHost, int inetPort) { - if (inetHost == null) { - throw new NullPointerException("inetHost"); - } - - return resolve(InetSocketAddress.createUnresolved(inetHost, inetPort)); + return resolve(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort)); } @Override public Future resolve(String inetHost, int inetPort, Promise promise) { - if (inetHost == null) { - throw new NullPointerException("inetHost"); - } - - return resolve(InetSocketAddress.createUnresolved(inetHost, inetPort), promise); + return resolve(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort), promise); } @Override public final Future resolve(SocketAddress address) { - if (address == null) { - throw new NullPointerException("unresolvedAddress"); - } - - if (!isSupported(address)) { + if (!isSupported(checkNotNull(address, "address"))) { // Address type not supported by the resolver return executor().newFailedFuture(new UnsupportedAddressTypeException()); } @@ -139,12 +131,8 @@ public abstract class SimpleNameResolver implements Nam @Override public final Future resolve(SocketAddress address, Promise promise) { - if (address == null) { - throw new NullPointerException("unresolvedAddress"); - } - if (promise == null) { - throw new NullPointerException("promise"); - } + checkNotNull(address, "address"); + checkNotNull(promise, "promise"); if (!isSupported(address)) { // Address type not supported by the resolver @@ -168,12 +156,80 @@ public abstract class SimpleNameResolver implements Nam } } + @Override + public final Future> resolveAll(String inetHost, int inetPort) { + return resolveAll(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort)); + } + + @Override + public Future> resolveAll(String inetHost, int inetPort, Promise> promise) { + return resolveAll(InetSocketAddress.createUnresolved(checkNotNull(inetHost, "inetHost"), inetPort), promise); + } + + @Override + public final Future> resolveAll(SocketAddress address) { + if (!isSupported(checkNotNull(address, "address"))) { + // Address type not supported by the resolver + return executor().newFailedFuture(new UnsupportedAddressTypeException()); + } + + if (isResolved(address)) { + // Resolved already; no need to perform a lookup + @SuppressWarnings("unchecked") + final T cast = (T) address; + return executor.newSucceededFuture(Collections.singletonList(cast)); + } + + try { + @SuppressWarnings("unchecked") + final T cast = (T) address; + final Promise> promise = executor().newPromise(); + doResolveAll(cast, promise); + return promise; + } catch (Exception e) { + return executor().newFailedFuture(e); + } + } + + @Override + public final Future> resolveAll(SocketAddress address, Promise> promise) { + checkNotNull(address, "address"); + checkNotNull(promise, "promise"); + + if (!isSupported(address)) { + // Address type not supported by the resolver + return promise.setFailure(new UnsupportedAddressTypeException()); + } + + if (isResolved(address)) { + // Resolved already; no need to perform a lookup + @SuppressWarnings("unchecked") + final T cast = (T) address; + return promise.setSuccess(Collections.singletonList(cast)); + } + + try { + @SuppressWarnings("unchecked") + final T cast = (T) address; + doResolveAll(cast, promise); + return promise; + } catch (Exception e) { + return promise.setFailure(e); + } + } + /** * Invoked by {@link #resolve(SocketAddress)} and {@link #resolve(String, int)} to perform the actual name * resolution. */ protected abstract void doResolve(T unresolvedAddress, Promise promise) throws Exception; + /** + * Invoked by {@link #resolveAll(SocketAddress)} and {@link #resolveAll(String, int)} to perform the actual name + * resolution. + */ + protected abstract void doResolveAll(T unresolvedAddress, Promise> promise) throws Exception; + @Override public void close() { } } diff --git a/transport/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java b/transport/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java index d276abc095..1caeb5e8a2 100644 --- a/transport/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java +++ b/transport/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java @@ -15,10 +15,27 @@ */ package io.netty.channel.socket; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; + /** * Internet Protocol (IP) families used byte the {@link DatagramChannel} */ public enum InternetProtocolFamily { - IPv4, - IPv6 + IPv4(Inet4Address.class), + IPv6(Inet6Address.class); + + private final Class addressType; + + InternetProtocolFamily(Class addressType) { + this.addressType = addressType; + } + + /** + * Returns the address type of this protocol family. + */ + public Class addressType() { + return addressType; + } } diff --git a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java index 39d5864221..ede4be1fa2 100644 --- a/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java +++ b/transport/src/test/java/io/netty/bootstrap/BootstrapTest.java @@ -36,6 +36,7 @@ import io.netty.resolver.SimpleNameResolver; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; +import io.netty.util.internal.OneTimeTask; import org.junit.AfterClass; import org.junit.Test; @@ -43,6 +44,7 @@ import java.net.SocketAddress; import java.net.SocketException; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; @@ -302,13 +304,29 @@ public class BootstrapTest { @Override protected void doResolve( final SocketAddress unresolvedAddress, final Promise promise) { - executor().execute(new Runnable() { + executor().execute(new OneTimeTask() { @Override public void run() { if (success) { promise.setSuccess(unresolvedAddress); } else { - promise.setFailure(new UnknownHostException()); + promise.setFailure(new UnknownHostException(unresolvedAddress.toString())); + } + } + }); + } + + @Override + protected void doResolveAll( + final SocketAddress unresolvedAddress, final Promise> promise) + throws Exception { + executor().execute(new OneTimeTask() { + @Override + public void run() { + if (success) { + promise.setSuccess(Collections.singletonList(unresolvedAddress)); + } else { + promise.setFailure(new UnknownHostException(unresolvedAddress.toString())); } } });