Add support for TCP fallback when we receive a truncated DnsResponse (#9139)

Motivation:

Sometimes DNS responses can be very large which mean they will not fit in a UDP packet. When this is happening the DNS server will set the TC flag (truncated flag) to tell the resolver that the response was truncated. When a truncated response was received we should allow to retry via TCP and use the received response (if possible) as a replacement for the truncated one.

See https://tools.ietf.org/html/rfc7766.

Modifications:

- Add support for TCP fallback by allow to specify a socketChannelFactory / socketChannelType on the DnsNameResolverBuilder. If this is set to something different then null we will try to fallback to TCP.
- Add decoder / encoder for TCP
- Add unit tests

Result:

Support for TCP fallback as defined by https://tools.ietf.org/html/rfc7766 when using DnsNameResolver.
This commit is contained in:
Norman Maurer 2019-05-17 14:37:11 +02:00 committed by GitHub
parent ccf56706f8
commit 1672b6d12c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 842 additions and 126 deletions

View File

@ -36,7 +36,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
@ChannelHandler.Sharable
public class DatagramDnsQueryEncoder extends MessageToMessageEncoder<AddressedEnvelope<DnsQuery, InetSocketAddress>> {
private final DnsRecordEncoder recordEncoder;
private final DnsQueryEncoder encoder;
/**
* Creates a new encoder with {@linkplain DnsRecordEncoder#DEFAULT the default record encoder}.
@ -49,7 +49,7 @@ public class DatagramDnsQueryEncoder extends MessageToMessageEncoder<AddressedEn
* Creates a new encoder with the specified {@code recordEncoder}.
*/
public DatagramDnsQueryEncoder(DnsRecordEncoder recordEncoder) {
this.recordEncoder = checkNotNull(recordEncoder, "recordEncoder");
this.encoder = new DnsQueryEncoder(recordEncoder);
}
@Override
@ -63,9 +63,7 @@ public class DatagramDnsQueryEncoder extends MessageToMessageEncoder<AddressedEn
boolean success = false;
try {
encodeHeader(query, buf);
encodeQuestions(query, buf);
encodeRecords(query, DnsSection.ADDITIONAL, buf);
encoder.encode(query, buf);
success = true;
} finally {
if (!success) {
@ -85,38 +83,4 @@ public class DatagramDnsQueryEncoder extends MessageToMessageEncoder<AddressedEn
@SuppressWarnings("unused") AddressedEnvelope<DnsQuery, InetSocketAddress> msg) throws Exception {
return ctx.alloc().ioBuffer(1024);
}
/**
* Encodes the header that is always 12 bytes long.
*
* @param query the query header being encoded
* @param buf the buffer the encoded data should be written to
*/
private static void encodeHeader(DnsQuery query, ByteBuf buf) {
buf.writeShort(query.id());
int flags = 0;
flags |= (query.opCode().byteValue() & 0xFF) << 14;
if (query.isRecursionDesired()) {
flags |= 1 << 8;
}
buf.writeShort(flags);
buf.writeShort(query.count(DnsSection.QUESTION));
buf.writeShort(0); // answerCount
buf.writeShort(0); // authorityResourceCount
buf.writeShort(query.count(DnsSection.ADDITIONAL));
}
private void encodeQuestions(DnsQuery query, ByteBuf buf) throws Exception {
final int count = query.count(DnsSection.QUESTION);
for (int i = 0; i < count; i++) {
recordEncoder.encodeQuestion((DnsQuestion) query.recordAt(DnsSection.QUESTION, i), buf);
}
}
private void encodeRecords(DnsQuery query, DnsSection section, ByteBuf buf) throws Exception {
final int count = query.count(section);
for (int i = 0; i < count; i++) {
recordEncoder.encodeRecord(query.recordAt(section, i), buf);
}
}
}

View File

@ -15,18 +15,15 @@
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.internal.UnstableApi;
import java.net.InetSocketAddress;
import java.util.List;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
/**
* Decodes a {@link DatagramPacket} into a {@link DatagramDnsResponse}.
*/
@ -34,7 +31,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull;
@ChannelHandler.Sharable
public class DatagramDnsResponseDecoder extends MessageToMessageDecoder<DatagramPacket> {
private final DnsRecordDecoder recordDecoder;
private final DnsResponseDecoder<InetSocketAddress> responseDecoder;
/**
* Creates a new decoder with {@linkplain DnsRecordDecoder#DEFAULT the default record decoder}.
@ -47,73 +44,17 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder<Datagram
* Creates a new decoder with the specified {@code recordDecoder}.
*/
public DatagramDnsResponseDecoder(DnsRecordDecoder recordDecoder) {
this.recordDecoder = checkNotNull(recordDecoder, "recordDecoder");
this.responseDecoder = new DnsResponseDecoder<InetSocketAddress>(recordDecoder) {
@Override
protected DnsResponse newResponse(InetSocketAddress sender, InetSocketAddress recipient,
int id, DnsOpCode opCode, DnsResponseCode responseCode) {
return new DatagramDnsResponse(sender, recipient, id, opCode, responseCode);
}
};
}
@Override
protected void decode(ChannelHandlerContext ctx, DatagramPacket packet, List<Object> out) throws Exception {
final ByteBuf buf = packet.content();
final DnsResponse response = newResponse(packet, buf);
boolean success = false;
try {
final int questionCount = buf.readUnsignedShort();
final int answerCount = buf.readUnsignedShort();
final int authorityRecordCount = buf.readUnsignedShort();
final int additionalRecordCount = buf.readUnsignedShort();
decodeQuestions(response, buf, questionCount);
decodeRecords(response, DnsSection.ANSWER, buf, answerCount);
decodeRecords(response, DnsSection.AUTHORITY, buf, authorityRecordCount);
decodeRecords(response, DnsSection.ADDITIONAL, buf, additionalRecordCount);
out.add(response);
success = true;
} finally {
if (!success) {
response.release();
}
}
}
private static DnsResponse newResponse(DatagramPacket packet, ByteBuf buf) {
final int id = buf.readUnsignedShort();
final int flags = buf.readUnsignedShort();
if (flags >> 15 == 0) {
throw new CorruptedFrameException("not a response");
}
final DnsResponse response = new DatagramDnsResponse(
packet.sender(),
packet.recipient(),
id,
DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)), DnsResponseCode.valueOf((byte) (flags & 0xf)));
response.setRecursionDesired((flags >> 8 & 1) == 1);
response.setAuthoritativeAnswer((flags >> 10 & 1) == 1);
response.setTruncated((flags >> 9 & 1) == 1);
response.setRecursionAvailable((flags >> 7 & 1) == 1);
response.setZ(flags >> 4 & 0x7);
return response;
}
private void decodeQuestions(DnsResponse response, ByteBuf buf, int questionCount) throws Exception {
for (int i = questionCount; i > 0; i --) {
response.addRecord(DnsSection.QUESTION, recordDecoder.decodeQuestion(buf));
}
}
private void decodeRecords(
DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; i --) {
final DnsRecord r = recordDecoder.decodeRecord(buf);
if (r == null) {
// Truncated response
break;
}
response.addRecord(section, r);
}
out.add(responseDecoder.decode(packet.sender(), packet.recipient(), packet.content()));
}
}

View File

@ -0,0 +1,75 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
final class DnsQueryEncoder {
private final DnsRecordEncoder recordEncoder;
/**
* Creates a new encoder with the specified {@code recordEncoder}.
*/
DnsQueryEncoder(DnsRecordEncoder recordEncoder) {
this.recordEncoder = checkNotNull(recordEncoder, "recordEncoder");
}
/**
* Encodes the given {@link DnsQuery} into a {@link ByteBuf}.
*/
void encode(DnsQuery query, ByteBuf out) throws Exception {
encodeHeader(query, out);
encodeQuestions(query, out);
encodeRecords(query, DnsSection.ADDITIONAL, out);
}
/**
* Encodes the header that is always 12 bytes long.
*
* @param query the query header being encoded
* @param buf the buffer the encoded data should be written to
*/
private static void encodeHeader(DnsQuery query, ByteBuf buf) {
buf.writeShort(query.id());
int flags = 0;
flags |= (query.opCode().byteValue() & 0xFF) << 14;
if (query.isRecursionDesired()) {
flags |= 1 << 8;
}
buf.writeShort(flags);
buf.writeShort(query.count(DnsSection.QUESTION));
buf.writeShort(0); // answerCount
buf.writeShort(0); // authorityResourceCount
buf.writeShort(query.count(DnsSection.ADDITIONAL));
}
private void encodeQuestions(DnsQuery query, ByteBuf buf) throws Exception {
final int count = query.count(DnsSection.QUESTION);
for (int i = 0; i < count; i++) {
recordEncoder.encodeQuestion((DnsQuestion) query.recordAt(DnsSection.QUESTION, i), buf);
}
}
private void encodeRecords(DnsQuery query, DnsSection section, ByteBuf buf) throws Exception {
final int count = query.count(section);
for (int i = 0; i < count; i++) {
recordEncoder.encodeRecord(query.recordAt(section, i), buf);
}
}
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.CorruptedFrameException;
import java.net.SocketAddress;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
abstract class DnsResponseDecoder<A extends SocketAddress> {
private final DnsRecordDecoder recordDecoder;
/**
* Creates a new decoder with the specified {@code recordDecoder}.
*/
DnsResponseDecoder(DnsRecordDecoder recordDecoder) {
this.recordDecoder = checkNotNull(recordDecoder, "recordDecoder");
}
final DnsResponse decode(A sender, A recipient, ByteBuf buffer) throws Exception {
final int id = buffer.readUnsignedShort();
final int flags = buffer.readUnsignedShort();
if (flags >> 15 == 0) {
throw new CorruptedFrameException("not a response");
}
final DnsResponse response = newResponse(
sender,
recipient,
id,
DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)), DnsResponseCode.valueOf((byte) (flags & 0xf)));
response.setRecursionDesired((flags >> 8 & 1) == 1);
response.setAuthoritativeAnswer((flags >> 10 & 1) == 1);
response.setTruncated((flags >> 9 & 1) == 1);
response.setRecursionAvailable((flags >> 7 & 1) == 1);
response.setZ(flags >> 4 & 0x7);
boolean success = false;
try {
final int questionCount = buffer.readUnsignedShort();
final int answerCount = buffer.readUnsignedShort();
final int authorityRecordCount = buffer.readUnsignedShort();
final int additionalRecordCount = buffer.readUnsignedShort();
decodeQuestions(response, buffer, questionCount);
decodeRecords(response, DnsSection.ANSWER, buffer, answerCount);
decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount);
decodeRecords(response, DnsSection.ADDITIONAL, buffer, additionalRecordCount);
success = true;
return response;
} finally {
if (!success) {
response.release();
}
}
}
protected abstract DnsResponse newResponse(A sender, A recipient, int id,
DnsOpCode opCode, DnsResponseCode responseCode) throws Exception;
private void decodeQuestions(DnsResponse response, ByteBuf buf, int questionCount) throws Exception {
for (int i = questionCount; i > 0; i --) {
response.addRecord(DnsSection.QUESTION, recordDecoder.decodeQuestion(buf));
}
}
private void decodeRecords(
DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; i --) {
final DnsRecord r = recordDecoder.decodeRecord(buf);
if (r == null) {
// Truncated response
break;
}
response.addRecord(section, r);
}
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.util.internal.UnstableApi;
@ChannelHandler.Sharable
@UnstableApi
public final class TcpDnsQueryEncoder extends MessageToByteEncoder<DnsQuery> {
private final DnsQueryEncoder encoder;
/**
* Creates a new encoder with {@linkplain DnsRecordEncoder#DEFAULT the default record encoder}.
*/
public TcpDnsQueryEncoder() {
this(DnsRecordEncoder.DEFAULT);
}
/**
* Creates a new encoder with the specified {@code recordEncoder}.
*/
public TcpDnsQueryEncoder(DnsRecordEncoder recordEncoder) {
this.encoder = new DnsQueryEncoder(recordEncoder);
}
@Override
protected void encode(ChannelHandlerContext ctx, DnsQuery msg, ByteBuf out) throws Exception {
// Length is two octets as defined by RFC-7766
// See https://tools.ietf.org/html/rfc7766#section-8
out.writerIndex(out.writerIndex() + 2);
encoder.encode(msg, out);
// Now fill in the correct length based on the amount of data that we wrote the ByteBuf.
out.setShort(0, out.readableBytes() - 2);
}
@Override
protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, @SuppressWarnings("unused") DnsQuery msg,
boolean preferDirect) {
if (preferDirect) {
return ctx.alloc().ioBuffer(1024);
} else {
return ctx.alloc().heapBuffer(1024);
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.internal.UnstableApi;
import java.net.SocketAddress;
@UnstableApi
public final class TcpDnsResponseDecoder extends LengthFieldBasedFrameDecoder {
private final DnsResponseDecoder<SocketAddress> responseDecoder;
/**
* Creates a new decoder with {@linkplain DnsRecordDecoder#DEFAULT the default record decoder}.
*/
public TcpDnsResponseDecoder() {
this(DnsRecordDecoder.DEFAULT, 64 * 1024);
}
/**
* Creates a new decoder with the specified {@code recordDecoder} and {@code maxFrameLength}
*/
public TcpDnsResponseDecoder(DnsRecordDecoder recordDecoder, int maxFrameLength) {
// Length is two octets as defined by RFC-7766
// See https://tools.ietf.org/html/rfc7766#section-8
super(maxFrameLength, 0, 2, 0, 2);
this.responseDecoder = new DnsResponseDecoder<SocketAddress>(recordDecoder) {
@Override
protected DnsResponse newResponse(SocketAddress sender, SocketAddress recipient,
int id, DnsOpCode opCode, DnsResponseCode responseCode) {
return new DefaultDnsResponse(id, opCode, responseCode);
}
};
}
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame = (ByteBuf) super.decode(ctx, in);
if (frame == null) {
return null;
}
try {
return responseDecoder.decode(ctx.channel().remoteAddress(), ctx.channel().localAddress(), frame.slice());
} finally {
frame.release();
}
}
@Override
protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) {
return buffer.copy(index, length);
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.util.concurrent.Promise;
import java.net.InetSocketAddress;
final class DatagramDnsQueryContext extends DnsQueryContext {
DatagramDnsQueryContext(DnsNameResolver parent, InetSocketAddress nameServerAddr, DnsQuestion question,
DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(parent, nameServerAddr, question, additionals, promise);
}
@Override
protected DnsQuery newQuery(int id) {
return new DatagramDnsQuery(null, nameServerAddr(), id);
}
@Override
protected Channel channel() {
return parent().ch;
}
@Override
protected String protocol() {
return "UDP";
}
}

View File

@ -32,6 +32,7 @@ import io.netty.channel.EventLoop;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
import io.netty.handler.codec.dns.DatagramDnsResponse;
import io.netty.handler.codec.dns.DatagramDnsResponseDecoder;
@ -41,6 +42,8 @@ import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
import io.netty.resolver.HostsFileEntries;
import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.InetNameResolver;
@ -66,6 +69,7 @@ import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketAddress;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.Collection;
@ -182,8 +186,9 @@ public class DnsNameResolver extends InetNameResolver {
return (List<String>) nameservers.invoke(instance);
}
private static final DatagramDnsResponseDecoder DECODER = new DatagramDnsResponseDecoder();
private static final DatagramDnsQueryEncoder ENCODER = new DatagramDnsQueryEncoder();
private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder();
private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder();
private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();
final Future<Channel> channelFuture;
final Channel ch;
@ -228,6 +233,7 @@ public class DnsNameResolver extends InetNameResolver {
private final boolean decodeIdn;
private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory;
private final boolean completeOncePreferredResolved;
private final ChannelFactory<? extends SocketChannel> socketChannelFactory;
/**
* Creates a new DNS-based name resolver that communicates with the specified list of DNS servers.
@ -326,7 +332,7 @@ public class DnsNameResolver extends InetNameResolver {
String[] searchDomains,
int ndots,
boolean decodeIdn) {
this(eventLoop, channelFactory, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache,
this(eventLoop, channelFactory, null, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache,
dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired,
maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver,
dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false);
@ -335,6 +341,7 @@ public class DnsNameResolver extends InetNameResolver {
DnsNameResolver(
EventLoop eventLoop,
ChannelFactory<? extends DatagramChannel> channelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
final DnsCache resolveCache,
final DnsCnameCache cnameCache,
final AuthoritativeDnsServerCache authoritativeDnsServerCache,
@ -374,7 +381,7 @@ public class DnsNameResolver extends InetNameResolver {
this.ndots = ndots >= 0 ? ndots : DEFAULT_NDOTS;
this.decodeIdn = decodeIdn;
this.completeOncePreferredResolved = completeOncePreferredResolved;
this.socketChannelFactory = socketChannelFactory;
switch (this.resolvedAddressTypes) {
case IPV4_ONLY:
supportsAAAARecords = false;
@ -414,8 +421,8 @@ public class DnsNameResolver extends InetNameResolver {
final DnsResponseHandler responseHandler = new DnsResponseHandler(executor().<Channel>newPromise());
b.handler(new ChannelInitializer<DatagramChannel>() {
@Override
protected void initChannel(DatagramChannel ch) throws Exception {
ch.pipeline().addLast(DECODER, ENCODER, responseHandler);
protected void initChannel(DatagramChannel ch) {
ch.pipeline().addLast(DATAGRAM_ENCODER, DATAGRAM_DECODER, responseHandler);
}
});
@ -1141,7 +1148,7 @@ public class DnsNameResolver extends InetNameResolver {
final Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> castPromise = cast(
checkNotNull(promise, "promise"));
try {
new DnsQueryContext(this, nameServerAddr, question, additionals, castPromise)
new DatagramDnsQueryContext(this, nameServerAddr, question, additionals, castPromise)
.query(flush, writePromise);
return castPromise;
} catch (Exception e) {
@ -1173,7 +1180,7 @@ public class DnsNameResolver extends InetNameResolver {
final int queryId = res.id();
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: [{}: {}], {}", ch, queryId, res.sender(), res);
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
}
final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
@ -1182,7 +1189,112 @@ public class DnsNameResolver extends InetNameResolver {
return;
}
qCtx.finish(res);
// Check if the response was truncated and if we can fallback to TCP to retry.
if (res.isTruncated() && socketChannelFactory != null) {
// Let's retain as we may need it later on.
res.retain();
Bootstrap bs = new Bootstrap();
bs.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(TCP_ENCODER);
ch.pipeline().addLast(new TcpDnsResponseDecoder());
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private boolean finish;
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
Channel channel = ctx.channel();
DnsResponse response = (DnsResponse) msg;
int queryId = response.id();
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
channel.remoteAddress(), response);
}
DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId);
if (tcpCtx == null) {
logger.warn("{} Received a DNS response with an unknown ID: {}",
channel, queryId);
qCtx.finish(res);
return;
}
// Release the original response as we will use the response that we
// received via TCP fallback.
res.release();
tcpCtx.finish(new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
finish = true;
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!finish) {
if (logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
ctx.channel().remoteAddress(), cause);
}
// TCP fallback failed, just use the truncated response as
qCtx.finish(res);
}
}
});
}
});
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (future.isSuccess()) {
final Channel channel = future.channel();
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
channel.eventLoop().newPromise();
new TcpDnsQueryContext(DnsNameResolver.this, channel,
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise());
promise.addListener(
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@Override
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
channel.close();
if (future.isSuccess()) {
qCtx.finish(future.getNow());
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
}
}
});
} else {
if (logger.isDebugEnabled()) {
logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause());
}
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
}
}
});
} else {
qCtx.finish(res);
}
} finally {
ReferenceCountUtil.safeRelease(msg);
}
@ -1196,7 +1308,116 @@ public class DnsNameResolver extends InetNameResolver {
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.warn("{} Unexpected exception: ", ch, cause);
logger.warn("{} Unexpected exception: ", ctx.channel(), cause);
}
}
private final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
private final InetSocketAddress sender;
private final InetSocketAddress recipient;
private final DnsResponse response;
AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
this.sender = sender;
this.recipient = recipient;
this.response = response;
}
@Override
public DnsResponse content() {
return response;
}
@Override
public InetSocketAddress sender() {
return sender;
}
@Override
public InetSocketAddress recipient() {
return recipient;
}
@Override
public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
response.retain();
return this;
}
@Override
public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
response.retain(increment);
return this;
}
@Override
public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
response.touch();
return this;
}
@Override
public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
response.touch(hint);
return this;
}
@Override
public int refCnt() {
return response.refCnt();
}
@Override
public boolean release() {
return response.release();
}
@Override
public boolean release(int decrement) {
return response.release(decrement);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (!(obj instanceof AddressedEnvelope)) {
return false;
}
@SuppressWarnings("unchecked")
final AddressedEnvelope<?, SocketAddress> that = (AddressedEnvelope<?, SocketAddress>) obj;
if (sender() == null) {
if (that.sender() != null) {
return false;
}
} else if (!sender().equals(that.sender())) {
return false;
}
if (recipient() == null) {
if (that.recipient() != null) {
return false;
}
} else if (!recipient().equals(that.recipient())) {
return false;
}
return response.equals(obj);
}
@Override
public int hashCode() {
int hashCode = response.hashCode();
if (sender() != null) {
hashCode = hashCode * 31 + sender().hashCode();
}
if (recipient() != null) {
hashCode = hashCode * 31 + recipient().hashCode();
}
return hashCode;
}
}
}

View File

@ -20,6 +20,7 @@ import io.netty.channel.EventLoop;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketChannel;
import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.util.concurrent.Future;
@ -40,6 +41,7 @@ import static io.netty.util.internal.ObjectUtil.intValue;
public final class DnsNameResolverBuilder {
private EventLoop eventLoop;
private ChannelFactory<? extends DatagramChannel> channelFactory;
private ChannelFactory<? extends SocketChannel> socketChannelFactory;
private DnsCache resolveCache;
private DnsCnameCache cnameCache;
private AuthoritativeDnsServerCache authoritativeDnsServerCache;
@ -115,6 +117,35 @@ public final class DnsNameResolverBuilder {
return channelFactory(new ReflectiveChannelFactory<DatagramChannel>(channelType));
}
/**
* Sets the {@link ChannelFactory} that will create a {@link SocketChannel} for
* <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> if needed.
*
* @param channelFactory the {@link ChannelFactory} or {@code null}
* if <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> should not be supported.
* @return {@code this}
*/
public DnsNameResolverBuilder socketChannelFactory(ChannelFactory<? extends SocketChannel> channelFactory) {
this.socketChannelFactory = channelFactory;
return this;
}
/**
* Sets the {@link ChannelFactory} as a {@link ReflectiveChannelFactory} of this type for
* <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a> if needed.
* Use as an alternative to {@link #socketChannelFactory(ChannelFactory)}.
*
* @param channelType the type or {@code null} if <a href="https://tools.ietf.org/html/rfc7766">TCP fallback</a>
* should not be supported.
* @return {@code this}
*/
public DnsNameResolverBuilder socketChannelType(Class<? extends SocketChannel> channelType) {
if (channelType == null) {
return socketChannelFactory(null);
}
return socketChannelFactory(new ReflectiveChannelFactory<SocketChannel>(channelType));
}
/**
* Sets the cache for resolution results.
*
@ -444,6 +475,7 @@ public final class DnsNameResolverBuilder {
return new DnsNameResolver(
eventLoop,
channelFactory,
socketChannelFactory,
resolveCache,
cnameCache,
authoritativeDnsServerCache,
@ -479,6 +511,10 @@ public final class DnsNameResolverBuilder {
copiedBuilder.channelFactory(channelFactory);
}
if (socketChannelFactory != null) {
copiedBuilder.socketChannelFactory(socketChannelFactory);
}
if (resolveCache != null) {
copiedBuilder.resolveCache(resolveCache);
}

View File

@ -20,7 +20,6 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
@ -40,7 +39,7 @@ import java.util.concurrent.TimeUnit;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>> {
abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>> {
private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class);
@ -89,10 +88,18 @@ final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRespo
return question;
}
DnsNameResolver parent() {
return parent;
}
protected abstract DnsQuery newQuery(int id);
protected abstract Channel channel();
protected abstract String protocol();
void query(boolean flush, ChannelPromise writePromise) {
final DnsQuestion question = question();
final InetSocketAddress nameServerAddr = nameServerAddr();
final DatagramDnsQuery query = new DatagramDnsQuery(null, nameServerAddr, id);
final DnsQuery query = newQuery(id);
query.setRecursionDesired(recursionDesired);
@ -107,7 +114,7 @@ final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRespo
}
if (logger.isDebugEnabled()) {
logger.debug("{} WRITE: [{}: {}], {}", parent.ch, id, nameServerAddr, question);
logger.debug("{} WRITE: {}, [{}: {}], {}", channel(), protocol(), id, nameServerAddr, question);
}
sendQuery(query, flush, writePromise);
@ -136,8 +143,8 @@ final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRespo
}
private void writeQuery(final DnsQuery query, final boolean flush, final ChannelPromise writePromise) {
final ChannelFuture writeFuture = flush ? parent.ch.writeAndFlush(query, writePromise) :
parent.ch.write(query, writePromise);
final ChannelFuture writeFuture = flush ? channel().writeAndFlush(query, writePromise) :
channel().write(query, writePromise);
if (writeFuture.isDone()) {
onQueryWriteCompletion(writeFuture);
} else {
@ -152,7 +159,7 @@ final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRespo
private void onQueryWriteCompletion(ChannelFuture writeFuture) {
if (!writeFuture.isSuccess()) {
setFailure("failed to send a query", writeFuture.cause());
setFailure("failed to send a query via " + protocol(), writeFuture.cause());
return;
}
@ -167,7 +174,8 @@ final class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRespo
return;
}
setFailure("query timed out after " + queryTimeoutMillis + " milliseconds", null);
setFailure("query via " + protocol() + " timed out after " +
queryTimeoutMillis + " milliseconds", null);
}
}, queryTimeoutMillis, TimeUnit.MILLISECONDS);
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2019 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.handler.codec.dns.DefaultDnsQuery;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.util.concurrent.Promise;
import java.net.InetSocketAddress;
final class TcpDnsQueryContext extends DnsQueryContext {
private final Channel channel;
TcpDnsQueryContext(DnsNameResolver parent, Channel channel, InetSocketAddress nameServerAddr, DnsQuestion question,
DnsRecord[] additionals, Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
super(parent, nameServerAddr, question, additionals, promise);
this.channel = channel;
}
@Override
protected DnsQuery newQuery(int id) {
return new DefaultDnsQuery(id);
}
@Override
protected Channel channel() {
return channel;
}
@Override
protected String protocol() {
return "TCP";
}
}

View File

@ -27,6 +27,7 @@ import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
@ -47,7 +48,9 @@ import io.netty.util.internal.StringUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.apache.directory.server.dns.DnsException;
import org.apache.directory.server.dns.io.encoder.DnsMessageEncoder;
import org.apache.directory.server.dns.messages.DnsMessage;
import org.apache.directory.server.dns.messages.DnsMessageModifier;
import org.apache.directory.server.dns.messages.QuestionRecord;
import org.apache.directory.server.dns.messages.RecordClass;
import org.apache.directory.server.dns.messages.RecordType;
@ -56,6 +59,7 @@ import org.apache.directory.server.dns.messages.ResourceRecordModifier;
import org.apache.directory.server.dns.messages.ResponseCode;
import org.apache.directory.server.dns.store.DnsAttribute;
import org.apache.directory.server.dns.store.RecordStore;
import org.apache.mina.core.buffer.IoBuffer;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@ -68,7 +72,10 @@ import java.net.DatagramSocket;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
@ -2734,4 +2741,131 @@ public class DnsNameResolverTest {
}
}
}
@Test(timeout = 5000)
public void testTruncatedWithoutTcpFallback() throws IOException {
testTruncated0(false);
}
@Test(timeout = 5000)
public void testTruncatedWithTcpFallback() throws IOException {
testTruncated0(true);
}
private static void testTruncated0(boolean tcpFallback) throws IOException {
final String host = "somehost.netty.io";
final String txt = "this is a txt record";
final AtomicReference<DnsMessage> messageRef = new AtomicReference<DnsMessage>();
TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() {
@Override
public Set<ResourceRecord> getRecords(QuestionRecord question) {
String name = question.getDomainName();
if (name.equals(host)) {
return Collections.<ResourceRecord>singleton(
new TestDnsServer.TestResourceRecord(name, RecordType.TXT,
Collections.<String, Object>singletonMap(
DnsAttribute.CHARACTER_STRING.toLowerCase(), txt)));
}
return null;
}
}) {
@Override
protected DnsMessage filterMessage(DnsMessage message) {
// Store a original message so we can replay it later on.
messageRef.set(message);
// Create a copy of the message but set the truncated flag.
DnsMessageModifier modifier = new DnsMessageModifier();
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
modifier.setAdditionalRecords(message.getAdditionalRecords());
modifier.setAnswerRecords(message.getAnswerRecords());
modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer());
modifier.setAuthorityRecords(message.getAuthorityRecords());
modifier.setMessageType(message.getMessageType());
modifier.setOpCode(message.getOpCode());
modifier.setQuestionRecords(message.getQuestionRecords());
modifier.setRecursionAvailable(message.isRecursionAvailable());
modifier.setRecursionDesired(message.isRecursionDesired());
modifier.setReserved(message.isReserved());
modifier.setResponseCode(message.getResponseCode());
modifier.setTransactionId(message.getTransactionId());
modifier.setTruncated(true);
return modifier.getDnsMessage();
}
};
dnsServer2.start();
DnsNameResolver resolver = null;
ServerSocket serverSocket = null;
try {
DnsNameResolverBuilder builder = newResolver()
.queryTimeoutMillis(10000)
.resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED)
.maxQueriesPerResolve(16)
.nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress()));
if (tcpFallback) {
// If we are configured to use TCP as a fallback also bind a TCP socket
serverSocket = new ServerSocket(dnsServer2.localAddress().getPort());
builder.socketChannelType(NioSocketChannel.class);
}
resolver = builder.build();
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> envelopeFuture = resolver.query(
new DefaultDnsQuestion(host, DnsRecordType.TXT));
if (tcpFallback) {
// If we are configured to use TCP as a fallback lets replay the dns message over TCP
Socket socket = serverSocket.accept();
IoBuffer ioBuffer = IoBuffer.allocate(1024);
new DnsMessageEncoder().encode(ioBuffer, messageRef.get());
ioBuffer.flip();
ByteBuffer lenBuffer = ByteBuffer.allocate(2);
lenBuffer.putShort((short) ioBuffer.remaining());
lenBuffer.flip();
while (lenBuffer.hasRemaining()) {
socket.getOutputStream().write(lenBuffer.get());
}
while (ioBuffer.hasRemaining()) {
socket.getOutputStream().write(ioBuffer.get());
}
socket.getOutputStream().flush();
socket.getOutputStream().close();
socket.close();
}
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = envelopeFuture.syncUninterruptibly().getNow();
assertNotNull(envelope.sender());
DnsResponse response = envelope.content();
assertNotNull(response);
assertEquals(DnsResponseCode.NOERROR, response.code());
int count = response.count(DnsSection.ANSWER);
assertEquals(1, count);
List<String> texts = decodeTxt(response.recordAt(DnsSection.ANSWER, 0));
assertEquals(1, texts.size());
assertEquals(txt, texts.get(0));
if (tcpFallback) {
assertFalse(envelope.content().isTruncated());
} else {
assertTrue(envelope.content().isTruncated());
}
envelope.release();
} finally {
dnsServer2.stop();
if (serverSocket != null) {
serverSocket.close();
}
if (resolver != null) {
resolver.close();
}
}
}
}