Fix IndexOutOfBoundsException caused by consuming the buffer multiple times in DatagramDnsQueryDecoder (#11592)

Motivation:

ccef8feedd introduced some changes to share code but did introduce a regression when decoding queries.

Modifications:

- Correctly only decode one time.
- Adjust unit test

Result:

Fixes https://github.com/netty/netty/issues/11591
This commit is contained in:
Norman Maurer 2021-08-18 14:57:57 +02:00
parent 89866da252
commit 8e4465d14c
2 changed files with 18 additions and 69 deletions

View File

@ -15,11 +15,9 @@
*/ */
package io.netty.handler.codec.dns; package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.internal.UnstableApi; import io.netty.util.internal.UnstableApi;
@ -51,71 +49,13 @@ public class DatagramDnsQueryDecoder extends MessageToMessageDecoder<DatagramPac
@Override @Override
protected void decode(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception { protected void decode(ChannelHandlerContext ctx, DatagramPacket packet) throws Exception {
final ByteBuf buf = packet.content(); DnsQuery query = DnsMessageUtil.decodeDnsQuery(recordDecoder, packet.content(),
new DnsMessageUtil.DnsQueryFactory() {
DnsMessageUtil.decodeDnsQuery(recordDecoder, buf, new DnsMessageUtil.DnsQueryFactory() {
@Override @Override
public DnsQuery newQuery(int id, DnsOpCode dnsOpCode) { public DnsQuery newQuery(int id, DnsOpCode dnsOpCode) {
return new DatagramDnsQuery(packet.sender(), packet.recipient(), id, dnsOpCode); return new DatagramDnsQuery(packet.sender(), packet.recipient(), id, dnsOpCode);
} }
}); });
ctx.fireChannelRead(query);
DnsQuery query = newQuery(packet, buf);
try {
final int questionCount = buf.readUnsignedShort();
final int answerCount = buf.readUnsignedShort();
final int authorityRecordCount = buf.readUnsignedShort();
final int additionalRecordCount = buf.readUnsignedShort();
decodeQuestions(query, buf, questionCount);
decodeRecords(query, DnsSection.ANSWER, buf, answerCount);
decodeRecords(query, DnsSection.AUTHORITY, buf, authorityRecordCount);
decodeRecords(query, DnsSection.ADDITIONAL, buf, additionalRecordCount);
DnsQuery q = query;
query = null;
ctx.fireChannelRead(q);
} finally {
if (query != null) {
query.release();
}
}
}
private static DnsQuery newQuery(DatagramPacket packet, ByteBuf buf) {
final int id = buf.readUnsignedShort();
final int flags = buf.readUnsignedShort();
if (flags >> 15 == 1) {
throw new CorruptedFrameException("not a query");
}
final DnsQuery query =
new DatagramDnsQuery(
packet.sender(),
packet.recipient(),
id,
DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)));
query.setRecursionDesired((flags >> 8 & 1) == 1);
query.setZ(flags >> 4 & 0x7);
return query;
}
private void decodeQuestions(DnsQuery query, ByteBuf buf, int questionCount) throws Exception {
for (int i = questionCount; i > 0; i--) {
query.addRecord(DnsSection.QUESTION, recordDecoder.decodeQuestion(buf));
}
}
private void decodeRecords(
DnsQuery query, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; i--) {
final DnsRecord r = recordDecoder.decodeRecord(buf);
if (r == null) {
// Truncated response
break;
}
query.addRecord(section, r);
}
} }
} }

View File

@ -27,15 +27,19 @@ import java.util.List;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class DnsQueryTest { public class DnsQueryTest {
@Test @Test
public void writeQueryTest() throws Exception { public void testEncodeAndDecodeQuery() {
InetSocketAddress addr = SocketUtils.socketAddress("8.8.8.8", 53); InetSocketAddress addr = SocketUtils.socketAddress("8.8.8.8", 53);
EmbeddedChannel embedder = new EmbeddedChannel(new DatagramDnsQueryEncoder()); EmbeddedChannel writeChannel = new EmbeddedChannel(new DatagramDnsQueryEncoder());
EmbeddedChannel readChannel = new EmbeddedChannel(new DatagramDnsQueryDecoder());
List<DnsQuery> queries = new ArrayList<>(5); List<DnsQuery> queries = new ArrayList<>(5);
queries.add(new DatagramDnsQuery(null, addr, 1).setRecord( queries.add(new DatagramDnsQuery(null, addr, 1).setRecord(
DnsSection.QUESTION, DnsSection.QUESTION,
@ -59,12 +63,17 @@ public class DnsQueryTest {
assertThat(query.count(DnsSection.AUTHORITY), is(0)); assertThat(query.count(DnsSection.AUTHORITY), is(0));
assertThat(query.count(DnsSection.ADDITIONAL), is(0)); assertThat(query.count(DnsSection.ADDITIONAL), is(0));
embedder.writeOutbound(query); assertTrue(writeChannel.writeOutbound(query));
DatagramPacket packet = embedder.readOutbound(); DatagramPacket packet = writeChannel.readOutbound();
assertTrue(packet.content().isReadable()); assertTrue(packet.content().isReadable());
packet.release(); assertTrue(readChannel.writeInbound(packet));
assertNull(embedder.readOutbound()); assertEquals(query, readChannel.readInbound());
assertNull(writeChannel.readOutbound());
assertNull(readChannel.readInbound());
} }
assertFalse(writeChannel.finish());
assertFalse(readChannel.finish());
} }
} }