From 1672b6d12cb8f80dd1535ba5fcc2e8c85945cd07 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 17 May 2019 14:37:11 +0200 Subject: [PATCH] Add support for TCP fallback when we receive a truncated DnsResponse (#9139) Motivation: Sometimes DNS responses can be very large which mean they will not fit in a UDP packet. When this is happening the DNS server will set the TC flag (truncated flag) to tell the resolver that the response was truncated. When a truncated response was received we should allow to retry via TCP and use the received response (if possible) as a replacement for the truncated one. See https://tools.ietf.org/html/rfc7766. Modifications: - Add support for TCP fallback by allow to specify a socketChannelFactory / socketChannelType on the DnsNameResolverBuilder. If this is set to something different then null we will try to fallback to TCP. - Add decoder / encoder for TCP - Add unit tests Result: Support for TCP fallback as defined by https://tools.ietf.org/html/rfc7766 when using DnsNameResolver. --- .../codec/dns/DatagramDnsQueryEncoder.java | 42 +-- .../codec/dns/DatagramDnsResponseDecoder.java | 79 +----- .../handler/codec/dns/DnsQueryEncoder.java | 75 ++++++ .../handler/codec/dns/DnsResponseDecoder.java | 97 +++++++ .../handler/codec/dns/TcpDnsQueryEncoder.java | 64 +++++ .../codec/dns/TcpDnsResponseDecoder.java | 72 ++++++ .../resolver/dns/DatagramDnsQueryContext.java | 51 ++++ .../netty/resolver/dns/DnsNameResolver.java | 241 +++++++++++++++++- .../resolver/dns/DnsNameResolverBuilder.java | 36 +++ .../netty/resolver/dns/DnsQueryContext.java | 24 +- .../resolver/dns/TcpDnsQueryContext.java | 53 ++++ .../resolver/dns/DnsNameResolverTest.java | 134 ++++++++++ 12 files changed, 842 insertions(+), 126 deletions(-) create mode 100644 codec-dns/src/main/java/io/netty/handler/codec/dns/DnsQueryEncoder.java create mode 100644 codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java create mode 100644 codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsQueryEncoder.java create mode 100644 codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsResponseDecoder.java create mode 100644 resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java create mode 100644 resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsQueryEncoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsQueryEncoder.java index dbb562fd2b..18c6009a2d 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsQueryEncoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsQueryEncoder.java @@ -36,7 +36,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; @ChannelHandler.Sharable public class DatagramDnsQueryEncoder extends MessageToMessageEncoder> { - private final DnsRecordEncoder recordEncoder; + private final DnsQueryEncoder encoder; /** * Creates a new encoder with {@linkplain DnsRecordEncoder#DEFAULT the default record encoder}. @@ -49,7 +49,7 @@ public class DatagramDnsQueryEncoder extends MessageToMessageEncoder msg) throws Exception { return ctx.alloc().ioBuffer(1024); } - - /** - * Encodes the header that is always 12 bytes long. - * - * @param query the query header being encoded - * @param buf the buffer the encoded data should be written to - */ - private static void encodeHeader(DnsQuery query, ByteBuf buf) { - buf.writeShort(query.id()); - int flags = 0; - flags |= (query.opCode().byteValue() & 0xFF) << 14; - if (query.isRecursionDesired()) { - flags |= 1 << 8; - } - buf.writeShort(flags); - buf.writeShort(query.count(DnsSection.QUESTION)); - buf.writeShort(0); // answerCount - buf.writeShort(0); // authorityResourceCount - buf.writeShort(query.count(DnsSection.ADDITIONAL)); - } - - private void encodeQuestions(DnsQuery query, ByteBuf buf) throws Exception { - final int count = query.count(DnsSection.QUESTION); - for (int i = 0; i < count; i++) { - recordEncoder.encodeQuestion((DnsQuestion) query.recordAt(DnsSection.QUESTION, i), buf); - } - } - - private void encodeRecords(DnsQuery query, DnsSection section, ByteBuf buf) throws Exception { - final int count = query.count(section); - for (int i = 0; i < count; i++) { - recordEncoder.encodeRecord(query.recordAt(section, i), buf); - } - } } diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java index 547d5aefac..da973263ac 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java @@ -15,18 +15,15 @@ */ package io.netty.handler.codec.dns; -import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.socket.DatagramPacket; -import io.netty.handler.codec.CorruptedFrameException; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.util.internal.UnstableApi; +import java.net.InetSocketAddress; import java.util.List; -import static io.netty.util.internal.ObjectUtil.checkNotNull; - /** * Decodes a {@link DatagramPacket} into a {@link DatagramDnsResponse}. */ @@ -34,7 +31,7 @@ import static io.netty.util.internal.ObjectUtil.checkNotNull; @ChannelHandler.Sharable public class DatagramDnsResponseDecoder extends MessageToMessageDecoder { - private final DnsRecordDecoder recordDecoder; + private final DnsResponseDecoder responseDecoder; /** * Creates a new decoder with {@linkplain DnsRecordDecoder#DEFAULT the default record decoder}. @@ -47,73 +44,17 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder(recordDecoder) { + @Override + protected DnsResponse newResponse(InetSocketAddress sender, InetSocketAddress recipient, + int id, DnsOpCode opCode, DnsResponseCode responseCode) { + return new DatagramDnsResponse(sender, recipient, id, opCode, responseCode); + } + }; } @Override protected void decode(ChannelHandlerContext ctx, DatagramPacket packet, List out) throws Exception { - final ByteBuf buf = packet.content(); - - final DnsResponse response = newResponse(packet, buf); - boolean success = false; - try { - final int questionCount = buf.readUnsignedShort(); - final int answerCount = buf.readUnsignedShort(); - final int authorityRecordCount = buf.readUnsignedShort(); - final int additionalRecordCount = buf.readUnsignedShort(); - - decodeQuestions(response, buf, questionCount); - decodeRecords(response, DnsSection.ANSWER, buf, answerCount); - decodeRecords(response, DnsSection.AUTHORITY, buf, authorityRecordCount); - decodeRecords(response, DnsSection.ADDITIONAL, buf, additionalRecordCount); - - out.add(response); - success = true; - } finally { - if (!success) { - response.release(); - } - } - } - - private static DnsResponse newResponse(DatagramPacket packet, ByteBuf buf) { - final int id = buf.readUnsignedShort(); - - final int flags = buf.readUnsignedShort(); - if (flags >> 15 == 0) { - throw new CorruptedFrameException("not a response"); - } - - final DnsResponse response = new DatagramDnsResponse( - packet.sender(), - packet.recipient(), - id, - DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)), DnsResponseCode.valueOf((byte) (flags & 0xf))); - - response.setRecursionDesired((flags >> 8 & 1) == 1); - response.setAuthoritativeAnswer((flags >> 10 & 1) == 1); - response.setTruncated((flags >> 9 & 1) == 1); - response.setRecursionAvailable((flags >> 7 & 1) == 1); - response.setZ(flags >> 4 & 0x7); - return response; - } - - private void decodeQuestions(DnsResponse response, ByteBuf buf, int questionCount) throws Exception { - for (int i = questionCount; i > 0; i --) { - response.addRecord(DnsSection.QUESTION, recordDecoder.decodeQuestion(buf)); - } - } - - private void decodeRecords( - DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception { - for (int i = count; i > 0; i --) { - final DnsRecord r = recordDecoder.decodeRecord(buf); - if (r == null) { - // Truncated response - break; - } - - response.addRecord(section, r); - } + out.add(responseDecoder.decode(packet.sender(), packet.recipient(), packet.content())); } } diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsQueryEncoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsQueryEncoder.java new file mode 100644 index 0000000000..db2a730d25 --- /dev/null +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsQueryEncoder.java @@ -0,0 +1,75 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.dns; + +import io.netty.buffer.ByteBuf; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +final class DnsQueryEncoder { + + private final DnsRecordEncoder recordEncoder; + + /** + * Creates a new encoder with the specified {@code recordEncoder}. + */ + DnsQueryEncoder(DnsRecordEncoder recordEncoder) { + this.recordEncoder = checkNotNull(recordEncoder, "recordEncoder"); + } + + /** + * Encodes the given {@link DnsQuery} into a {@link ByteBuf}. + */ + void encode(DnsQuery query, ByteBuf out) throws Exception { + encodeHeader(query, out); + encodeQuestions(query, out); + encodeRecords(query, DnsSection.ADDITIONAL, out); + } + + /** + * Encodes the header that is always 12 bytes long. + * + * @param query the query header being encoded + * @param buf the buffer the encoded data should be written to + */ + private static void encodeHeader(DnsQuery query, ByteBuf buf) { + buf.writeShort(query.id()); + int flags = 0; + flags |= (query.opCode().byteValue() & 0xFF) << 14; + if (query.isRecursionDesired()) { + flags |= 1 << 8; + } + buf.writeShort(flags); + buf.writeShort(query.count(DnsSection.QUESTION)); + buf.writeShort(0); // answerCount + buf.writeShort(0); // authorityResourceCount + buf.writeShort(query.count(DnsSection.ADDITIONAL)); + } + + private void encodeQuestions(DnsQuery query, ByteBuf buf) throws Exception { + final int count = query.count(DnsSection.QUESTION); + for (int i = 0; i < count; i++) { + recordEncoder.encodeQuestion((DnsQuestion) query.recordAt(DnsSection.QUESTION, i), buf); + } + } + + private void encodeRecords(DnsQuery query, DnsSection section, ByteBuf buf) throws Exception { + final int count = query.count(section); + for (int i = 0; i < count; i++) { + recordEncoder.encodeRecord(query.recordAt(section, i), buf); + } + } +} diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java new file mode 100644 index 0000000000..fd5a1bed91 --- /dev/null +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java @@ -0,0 +1,97 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.dns; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.CorruptedFrameException; + +import java.net.SocketAddress; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +abstract class DnsResponseDecoder { + + private final DnsRecordDecoder recordDecoder; + + /** + * Creates a new decoder with the specified {@code recordDecoder}. + */ + DnsResponseDecoder(DnsRecordDecoder recordDecoder) { + this.recordDecoder = checkNotNull(recordDecoder, "recordDecoder"); + } + + final DnsResponse decode(A sender, A recipient, ByteBuf buffer) throws Exception { + final int id = buffer.readUnsignedShort(); + + final int flags = buffer.readUnsignedShort(); + if (flags >> 15 == 0) { + throw new CorruptedFrameException("not a response"); + } + + final DnsResponse response = newResponse( + sender, + recipient, + id, + DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)), DnsResponseCode.valueOf((byte) (flags & 0xf))); + + response.setRecursionDesired((flags >> 8 & 1) == 1); + response.setAuthoritativeAnswer((flags >> 10 & 1) == 1); + response.setTruncated((flags >> 9 & 1) == 1); + response.setRecursionAvailable((flags >> 7 & 1) == 1); + response.setZ(flags >> 4 & 0x7); + + boolean success = false; + try { + final int questionCount = buffer.readUnsignedShort(); + final int answerCount = buffer.readUnsignedShort(); + final int authorityRecordCount = buffer.readUnsignedShort(); + final int additionalRecordCount = buffer.readUnsignedShort(); + + decodeQuestions(response, buffer, questionCount); + decodeRecords(response, DnsSection.ANSWER, buffer, answerCount); + decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount); + decodeRecords(response, DnsSection.ADDITIONAL, buffer, additionalRecordCount); + success = true; + return response; + } finally { + if (!success) { + response.release(); + } + } + } + + protected abstract DnsResponse newResponse(A sender, A recipient, int id, + DnsOpCode opCode, DnsResponseCode responseCode) throws Exception; + + private void decodeQuestions(DnsResponse response, ByteBuf buf, int questionCount) throws Exception { + for (int i = questionCount; i > 0; i --) { + response.addRecord(DnsSection.QUESTION, recordDecoder.decodeQuestion(buf)); + } + } + + private void decodeRecords( + DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception { + for (int i = count; i > 0; i --) { + final DnsRecord r = recordDecoder.decodeRecord(buf); + if (r == null) { + // Truncated response + break; + } + + response.addRecord(section, r); + } + } +} diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsQueryEncoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsQueryEncoder.java new file mode 100644 index 0000000000..2978fff900 --- /dev/null +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsQueryEncoder.java @@ -0,0 +1,64 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.dns; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.internal.UnstableApi; + +@ChannelHandler.Sharable +@UnstableApi +public final class TcpDnsQueryEncoder extends MessageToByteEncoder { + + private final DnsQueryEncoder encoder; + + /** + * Creates a new encoder with {@linkplain DnsRecordEncoder#DEFAULT the default record encoder}. + */ + public TcpDnsQueryEncoder() { + this(DnsRecordEncoder.DEFAULT); + } + + /** + * Creates a new encoder with the specified {@code recordEncoder}. + */ + public TcpDnsQueryEncoder(DnsRecordEncoder recordEncoder) { + this.encoder = new DnsQueryEncoder(recordEncoder); + } + + @Override + protected void encode(ChannelHandlerContext ctx, DnsQuery msg, ByteBuf out) throws Exception { + // Length is two octets as defined by RFC-7766 + // See https://tools.ietf.org/html/rfc7766#section-8 + out.writerIndex(out.writerIndex() + 2); + encoder.encode(msg, out); + + // Now fill in the correct length based on the amount of data that we wrote the ByteBuf. + out.setShort(0, out.readableBytes() - 2); + } + + @Override + protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, @SuppressWarnings("unused") DnsQuery msg, + boolean preferDirect) { + if (preferDirect) { + return ctx.alloc().ioBuffer(1024); + } else { + return ctx.alloc().heapBuffer(1024); + } + } +} diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsResponseDecoder.java new file mode 100644 index 0000000000..f8a92d9929 --- /dev/null +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/TcpDnsResponseDecoder.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.dns; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.internal.UnstableApi; + +import java.net.SocketAddress; + +@UnstableApi +public final class TcpDnsResponseDecoder extends LengthFieldBasedFrameDecoder { + + private final DnsResponseDecoder responseDecoder; + + /** + * Creates a new decoder with {@linkplain DnsRecordDecoder#DEFAULT the default record decoder}. + */ + public TcpDnsResponseDecoder() { + this(DnsRecordDecoder.DEFAULT, 64 * 1024); + } + + /** + * Creates a new decoder with the specified {@code recordDecoder} and {@code maxFrameLength} + */ + public TcpDnsResponseDecoder(DnsRecordDecoder recordDecoder, int maxFrameLength) { + // Length is two octets as defined by RFC-7766 + // See https://tools.ietf.org/html/rfc7766#section-8 + super(maxFrameLength, 0, 2, 0, 2); + + this.responseDecoder = new DnsResponseDecoder(recordDecoder) { + @Override + protected DnsResponse newResponse(SocketAddress sender, SocketAddress recipient, + int id, DnsOpCode opCode, DnsResponseCode responseCode) { + return new DefaultDnsResponse(id, opCode, responseCode); + } + }; + } + + @Override + protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { + ByteBuf frame = (ByteBuf) super.decode(ctx, in); + if (frame == null) { + return null; + } + + try { + return responseDecoder.decode(ctx.channel().remoteAddress(), ctx.channel().localAddress(), frame.slice()); + } finally { + frame.release(); + } + } + + @Override + protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) { + return buffer.copy(index, length); + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java new file mode 100644 index 0000000000..0ac80b9cb0 --- /dev/null +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java @@ -0,0 +1,51 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver.dns; + +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.Channel; +import io.netty.handler.codec.dns.DatagramDnsQuery; +import io.netty.handler.codec.dns.DnsQuery; +import io.netty.handler.codec.dns.DnsQuestion; +import io.netty.handler.codec.dns.DnsRecord; +import io.netty.handler.codec.dns.DnsResponse; +import io.netty.util.concurrent.Promise; + +import java.net.InetSocketAddress; + +final class DatagramDnsQueryContext extends DnsQueryContext { + + DatagramDnsQueryContext(DnsNameResolver parent, InetSocketAddress nameServerAddr, DnsQuestion question, + DnsRecord[] additionals, + Promise> promise) { + super(parent, nameServerAddr, question, additionals, promise); + } + + @Override + protected DnsQuery newQuery(int id) { + return new DatagramDnsQuery(null, nameServerAddr(), id); + } + + @Override + protected Channel channel() { + return parent().ch; + } + + @Override + protected String protocol() { + return "UDP"; + } +} diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 020f593b4d..dcdac54019 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -32,6 +32,7 @@ import io.netty.channel.EventLoop; import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.dns.DatagramDnsQueryEncoder; import io.netty.handler.codec.dns.DatagramDnsResponse; import io.netty.handler.codec.dns.DatagramDnsResponseDecoder; @@ -41,6 +42,8 @@ import io.netty.handler.codec.dns.DnsRawRecord; import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsResponse; +import io.netty.handler.codec.dns.TcpDnsQueryEncoder; +import io.netty.handler.codec.dns.TcpDnsResponseDecoder; import io.netty.resolver.HostsFileEntries; import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.InetNameResolver; @@ -66,6 +69,7 @@ import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; +import java.net.SocketAddress; import java.net.SocketException; import java.util.ArrayList; import java.util.Collection; @@ -182,8 +186,9 @@ public class DnsNameResolver extends InetNameResolver { return (List) nameservers.invoke(instance); } - private static final DatagramDnsResponseDecoder DECODER = new DatagramDnsResponseDecoder(); - private static final DatagramDnsQueryEncoder ENCODER = new DatagramDnsQueryEncoder(); + private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder(); + private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder(); + private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder(); final Future channelFuture; final Channel ch; @@ -228,6 +233,7 @@ public class DnsNameResolver extends InetNameResolver { private final boolean decodeIdn; private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory; private final boolean completeOncePreferredResolved; + private final ChannelFactory socketChannelFactory; /** * Creates a new DNS-based name resolver that communicates with the specified list of DNS servers. @@ -326,7 +332,7 @@ public class DnsNameResolver extends InetNameResolver { String[] searchDomains, int ndots, boolean decodeIdn) { - this(eventLoop, channelFactory, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, + this(eventLoop, channelFactory, null, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver, dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false); @@ -335,6 +341,7 @@ public class DnsNameResolver extends InetNameResolver { DnsNameResolver( EventLoop eventLoop, ChannelFactory channelFactory, + ChannelFactory socketChannelFactory, final DnsCache resolveCache, final DnsCnameCache cnameCache, final AuthoritativeDnsServerCache authoritativeDnsServerCache, @@ -374,7 +381,7 @@ public class DnsNameResolver extends InetNameResolver { this.ndots = ndots >= 0 ? ndots : DEFAULT_NDOTS; this.decodeIdn = decodeIdn; this.completeOncePreferredResolved = completeOncePreferredResolved; - + this.socketChannelFactory = socketChannelFactory; switch (this.resolvedAddressTypes) { case IPV4_ONLY: supportsAAAARecords = false; @@ -414,8 +421,8 @@ public class DnsNameResolver extends InetNameResolver { final DnsResponseHandler responseHandler = new DnsResponseHandler(executor().newPromise()); b.handler(new ChannelInitializer() { @Override - protected void initChannel(DatagramChannel ch) throws Exception { - ch.pipeline().addLast(DECODER, ENCODER, responseHandler); + protected void initChannel(DatagramChannel ch) { + ch.pipeline().addLast(DATAGRAM_ENCODER, DATAGRAM_DECODER, responseHandler); } }); @@ -1141,7 +1148,7 @@ public class DnsNameResolver extends InetNameResolver { final Promise> castPromise = cast( checkNotNull(promise, "promise")); try { - new DnsQueryContext(this, nameServerAddr, question, additionals, castPromise) + new DatagramDnsQueryContext(this, nameServerAddr, question, additionals, castPromise) .query(flush, writePromise); return castPromise; } catch (Exception e) { @@ -1173,7 +1180,7 @@ public class DnsNameResolver extends InetNameResolver { final int queryId = res.id(); if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: [{}: {}], {}", ch, queryId, res.sender(), res); + logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res); } final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); @@ -1182,7 +1189,112 @@ public class DnsNameResolver extends InetNameResolver { return; } - qCtx.finish(res); + // Check if the response was truncated and if we can fallback to TCP to retry. + if (res.isTruncated() && socketChannelFactory != null) { + // Let's retain as we may need it later on. + res.retain(); + + Bootstrap bs = new Bootstrap(); + bs.option(ChannelOption.SO_REUSEADDR, true) + .group(executor()) + .channelFactory(socketChannelFactory) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(TCP_ENCODER); + ch.pipeline().addLast(new TcpDnsResponseDecoder()); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + private boolean finish; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + try { + Channel channel = ctx.channel(); + DnsResponse response = (DnsResponse) msg; + int queryId = response.id(); + + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId, + channel.remoteAddress(), response); + } + + DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId); + if (tcpCtx == null) { + logger.warn("{} Received a DNS response with an unknown ID: {}", + channel, queryId); + qCtx.finish(res); + return; + } + + // Release the original response as we will use the response that we + // received via TCP fallback. + res.release(); + + tcpCtx.finish(new AddressedEnvelopeAdapter( + (InetSocketAddress) ctx.channel().remoteAddress(), + (InetSocketAddress) ctx.channel().localAddress(), + response)); + + finish = true; + } finally { + ReferenceCountUtil.release(msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!finish) { + if (logger.isDebugEnabled()) { + logger.debug("{} Error during processing response: TCP [{}: {}]", + ctx.channel(), queryId, + ctx.channel().remoteAddress(), cause); + } + // TCP fallback failed, just use the truncated response as + qCtx.finish(res); + } + } + }); + } + }); + bs.connect(res.sender()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + final Channel channel = future.channel(); + + Promise> promise = + channel.eventLoop().newPromise(); + new TcpDnsQueryContext(DnsNameResolver.this, channel, + (InetSocketAddress) channel.remoteAddress(), qCtx.question(), + EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise()); + promise.addListener( + new FutureListener>() { + @Override + public void operationComplete( + Future> future) { + channel.close(); + + if (future.isSuccess()) { + qCtx.finish(future.getNow()); + } else { + // TCP fallback failed, just use the truncated response. + qCtx.finish(res); + } + } + }); + } else { + if (logger.isDebugEnabled()) { + logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause()); + } + + // TCP fallback failed, just use the truncated response. + qCtx.finish(res); + } + } + }); + } else { + qCtx.finish(res); + } } finally { ReferenceCountUtil.safeRelease(msg); } @@ -1196,7 +1308,116 @@ public class DnsNameResolver extends InetNameResolver { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - logger.warn("{} Unexpected exception: ", ch, cause); + logger.warn("{} Unexpected exception: ", ctx.channel(), cause); + } + } + + private final class AddressedEnvelopeAdapter implements AddressedEnvelope { + private final InetSocketAddress sender; + private final InetSocketAddress recipient; + private final DnsResponse response; + + AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) { + this.sender = sender; + this.recipient = recipient; + this.response = response; + } + + @Override + public DnsResponse content() { + return response; + } + + @Override + public InetSocketAddress sender() { + return sender; + } + + @Override + public InetSocketAddress recipient() { + return recipient; + } + + @Override + public AddressedEnvelope retain() { + response.retain(); + return this; + } + + @Override + public AddressedEnvelope retain(int increment) { + response.retain(increment); + return this; + } + + @Override + public AddressedEnvelope touch() { + response.touch(); + return this; + } + + @Override + public AddressedEnvelope touch(Object hint) { + response.touch(hint); + return this; + } + + @Override + public int refCnt() { + return response.refCnt(); + } + + @Override + public boolean release() { + return response.release(); + } + + @Override + public boolean release(int decrement) { + return response.release(decrement); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof AddressedEnvelope)) { + return false; + } + + @SuppressWarnings("unchecked") + final AddressedEnvelope that = (AddressedEnvelope) obj; + if (sender() == null) { + if (that.sender() != null) { + return false; + } + } else if (!sender().equals(that.sender())) { + return false; + } + + if (recipient() == null) { + if (that.recipient() != null) { + return false; + } + } else if (!recipient().equals(that.recipient())) { + return false; + } + + return response.equals(obj); + } + + @Override + public int hashCode() { + int hashCode = response.hashCode(); + if (sender() != null) { + hashCode = hashCode * 31 + sender().hashCode(); + } + if (recipient() != null) { + hashCode = hashCode * 31 + recipient().hashCode(); + } + return hashCode; } } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java index fce8469322..8c2bf95abb 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java @@ -20,6 +20,7 @@ import io.netty.channel.EventLoop; import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.channel.socket.SocketChannel; import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.ResolvedAddressTypes; import io.netty.util.concurrent.Future; @@ -40,6 +41,7 @@ import static io.netty.util.internal.ObjectUtil.intValue; public final class DnsNameResolverBuilder { private EventLoop eventLoop; private ChannelFactory channelFactory; + private ChannelFactory socketChannelFactory; private DnsCache resolveCache; private DnsCnameCache cnameCache; private AuthoritativeDnsServerCache authoritativeDnsServerCache; @@ -115,6 +117,35 @@ public final class DnsNameResolverBuilder { return channelFactory(new ReflectiveChannelFactory(channelType)); } + /** + * Sets the {@link ChannelFactory} that will create a {@link SocketChannel} for + * TCP fallback if needed. + * + * @param channelFactory the {@link ChannelFactory} or {@code null} + * if TCP fallback should not be supported. + * @return {@code this} + */ + public DnsNameResolverBuilder socketChannelFactory(ChannelFactory channelFactory) { + this.socketChannelFactory = channelFactory; + return this; + } + + /** + * Sets the {@link ChannelFactory} as a {@link ReflectiveChannelFactory} of this type for + * TCP fallback if needed. + * Use as an alternative to {@link #socketChannelFactory(ChannelFactory)}. + * + * @param channelType the type or {@code null} if TCP fallback + * should not be supported. + * @return {@code this} + */ + public DnsNameResolverBuilder socketChannelType(Class channelType) { + if (channelType == null) { + return socketChannelFactory(null); + } + return socketChannelFactory(new ReflectiveChannelFactory(channelType)); + } + /** * Sets the cache for resolution results. * @@ -444,6 +475,7 @@ public final class DnsNameResolverBuilder { return new DnsNameResolver( eventLoop, channelFactory, + socketChannelFactory, resolveCache, cnameCache, authoritativeDnsServerCache, @@ -479,6 +511,10 @@ public final class DnsNameResolverBuilder { copiedBuilder.channelFactory(channelFactory); } + if (socketChannelFactory != null) { + copiedBuilder.socketChannelFactory(socketChannelFactory); + } + if (resolveCache != null) { copiedBuilder.resolveCache(resolveCache); } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 08bbcb6088..d5ba91e5e8 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -20,7 +20,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.dns.DatagramDnsQuery; import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord; import io.netty.handler.codec.dns.DnsQuery; import io.netty.handler.codec.dns.DnsQuestion; @@ -40,7 +39,7 @@ import java.util.concurrent.TimeUnit; import static io.netty.util.internal.ObjectUtil.checkNotNull; -final class DnsQueryContext implements FutureListener> { +abstract class DnsQueryContext implements FutureListener> { private static final InternalLogger logger = InternalLoggerFactory.getInstance(DnsQueryContext.class); @@ -89,10 +88,18 @@ final class DnsQueryContext implements FutureListener> promise) { + super(parent, nameServerAddr, question, additionals, promise); + this.channel = channel; + } + + @Override + protected DnsQuery newQuery(int id) { + return new DefaultDnsQuery(id); + } + + @Override + protected Channel channel() { + return channel; + } + + @Override + protected String protocol() { + return "TCP"; + } +} diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 0090bd9c63..39968ebc85 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -27,6 +27,7 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.DatagramChannel; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.channel.socket.nio.NioDatagramChannel; +import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.dns.DefaultDnsQuestion; import io.netty.handler.codec.dns.DnsQuestion; import io.netty.handler.codec.dns.DnsRawRecord; @@ -47,7 +48,9 @@ import io.netty.util.internal.StringUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import org.apache.directory.server.dns.DnsException; +import org.apache.directory.server.dns.io.encoder.DnsMessageEncoder; import org.apache.directory.server.dns.messages.DnsMessage; +import org.apache.directory.server.dns.messages.DnsMessageModifier; import org.apache.directory.server.dns.messages.QuestionRecord; import org.apache.directory.server.dns.messages.RecordClass; import org.apache.directory.server.dns.messages.RecordType; @@ -56,6 +59,7 @@ import org.apache.directory.server.dns.messages.ResourceRecordModifier; import org.apache.directory.server.dns.messages.ResponseCode; import org.apache.directory.server.dns.store.DnsAttribute; import org.apache.directory.server.dns.store.RecordStore; +import org.apache.mina.core.buffer.IoBuffer; import org.hamcrest.Matchers; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -68,7 +72,10 @@ import java.net.DatagramSocket; import java.net.Inet4Address; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; import java.net.UnknownHostException; +import java.nio.ByteBuffer; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; @@ -2734,4 +2741,131 @@ public class DnsNameResolverTest { } } } + + @Test(timeout = 5000) + public void testTruncatedWithoutTcpFallback() throws IOException { + testTruncated0(false); + } + + @Test(timeout = 5000) + public void testTruncatedWithTcpFallback() throws IOException { + testTruncated0(true); + } + + private static void testTruncated0(boolean tcpFallback) throws IOException { + final String host = "somehost.netty.io"; + final String txt = "this is a txt record"; + final AtomicReference messageRef = new AtomicReference(); + + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + @Override + public Set getRecords(QuestionRecord question) { + String name = question.getDomainName(); + if (name.equals(host)) { + return Collections.singleton( + new TestDnsServer.TestResourceRecord(name, RecordType.TXT, + Collections.singletonMap( + DnsAttribute.CHARACTER_STRING.toLowerCase(), txt))); + } + return null; + } + }) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + // Store a original message so we can replay it later on. + messageRef.set(message); + + // Create a copy of the message but set the truncated flag. + DnsMessageModifier modifier = new DnsMessageModifier(); + modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData()); + modifier.setAdditionalRecords(message.getAdditionalRecords()); + modifier.setAnswerRecords(message.getAnswerRecords()); + modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer()); + modifier.setAuthorityRecords(message.getAuthorityRecords()); + modifier.setMessageType(message.getMessageType()); + modifier.setOpCode(message.getOpCode()); + modifier.setQuestionRecords(message.getQuestionRecords()); + modifier.setRecursionAvailable(message.isRecursionAvailable()); + modifier.setRecursionDesired(message.isRecursionDesired()); + modifier.setReserved(message.isReserved()); + modifier.setResponseCode(message.getResponseCode()); + modifier.setTransactionId(message.getTransactionId()); + modifier.setTruncated(true); + return modifier.getDnsMessage(); + } + }; + dnsServer2.start(); + DnsNameResolver resolver = null; + ServerSocket serverSocket = null; + try { + DnsNameResolverBuilder builder = newResolver() + .queryTimeoutMillis(10000) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())); + + if (tcpFallback) { + // If we are configured to use TCP as a fallback also bind a TCP socket + serverSocket = new ServerSocket(dnsServer2.localAddress().getPort()); + + builder.socketChannelType(NioSocketChannel.class); + } + resolver = builder.build(); + Future> envelopeFuture = resolver.query( + new DefaultDnsQuestion(host, DnsRecordType.TXT)); + + if (tcpFallback) { + // If we are configured to use TCP as a fallback lets replay the dns message over TCP + Socket socket = serverSocket.accept(); + + IoBuffer ioBuffer = IoBuffer.allocate(1024); + new DnsMessageEncoder().encode(ioBuffer, messageRef.get()); + ioBuffer.flip(); + + ByteBuffer lenBuffer = ByteBuffer.allocate(2); + lenBuffer.putShort((short) ioBuffer.remaining()); + lenBuffer.flip(); + + while (lenBuffer.hasRemaining()) { + socket.getOutputStream().write(lenBuffer.get()); + } + + while (ioBuffer.hasRemaining()) { + socket.getOutputStream().write(ioBuffer.get()); + } + socket.getOutputStream().flush(); + socket.getOutputStream().close(); + socket.close(); + } + + AddressedEnvelope envelope = envelopeFuture.syncUninterruptibly().getNow(); + assertNotNull(envelope.sender()); + + DnsResponse response = envelope.content(); + assertNotNull(response); + + assertEquals(DnsResponseCode.NOERROR, response.code()); + int count = response.count(DnsSection.ANSWER); + + assertEquals(1, count); + List texts = decodeTxt(response.recordAt(DnsSection.ANSWER, 0)); + assertEquals(1, texts.size()); + assertEquals(txt, texts.get(0)); + + if (tcpFallback) { + assertFalse(envelope.content().isTruncated()); + } else { + assertTrue(envelope.content().isTruncated()); + } + envelope.release(); + } finally { + dnsServer2.stop(); + if (serverSocket != null) { + serverSocket.close(); + } + if (resolver != null) { + resolver.close(); + } + } + } }