Detect truncated responses caused by EDNS0 and MTU miss-match (#9468)

Motivation:

It is possible that the user uses a too big EDNS0 setting for the MTU and so we may receive a truncated datagram packet. In this case we should try to detect this and retry via TCP if possible

Modifications:

- Fix detecting of incomplete records
- Mark response as truncated if we did not consume the whole packet
- Add unit test

Result:

Fixes https://github.com/netty/netty/issues/9365
This commit is contained in:
Norman Maurer 2019-08-17 09:58:22 +02:00
parent 188c927576
commit 1a53df1031
6 changed files with 102 additions and 27 deletions

View File

@ -55,6 +55,10 @@ public class DatagramDnsResponseDecoder extends MessageToMessageDecoder<Datagram
@Override
protected void decode(ChannelHandlerContext ctx, DatagramPacket packet, List<Object> out) throws Exception {
out.add(responseDecoder.decode(packet.sender(), packet.recipient(), packet.content()));
out.add(decodeResponse(ctx, packet));
}
protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception {
return responseDecoder.decode(packet.sender(), packet.recipient(), packet.content());
}
}

View File

@ -47,7 +47,7 @@ public class DefaultDnsRecordDecoder implements DnsRecordDecoder {
final String name = decodeName(in);
final int endOffset = in.writerIndex();
if (endOffset - startOffset < 10) {
if (endOffset - in.readerIndex() < 10) {
// Not enough data
in.readerIndex(startOffset);
return null;

View File

@ -60,8 +60,15 @@ abstract class DnsResponseDecoder<A extends SocketAddress> {
final int additionalRecordCount = buffer.readUnsignedShort();
decodeQuestions(response, buffer, questionCount);
decodeRecords(response, DnsSection.ANSWER, buffer, answerCount);
decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount);
if (!decodeRecords(response, DnsSection.ANSWER, buffer, answerCount)) {
success = true;
return response;
}
if (!decodeRecords(response, DnsSection.AUTHORITY, buffer, authorityRecordCount)) {
success = true;
return response;
}
decodeRecords(response, DnsSection.ADDITIONAL, buffer, additionalRecordCount);
success = true;
return response;
@ -81,16 +88,17 @@ abstract class DnsResponseDecoder<A extends SocketAddress> {
}
}
private void decodeRecords(
private boolean decodeRecords(
DnsResponse response, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; i --) {
final DnsRecord r = recordDecoder.decodeRecord(buf);
if (r == null) {
// Truncated response
break;
return false;
}
response.addRecord(section, r);
}
return true;
}
}

View File

@ -232,4 +232,24 @@ public class DefaultDnsRecordDecoderTest {
buffer.release();
}
}
@Test
public void testTruncatedPacket() throws Exception {
ByteBuf buffer = Unpooled.buffer();
buffer.writeByte(0);
buffer.writeShort(DnsRecordType.A.intValue());
buffer.writeShort(1);
buffer.writeInt(32);
// Write a truncated last value.
buffer.writeByte(0);
DefaultDnsRecordDecoder decoder = new DefaultDnsRecordDecoder();
try {
int readerIndex = buffer.readerIndex();
assertNull(decoder.decodeRecord(buffer));
assertEquals(readerIndex, buffer.readerIndex());
} finally {
buffer.release();
}
}
}

View File

@ -32,6 +32,7 @@ import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.dns.DatagramDnsQueryEncoder;
@ -187,7 +188,25 @@ public class DnsNameResolver extends InetNameResolver {
return (List<String>) nameservers.invoke(instance);
}
private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder();
private static final DatagramDnsResponseDecoder DATAGRAM_DECODER = new DatagramDnsResponseDecoder() {
@Override
protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception {
DnsResponse response = super.decodeResponse(ctx, packet);
if (packet.content().isReadable()) {
// If there is still something to read we did stop parsing because of a truncated message.
// This can happen if we enabled EDNS0 but our MTU is not big enough to handle all the
// data.
response.setTruncated(true);
if (logger.isDebugEnabled()) {
logger.debug(
"{} RECEIVED: UDP truncated packet received, consider adjusting maxPayloadSize for the {}.",
ctx.channel(), StringUtil.simpleClassName(DnsNameResolver.class));
}
}
return response;
}
};
private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder();
private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();

View File

@ -20,12 +20,15 @@ import io.netty.buffer.ByteBufHolder;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultithreadEventLoopGroup;
import io.netty.channel.ReflectiveChannelFactory;
import io.netty.channel.nio.NioHandler;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
@ -2715,15 +2718,20 @@ public class DnsNameResolverTest {
@Test(timeout = 5000)
public void testTruncatedWithoutTcpFallback() throws IOException {
testTruncated0(false);
testTruncated0(false, false);
}
@Test(timeout = 5000)
public void testTruncatedWithTcpFallback() throws IOException {
testTruncated0(true);
testTruncated0(true, false);
}
private static void testTruncated0(boolean tcpFallback) throws IOException {
@Test(timeout = 5000)
public void testTruncatedWithTcpFallbackBecauseOfMtu() throws IOException {
testTruncated0(true, true);
}
private static void testTruncated0(boolean tcpFallback, final boolean truncatedBecauseOfMtu) throws IOException {
final String host = "somehost.netty.io";
final String txt = "this is a txt record";
final AtomicReference<DnsMessage> messageRef = new AtomicReference<DnsMessage>();
@ -2746,6 +2754,7 @@ public class DnsNameResolverTest {
// Store a original message so we can replay it later on.
messageRef.set(message);
if (!truncatedBecauseOfMtu) {
// Create a copy of the message but set the truncated flag.
DnsMessageModifier modifier = new DnsMessageModifier();
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
@ -2764,6 +2773,8 @@ public class DnsNameResolverTest {
modifier.setTruncated(true);
return modifier.getDnsMessage();
}
return message;
}
};
dnsServer2.start();
DnsNameResolver resolver = null;
@ -2783,6 +2794,19 @@ public class DnsNameResolverTest {
builder.socketChannelType(NioSocketChannel.class);
}
resolver = builder.build();
if (truncatedBecauseOfMtu) {
resolver.ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof DatagramPacket) {
// Truncate the packet by 1 byte.
DatagramPacket packet = (DatagramPacket) msg;
packet.content().writerIndex(packet.content().writerIndex() - 1);
}
ctx.fireChannelRead(msg);
}
});
}
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> envelopeFuture = resolver.query(
new DefaultDnsQuestion(host, DnsRecordType.TXT));