Fix race-condition when using DnsCache in DnsNameResolver
Motivation: The usage of DnsCache in DnsNameResolver was racy in general. First of the isEmpty() was not called in a synchronized block while we depended on synchronized. The other problem was that this whole synchronization only worked if the DefaultDnsCache was used and the returned List was not wrapped by the user. Modifications: - Rewrite DefaultDnsCache to not depend on synchronization on the returned List by using a CoW approach. Result: Fixes [#7583] and other races.
This commit is contained in:
parent
46dac128f7
commit
e305488c5a
@ -23,11 +23,13 @@ import io.netty.util.internal.UnstableApi;
|
||||
|
||||
import java.net.InetAddress;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
||||
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
||||
@ -39,8 +41,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
||||
@UnstableApi
|
||||
public class DefaultDnsCache implements DnsCache {
|
||||
|
||||
private final ConcurrentMap<String, List<DefaultDnsCacheEntry>> resolveCache =
|
||||
PlatformDependent.newConcurrentHashMap();
|
||||
private final ConcurrentMap<String, Entries> resolveCache = PlatformDependent.newConcurrentHashMap();
|
||||
|
||||
private final int minTtl;
|
||||
private final int maxTtl;
|
||||
private final int negativeTtl;
|
||||
@ -97,28 +99,21 @@ public class DefaultDnsCache implements DnsCache {
|
||||
|
||||
@Override
|
||||
public void clear() {
|
||||
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator();
|
||||
i.hasNext();) {
|
||||
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next();
|
||||
i.remove();
|
||||
cancelExpiration(e.getValue());
|
||||
while (!resolveCache.isEmpty()) {
|
||||
for (Iterator<Map.Entry<String, Entries>> i = resolveCache.entrySet().iterator(); i.hasNext();) {
|
||||
Map.Entry<String, Entries> e = i.next();
|
||||
i.remove();
|
||||
|
||||
e.getValue().cancelExpiration();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean clear(String hostname) {
|
||||
checkNotNull(hostname, "hostname");
|
||||
boolean removed = false;
|
||||
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator();
|
||||
i.hasNext();) {
|
||||
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next();
|
||||
if (e.getKey().equals(hostname)) {
|
||||
i.remove();
|
||||
cancelExpiration(e.getValue());
|
||||
removed = true;
|
||||
}
|
||||
}
|
||||
return removed;
|
||||
Entries entries = resolveCache.remove(hostname);
|
||||
return entries != null && entries.cancelExpiration();
|
||||
}
|
||||
|
||||
private static boolean emptyAdditionals(DnsRecord[] additionals) {
|
||||
@ -129,22 +124,11 @@ public class DefaultDnsCache implements DnsCache {
|
||||
public List<? extends DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
|
||||
checkNotNull(hostname, "hostname");
|
||||
if (!emptyAdditionals(additionals)) {
|
||||
return null;
|
||||
return Collections.<DnsCacheEntry>emptyList();
|
||||
}
|
||||
return resolveCache.get(hostname);
|
||||
}
|
||||
|
||||
private List<DefaultDnsCacheEntry> cachedEntries(String hostname) {
|
||||
List<DefaultDnsCacheEntry> oldEntries = resolveCache.get(hostname);
|
||||
final List<DefaultDnsCacheEntry> entries;
|
||||
if (oldEntries == null) {
|
||||
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(8);
|
||||
oldEntries = resolveCache.putIfAbsent(hostname, newEntries);
|
||||
entries = oldEntries != null? oldEntries : newEntries;
|
||||
} else {
|
||||
entries = oldEntries;
|
||||
}
|
||||
return entries;
|
||||
Entries entries = resolveCache.get(hostname);
|
||||
return entries == null ? null : entries.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -157,22 +141,7 @@ public class DefaultDnsCache implements DnsCache {
|
||||
if (maxTtl == 0 || !emptyAdditionals(additionals)) {
|
||||
return e;
|
||||
}
|
||||
final int ttl = Math.max(minTtl, (int) Math.min(maxTtl, originalTtl));
|
||||
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
|
||||
|
||||
synchronized (entries) {
|
||||
if (!entries.isEmpty()) {
|
||||
final DefaultDnsCacheEntry firstEntry = entries.get(0);
|
||||
if (firstEntry.cause() != null) {
|
||||
assert entries.size() == 1;
|
||||
firstEntry.cancelExpiration();
|
||||
entries.clear();
|
||||
}
|
||||
}
|
||||
entries.add(e);
|
||||
}
|
||||
|
||||
scheduleCacheExpiration(entries, e, ttl, loop);
|
||||
cache0(e, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop);
|
||||
return e;
|
||||
}
|
||||
|
||||
@ -186,40 +155,34 @@ public class DefaultDnsCache implements DnsCache {
|
||||
if (negativeTtl == 0 || !emptyAdditionals(additionals)) {
|
||||
return e;
|
||||
}
|
||||
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
|
||||
|
||||
synchronized (entries) {
|
||||
final int numEntries = entries.size();
|
||||
for (int i = 0; i < numEntries; i ++) {
|
||||
entries.get(i).cancelExpiration();
|
||||
}
|
||||
entries.clear();
|
||||
entries.add(e);
|
||||
}
|
||||
|
||||
scheduleCacheExpiration(entries, e, negativeTtl, loop);
|
||||
cache0(e, negativeTtl, loop);
|
||||
return e;
|
||||
}
|
||||
|
||||
private static void cancelExpiration(List<DefaultDnsCacheEntry> entries) {
|
||||
final int numEntries = entries.size();
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
entries.get(i).cancelExpiration();
|
||||
private void cache0(DefaultDnsCacheEntry e, int ttl, EventLoop loop) {
|
||||
Entries entries = resolveCache.get(e.hostname());
|
||||
if (entries == null) {
|
||||
entries = new Entries(e);
|
||||
Entries oldEntries = resolveCache.putIfAbsent(e.hostname(), entries);
|
||||
if (oldEntries != null) {
|
||||
entries = oldEntries;
|
||||
entries.add(e);
|
||||
}
|
||||
}
|
||||
|
||||
scheduleCacheExpiration(entries, e, ttl, loop);
|
||||
}
|
||||
|
||||
private void scheduleCacheExpiration(final List<DefaultDnsCacheEntry> entries,
|
||||
private void scheduleCacheExpiration(final Entries entries,
|
||||
final DefaultDnsCacheEntry e,
|
||||
int ttl,
|
||||
EventLoop loop) {
|
||||
e.scheduleExpiration(loop, new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
synchronized (entries) {
|
||||
entries.remove(e);
|
||||
if (entries.isEmpty()) {
|
||||
resolveCache.remove(e.hostname());
|
||||
}
|
||||
if (entries.remove(e)) {
|
||||
resolveCache.remove(e.hostname);
|
||||
}
|
||||
}
|
||||
}, ttl, TimeUnit.SECONDS);
|
||||
@ -289,4 +252,98 @@ public class DefaultDnsCache implements DnsCache {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Directly extend AtomicReference for intrinsics and also to keep memory overhead low.
|
||||
private static final class Entries extends AtomicReference<List<DefaultDnsCacheEntry>> {
|
||||
|
||||
Entries(DefaultDnsCacheEntry entry) {
|
||||
super(Collections.singletonList(entry));
|
||||
}
|
||||
|
||||
void add(DefaultDnsCacheEntry e) {
|
||||
if (e.cause() == null) {
|
||||
for (;;) {
|
||||
List<DefaultDnsCacheEntry> entries = get();
|
||||
if (!entries.isEmpty()) {
|
||||
final DefaultDnsCacheEntry firstEntry = entries.get(0);
|
||||
if (firstEntry.cause() != null) {
|
||||
assert entries.size() == 1;
|
||||
if (compareAndSet(entries, Collections.singletonList(e))) {
|
||||
firstEntry.cancelExpiration();
|
||||
return;
|
||||
} else {
|
||||
// Need to try again as CAS failed
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Create a new List for COW semantics
|
||||
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(entries.size() + 1);
|
||||
newEntries.addAll(entries);
|
||||
newEntries.add(e);
|
||||
if (compareAndSet(entries, newEntries)) {
|
||||
return;
|
||||
}
|
||||
} else if (compareAndSet(entries, Collections.singletonList(e))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.singletonList(e));
|
||||
cancelExpiration(entries);
|
||||
}
|
||||
}
|
||||
|
||||
boolean remove(DefaultDnsCacheEntry entry) {
|
||||
for (;;) {
|
||||
List<DefaultDnsCacheEntry> entries = get();
|
||||
int size = entries.size();
|
||||
if (size == 0) {
|
||||
return false;
|
||||
} else if (size == 1) {
|
||||
if (entries.get(0).equals(entry)) {
|
||||
// If the list is empty we just return early and so not allocate a new ArrayList.
|
||||
if (compareAndSet(entries, Collections.<DefaultDnsCacheEntry>emptyList())) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Just size the new ArrayList as before as we may not find the entry we are looking for and not
|
||||
// want to cause an extra allocation / memory copy in this case.
|
||||
//
|
||||
// Its very likely we find the entry we are looking for so we directly create a new ArrayList
|
||||
// and fill it.
|
||||
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(size);
|
||||
for (int i = 0; i < size; i++) {
|
||||
DefaultDnsCacheEntry e = entries.get(i);
|
||||
if (!e.equals(entry)) {
|
||||
newEntries.add(e);
|
||||
}
|
||||
}
|
||||
if (compareAndSet(entries, Collections.unmodifiableList(newEntries))) {
|
||||
// This will return true if an entry was removed.
|
||||
return newEntries.size() != size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
boolean cancelExpiration() {
|
||||
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.<DefaultDnsCacheEntry>emptyList());
|
||||
if (entries.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
cancelExpiration(entries);
|
||||
return true;
|
||||
}
|
||||
|
||||
private static void cancelExpiration(List<DefaultDnsCacheEntry> entryList) {
|
||||
final int numEntries = entryList.size();
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
entryList.get(i).cancelExpiration();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -600,43 +600,28 @@ public class DnsNameResolver extends InetNameResolver {
|
||||
Promise<InetAddress> promise,
|
||||
DnsCache resolveCache) {
|
||||
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
||||
if (cachedEntries == null) {
|
||||
if (cachedEntries == null || cachedEntries.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
InetAddress address = null;
|
||||
Throwable cause = null;
|
||||
synchronized (cachedEntries) {
|
||||
Throwable cause = cachedEntries.get(0).cause();
|
||||
if (cause == null) {
|
||||
final int numEntries = cachedEntries.size();
|
||||
if (numEntries == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cachedEntries.get(0).cause() != null) {
|
||||
cause = cachedEntries.get(0).cause();
|
||||
} else {
|
||||
// Find the first entry with the preferred address type.
|
||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
final DnsCacheEntry e = cachedEntries.get(i);
|
||||
if (f.addressType().isInstance(e.address())) {
|
||||
address = e.address();
|
||||
break;
|
||||
}
|
||||
// Find the first entry with the preferred address type.
|
||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
final DnsCacheEntry e = cachedEntries.get(i);
|
||||
if (f.addressType().isInstance(e.address())) {
|
||||
trySuccess(promise, e.address());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (address != null) {
|
||||
trySuccess(promise, address);
|
||||
return true;
|
||||
}
|
||||
if (cause != null) {
|
||||
return false;
|
||||
} else {
|
||||
tryFailure(promise, cause);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static <T> void trySuccess(Promise<T> promise, T result) {
|
||||
@ -732,44 +717,34 @@ public class DnsNameResolver extends InetNameResolver {
|
||||
Promise<List<InetAddress>> promise,
|
||||
DnsCache resolveCache) {
|
||||
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
||||
if (cachedEntries == null) {
|
||||
if (cachedEntries == null || cachedEntries.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
List<InetAddress> result = null;
|
||||
Throwable cause = null;
|
||||
synchronized (cachedEntries) {
|
||||
Throwable cause = cachedEntries.get(0).cause();
|
||||
if (cause == null) {
|
||||
List<InetAddress> result = null;
|
||||
final int numEntries = cachedEntries.size();
|
||||
if (numEntries == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cachedEntries.get(0).cause() != null) {
|
||||
cause = cachedEntries.get(0).cause();
|
||||
} else {
|
||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
final DnsCacheEntry e = cachedEntries.get(i);
|
||||
if (f.addressType().isInstance(e.address())) {
|
||||
if (result == null) {
|
||||
result = new ArrayList<InetAddress>(numEntries);
|
||||
}
|
||||
result.add(e.address());
|
||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||
for (int i = 0; i < numEntries; i++) {
|
||||
final DnsCacheEntry e = cachedEntries.get(i);
|
||||
if (f.addressType().isInstance(e.address())) {
|
||||
if (result == null) {
|
||||
result = new ArrayList<InetAddress>(numEntries);
|
||||
}
|
||||
result.add(e.address());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (result != null) {
|
||||
trySuccess(promise, result);
|
||||
return true;
|
||||
}
|
||||
if (cause != null) {
|
||||
if (result != null) {
|
||||
trySuccess(promise, result);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
tryFailure(promise, cause);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
|
||||
|
Loading…
Reference in New Issue
Block a user