From e305488c5ab9eb29abfb65ca103c55a4680f4c72 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 16 Jan 2018 12:35:34 +0100 Subject: [PATCH] Fix race-condition when using DnsCache in DnsNameResolver Motivation: The usage of DnsCache in DnsNameResolver was racy in general. First of the isEmpty() was not called in a synchronized block while we depended on synchronized. The other problem was that this whole synchronization only worked if the DefaultDnsCache was used and the returned List was not wrapped by the user. Modifications: - Rewrite DefaultDnsCache to not depend on synchronization on the returned List by using a CoW approach. Result: Fixes [#7583] and other races. --- .../netty/resolver/dns/DefaultDnsCache.java | 195 +++++++++++------- .../netty/resolver/dns/DnsNameResolver.java | 83 +++----- 2 files changed, 155 insertions(+), 123 deletions(-) diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java index f2acd3fac7..2b3477bd38 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DefaultDnsCache.java @@ -23,11 +23,13 @@ import io.netty.util.internal.UnstableApi; import java.net.InetAddress; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; @@ -39,8 +41,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; @UnstableApi public class DefaultDnsCache implements DnsCache { - private final ConcurrentMap> resolveCache = - PlatformDependent.newConcurrentHashMap(); + private final ConcurrentMap resolveCache = PlatformDependent.newConcurrentHashMap(); + private final int minTtl; private final int maxTtl; private final int negativeTtl; @@ -97,28 +99,21 @@ public class DefaultDnsCache implements DnsCache { @Override public void clear() { - for (Iterator>> i = resolveCache.entrySet().iterator(); - i.hasNext();) { - final Map.Entry> e = i.next(); - i.remove(); - cancelExpiration(e.getValue()); + while (!resolveCache.isEmpty()) { + for (Iterator> i = resolveCache.entrySet().iterator(); i.hasNext();) { + Map.Entry e = i.next(); + i.remove(); + + e.getValue().cancelExpiration(); + } } } @Override public boolean clear(String hostname) { checkNotNull(hostname, "hostname"); - boolean removed = false; - for (Iterator>> i = resolveCache.entrySet().iterator(); - i.hasNext();) { - final Map.Entry> e = i.next(); - if (e.getKey().equals(hostname)) { - i.remove(); - cancelExpiration(e.getValue()); - removed = true; - } - } - return removed; + Entries entries = resolveCache.remove(hostname); + return entries != null && entries.cancelExpiration(); } private static boolean emptyAdditionals(DnsRecord[] additionals) { @@ -129,22 +124,11 @@ public class DefaultDnsCache implements DnsCache { public List get(String hostname, DnsRecord[] additionals) { checkNotNull(hostname, "hostname"); if (!emptyAdditionals(additionals)) { - return null; + return Collections.emptyList(); } - return resolveCache.get(hostname); - } - private List cachedEntries(String hostname) { - List oldEntries = resolveCache.get(hostname); - final List entries; - if (oldEntries == null) { - List newEntries = new ArrayList(8); - oldEntries = resolveCache.putIfAbsent(hostname, newEntries); - entries = oldEntries != null? oldEntries : newEntries; - } else { - entries = oldEntries; - } - return entries; + Entries entries = resolveCache.get(hostname); + return entries == null ? null : entries.get(); } @Override @@ -157,22 +141,7 @@ public class DefaultDnsCache implements DnsCache { if (maxTtl == 0 || !emptyAdditionals(additionals)) { return e; } - final int ttl = Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)); - final List entries = cachedEntries(hostname); - - synchronized (entries) { - if (!entries.isEmpty()) { - final DefaultDnsCacheEntry firstEntry = entries.get(0); - if (firstEntry.cause() != null) { - assert entries.size() == 1; - firstEntry.cancelExpiration(); - entries.clear(); - } - } - entries.add(e); - } - - scheduleCacheExpiration(entries, e, ttl, loop); + cache0(e, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop); return e; } @@ -186,40 +155,34 @@ public class DefaultDnsCache implements DnsCache { if (negativeTtl == 0 || !emptyAdditionals(additionals)) { return e; } - final List entries = cachedEntries(hostname); - synchronized (entries) { - final int numEntries = entries.size(); - for (int i = 0; i < numEntries; i ++) { - entries.get(i).cancelExpiration(); - } - entries.clear(); - entries.add(e); - } - - scheduleCacheExpiration(entries, e, negativeTtl, loop); + cache0(e, negativeTtl, loop); return e; } - private static void cancelExpiration(List entries) { - final int numEntries = entries.size(); - for (int i = 0; i < numEntries; i++) { - entries.get(i).cancelExpiration(); + private void cache0(DefaultDnsCacheEntry e, int ttl, EventLoop loop) { + Entries entries = resolveCache.get(e.hostname()); + if (entries == null) { + entries = new Entries(e); + Entries oldEntries = resolveCache.putIfAbsent(e.hostname(), entries); + if (oldEntries != null) { + entries = oldEntries; + entries.add(e); + } } + + scheduleCacheExpiration(entries, e, ttl, loop); } - private void scheduleCacheExpiration(final List entries, + private void scheduleCacheExpiration(final Entries entries, final DefaultDnsCacheEntry e, int ttl, EventLoop loop) { e.scheduleExpiration(loop, new Runnable() { @Override public void run() { - synchronized (entries) { - entries.remove(e); - if (entries.isEmpty()) { - resolveCache.remove(e.hostname()); - } + if (entries.remove(e)) { + resolveCache.remove(e.hostname); } } }, ttl, TimeUnit.SECONDS); @@ -289,4 +252,98 @@ public class DefaultDnsCache implements DnsCache { } } } + + // Directly extend AtomicReference for intrinsics and also to keep memory overhead low. + private static final class Entries extends AtomicReference> { + + Entries(DefaultDnsCacheEntry entry) { + super(Collections.singletonList(entry)); + } + + void add(DefaultDnsCacheEntry e) { + if (e.cause() == null) { + for (;;) { + List entries = get(); + if (!entries.isEmpty()) { + final DefaultDnsCacheEntry firstEntry = entries.get(0); + if (firstEntry.cause() != null) { + assert entries.size() == 1; + if (compareAndSet(entries, Collections.singletonList(e))) { + firstEntry.cancelExpiration(); + return; + } else { + // Need to try again as CAS failed + continue; + } + } + // Create a new List for COW semantics + List newEntries = new ArrayList(entries.size() + 1); + newEntries.addAll(entries); + newEntries.add(e); + if (compareAndSet(entries, newEntries)) { + return; + } + } else if (compareAndSet(entries, Collections.singletonList(e))) { + return; + } + } + } else { + List entries = getAndSet(Collections.singletonList(e)); + cancelExpiration(entries); + } + } + + boolean remove(DefaultDnsCacheEntry entry) { + for (;;) { + List entries = get(); + int size = entries.size(); + if (size == 0) { + return false; + } else if (size == 1) { + if (entries.get(0).equals(entry)) { + // If the list is empty we just return early and so not allocate a new ArrayList. + if (compareAndSet(entries, Collections.emptyList())) { + return true; + } + } else { + return false; + } + } else { + // Just size the new ArrayList as before as we may not find the entry we are looking for and not + // want to cause an extra allocation / memory copy in this case. + // + // Its very likely we find the entry we are looking for so we directly create a new ArrayList + // and fill it. + List newEntries = new ArrayList(size); + for (int i = 0; i < size; i++) { + DefaultDnsCacheEntry e = entries.get(i); + if (!e.equals(entry)) { + newEntries.add(e); + } + } + if (compareAndSet(entries, Collections.unmodifiableList(newEntries))) { + // This will return true if an entry was removed. + return newEntries.size() != size; + } + } + } + } + + boolean cancelExpiration() { + List entries = getAndSet(Collections.emptyList()); + if (entries.isEmpty()) { + return false; + } + + cancelExpiration(entries); + return true; + } + + private static void cancelExpiration(List entryList) { + final int numEntries = entryList.size(); + for (int i = 0; i < numEntries; i++) { + entryList.get(i).cancelExpiration(); + } + } + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 1a928f8c23..eaa714026d 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -600,43 +600,28 @@ public class DnsNameResolver extends InetNameResolver { Promise promise, DnsCache resolveCache) { final List cachedEntries = resolveCache.get(hostname, additionals); - if (cachedEntries == null) { + if (cachedEntries == null || cachedEntries.isEmpty()) { return false; } - InetAddress address = null; - Throwable cause = null; - synchronized (cachedEntries) { + Throwable cause = cachedEntries.get(0).cause(); + if (cause == null) { final int numEntries = cachedEntries.size(); - if (numEntries == 0) { - return false; - } - - if (cachedEntries.get(0).cause() != null) { - cause = cachedEntries.get(0).cause(); - } else { - // Find the first entry with the preferred address type. - for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) { - for (int i = 0; i < numEntries; i++) { - final DnsCacheEntry e = cachedEntries.get(i); - if (f.addressType().isInstance(e.address())) { - address = e.address(); - break; - } + // Find the first entry with the preferred address type. + for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) { + for (int i = 0; i < numEntries; i++) { + final DnsCacheEntry e = cachedEntries.get(i); + if (f.addressType().isInstance(e.address())) { + trySuccess(promise, e.address()); + return true; } } } - } - - if (address != null) { - trySuccess(promise, address); - return true; - } - if (cause != null) { + return false; + } else { tryFailure(promise, cause); return true; } - return false; } private static void trySuccess(Promise promise, T result) { @@ -732,44 +717,34 @@ public class DnsNameResolver extends InetNameResolver { Promise> promise, DnsCache resolveCache) { final List cachedEntries = resolveCache.get(hostname, additionals); - if (cachedEntries == null) { + if (cachedEntries == null || cachedEntries.isEmpty()) { return false; } - List result = null; - Throwable cause = null; - synchronized (cachedEntries) { + Throwable cause = cachedEntries.get(0).cause(); + if (cause == null) { + List result = null; final int numEntries = cachedEntries.size(); - if (numEntries == 0) { - return false; - } - - if (cachedEntries.get(0).cause() != null) { - cause = cachedEntries.get(0).cause(); - } else { - for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) { - for (int i = 0; i < numEntries; i++) { - final DnsCacheEntry e = cachedEntries.get(i); - if (f.addressType().isInstance(e.address())) { - if (result == null) { - result = new ArrayList(numEntries); - } - result.add(e.address()); + for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) { + for (int i = 0; i < numEntries; i++) { + final DnsCacheEntry e = cachedEntries.get(i); + if (f.addressType().isInstance(e.address())) { + if (result == null) { + result = new ArrayList(numEntries); } + result.add(e.address()); } } } - } - - if (result != null) { - trySuccess(promise, result); - return true; - } - if (cause != null) { + if (result != null) { + trySuccess(promise, result); + return true; + } + return false; + } else { tryFailure(promise, cause); return true; } - return false; } static final class ListResolverContext extends DnsNameResolverContext> {