diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index c69dfb52ab..7ba7149a9f 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -37,7 +37,6 @@ import io.netty.resolver.NameResolver; import io.netty.resolver.SimpleNameResolver; import io.netty.util.NetUtil; import io.netty.util.ReferenceCountUtil; -import io.netty.util.collection.IntObjectHashMap; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; @@ -59,9 +58,8 @@ import java.util.List; import java.util.Map.Entry; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReferenceArray; -import static io.netty.util.internal.ObjectUtil.*; +import static io.netty.util.internal.ObjectUtil.checkNotNull; /** * A DNS-based {@link NameResolver}. @@ -95,11 +93,9 @@ public class DnsNameResolver extends SimpleNameResolver { final DatagramChannel ch; /** - * An array whose index is the ID of a DNS query and whose value is the promise of the corresponsing response. We - * don't use {@link IntObjectHashMap} or map-like data structure here because 64k elements are fairly small, which - * is only about 512KB. + * Manages the {@link DnsQueryContext}s in progress and their query IDs. */ - final AtomicReferenceArray promises = new AtomicReferenceArray(65536); + final DnsQueryContextManager queryContextManager = new DnsQueryContextManager(); /** * Cache for {@link #doResolve(InetSocketAddress, Promise)} and {@link #doResolveAll(InetSocketAddress, Promise)}. @@ -937,7 +933,7 @@ public class DnsNameResolver extends SimpleNameResolver { logger.debug("{} RECEIVED: [{}: {}], {}", ch, queryId, res.sender(), res); } - final DnsQueryContext qCtx = promises.get(queryId); + final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); if (qCtx == null) { if (logger.isWarnEnabled()) { logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 1d1d96c35c..8217c82ea6 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -31,7 +31,6 @@ import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.StringUtil; -import io.netty.util.internal.ThreadLocalRandom; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; @@ -60,9 +59,9 @@ final class DnsQueryContext { this.nameServerAddr = nameServerAddr; this.question = question; this.promise = promise; - - id = allocateId(); recursionDesired = parent.isRecursionDesired(); + id = parent.queryContextManager.add(this); + if (parent.isOptResourceEnabled()) { optResource = new DefaultDnsRawRecord( StringUtil.EMPTY_STRING, DnsRecordType.OPT, parent.maxPayloadSize(), 0, Unpooled.EMPTY_BUFFER); @@ -71,25 +70,17 @@ final class DnsQueryContext { } } - private int allocateId() { - int id = ThreadLocalRandom.current().nextInt(parent.promises.length()); - final int maxTries = parent.promises.length() << 1; - int tries = 0; - for (;;) { - if (parent.promises.compareAndSet(id, null, this)) { - return id; - } + InetSocketAddress nameServerAddr() { + return nameServerAddr; + } - id = id + 1 & 0xFFFF; - - if (++ tries >= maxTries) { - throw new IllegalStateException("query ID space exhausted: " + question); - } - } + DnsQuestion question() { + return question; } void query() { - final DnsQuestion question = this.question; + final DnsQuestion question = question(); + final InetSocketAddress nameServerAddr = nameServerAddr(); final DatagramDnsQuery query = new DatagramDnsQuery(null, nameServerAddr, id); query.setRecursionDesired(recursionDesired); query.setRecord(DnsSection.QUESTION, question); @@ -159,13 +150,13 @@ final class DnsQueryContext { } void finish(AddressedEnvelope envelope) { - DnsResponse res = envelope.content(); + final DnsResponse res = envelope.content(); if (res.count(DnsSection.QUESTION) != 1) { logger.warn("Received a DNS response with invalid number of questions: {}", envelope); return; } - if (!question.equals(res.recordAt(DnsSection.QUESTION))) { + if (!question().equals(res.recordAt(DnsSection.QUESTION))) { logger.warn("Received a mismatching DNS response: {}", envelope); return; } @@ -174,7 +165,7 @@ final class DnsQueryContext { } private void setSuccess(AddressedEnvelope envelope) { - parent.promises.set(id, null); + parent.queryContextManager.remove(nameServerAddr(), id); // Cancel the timeout task. final ScheduledFuture timeoutFuture = this.timeoutFuture; @@ -192,7 +183,8 @@ final class DnsQueryContext { } private void setFailure(String message, Throwable cause) { - parent.promises.set(id, null); + final InetSocketAddress nameServerAddr = nameServerAddr(); + parent.queryContextManager.remove(nameServerAddr, id); final StringBuilder buf = new StringBuilder(message.length() + 64); buf.append('[') @@ -203,9 +195,9 @@ final class DnsQueryContext { final DnsNameResolverException e; if (cause != null) { - e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause); + e = new DnsNameResolverException(nameServerAddr, question(), buf.toString(), cause); } else { - e = new DnsNameResolverException(nameServerAddr, question, buf.toString()); + e = new DnsNameResolverException(nameServerAddr, question(), buf.toString()); } promise.tryFailure(e); diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContextManager.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContextManager.java new file mode 100644 index 0000000000..9c3946c72f --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContextManager.java @@ -0,0 +1,148 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.resolver.dns; + +import io.netty.util.NetUtil; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.netty.util.internal.ThreadLocalRandom; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.Map; + +final class DnsQueryContextManager { + + /** + * A map whose key is the DNS server address and value is the map of the DNS query ID and its corresponding + * {@link DnsQueryContext}. + */ + final Map> map = + new HashMap>(); + + int add(DnsQueryContext qCtx) { + final IntObjectMap contexts = getOrCreateContextMap(qCtx.nameServerAddr()); + + int id = ThreadLocalRandom.current().nextInt(1, 65536); + final int maxTries = 65535 << 1; + int tries = 0; + + synchronized (contexts) { + for (;;) { + if (!contexts.containsKey(id)) { + contexts.put(id, qCtx); + return id; + } + + id = id + 1 & 0xFFFF; + + if (++tries >= maxTries) { + throw new IllegalStateException("query ID space exhausted: " + qCtx.question()); + } + } + } + } + + DnsQueryContext get(InetSocketAddress nameServerAddr, int id) { + final IntObjectMap contexts = getContextMap(nameServerAddr); + final DnsQueryContext qCtx; + if (contexts != null) { + synchronized (contexts) { + qCtx = contexts.get(id); + } + } else { + qCtx = null; + } + + return qCtx; + } + + DnsQueryContext remove(InetSocketAddress nameServerAddr, int id) { + final IntObjectMap contexts = getContextMap(nameServerAddr); + if (contexts == null) { + return null; + } + + synchronized (contexts) { + return contexts.remove(id); + } + } + + private IntObjectMap getContextMap(InetSocketAddress nameServerAddr) { + synchronized (map) { + return map.get(nameServerAddr); + } + } + + private IntObjectMap getOrCreateContextMap(InetSocketAddress nameServerAddr) { + synchronized (map) { + final IntObjectMap contexts = map.get(nameServerAddr); + if (contexts != null) { + return contexts; + } + + final IntObjectMap newContexts = new IntObjectHashMap(); + final InetAddress a = nameServerAddr.getAddress(); + final int port = nameServerAddr.getPort(); + map.put(nameServerAddr, newContexts); + + if (a instanceof Inet4Address) { + // Also add the mapping for the IPv4-compatible IPv6 address. + final Inet4Address a4 = (Inet4Address) a; + if (a4.isLoopbackAddress()) { + map.put(new InetSocketAddress(NetUtil.LOCALHOST6, port), newContexts); + } else { + map.put(new InetSocketAddress(toCompatAddress(a4), port), newContexts); + } + } else if (a instanceof Inet6Address) { + // Also add the mapping for the IPv4 address if this IPv6 address is compatible. + final Inet6Address a6 = (Inet6Address) a; + if (a6.isLoopbackAddress()) { + map.put(new InetSocketAddress(NetUtil.LOCALHOST4, port), newContexts); + } else if (a6.isIPv4CompatibleAddress()) { + map.put(new InetSocketAddress(toIPv4Address(a6), port), newContexts); + } + } + + return newContexts; + } + } + + private static Inet6Address toCompatAddress(Inet4Address a4) { + byte[] b4 = a4.getAddress(); + byte[] b6 = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, b4[0], b4[1], b4[2], b4[3] }; + try { + return (Inet6Address) InetAddress.getByAddress(b6); + } catch (UnknownHostException e) { + throw new Error(e); + } + } + + private static Inet4Address toIPv4Address(Inet6Address a6) { + byte[] b6 = a6.getAddress(); + byte[] b4 = { b6[12], b6[13], b6[14], b6[15] }; + try { + return (Inet4Address) InetAddress.getByAddress(b4); + } catch (UnknownHostException e) { + throw new Error(e); + } + } +}