diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java index da973263ac..b9c4a1cf81 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DatagramDnsResponseDecoder.java @@ -55,6 +55,10 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder out) throws Exception { - out.add(responseDecoder.decode(packet.sender(), packet.recipient(), packet.content())); + out.add(decodeResponse(ctx, packet)); + } + + protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception { + return responseDecoder.decode(packet.sender(), packet.recipient(), packet.content()); } } diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java index e61d46cc20..70df1c6f39 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoder.java @@ -47,7 +47,7 @@ public class DefaultDnsRecordDecoder implements DnsRecordDecoder { final String name = decodeName(in); final int endOffset = in.writerIndex(); - if (endOffset - startOffset < 10) { + if (endOffset - in.readerIndex() < 10) { // Not enough data in.readerIndex(startOffset); return null; diff --git a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java index fd5a1bed91..aadf3c25c4 100644 --- a/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java +++ b/codec-dns/src/main/java/io/netty/handler/codec/dns/DnsResponseDecoder.java @@ -61,8 +61,15 @@ abstract class DnsResponseDecoder { final int additionalRecordCount = buffer.readUnsignedShort(); decodeQuestions(response, buffer, questionCount); - decodeRecords(response, DnsSection.ANSWER, buffer, answerCount); - decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount); + if (!decodeRecords(response, DnsSection.ANSWER, buffer, answerCount)) { + success = true; + return response; + } + if (!decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount)) { + success = true; + return response; + } + decodeRecords(response, DnsSection.ADDITIONAL, buffer, additionalRecordCount); success = true; return response; @@ -82,16 +89,17 @@ abstract class DnsResponseDecoder { } } - private void decodeRecords( + private boolean decodeRecords( DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception { for (int i = count; i > 0; i --) { final DnsRecord r = recordDecoder.decodeRecord(buf); if (r == null) { // Truncated response - break; + return false; } response.addRecord(section, r); } + return true; } } diff --git a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java index a90acaa283..b18c19ecbe 100644 --- a/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java +++ b/codec-dns/src/test/java/io/netty/handler/codec/dns/DefaultDnsRecordDecoderTest.java @@ -232,4 +232,24 @@ public class DefaultDnsRecordDecoderTest { buffer.release(); } } + + @Test + public void testTruncatedPacket() throws Exception { + ByteBuf buffer = Unpooled.buffer(); + buffer.writeByte(0); + buffer.writeShort(DnsRecordType.A.intValue()); + buffer.writeShort(1); + buffer.writeInt(32); + + // Write a truncated last value. + buffer.writeByte(0); + DefaultDnsRecordDecoder decoder = new DefaultDnsRecordDecoder(); + try { + int readerIndex = buffer.readerIndex(); + assertNull(decoder.decodeRecord(buffer)); + assertEquals(readerIndex, buffer.readerIndex()); + } finally { + buffer.release(); + } + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index dcdac54019..c98179e2a2 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -31,6 +31,7 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.channel.socket.SocketChannel; import io.netty.handler.codec.dns.DatagramDnsQueryEncoder; @@ -186,7 +187,25 @@ public class DnsNameResolver extends InetNameResolver { return (List) nameservers.invoke(instance); } - private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder(); + private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder() { + @Override + protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception { + DnsResponse response = super.decodeResponse(ctx, packet); + if (packet.content().isReadable()) { + // If there is still something to read we did stop parsing because of a truncated message. + // This can happen if we enabled EDNS0 but our MTU is not big enough to handle all the + // data. + response.setTruncated(true); + + if (logger.isDebugEnabled()) { + logger.debug( + "{} RECEIVED: UDP truncated packet received, consider adjusting maxPayloadSize for the {}.", + ctx.channel(), StringUtil.simpleClassName(DnsNameResolver.class)); + } + } + return response; + } + }; private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder(); private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder(); diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 25484c0a5e..38a27ca4a1 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -20,11 +20,14 @@ import io.netty.buffer.ByteBufHolder; import io.netty.channel.AddressedEnvelope; import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.channel.socket.nio.NioSocketChannel; @@ -2743,15 +2746,20 @@ public class DnsNameResolverTest { @Test(timeout = 5000) public void testTruncatedWithoutTcpFallback() throws IOException { - testTruncated0(false); + testTruncated0(false, false); } @Test(timeout = 5000) public void testTruncatedWithTcpFallback() throws IOException { - testTruncated0(true); + testTruncated0(true, false); } - private static void testTruncated0(boolean tcpFallback) throws IOException { + @Test(timeout = 5000) + public void testTruncatedWithTcpFallbackBecauseOfMtu() throws IOException { + testTruncated0(true, true); + } + + private static void testTruncated0(boolean tcpFallback, final boolean truncatedBecauseOfMtu) throws IOException { final String host = "somehost.netty.io"; final String txt = "this is a txt record"; final AtomicReference messageRef = new AtomicReference(); @@ -2774,23 +2782,26 @@ public class DnsNameResolverTest { // Store a original message so we can replay it later on. messageRef.set(message); - // Create a copy of the message but set the truncated flag. - DnsMessageModifier modifier = new DnsMessageModifier(); - modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData()); - modifier.setAdditionalRecords(message.getAdditionalRecords()); - modifier.setAnswerRecords(message.getAnswerRecords()); - modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer()); - modifier.setAuthorityRecords(message.getAuthorityRecords()); - modifier.setMessageType(message.getMessageType()); - modifier.setOpCode(message.getOpCode()); - modifier.setQuestionRecords(message.getQuestionRecords()); - modifier.setRecursionAvailable(message.isRecursionAvailable()); - modifier.setRecursionDesired(message.isRecursionDesired()); - modifier.setReserved(message.isReserved()); - modifier.setResponseCode(message.getResponseCode()); - modifier.setTransactionId(message.getTransactionId()); - modifier.setTruncated(true); - return modifier.getDnsMessage(); + if (!truncatedBecauseOfMtu) { + // Create a copy of the message but set the truncated flag. + DnsMessageModifier modifier = new DnsMessageModifier(); + modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData()); + modifier.setAdditionalRecords(message.getAdditionalRecords()); + modifier.setAnswerRecords(message.getAnswerRecords()); + modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer()); + modifier.setAuthorityRecords(message.getAuthorityRecords()); + modifier.setMessageType(message.getMessageType()); + modifier.setOpCode(message.getOpCode()); + modifier.setQuestionRecords(message.getQuestionRecords()); + modifier.setRecursionAvailable(message.isRecursionAvailable()); + modifier.setRecursionDesired(message.isRecursionDesired()); + modifier.setReserved(message.isReserved()); + modifier.setResponseCode(message.getResponseCode()); + modifier.setTransactionId(message.getTransactionId()); + modifier.setTruncated(true); + return modifier.getDnsMessage(); + } + return message; } }; dnsServer2.start(); @@ -2811,6 +2822,19 @@ public class DnsNameResolverTest { builder.socketChannelType(NioSocketChannel.class); } resolver = builder.build(); + if (truncatedBecauseOfMtu) { + resolver.ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof DatagramPacket) { + // Truncate the packet by 1 byte. + DatagramPacket packet = (DatagramPacket) msg; + packet.content().writerIndex(packet.content().writerIndex() - 1); + } + ctx.fireChannelRead(msg); + } + }); + } Future> envelopeFuture = resolver.query( new DefaultDnsQuestion(host, DnsRecordType.TXT));