Add support for Client Subnet in DNS Queries (RFC7871)

Motivation:

RFC7871 defines an extension which allows to request responses for a given subset.

Modifications:

- Add DnsOptPseudoRrRecord which can act as base class for extensions based on EDNS(0) as defined in RFC6891
- Add DnsOptEcsRecord to support the Client Subnet in DNS Queries extension
- Add tests

Result:

Client Subnet in DNS Queries extension is now supported.
This commit is contained in:
Norman Maurer 2016-07-29 10:30:08 +02:00
parent bd7806dd4f
commit dfa3bbbf00
14 changed files with 783 additions and 92 deletions

View File

@ -0,0 +1,78 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.UnstableApi;
/**
* An <a href="https://tools.ietf.org/html/rfc6891#section-6.1">OPT RR</a> record.
*
* This is used for <a href="https://tools.ietf.org/html/rfc6891#section-6.1.3">
* Extension Mechanisms for DNS (EDNS(0))</a>.
*/
@UnstableApi
public abstract class AbstractDnsOptPseudoRrRecord extends AbstractDnsRecord implements DnsOptPseudoRecord {
protected AbstractDnsOptPseudoRrRecord(int maxPayloadSize, int extendedRcode, int version) {
super(StringUtil.EMPTY_STRING, DnsRecordType.OPT, maxPayloadSize, packIntoLong(extendedRcode, version));
}
protected AbstractDnsOptPseudoRrRecord(int maxPayloadSize) {
super(StringUtil.EMPTY_STRING, DnsRecordType.OPT, maxPayloadSize, 0);
}
// See https://tools.ietf.org/html/rfc6891#section-6.1.3
private static long packIntoLong(int val, int val2) {
// We are currently not support DO and Z fields, just use 0.
return ((val & 0xff) << 24 | (val2 & 0xff) << 16 | (0 & 0xff) << 8 | 0 & 0xff) & 0xFFFFFFFFL;
}
@Override
public int extendedRcode() {
return (short) (((int) timeToLive() >> 24) & 0xff);
}
@Override
public int version() {
return (short) (((int) timeToLive() >> 16) & 0xff);
}
@Override
public int flags() {
return (short) ((short) timeToLive() & 0xff);
}
@Override
public String toString() {
return toStringBuilder().toString();
}
final StringBuilder toStringBuilder() {
return new StringBuilder(64)
.append(StringUtil.simpleClassName(this))
.append('(')
.append("OPT flags:")
.append(flags())
.append(" version:")
.append(version())
.append(" extendedRecode:")
.append(extendedRcode())
.append(" udp:")
.append(dnsClass())
.append(')');
}
}

View File

@ -0,0 +1,104 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.util.internal.UnstableApi;
import java.net.InetAddress;
import java.util.Arrays;
/**
* Default {@link DnsOptEcsRecord} implementation.
*/
@UnstableApi
public final class DefaultDnsOptEcsRecord extends AbstractDnsOptPseudoRrRecord implements DnsOptEcsRecord {
private final int srcPrefixLength;
private final byte[] address;
/**
* Creates a new instance.
*
* @param maxPayloadSize the suggested max payload size in bytes
* @param extendedRcode the extended rcode
* @param version the version
* @param srcPrefixLength the prefix length
* @param address the bytes of the {@link InetAddress} to use
*/
public DefaultDnsOptEcsRecord(int maxPayloadSize, int extendedRcode, int version,
int srcPrefixLength, byte[] address) {
super(maxPayloadSize, extendedRcode, version);
this.srcPrefixLength = srcPrefixLength;
this.address = verifyAddress(address).clone();
}
/**
* Creates a new instance.
*
* @param maxPayloadSize the suggested max payload size in bytes
* @param srcPrefixLength the prefix length
* @param address the bytes of the {@link InetAddress} to use
*/
public DefaultDnsOptEcsRecord(int maxPayloadSize, int srcPrefixLength, byte[] address) {
this(maxPayloadSize, 0, 0, srcPrefixLength, address);
}
/**
* Creates a new instance.
*
* @param maxPayloadSize the suggested max payload size in bytes
* @param protocolFamily the {@link InternetProtocolFamily} to use. This should be the same as the one used to
* send the query.
*/
public DefaultDnsOptEcsRecord(int maxPayloadSize, InternetProtocolFamily protocolFamily) {
this(maxPayloadSize, 0, 0, 0, protocolFamily.localhost().getAddress());
}
private static byte[] verifyAddress(byte[] bytes) {
if (bytes.length == 4 || bytes.length == 16) {
return bytes;
}
throw new IllegalArgumentException("bytes.length must either 4 or 16");
}
@Override
public int sourcePrefixLength() {
return srcPrefixLength;
}
@Override
public int scopePrefixLength() {
return 0;
}
@Override
public byte[] address() {
return address.clone();
}
@Override
public String toString() {
StringBuilder sb = toStringBuilder();
sb.setLength(sb.length() - 1);
return sb.append(" address:")
.append(Arrays.toString(address))
.append(" sourcePrefixLength:")
.append(sourcePrefixLength())
.append(" scopePrefixLength:")
.append(scopePrefixLength())
.append(')').toString();
}
}

View File

@ -17,6 +17,7 @@ package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.UnstableApi;
@ -30,6 +31,7 @@ import static io.netty.handler.codec.dns.DefaultDnsRecordDecoder.ROOT;
*/
@UnstableApi
public class DefaultDnsRecordEncoder implements DnsRecordEncoder {
private static final int PREFIX_MASK = Byte.SIZE - 1;
/**
* Creates a new instance.
@ -49,6 +51,10 @@ public class DefaultDnsRecordEncoder implements DnsRecordEncoder {
encodeQuestion((DnsQuestion) record, out);
} else if (record instanceof DnsPtrRecord) {
encodePtrRecord((DnsPtrRecord) record, out);
} else if (record instanceof DnsOptEcsRecord) {
encodeOptEcsRecord((DnsOptEcsRecord) record, out);
} else if (record instanceof DnsOptPseudoRecord) {
encodeOptPseudoRecord((DnsOptPseudoRecord) record, out);
} else if (record instanceof DnsRawRecord) {
encodeRawRecord((DnsRawRecord) record, out);
} else {
@ -56,22 +62,76 @@ public class DefaultDnsRecordEncoder implements DnsRecordEncoder {
}
}
private void encodePtrRecord(DnsPtrRecord record, ByteBuf out) throws Exception {
private void encodeRecord0(DnsRecord record, ByteBuf out) throws Exception {
encodeName(record.name(), out);
out.writeShort(record.type().intValue());
out.writeShort(record.dnsClass());
out.writeInt((int) record.timeToLive());
}
private void encodePtrRecord(DnsPtrRecord record, ByteBuf out) throws Exception {
encodeRecord0(record, out);
encodeName(record.hostname(), out);
}
private void encodeRawRecord(DnsRawRecord record, ByteBuf out) throws Exception {
encodeName(record.name(), out);
private void encodeOptPseudoRecord(DnsOptPseudoRecord record, ByteBuf out) throws Exception {
encodeRecord0(record, out);
out.writeShort(0);
}
out.writeShort(record.type().intValue());
out.writeShort(record.dnsClass());
out.writeInt((int) record.timeToLive());
private void encodeOptEcsRecord(DnsOptEcsRecord record, ByteBuf out) throws Exception {
encodeRecord0(record, out);
int sourcePrefixLength = record.sourcePrefixLength();
int scopePrefixLength = record.scopePrefixLength();
int lowOrderBitsToPreserve = sourcePrefixLength & PREFIX_MASK;
byte[] bytes = record.address();
int addressBits = bytes.length << 3;
if (addressBits < sourcePrefixLength || sourcePrefixLength < 0) {
throw new IllegalArgumentException(sourcePrefixLength + ": " +
sourcePrefixLength + " (expected: 0 >= " + addressBits + ')');
}
// See http://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml
final short addressNumber = (short) (bytes.length == 4 ?
InternetProtocolFamily.IPv4.addressNumber() : InternetProtocolFamily.IPv6.addressNumber());
int payloadLength = calculateEcsAddressLength(sourcePrefixLength, lowOrderBitsToPreserve);
int fullPayloadLength = 2 + // OPTION-CODE
2 + // OPTION-LENGTH
2 + // FAMILY
1 + // SOURCE PREFIX-LENGTH
1 + // SCOPE PREFIX-LENGTH
payloadLength; // ADDRESS...
out.writeShort(fullPayloadLength);
out.writeShort(8); // This is the defined type for ECS.
out.writeShort(fullPayloadLength - 4); // Not include OPTION-CODE and OPTION-LENGTH
out.writeShort(addressNumber);
out.writeByte(sourcePrefixLength);
out.writeByte(scopePrefixLength); // Must be 0 in queries.
if (lowOrderBitsToPreserve > 0) {
int bytesLength = payloadLength - 1;
out.writeBytes(bytes, 0, bytesLength);
// Pad the leftover of the last byte with zeros.
out.writeByte(padWithZeros(bytes[bytesLength], lowOrderBitsToPreserve));
} else {
// The sourcePrefixLength align with Byte so just copy in the bytes directly.
out.writeBytes(bytes, 0, payloadLength);
}
}
// Package-Private for testing
static int calculateEcsAddressLength(int sourcePrefixLength, int lowOrderBitsToPreserve) {
return (sourcePrefixLength >>> 3) + (lowOrderBitsToPreserve != 0 ? 1 : 0);
}
private void encodeRawRecord(DnsRawRecord record, ByteBuf out) throws Exception {
encodeRecord0(record, out);
ByteBuf content = record.content();
int contentLen = content.readableBytes();
@ -101,4 +161,30 @@ public class DefaultDnsRecordEncoder implements DnsRecordEncoder {
buf.writeByte(0); // marks end of name field
}
// Package private so it can be reused in the test.
static byte padWithZeros(byte b, int lowOrderBitsToPreserve) {
switch (lowOrderBitsToPreserve) {
case 0:
return 0;
case 1:
return (byte) (0x01 & b);
case 2:
return (byte) (0x03 & b);
case 3:
return (byte) (0x07 & b);
case 4:
return (byte) (0x0F & b);
case 5:
return (byte) (0x1F & b);
case 6:
return (byte) (0x3F & b);
case 7:
return (byte) (0x7F & b);
case 8:
return b;
default:
throw new IllegalArgumentException("lowOrderBitsToPreserve: " + lowOrderBitsToPreserve);
}
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.util.internal.UnstableApi;
import java.net.InetAddress;
/**
* An ECS record as defined in <a href="https://tools.ietf.org/html/rfc7871#section-6">Client Subnet in DNS Queries</a>.
*/
@UnstableApi
public interface DnsOptEcsRecord extends DnsOptPseudoRecord {
/**
* Returns the leftmost number of significant bits of ADDRESS to be used for the lookup.
*/
int sourcePrefixLength();
/**
* Returns the leftmost number of significant bits of ADDRESS that the response covers.
* In queries, it MUST be 0.
*/
int scopePrefixLength();
/**
* Retuns the bytes of the {@link InetAddress} to use.
*/
byte[] address();
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.util.internal.UnstableApi;
/**
* An <a href="https://tools.ietf.org/html/rfc6891#section-6.1">OPT RR</a> record.
* <p>
* This is used for <a href="https://tools.ietf.org/html/rfc6891#section-6.1.3">Extension
* Mechanisms for DNS (EDNS(0))</a>.
*/
@UnstableApi
public interface DnsOptPseudoRecord extends DnsRecord {
/**
* Returns the {@code EXTENDED-RCODE} which is encoded into {@link DnsOptPseudoRecord#timeToLive()}.
*/
int extendedRcode();
/**
* Returns the {@code VERSION} which is encoded into {@link DnsOptPseudoRecord#timeToLive()}.
*/
int version();
/**
* Returns the {@code flags} which includes {@code DO} and {@code Z} which is encoded
* into {@link DnsOptPseudoRecord#timeToLive()}.
*/
int flags();
}

View File

@ -16,11 +16,14 @@
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test;
import java.net.InetAddress;
import static org.junit.Assert.assertEquals;
public class DefaultDnsRecordEncoderTest {
@ -63,4 +66,82 @@ public class DefaultDnsRecordEncoderTest {
expectedBuf.release();
}
}
@Test
public void testOptEcsRecordIpv4() throws Exception {
testOptEcsRecordIp(InetAddress.getByName("1.2.3.4"));
}
@Test
public void testOptEcsRecordIpv6() throws Exception {
testOptEcsRecordIp(InetAddress.getByName("::0"));
}
private static void testOptEcsRecordIp(InetAddress address) throws Exception {
int addressBits = address.getAddress().length * Byte.SIZE;
for (int i = 0; i <= addressBits; ++i) {
testIp(address, i);
}
}
private static void testIp(InetAddress address, int prefix) throws Exception {
int lowOrderBitsToPreserve = prefix % Byte.SIZE;
ByteBuf addressPart = Unpooled.wrappedBuffer(address.getAddress(), 0,
DefaultDnsRecordEncoder.calculateEcsAddressLength(prefix, lowOrderBitsToPreserve));
if (lowOrderBitsToPreserve > 0) {
// Pad the leftover of the last byte with zeros.
int idx = addressPart.writerIndex() - 1;
byte lastByte = addressPart.getByte(idx);
addressPart.setByte(idx, DefaultDnsRecordEncoder.padWithZeros(lastByte, lowOrderBitsToPreserve));
}
int payloadSize = nextInt(Short.MAX_VALUE);
int extendedRcode = nextInt(Byte.MAX_VALUE * 2); // Unsigned
int version = nextInt(Byte.MAX_VALUE * 2); // Unsigned
DefaultDnsRecordEncoder encoder = new DefaultDnsRecordEncoder();
ByteBuf out = Unpooled.buffer();
try {
DnsOptEcsRecord record = new DefaultDnsOptEcsRecord(
payloadSize, extendedRcode, version, prefix, address.getAddress());
encoder.encodeRecord(record, out);
assertEquals(0, out.readByte()); // Name
assertEquals(DnsRecordType.OPT.intValue(), out.readUnsignedShort()); // Opt
assertEquals(payloadSize, out.readUnsignedShort()); // payload
assertEquals(record.timeToLive(), out.getUnsignedInt(out.readerIndex()));
// Read unpacked TTL.
assertEquals(extendedRcode, out.readUnsignedByte());
assertEquals(version, out.readUnsignedByte());
assertEquals(extendedRcode, record.extendedRcode());
assertEquals(version, record.version());
assertEquals(0, record.flags());
assertEquals(0, out.readShort());
int payloadLength = out.readUnsignedShort();
assertEquals(payloadLength, out.readableBytes());
assertEquals(8, out.readShort()); // As defined by RFC.
int rdataLength = out.readUnsignedShort();
assertEquals(rdataLength, out.readableBytes());
assertEquals((short) InternetProtocolFamily.of(address).addressNumber(), out.readShort());
assertEquals(prefix, out.readUnsignedByte());
assertEquals(0, out.readUnsignedByte()); // This must be 0 for requests.
assertEquals(addressPart, out);
} finally {
addressPart.release();
out.release();
}
}
private static int nextInt(int max) {
return ThreadLocalRandom.current().nextInt(0, max);
}
}

View File

@ -16,6 +16,7 @@
package io.netty.resolver.dns;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.UnstableApi;
@ -32,6 +33,7 @@ import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
/**
* Default implementation of {@link DnsCache}, backed by a {@link ConcurrentMap}.
* If any additional {@link DnsRecord} is used, no caching takes place.
*/
@UnstableApi
public class DefaultDnsCache implements DnsCache {
@ -115,9 +117,16 @@ public class DefaultDnsCache implements DnsCache {
return removed;
}
private static boolean emptyAdditionals(DnsRecord[] additionals) {
return additionals == null || additionals.length == 0;
}
@Override
public List<DnsCacheEntry> get(String hostname) {
public List<DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
checkNotNull(hostname, "hostname");
if (!emptyAdditionals(additionals)) {
return null;
}
return resolveCache.get(hostname);
}
@ -135,14 +144,14 @@ public class DefaultDnsCache implements DnsCache {
}
@Override
public void cache(String hostname, InetAddress address, long originalTtl, EventLoop loop) {
if (maxTtl == 0) {
return;
}
public void cache(String hostname, DnsRecord[] additionals,
InetAddress address, long originalTtl, EventLoop loop) {
checkNotNull(hostname, "hostname");
checkNotNull(address, "address");
checkNotNull(loop, "loop");
if (maxTtl == 0 || !emptyAdditionals(additionals)) {
return;
}
final int ttl = Math.max(minTtl, (int) Math.min(maxTtl, originalTtl));
final List<DnsCacheEntry> entries = cachedEntries(hostname);
final DnsCacheEntry e = new DnsCacheEntry(hostname, address);
@ -163,14 +172,14 @@ public class DefaultDnsCache implements DnsCache {
}
@Override
public void cache(String hostname, Throwable cause, EventLoop loop) {
if (negativeTtl == 0) {
return;
}
public void cache(String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop) {
checkNotNull(hostname, "hostname");
checkNotNull(cause, "cause");
checkNotNull(loop, "loop");
if (negativeTtl == 0 || !emptyAdditionals(additionals)) {
return;
}
final List<DnsCacheEntry> entries = cachedEntries(hostname);
final DnsCacheEntry e = new DnsCacheEntry(hostname, cause);

View File

@ -16,6 +16,7 @@
package io.netty.resolver.dns;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.util.internal.UnstableApi;
import java.net.InetAddress;
@ -45,24 +46,27 @@ public interface DnsCache {
/**
* Return the cached entries for the given hostname.
* @param hostname the hostname
* @param additionals the additional records
* @return the cached entries
*/
List<DnsCacheEntry> get(String hostname);
List<DnsCacheEntry> get(String hostname, DnsRecord[] additionals);
/**
* Cache a resolved address for a given hostname.
* @param hostname the hostname
* @param additionals the additional records
* @param address the resolved adresse
* @param originalTtl the TLL as returned by the DNS server
* @param loop the {@link EventLoop} used to register the TTL timeout
*/
void cache(String hostname, InetAddress address, long originalTtl, EventLoop loop);
void cache(String hostname, DnsRecord[] additionals, InetAddress address, long originalTtl, EventLoop loop);
/**
* Cache the resolution failure for a given hostname.
* @param hostname the hostname
* @param additionals the additional records
* @param cause the resolution failure
* @param loop the {@link EventLoop} used to register the TTL timeout
*/
void cache(String hostname, Throwable cause, EventLoop loop);
void cache(String hostname, DnsRecord[] additionals, Throwable cause, EventLoop loop);
}

View File

@ -31,6 +31,7 @@ import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
import io.netty.handler.codec.dns.DatagramDnsResponse;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DatagramDnsResponseDecoder;
import io.netty.handler.codec.dns.DnsQuestion;
@ -55,7 +56,9 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import static io.netty.util.internal.ObjectUtil.*;
@ -69,6 +72,7 @@ public class DnsNameResolver extends InetNameResolver {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsNameResolver.class);
private static final String LOCALHOST = "localhost";
private static final InetAddress LOCALHOST_ADDRESS;
private static final DnsRecord[] EMTPY_ADDITIONALS = new DnsRecord[0];
static final InternetProtocolFamily[] DEFAULT_RESOLVE_ADDRESS_TYPES;
static final String[] DEFAULT_SEACH_DOMAINS;
@ -335,9 +339,108 @@ public class DnsNameResolver extends InetNameResolver {
}
}
/**
* Resolves the specified name into an address.
*
* @param inetHost the name to resolve
* @param additionals additional records ({@code OPT})
*
* @return the address as the result of the resolution
*/
public final Future<InetAddress> resolve(String inetHost, Iterable<DnsRecord> additionals) {
return resolve(inetHost, additionals, executor().<InetAddress>newPromise());
}
/**
* Resolves the specified name into an address.
*
* @param inetHost the name to resolve
* @param additionals additional records ({@code OPT})
* @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
*
* @return the address as the result of the resolution
*/
public final Future<InetAddress> resolve(String inetHost, Iterable<DnsRecord> additionals,
Promise<InetAddress> promise) {
checkNotNull(inetHost, "inetHost");
checkNotNull(promise, "promise");
DnsRecord[] additionalsArray = toArray(additionals, true);
try {
doResolve(inetHost, additionalsArray, promise, resolveCache);
return promise;
} catch (Exception e) {
return promise.setFailure(e);
}
}
/**
* Resolves the specified host name and port into a list of address.
*
* @param inetHost the name to resolve
* @param additionals additional records ({@code OPT})
*
* @return the list of the address as the result of the resolution
*/
public final Future<List<InetAddress>> resolveAll(String inetHost, Iterable<DnsRecord> additionals) {
return resolveAll(inetHost, additionals, executor().<List<InetAddress>>newPromise());
}
/**
* Resolves the specified host name and port into a list of address.
*
* @param inetHost the name to resolve
* @param additionals additional records ({@code OPT})
* @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
*
* @return the list of the address as the result of the resolution
*/
public final Future<List<InetAddress>> resolveAll(String inetHost, Iterable<DnsRecord> additionals,
Promise<List<InetAddress>> promise) {
checkNotNull(inetHost, "inetHost");
checkNotNull(promise, "promise");
DnsRecord[] additionalsArray = toArray(additionals, true);
try {
doResolveAll(inetHost, additionalsArray, promise, resolveCache);
return promise;
} catch (Exception e) {
return promise.setFailure(e);
}
}
@Override
protected void doResolve(String inetHost, Promise<InetAddress> promise) throws Exception {
doResolve(inetHost, promise, resolveCache);
doResolve(inetHost, EMTPY_ADDITIONALS, promise, resolveCache);
}
private static DnsRecord[] toArray(Iterable<DnsRecord> additionals, boolean validateType) {
checkNotNull(additionals, "additionals");
if (additionals instanceof Collection) {
Collection<DnsRecord> records = (Collection<DnsRecord>) additionals;
for (DnsRecord r: additionals) {
validateAdditional(r, validateType);
}
return records.toArray(new DnsRecord[records.size()]);
}
Iterator<DnsRecord> additionalsIt = additionals.iterator();
if (!additionalsIt.hasNext()) {
return EMTPY_ADDITIONALS;
}
List<DnsRecord> records = new ArrayList<DnsRecord>();
do {
DnsRecord r = additionalsIt.next();
validateAdditional(r, validateType);
records.add(r);
} while (additionalsIt.hasNext());
return records.toArray(new DnsRecord[records.size()]);
}
private static void validateAdditional(DnsRecord record, boolean validateType) {
checkNotNull(record, "record");
if (validateType && record instanceof DnsRawRecord) {
throw new IllegalArgumentException("DnsRawRecord implementations not allowed: " + record);
}
}
/**
@ -345,6 +448,7 @@ public class DnsNameResolver extends InetNameResolver {
* instead of using the global one.
*/
protected void doResolve(String inetHost,
DnsRecord[] additionals,
Promise<InetAddress> promise,
DnsCache resolveCache) throws Exception {
final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(inetHost);
@ -362,15 +466,16 @@ public class DnsNameResolver extends InetNameResolver {
return;
}
if (!doResolveCached(hostname, promise, resolveCache)) {
doResolveUncached(hostname, promise, resolveCache);
if (!doResolveCached(hostname, additionals, promise, resolveCache)) {
doResolveUncached(hostname, additionals, promise, resolveCache);
}
}
private boolean doResolveCached(String hostname,
DnsRecord[] additionals,
Promise<InetAddress> promise,
DnsCache resolveCache) {
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname);
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
if (cachedEntries == null || cachedEntries.isEmpty()) {
return false;
}
@ -398,41 +503,47 @@ public class DnsNameResolver extends InetNameResolver {
}
if (address != null) {
setSuccess(promise, address);
} else if (cause != null) {
if (!promise.tryFailure(cause)) {
logger.warn("Failed to notify failure to a promise: {}", promise, cause);
}
} else {
return false;
trySuccess(promise, address);
return true;
}
return true;
if (cause != null) {
tryFailure(promise, cause);
return true;
}
return false;
}
private static void setSuccess(Promise<InetAddress> promise, InetAddress result) {
private static <T> void trySuccess(Promise<T> promise, T result) {
if (!promise.trySuccess(result)) {
logger.warn("Failed to notify success ({}) to a promise: {}", result, promise);
}
}
private static void tryFailure(Promise<?> promise, Throwable cause) {
if (!promise.tryFailure(cause)) {
logger.warn("Failed to notify failure to a promise: {}", promise, cause);
}
}
private void doResolveUncached(String hostname,
DnsRecord[] additionals,
Promise<InetAddress> promise,
DnsCache resolveCache) {
SingleResolverContext ctx = new SingleResolverContext(this, hostname, resolveCache);
SingleResolverContext ctx = new SingleResolverContext(this, hostname, additionals, resolveCache);
ctx.resolve(promise);
}
final class SingleResolverContext extends DnsNameResolverContext<InetAddress> {
static final class SingleResolverContext extends DnsNameResolverContext<InetAddress> {
SingleResolverContext(DnsNameResolver parent, String hostname, DnsCache resolveCache) {
super(parent, hostname, resolveCache);
SingleResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache) {
super(parent, hostname, additionals, resolveCache);
}
@Override
DnsNameResolverContext<InetAddress> newResolverContext(DnsNameResolver parent,
String hostname, DnsCache resolveCache) {
return new SingleResolverContext(parent, hostname, resolveCache);
DnsNameResolverContext<InetAddress> newResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache) {
return new SingleResolverContext(parent, hostname, additionals, resolveCache);
}
@Override
@ -444,7 +555,7 @@ public class DnsNameResolver extends InetNameResolver {
for (int i = 0; i < numEntries; i++) {
final InetAddress a = resolvedEntries.get(i).address();
if (addressType.isInstance(a)) {
setSuccess(promise, a);
trySuccess(promise, a);
return true;
}
}
@ -454,7 +565,7 @@ public class DnsNameResolver extends InetNameResolver {
@Override
protected void doResolveAll(String inetHost, Promise<List<InetAddress>> promise) throws Exception {
doResolveAll(inetHost, promise, resolveCache);
doResolveAll(inetHost, EMTPY_ADDITIONALS, promise, resolveCache);
}
/**
@ -462,9 +573,9 @@ public class DnsNameResolver extends InetNameResolver {
* instead of using the global one.
*/
protected void doResolveAll(String inetHost,
DnsRecord[] additionals,
Promise<List<InetAddress>> promise,
DnsCache resolveCache) throws Exception {
final byte[] bytes = NetUtil.createByteArrayFromIpAddressString(inetHost);
if (bytes != null) {
// The unresolvedAddress was created via a String that contains an ipaddress.
@ -480,15 +591,16 @@ public class DnsNameResolver extends InetNameResolver {
return;
}
if (!doResolveAllCached(hostname, promise, resolveCache)) {
doResolveAllUncached(hostname, promise, resolveCache);
if (!doResolveAllCached(hostname, additionals, promise, resolveCache)) {
doResolveAllUncached(hostname, additionals, promise, resolveCache);
}
}
private boolean doResolveAllCached(String hostname,
DnsRecord[] additionals,
Promise<List<InetAddress>> promise,
DnsCache resolveCache) {
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname);
final List<DnsCacheEntry> cachedEntries = resolveCache.get(hostname, additionals);
if (cachedEntries == null || cachedEntries.isEmpty()) {
return false;
}
@ -517,25 +629,26 @@ public class DnsNameResolver extends InetNameResolver {
}
if (result != null) {
promise.trySuccess(result);
} else if (cause != null) {
promise.tryFailure(cause);
} else {
return false;
trySuccess(promise, result);
return true;
}
return true;
if (cause != null) {
tryFailure(promise, cause);
return true;
}
return false;
}
final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
ListResolverContext(DnsNameResolver parent, String hostname, DnsCache resolveCache) {
super(parent, hostname, resolveCache);
static final class ListResolverContext extends DnsNameResolverContext<List<InetAddress>> {
ListResolverContext(DnsNameResolver parent, String hostname,
DnsRecord[] additionals, DnsCache resolveCache) {
super(parent, hostname, additionals, resolveCache);
}
@Override
DnsNameResolverContext<List<InetAddress>> newResolverContext(DnsNameResolver parent, String hostname,
DnsCache resolveCache) {
return new ListResolverContext(parent, hostname, resolveCache);
DnsNameResolverContext<List<InetAddress>> newResolverContext(
DnsNameResolver parent, String hostname, DnsRecord[] additionals, DnsCache resolveCache) {
return new ListResolverContext(parent, hostname, additionals, resolveCache);
}
@Override
@ -564,9 +677,11 @@ public class DnsNameResolver extends InetNameResolver {
}
private void doResolveAllUncached(String hostname,
DnsRecord[] additionals,
Promise<List<InetAddress>> promise,
DnsCache resolveCache) {
DnsNameResolverContext<List<InetAddress>> ctx = new ListResolverContext(this, hostname, resolveCache);
DnsNameResolverContext<List<InetAddress>> ctx = new ListResolverContext(
this, hostname, additionals, resolveCache);
ctx.resolve(promise);
}
@ -590,8 +705,8 @@ public class DnsNameResolver extends InetNameResolver {
* Sends a DNS query with the specified question with additional records.
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
DnsQuestion question, Iterable<DnsRecord> additional) {
return query(nextNameServerAddress(), question, additional);
DnsQuestion question, Iterable<DnsRecord> additionals) {
return query(nextNameServerAddress(), question, additionals);
}
/**
@ -612,7 +727,7 @@ public class DnsNameResolver extends InetNameResolver {
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question) {
return query0(nameServerAddr, question, Collections.<DnsRecord>emptyList(),
return query0(nameServerAddr, question, EMTPY_ADDITIONALS,
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
}
@ -620,9 +735,9 @@ public class DnsNameResolver extends InetNameResolver {
* Sends a DNS query with the specified question with additional records using the specified name server list.
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question, Iterable<DnsRecord> additional) {
InetSocketAddress nameServerAddr, DnsQuestion question, Iterable<DnsRecord> additionals) {
return query0(nameServerAddr, question, additional,
return query0(nameServerAddr, question, toArray(additionals, false),
ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
}
@ -633,7 +748,7 @@ public class DnsNameResolver extends InetNameResolver {
InetSocketAddress nameServerAddr, DnsQuestion question,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
return query0(nameServerAddr, question, Collections.<DnsRecord>emptyList(), promise);
return query0(nameServerAddr, question, EMTPY_ADDITIONALS, promise);
}
/**
@ -641,21 +756,21 @@ public class DnsNameResolver extends InetNameResolver {
*/
public Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query(
InetSocketAddress nameServerAddr, DnsQuestion question,
Iterable<DnsRecord> additional,
Iterable<DnsRecord> additionals,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
return query0(nameServerAddr, question, additional, promise);
return query0(nameServerAddr, question, toArray(additionals, false), promise);
}
private Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
InetSocketAddress nameServerAddr, DnsQuestion question,
Iterable<DnsRecord> additional,
DnsRecord[] additionals,
Promise<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>> promise) {
final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> castPromise = cast(
checkNotNull(promise, "promise"));
try {
new DnsQueryContext(this, nameServerAddr, question, additional, castPromise).query();
new DnsQueryContext(this, nameServerAddr, question, additionals, castPromise).query();
return castPromise;
} catch (Exception e) {
return castPromise.setFailure(e);

View File

@ -74,6 +74,7 @@ abstract class DnsNameResolverContext<T> {
private final boolean traceEnabled;
private final int maxAllowedQueries;
private final InternetProtocolFamily[] resolveAddressTypes;
private final DnsRecord[] additionals;
private final Set<Future<AddressedEnvelope<DnsResponse, InetSocketAddress>>> queriesInProgress =
Collections.newSetFromMap(
@ -86,9 +87,11 @@ abstract class DnsNameResolverContext<T> {
protected DnsNameResolverContext(DnsNameResolver parent,
String hostname,
DnsRecord[] additionals,
DnsCache resolveCache) {
this.parent = parent;
this.hostname = hostname;
this.additionals = additionals;
this.resolveCache = resolveCache;
nameServerAddrs = parent.nameServerAddresses.stream();
@ -114,9 +117,9 @@ abstract class DnsNameResolverContext<T> {
} else if (count < parent.searchDomains().length) {
String searchDomain = parent.searchDomains()[count++];
Promise<T> nextPromise = parent.executor().newPromise();
String nextHostname = DnsNameResolverContext.this.hostname + "." + searchDomain;
String nextHostname = hostname + '.' + searchDomain;
DnsNameResolverContext<T> nextContext = newResolverContext(parent,
nextHostname, resolveCache);
nextHostname, additionals, resolveCache);
nextContext.pristineHostname = hostname;
nextContext.internalResolve(nextPromise);
nextPromise.addListener(this);
@ -167,7 +170,9 @@ abstract class DnsNameResolverContext<T> {
allowedQueries --;
final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> f = parent.query(nameServerAddr, question);
final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> f = parent.query0(
nameServerAddr, question, additionals,
parent.ch.eventLoop().<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
queriesInProgress.add(f);
f.addListener(new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@ -288,7 +293,7 @@ abstract class DnsNameResolverContext<T> {
}
final DnsCacheEntry e = new DnsCacheEntry(hostname, resolved);
resolveCache.cache(hostname, resolved, r.timeToLive(), parent.ch.eventLoop());
resolveCache.cache(hostname, additionals, resolved, r.timeToLive(), parent.ch.eventLoop());
resolvedEntries.add(e);
found = true;
@ -475,7 +480,7 @@ abstract class DnsNameResolverContext<T> {
}
final UnknownHostException cause = new UnknownHostException(buf.toString());
resolveCache.cache(hostname, cause, parent.ch.eventLoop());
resolveCache.cache(hostname, additionals, cause, parent.ch.eventLoop());
promise.tryFailure(cause);
}
@ -483,7 +488,7 @@ abstract class DnsNameResolverContext<T> {
Promise<T> promise);
abstract DnsNameResolverContext<T> newResolverContext(DnsNameResolver parent, String hostname,
DnsCache resolveCache);
DnsRecord[] additionals, DnsCache resolveCache);
static String decodeDomainName(ByteBuf in) {
in.markReaderIndex();

View File

@ -15,24 +15,21 @@
*/
package io.netty.resolver.dns;
import io.netty.buffer.Unpooled;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
@ -49,7 +46,7 @@ final class DnsQueryContext {
private final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise;
private final int id;
private final DnsQuestion question;
private final Iterable<DnsRecord> additional;
private final DnsRecord[] additionals;
private final DnsRecord optResource;
private final InetSocketAddress nameServerAddr;
@ -59,20 +56,21 @@ final class DnsQueryContext {
DnsQueryContext(DnsNameResolver parent,
InetSocketAddress nameServerAddr,
DnsQuestion question,
Iterable<DnsRecord> additional,
DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
this.parent = checkNotNull(parent, "parent");
this.nameServerAddr = checkNotNull(nameServerAddr, "nameServerAddr");
this.question = checkNotNull(question, "question");
this.additional = checkNotNull(additional, "additional");
this.additionals = checkNotNull(additionals, "additionals");
this.promise = checkNotNull(promise, "promise");
recursionDesired = parent.isRecursionDesired();
id = parent.queryContextManager.add(this);
if (parent.isOptResourceEnabled()) {
optResource = new DefaultDnsRawRecord(
StringUtil.EMPTY_STRING, DnsRecordType.OPT, parent.maxPayloadSize(), 0, Unpooled.EMPTY_BUFFER);
optResource = new AbstractDnsOptPseudoRrRecord(parent.maxPayloadSize(), 0, 0) {
// We may want to remove this in the future and let the user just specify the opt record in the query.
};
} else {
optResource = null;
}
@ -95,9 +93,10 @@ final class DnsQueryContext {
query.addRecord(DnsSection.QUESTION, question);
for (DnsRecord record:additional) {
for (DnsRecord record: additionals) {
query.addRecord(DnsSection.ADDITIONAL, record);
}
if (optResource != null) {
query.addRecord(DnsSection.ADDITIONAL, optResource);
}

View File

@ -16,6 +16,7 @@
package io.netty.resolver.dns;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.util.internal.UnstableApi;
import java.net.InetAddress;
@ -46,16 +47,17 @@ public final class NoopDnsCache implements DnsCache {
}
@Override
public List<DnsCacheEntry> get(String hostname) {
public List<DnsCacheEntry> get(String hostname, DnsRecord[] additionals) {
return Collections.emptyList();
}
@Override
public void cache(String hostname, InetAddress address, long originalTtl, EventLoop loop) {
public void cache(String hostname, DnsRecord[] additional,
InetAddress address, long originalTtl, EventLoop loop) {
}
@Override
public void cache(String hostname, Throwable cause, EventLoop loop) {
public void cache(String hostname, DnsRecord[] additional, Throwable cause, EventLoop loop) {
}
@Override

View File

@ -0,0 +1,66 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.handler.codec.dns.DefaultDnsOptEcsRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.util.concurrent.Future;
import org.junit.Ignore;
import org.junit.Test;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
public class DnsNameResolverClientSubnetTest {
// See https://www.gsic.uva.es/~jnisigl/dig-edns-client-subnet.html
// Ignore as this needs to query real DNS servers.
@Ignore
@Test
public void testSubnetQuery() throws Exception {
EventLoopGroup group = new NioEventLoopGroup(1);
DnsNameResolver resolver = newResolver(group).build();
try {
// Same as:
// # /.bind-9.9.3-edns/bin/dig @ns1.google.com www.google.es +client=157.88.0.0/24
Future<List<InetAddress>> future = resolver.resolveAll("www.google.es",
Collections.<DnsRecord>singleton(
// Suggest max payload size of 1024
// 157.88.0.0 / 24
new DefaultDnsOptEcsRecord(1024, 24, InetAddress.getByName("157.88.0.0").getAddress())));
for (InetAddress address: future.syncUninterruptibly().getNow()) {
System.err.println(address);
}
} finally {
resolver.close();
group.shutdownGracefully(0, 0, TimeUnit.SECONDS);
}
}
private static DnsNameResolverBuilder newResolver(EventLoopGroup group) {
return new DnsNameResolverBuilder(group.next())
.channelType(NioDatagramChannel.class)
.nameServerAddresses(DnsServerAddresses.singleton(new InetSocketAddress("8.8.8.8", 53)))
.maxQueriesPerResolve(1)
.optResourceEnabled(false);
}
}

View File

@ -15,6 +15,8 @@
*/
package io.netty.channel.socket;
import io.netty.util.NetUtil;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
@ -27,9 +29,13 @@ public enum InternetProtocolFamily {
IPv6(Inet6Address.class);
private final Class<? extends InetAddress> addressType;
private final int addressNumber;
private final InetAddress localHost;
InternetProtocolFamily(Class<? extends InetAddress> addressType) {
this.addressType = addressType;
addressNumber = addressNumber(addressType);
localHost = localhost(addressType);
}
/**
@ -38,4 +44,53 @@ public enum InternetProtocolFamily {
public Class<? extends InetAddress> addressType() {
return addressType;
}
/**
* Returns the
* <a href="http://www.iana.org/assignments/address-family-numbers/address-family-numbers.xhtml">address number</a>
* of the family.
*/
public int addressNumber() {
return addressNumber;
}
/**
* Returns the {@link InetAddress} that represent the {@code LOCALHOST} for the family.
*/
public InetAddress localhost() {
return localHost;
}
private static InetAddress localhost(Class<? extends InetAddress> addressType) {
if (addressType.isAssignableFrom(Inet4Address.class)) {
return NetUtil.LOCALHOST4;
}
if (addressType.isAssignableFrom(Inet6Address.class)) {
return NetUtil.LOCALHOST6;
}
throw new Error();
}
private static int addressNumber(Class<? extends InetAddress> addressType) {
if (addressType.isAssignableFrom(Inet4Address.class)) {
return 1;
}
if (addressType.isAssignableFrom(Inet6Address.class)) {
return 2;
}
throw new IllegalArgumentException("addressType " + addressType + " not supported");
}
/**
* Returns the {@link InternetProtocolFamily} for the given {@link InetAddress}.
*/
public static InternetProtocolFamily of(InetAddress address) {
if (address instanceof Inet4Address) {
return IPv4;
}
if (address instanceof Inet6Address) {
return IPv6;
}
throw new IllegalArgumentException("address " + address + " not supported");
}
}