Fix race-condition when using DnsCache in DnsNameResolver
Motivation: The usage of DnsCache in DnsNameResolver was racy in general. First of the isEmpty() was not called in a synchronized block while we depended on synchronized. The other problem was that this whole synchronization only worked if the DefaultDnsCache was used and the returned List was not wrapped by the user. Modifications: - Rewrite DefaultDnsCache to not depend on synchronization on the returned List by using a CoW approach. Result: Fixes [#7583] and other races.
This commit is contained in:
parent
46dac128f7
commit
e305488c5a
@ -23,11 +23,13 @@ import io.netty.util.internal.UnstableApi;
|
|||||||
|
|
||||||
import java.net.InetAddress;
|
import java.net.InetAddress;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentMap;
|
import java.util.concurrent.ConcurrentMap;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
import static io.netty.util.internal.ObjectUtil.checkNotNull;
|
||||||
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
||||||
@ -39,8 +41,8 @@ import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
|
|||||||
@UnstableApi
|
@UnstableApi
|
||||||
public class DefaultDnsCache implements DnsCache {
|
public class DefaultDnsCache implements DnsCache {
|
||||||
|
|
||||||
private final ConcurrentMap<String, List<DefaultDnsCacheEntry>> resolveCache =
|
private final ConcurrentMap<String, Entries> resolveCache = PlatformDependent.newConcurrentHashMap();
|
||||||
PlatformDependent.newConcurrentHashMap();
|
|
||||||
private final int minTtl;
|
private final int minTtl;
|
||||||
private final int maxTtl;
|
private final int maxTtl;
|
||||||
private final int negativeTtl;
|
private final int negativeTtl;
|
||||||
@ -97,28 +99,21 @@ public class DefaultDnsCache implements DnsCache {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void clear() {
|
public void clear() {
|
||||||
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator();
|
while (!resolveCache.isEmpty()) {
|
||||||
i.hasNext();) {
|
for (Iterator<Map.Entry<String, Entries>> i = resolveCache.entrySet().iterator(); i.hasNext();) {
|
||||||
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next();
|
Map.Entry<String, Entries> e = i.next();
|
||||||
i.remove();
|
i.remove();
|
||||||
cancelExpiration(e.getValue());
|
|
||||||
|
e.getValue().cancelExpiration();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean clear(String hostname) {
|
public boolean clear(String hostname) {
|
||||||
checkNotNull(hostname, "hostname");
|
checkNotNull(hostname, "hostname");
|
||||||
boolean removed = false;
|
Entries entries = resolveCache.remove(hostname);
|
||||||
for (Iterator<Map.Entry<String, List<DefaultDnsCacheEntry>>> i = resolveCache.entrySet().iterator();
|
return entries != null && entries.cancelExpiration();
|
||||||
i.hasNext();) {
|
|
||||||
final Map.Entry<String, List<DefaultDnsCacheEntry>> e = i.next();
|
|
||||||
if (e.getKey().equals(hostname)) {
|
|
||||||
i.remove();
|
|
||||||
cancelExpiration(e.getValue());
|
|
||||||
removed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return removed;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static boolean emptyAdditionals(DnsRecord[] additionals) {
|
private static boolean emptyAdditionals(DnsRecord[] additionals) {
|
||||||
@ -129,22 +124,11 @@ public class DefaultDnsCache implements DnsCache {
|
|||||||
public List<? extends DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
|
public List<? extends DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
|
||||||
checkNotNull(hostname, "hostname");
|
checkNotNull(hostname, "hostname");
|
||||||
if (!emptyAdditionals(additionals)) {
|
if (!emptyAdditionals(additionals)) {
|
||||||
return null;
|
return Collections.<DnsCacheEntry>emptyList();
|
||||||
}
|
|
||||||
return resolveCache.get(hostname);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<DefaultDnsCacheEntry> cachedEntries(String hostname) {
|
Entries entries = resolveCache.get(hostname);
|
||||||
List<DefaultDnsCacheEntry> oldEntries = resolveCache.get(hostname);
|
return entries == null ? null : entries.get();
|
||||||
final List<DefaultDnsCacheEntry> entries;
|
|
||||||
if (oldEntries == null) {
|
|
||||||
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(8);
|
|
||||||
oldEntries = resolveCache.putIfAbsent(hostname, newEntries);
|
|
||||||
entries = oldEntries != null? oldEntries : newEntries;
|
|
||||||
} else {
|
|
||||||
entries = oldEntries;
|
|
||||||
}
|
|
||||||
return entries;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -157,22 +141,7 @@ public class DefaultDnsCache implements DnsCache {
|
|||||||
if (maxTtl == 0 || !emptyAdditionals(additionals)) {
|
if (maxTtl == 0 || !emptyAdditionals(additionals)) {
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
final int ttl = Math.max(minTtl, (int) Math.min(maxTtl, originalTtl));
|
cache0(e, Math.max(minTtl, (int) Math.min(maxTtl, originalTtl)), loop);
|
||||||
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
|
|
||||||
|
|
||||||
synchronized (entries) {
|
|
||||||
if (!entries.isEmpty()) {
|
|
||||||
final DefaultDnsCacheEntry firstEntry = entries.get(0);
|
|
||||||
if (firstEntry.cause() != null) {
|
|
||||||
assert entries.size() == 1;
|
|
||||||
firstEntry.cancelExpiration();
|
|
||||||
entries.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
entries.add(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
scheduleCacheExpiration(entries, e, ttl, loop);
|
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,40 +155,34 @@ public class DefaultDnsCache implements DnsCache {
|
|||||||
if (negativeTtl == 0 || !emptyAdditionals(additionals)) {
|
if (negativeTtl == 0 || !emptyAdditionals(additionals)) {
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
final List<DefaultDnsCacheEntry> entries = cachedEntries(hostname);
|
|
||||||
|
|
||||||
synchronized (entries) {
|
cache0(e, negativeTtl, loop);
|
||||||
final int numEntries = entries.size();
|
|
||||||
for (int i = 0; i < numEntries; i ++) {
|
|
||||||
entries.get(i).cancelExpiration();
|
|
||||||
}
|
|
||||||
entries.clear();
|
|
||||||
entries.add(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
scheduleCacheExpiration(entries, e, negativeTtl, loop);
|
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void cancelExpiration(List<DefaultDnsCacheEntry> entries) {
|
private void cache0(DefaultDnsCacheEntry e, int ttl, EventLoop loop) {
|
||||||
final int numEntries = entries.size();
|
Entries entries = resolveCache.get(e.hostname());
|
||||||
for (int i = 0; i < numEntries; i++) {
|
if (entries == null) {
|
||||||
entries.get(i).cancelExpiration();
|
entries = new Entries(e);
|
||||||
|
Entries oldEntries = resolveCache.putIfAbsent(e.hostname(), entries);
|
||||||
|
if (oldEntries != null) {
|
||||||
|
entries = oldEntries;
|
||||||
|
entries.add(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void scheduleCacheExpiration(final List<DefaultDnsCacheEntry> entries,
|
scheduleCacheExpiration(entries, e, ttl, loop);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void scheduleCacheExpiration(final Entries entries,
|
||||||
final DefaultDnsCacheEntry e,
|
final DefaultDnsCacheEntry e,
|
||||||
int ttl,
|
int ttl,
|
||||||
EventLoop loop) {
|
EventLoop loop) {
|
||||||
e.scheduleExpiration(loop, new Runnable() {
|
e.scheduleExpiration(loop, new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void run() {
|
public void run() {
|
||||||
synchronized (entries) {
|
if (entries.remove(e)) {
|
||||||
entries.remove(e);
|
resolveCache.remove(e.hostname);
|
||||||
if (entries.isEmpty()) {
|
|
||||||
resolveCache.remove(e.hostname());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, ttl, TimeUnit.SECONDS);
|
}, ttl, TimeUnit.SECONDS);
|
||||||
@ -289,4 +252,98 @@ public class DefaultDnsCache implements DnsCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Directly extend AtomicReference for intrinsics and also to keep memory overhead low.
|
||||||
|
private static final class Entries extends AtomicReference<List<DefaultDnsCacheEntry>> {
|
||||||
|
|
||||||
|
Entries(DefaultDnsCacheEntry entry) {
|
||||||
|
super(Collections.singletonList(entry));
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(DefaultDnsCacheEntry e) {
|
||||||
|
if (e.cause() == null) {
|
||||||
|
for (;;) {
|
||||||
|
List<DefaultDnsCacheEntry> entries = get();
|
||||||
|
if (!entries.isEmpty()) {
|
||||||
|
final DefaultDnsCacheEntry firstEntry = entries.get(0);
|
||||||
|
if (firstEntry.cause() != null) {
|
||||||
|
assert entries.size() == 1;
|
||||||
|
if (compareAndSet(entries, Collections.singletonList(e))) {
|
||||||
|
firstEntry.cancelExpiration();
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
// Need to try again as CAS failed
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Create a new List for COW semantics
|
||||||
|
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(entries.size() + 1);
|
||||||
|
newEntries.addAll(entries);
|
||||||
|
newEntries.add(e);
|
||||||
|
if (compareAndSet(entries, newEntries)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if (compareAndSet(entries, Collections.singletonList(e))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.singletonList(e));
|
||||||
|
cancelExpiration(entries);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean remove(DefaultDnsCacheEntry entry) {
|
||||||
|
for (;;) {
|
||||||
|
List<DefaultDnsCacheEntry> entries = get();
|
||||||
|
int size = entries.size();
|
||||||
|
if (size == 0) {
|
||||||
|
return false;
|
||||||
|
} else if (size == 1) {
|
||||||
|
if (entries.get(0).equals(entry)) {
|
||||||
|
// If the list is empty we just return early and so not allocate a new ArrayList.
|
||||||
|
if (compareAndSet(entries, Collections.<DefaultDnsCacheEntry>emptyList())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Just size the new ArrayList as before as we may not find the entry we are looking for and not
|
||||||
|
// want to cause an extra allocation / memory copy in this case.
|
||||||
|
//
|
||||||
|
// Its very likely we find the entry we are looking for so we directly create a new ArrayList
|
||||||
|
// and fill it.
|
||||||
|
List<DefaultDnsCacheEntry> newEntries = new ArrayList<DefaultDnsCacheEntry>(size);
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
DefaultDnsCacheEntry e = entries.get(i);
|
||||||
|
if (!e.equals(entry)) {
|
||||||
|
newEntries.add(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (compareAndSet(entries, Collections.unmodifiableList(newEntries))) {
|
||||||
|
// This will return true if an entry was removed.
|
||||||
|
return newEntries.size() != size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean cancelExpiration() {
|
||||||
|
List<DefaultDnsCacheEntry> entries = getAndSet(Collections.<DefaultDnsCacheEntry>emptyList());
|
||||||
|
if (entries.isEmpty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
cancelExpiration(entries);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void cancelExpiration(List<DefaultDnsCacheEntry> entryList) {
|
||||||
|
final int numEntries = entryList.size();
|
||||||
|
for (int i = 0; i < numEntries; i++) {
|
||||||
|
entryList.get(i).cancelExpiration();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -600,43 +600,28 @@ public class DnsNameResolver extends InetNameResolver {
|
|||||||
Promise<InetAddress> promise,
|
Promise<InetAddress> promise,
|
||||||
DnsCache resolveCache) {
|
DnsCache resolveCache) {
|
||||||
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
||||||
if (cachedEntries == null) {
|
if (cachedEntries == null || cachedEntries.isEmpty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
InetAddress address = null;
|
Throwable cause = cachedEntries.get(0).cause();
|
||||||
Throwable cause = null;
|
if (cause == null) {
|
||||||
synchronized (cachedEntries) {
|
|
||||||
final int numEntries = cachedEntries.size();
|
final int numEntries = cachedEntries.size();
|
||||||
if (numEntries == 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cachedEntries.get(0).cause() != null) {
|
|
||||||
cause = cachedEntries.get(0).cause();
|
|
||||||
} else {
|
|
||||||
// Find the first entry with the preferred address type.
|
// Find the first entry with the preferred address type.
|
||||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||||
for (int i = 0; i < numEntries; i++) {
|
for (int i = 0; i < numEntries; i++) {
|
||||||
final DnsCacheEntry e = cachedEntries.get(i);
|
final DnsCacheEntry e = cachedEntries.get(i);
|
||||||
if (f.addressType().isInstance(e.address())) {
|
if (f.addressType().isInstance(e.address())) {
|
||||||
address = e.address();
|
trySuccess(promise, e.address());
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (address != null) {
|
|
||||||
trySuccess(promise, address);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (cause != null) {
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
tryFailure(promise, cause);
|
tryFailure(promise, cause);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static <T> void trySuccess(Promise<T> promise, T result) {
|
private static <T> void trySuccess(Promise<T> promise, T result) {
|
||||||
@ -732,21 +717,14 @@ public class DnsNameResolver extends InetNameResolver {
|
|||||||
Promise<List<InetAddress>> promise,
|
Promise<List<InetAddress>> promise,
|
||||||
DnsCache resolveCache) {
|
DnsCache resolveCache) {
|
||||||
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
final List<? extends DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
|
||||||
if (cachedEntries == null) {
|
if (cachedEntries == null || cachedEntries.isEmpty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Throwable cause = cachedEntries.get(0).cause();
|
||||||
|
if (cause == null) {
|
||||||
List<InetAddress> result = null;
|
List<InetAddress> result = null;
|
||||||
Throwable cause = null;
|
|
||||||
synchronized (cachedEntries) {
|
|
||||||
final int numEntries = cachedEntries.size();
|
final int numEntries = cachedEntries.size();
|
||||||
if (numEntries == 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cachedEntries.get(0).cause() != null) {
|
|
||||||
cause = cachedEntries.get(0).cause();
|
|
||||||
} else {
|
|
||||||
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
for (InternetProtocolFamily f : resolvedInternetProtocolFamilies) {
|
||||||
for (int i = 0; i < numEntries; i++) {
|
for (int i = 0; i < numEntries; i++) {
|
||||||
final DnsCacheEntry e = cachedEntries.get(i);
|
final DnsCacheEntry e = cachedEntries.get(i);
|
||||||
@ -758,18 +736,15 @@ public class DnsNameResolver extends InetNameResolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (result != null) {
|
if (result != null) {
|
||||||
trySuccess(promise, result);
|
trySuccess(promise, result);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (cause != null) {
|
return false;
|
||||||
|
} else {
|
||||||
tryFailure(promise, cause);
|
tryFailure(promise, cause);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
|
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
|
||||||
|
Loading…
Reference in New Issue
Block a user