We should fail fast when a CNAME loop is detected (#10305)

Motivation:

Once a CNAME loop was detected we can just fail fast and so reduce the number of queries.

Modifications:

Fail fast once a CNAME loop is detected

Result:

Fail fast
This commit is contained in:
Norman Maurer 2020-05-20 07:10:16 +02:00
parent 9fe5f6f958
commit 3f7dbd8790
2 changed files with 32 additions and 18 deletions

View File

@ -268,7 +268,7 @@ abstract class DnsResolveContext<T> {
// guards against loops in the cache but early return once a loop is detected. // guards against loops in the cache but early return once a loop is detected.
// //
// Visible for testing only // Visible for testing only
static String cnameResolveFromCache(DnsCnameCache cnameCache, String name) { static String cnameResolveFromCache(DnsCnameCache cnameCache, String name) throws UnknownHostException {
String first = cnameCache.get(hostnameWithDot(name)); String first = cnameCache.get(hostnameWithDot(name));
if (first == null) { if (first == null) {
// Nothing in the cache at all // Nothing in the cache at all
@ -281,15 +281,12 @@ abstract class DnsResolveContext<T> {
return first; return first;
} }
if (first.equals(second)) { checkCnameLoop(name, first, second);
// Loop detected.... early return. return cnameResolveFromCacheLoop(cnameCache, name, first, second);
return first;
} }
return cnameResolveFromCacheLoop(cnameCache, first, second); private static String cnameResolveFromCacheLoop(
} DnsCnameCache cnameCache, String hostname, String first, String mapping) throws UnknownHostException {
static String cnameResolveFromCacheLoop(DnsCnameCache cnameCache, String first, String mapping) {
// Detect loops by advance only every other iteration. // Detect loops by advance only every other iteration.
// See https://en.wikipedia.org/wiki/Cycle_detection#Floyd's_Tortoise_and_Hare // See https://en.wikipedia.org/wiki/Cycle_detection#Floyd's_Tortoise_and_Hare
boolean advance = false; boolean advance = false;
@ -297,10 +294,7 @@ abstract class DnsResolveContext<T> {
String name = mapping; String name = mapping;
// Resolve from cnameCache() until there is no more cname entry cached. // Resolve from cnameCache() until there is no more cname entry cached.
while ((mapping = cnameCache.get(hostnameWithDot(name))) != null) { while ((mapping = cnameCache.get(hostnameWithDot(name))) != null) {
if (first.equals(mapping)) { checkCnameLoop(hostname, first, mapping);
// Follow CNAME from cache would loop. Lets break here.
break;
}
name = mapping; name = mapping;
if (advance) { if (advance) {
first = cnameCache.get(first); first = cnameCache.get(first);
@ -310,9 +304,20 @@ abstract class DnsResolveContext<T> {
return name; return name;
} }
private static void checkCnameLoop(String hostname, String first, String second) throws UnknownHostException {
if (first.equals(second)) {
// Follow CNAME from cache would loop. Lets throw and so fail the resolution.
throw new UnknownHostException("CNAME loop detected for '" + hostname + '\'');
}
}
private void internalResolve(String name, Promise<List<T>> promise) { private void internalResolve(String name, Promise<List<T>> promise) {
try {
// Resolve from cnameCache() until there is no more cname entry cached. // Resolve from cnameCache() until there is no more cname entry cached.
name = cnameResolveFromCache(cnameCache(), name); name = cnameResolveFromCache(cnameCache(), name);
} catch (Throwable cause) {
promise.tryFailure(cause);
return;
}
try { try {
DnsServerAddressStream nameServerAddressStream = getNameServers(name); DnsServerAddressStream nameServerAddressStream = getNameServers(name);
@ -977,11 +982,11 @@ abstract class DnsResolveContext<T> {
private void followCname(DnsQuestion question, String cname, DnsQueryLifecycleObserver queryLifecycleObserver, private void followCname(DnsQuestion question, String cname, DnsQueryLifecycleObserver queryLifecycleObserver,
Promise<List<T>> promise) { Promise<List<T>> promise) {
cname = cnameResolveFromCache(cnameCache(), cname);
DnsServerAddressStream stream = getNameServers(cname);
final DnsQuestion cnameQuestion; final DnsQuestion cnameQuestion;
final DnsServerAddressStream stream;
try { try {
cname = cnameResolveFromCache(cnameCache(), cname);
stream = getNameServers(cname);
cnameQuestion = new DefaultDnsQuestion(cname, question.type(), dnsClass); cnameQuestion = new DefaultDnsQuestion(cname, question.type(), dnsClass);
} catch (Throwable cause) { } catch (Throwable cause) {
queryLifecycleObserver.queryFailed(cause); queryLifecycleObserver.queryFailed(cause);

View File

@ -18,6 +18,10 @@ package io.netty.resolver.dns;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.Test; import org.junit.Test;
import java.net.UnknownHostException;
import static org.junit.Assert.fail;
public class DnsResolveContextTest { public class DnsResolveContextTest {
private static final String HOSTNAME = "netty.io."; private static final String HOSTNAME = "netty.io.";
@ -25,7 +29,12 @@ public class DnsResolveContextTest {
@Test @Test
public void testCnameLoop() { public void testCnameLoop() {
for (int i = 1; i < 128; i++) { for (int i = 1; i < 128; i++) {
try {
DnsResolveContext.cnameResolveFromCache(buildCache(i), HOSTNAME); DnsResolveContext.cnameResolveFromCache(buildCache(i), HOSTNAME);
fail();
} catch (UnknownHostException expected) {
// expected
}
} }
} }