Fix race-condition when using DnsCache in DnsNameResolver

Motivation:

The usage of DnsCache in DnsNameResolver was racy in general. First of the isEmpty() was not called in a synchronized block while we depended on synchronized. The other problem was that this whole synchronization only worked if the DefaultDnsCache was used and the returned List was not wrapped by the user.

Modifications:

- Rewrite DefaultDnsCache to not depend on synchronization on the returned List by using a CoW approach.

Result:

Fixes [#7583] and other races.
This commit is contained in:
Norman Maurer 2018-01-16 12:35:34 +01:00
parent 46dac128f7
commit e305488c5a
2 changed files with 155 additions and 123 deletions

View File

@ -23,11 +23,13 @@ import io.netty.util.internal.UnstableApi;
import java.net.InetAddress; import java.net.InetAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.util.internal.ObjectUtil.checkNotNull; import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
@ -39,8 +41,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
@UnstableApi @UnstableApi
public class DefaultDnsCache implements DnsCache { public class DefaultDnsCache implements DnsCache {
private final ConcurrentMap<String, List<DefaultDnsCacheEntry>> resolveCache = private final ConcurrentMap<String, Entries> resolveCache = PlatformDependent.newConcurrentHashMap();
PlatformDependent.newConcurrentHashMap();
private final int minTtl; private final int minTtl;
private final int maxTtl; private final int maxTtl;
private final int negativeTtl; private final int negativeTtl;
@ -97,28 +99,21 @@ public class DefaultDnsCache implements DnsCache {
@Override @Override
public void clear() { public void clear() {
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator(); while (!resolveCache.isEmpty()) {
i.hasNext();) { for (Iterator<Map.Entry<String, Entries>> i = resolveCache.entrySet().iterator(); i.hasNext();) {
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next(); Map.Entry<String, Entries> e = i.next();
i.remove(); i.remove();
cancelExpiration(e.getValue());
e.getValue().cancelExpiration();
}
} }
} }
@Override @Override
public boolean clear(String hostname) { public boolean clear(String hostname) {
checkNotNull(hostname, "hostname"); checkNotNull(hostname, "hostname");
boolean removed = false; Entries entries = resolveCache.remove(hostname);
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator(); return entries != null && entries.cancelExpiration();
i.hasNext();) {
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next();
if (e.getKey().equals(hostname)) {
i.remove();
cancelExpiration(e.getValue());
removed = true;
}
}
return removed;
} }
private static boolean emptyAdditionals(DnsRecord[] additionals) { private static boolean emptyAdditionals(DnsRecord[] additionals) {
@ -129,22 +124,11 @@ public class DefaultDnsCache implements DnsCache {
public List<? extends DnsCacheEntry> get(String hostname, DnsRecord[] additionals) { public List<? extends DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
checkNotNull(hostname, "hostname"); checkNotNull(hostname, "hostname");
if (!emptyAdditionals(additionals)) { if (!emptyAdditionals(additionals)) {
return null; return Collections.<DnsCacheEntry>emptyList();
} }
return resolveCache.get(hostname);
}
private List<DefaultDnsCacheEntry> cachedEntries(String hostname) { Entries entries = resolveCache.get(hostname);
List<DefaultDnsCacheEntry> oldEntries = resolveCache.get(hostname); return entries == null ? null : entries.get();
final List<DefaultDnsCacheEntry> entries;
if (oldEntries == null) {
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(8);
oldEntries = resolveCache.putIfAbsent(hostname, newEntries);
entries = oldEntries != null? oldEntries : newEntries;
} else {
entries = oldEntries;
}
return entries;
} }
@Override @Override
@ -157,22 +141,7 @@ public class DefaultDnsCache implements DnsCache {
if (maxTtl == 0 || !emptyAdditionals(additionals)) { if (maxTtl == 0 || !emptyAdditionals(additionals)) {
return e; return e;
} }
final int ttl = Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)); cache0(e, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop);
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
synchronized (entries) {
if (!entries.isEmpty()) {
final DefaultDnsCacheEntry firstEntry = entries.get(0);
if (firstEntry.cause() != null) {
assert entries.size() == 1;
firstEntry.cancelExpiration();
entries.clear();
}
}
entries.add(e);
}
scheduleCacheExpiration(entries, e, ttl, loop);
return e; return e;
} }
@ -186,40 +155,34 @@ public class DefaultDnsCache implements DnsCache {
if (negativeTtl == 0 || !emptyAdditionals(additionals)) { if (negativeTtl == 0 || !emptyAdditionals(additionals)) {
return e; return e;
} }
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
synchronized (entries) { cache0(e, negativeTtl, loop);
final int numEntries = entries.size();
for (int i = 0; i < numEntries; i ++) {
entries.get(i).cancelExpiration();
}
entries.clear();
entries.add(e);
}
scheduleCacheExpiration(entries, e, negativeTtl, loop);
return e; return e;
} }
private static void cancelExpiration(List<DefaultDnsCacheEntry> entries) { private void cache0(DefaultDnsCacheEntry e, int ttl, EventLoop loop) {
final int numEntries = entries.size(); Entries entries = resolveCache.get(e.hostname());
for (int i = 0; i < numEntries; i++) { if (entries == null) {
entries.get(i).cancelExpiration(); entries = new Entries(e);
Entries oldEntries = resolveCache.putIfAbsent(e.hostname(), entries);
if (oldEntries != null) {
entries = oldEntries;
entries.add(e);
}
} }
scheduleCacheExpiration(entries, e, ttl, loop);
} }
private void scheduleCacheExpiration(final List<DefaultDnsCacheEntry> entries, private void scheduleCacheExpiration(final Entries entries,
final DefaultDnsCacheEntry e, final DefaultDnsCacheEntry e,
int ttl, int ttl,
EventLoop loop) { EventLoop loop) {
e.scheduleExpiration(loop, new Runnable() { e.scheduleExpiration(loop, new Runnable() {
@Override @Override
public void run() { public void run() {
synchronized (entries) { if (entries.remove(e)) {
entries.remove(e); resolveCache.remove(e.hostname);
if (entries.isEmpty()) {
resolveCache.remove(e.hostname());
}
} }
} }
}, ttl, TimeUnit.SECONDS); }, ttl, TimeUnit.SECONDS);
@ -289,4 +252,98 @@ public class DefaultDnsCache implements DnsCache {
} }
} }
} }
// Directly extend AtomicReference for intrinsics and also to keep memory overhead low.
private static final class Entries extends AtomicReference<List<DefaultDnsCacheEntry>> {
Entries(DefaultDnsCacheEntry entry) {
super(Collections.singletonList(entry));
}
void add(DefaultDnsCacheEntry e) {
if (e.cause() == null) {
for (;;) {
List<DefaultDnsCacheEntry> entries = get();
if (!entries.isEmpty()) {
final DefaultDnsCacheEntry firstEntry = entries.get(0);
if (firstEntry.cause() != null) {
assert entries.size() == 1;
if (compareAndSet(entries, Collections.singletonList(e))) {
firstEntry.cancelExpiration();
return;
} else {
// Need to try again as CAS failed
continue;
}
}
// Create a new List for COW semantics
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(entries.size() + 1);
newEntries.addAll(entries);
newEntries.add(e);
if (compareAndSet(entries, newEntries)) {
return;
}
} else if (compareAndSet(entries, Collections.singletonList(e))) {
return;
}
}
} else {
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.singletonList(e));
cancelExpiration(entries);
}
}
boolean remove(DefaultDnsCacheEntry entry) {
for (;;) {
List<DefaultDnsCacheEntry> entries = get();
int size = entries.size();
if (size == 0) {
return false;
} else if (size == 1) {
if (entries.get(0).equals(entry)) {
// If the list is empty we just return early and so not allocate a new ArrayList.
if (compareAndSet(entries, Collections.<DefaultDnsCacheEntry>emptyList())) {
return true;
}
} else {
return false;
}
} else {
// Just size the new ArrayList as before as we may not find the entry we are looking for and not
// want to cause an extra allocation / memory copy in this case.
//
// Its very likely we find the entry we are looking for so we directly create a new ArrayList
// and fill it.
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(size);
for (int i = 0; i < size; i++) {
DefaultDnsCacheEntry e = entries.get(i);
if (!e.equals(entry)) {
newEntries.add(e);
}
}
if (compareAndSet(entries, Collections.unmodifiableList(newEntries))) {
// This will return true if an entry was removed.
return newEntries.size() != size;
}
}
}
}
boolean cancelExpiration() {
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.<DefaultDnsCacheEntry>emptyList());
if (entries.isEmpty()) {
return false;
}
cancelExpiration(entries);
return true;
}
private static void cancelExpiration(List<DefaultDnsCacheEntry> entryList) {
final int numEntries = entryList.size();
for (int i = 0; i < numEntries; i++) {
entryList.get(i).cancelExpiration();
}
}
}
} }

View File

@ -600,43 +600,28 @@ public class DnsNameResolver extends InetNameResolver {
Promise<InetAddress> promise, Promise<InetAddress> promise,
DnsCache resolveCache) { DnsCache resolveCache) {
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals); final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
if (cachedEntries == null) { if (cachedEntries == null || cachedEntries.isEmpty()) {
return false; return false;
} }
InetAddress address = null; Throwable cause = cachedEntries.get(0).cause();
Throwable cause = null; if (cause == null) {
synchronized (cachedEntries) {
final int numEntries = cachedEntries.size(); final int numEntries = cachedEntries.size();
if (numEntries == 0) { // Find the first entry with the preferred address type.
return false; for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
} for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (cachedEntries.get(0).cause() != null) { if (f.addressType().isInstance(e.address())) {
cause = cachedEntries.get(0).cause(); trySuccess(promise, e.address());
} else { return true;
// Find the first entry with the preferred address type.
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
address = e.address();
break;
}
} }
} }
} }
} return false;
} else {
if (address != null) {
trySuccess(promise, address);
return true;
}
if (cause != null) {
tryFailure(promise, cause); tryFailure(promise, cause);
return true; return true;
} }
return false;
} }
private static <T> void trySuccess(Promise<T> promise, T result) { private static <T> void trySuccess(Promise<T> promise, T result) {
@ -732,44 +717,34 @@ public class DnsNameResolver extends InetNameResolver {
Promise<List<InetAddress>> promise, Promise<List<InetAddress>> promise,
DnsCache resolveCache) { DnsCache resolveCache) {
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals); final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
if (cachedEntries == null) { if (cachedEntries == null || cachedEntries.isEmpty()) {
return false; return false;
} }
List<InetAddress> result = null; Throwable cause = cachedEntries.get(0).cause();
Throwable cause = null; if (cause == null) {
synchronized (cachedEntries) { List<InetAddress> result = null;
final int numEntries = cachedEntries.size(); final int numEntries = cachedEntries.size();
if (numEntries == 0) { for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
return false; for (int i = 0; i < numEntries; i++) {
} final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
if (cachedEntries.get(0).cause() != null) { if (result == null) {
cause = cachedEntries.get(0).cause(); result = new ArrayList<InetAddress>(numEntries);
} else {
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
for (int i = 0; i < numEntries; i++) {
final DnsCacheEntry e = cachedEntries.get(i);
if (f.addressType().isInstance(e.address())) {
if (result == null) {
result = new ArrayList<InetAddress>(numEntries);
}
result.add(e.address());
} }
result.add(e.address());
} }
} }
} }
} if (result != null) {
trySuccess(promise, result);
if (result != null) { return true;
trySuccess(promise, result); }
return true; return false;
} } else {
if (cause != null) {
tryFailure(promise, cause); tryFailure(promise, cause);
return true; return true;
} }
return false;
} }
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> { static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {