Add TcpDnsQueryDecoder and TcpDnsResponseEncoder (#11415)

Motivation:
There is no decoder and encoder for TCP based DNS.

Result:
- Added decoder and encoder
- Added tests
- Added example

Result:

Be able to decode and encode TCP based dns 

Co-authored-by: Norman Maurer <norman_maurer@apple.com>
This commit is contained in:
RockyLOMO 2021-07-28 15:24:28 +08:00 committed by GitHub
parent c83c5b7d7c
commit ccef8feedd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 490 additions and 60 deletions

View File

@ -51,9 +51,16 @@ public class DatagramDnsQueryDecoder extends MessageToMessageDecoder<DatagramPac
}
@Override
protected void decode(ChannelHandlerContext ctx, DatagramPacket packet, List<Object> out) throws Exception {
protected void decode(ChannelHandlerContext ctx, final DatagramPacket packet, List<Object> out) throws Exception {
final ByteBuf buf = packet.content();
DnsMessageUtil.decodeDnsQuery(recordDecoder, buf, new DnsMessageUtil.DnsQueryFactory() {
@Override
public DnsQuery newQuery(int id, DnsOpCode dnsOpCode) {
return new DatagramDnsQuery(packet.sender(), packet.recipient(), id, dnsOpCode);
}
});
final DnsQuery query = newQuery(packet, buf);
boolean success = false;
try {

View File

@ -61,19 +61,7 @@ public class DatagramDnsResponseEncoder
final DnsResponse response = in.content();
final ByteBuf buf = allocateBuffer(ctx, in);
boolean success = false;
try {
encodeHeader(response, buf);
encodeQuestions(response, buf);
encodeRecords(response, DnsSection.ANSWER, buf);
encodeRecords(response, DnsSection.AUTHORITY, buf);
encodeRecords(response, DnsSection.ADDITIONAL, buf);
success = true;
} finally {
if (!success) {
buf.release();
}
}
DnsMessageUtil.encodeDnsResponse(recordEncoder, response, buf);
out.add(new DatagramPacket(buf, recipient, null));
}
@ -87,49 +75,4 @@ public class DatagramDnsResponseEncoder
@SuppressWarnings("unused") AddressedEnvelope<DnsResponse, InetSocketAddress> msg) throws Exception {
return ctx.alloc().ioBuffer(1024);
}
/**
* Encodes the header that is always 12 bytes long.
*
* @param response the response header being encoded
* @param buf the buffer the encoded data should be written to
*/
private static void encodeHeader(DnsResponse response, ByteBuf buf) {
buf.writeShort(response.id());
int flags = 32768;
flags |= (response.opCode().byteValue() & 0xFF) << 11;
if (response.isAuthoritativeAnswer()) {
flags |= 1 << 10;
}
if (response.isTruncated()) {
flags |= 1 << 9;
}
if (response.isRecursionDesired()) {
flags |= 1 << 8;
}
if (response.isRecursionAvailable()) {
flags |= 1 << 7;
}
flags |= response.z() << 4;
flags |= response.code().intValue();
buf.writeShort(flags);
buf.writeShort(response.count(DnsSection.QUESTION));
buf.writeShort(response.count(DnsSection.ANSWER));
buf.writeShort(response.count(DnsSection.AUTHORITY));
buf.writeShort(response.count(DnsSection.ADDITIONAL));
}
private void encodeQuestions(DnsResponse response, ByteBuf buf) throws Exception {
final int count = response.count(DnsSection.QUESTION);
for (int i = 0; i < count; i++) {
recordEncoder.encodeQuestion((DnsQuestion) response.recordAt(DnsSection.QUESTION, i), buf);
}
}
private void encodeRecords(DnsResponse response, DnsSection section, ByteBuf buf) throws Exception {
final int count = response.count(section);
for (int i = 0; i < count; i++) {
recordEncoder.encodeRecord(response.recordAt(section, i), buf);
}
}
}

View File

@ -15,7 +15,9 @@
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.util.internal.StringUtil;
import java.net.SocketAddress;
@ -177,5 +179,124 @@ final class DnsMessageUtil {
}
}
private DnsMessageUtil() { }
static DnsQuery decodeDnsQuery(DnsRecordDecoder decoder, ByteBuf buf, DnsQueryFactory supplier) throws Exception {
DnsQuery query = newQuery(buf, supplier);
boolean success = false;
try {
int questionCount = buf.readUnsignedShort();
int answerCount = buf.readUnsignedShort();
int authorityRecordCount = buf.readUnsignedShort();
int additionalRecordCount = buf.readUnsignedShort();
decodeQuestions(decoder, query, buf, questionCount);
decodeRecords(decoder, query, DnsSection.ANSWER, buf, answerCount);
decodeRecords(decoder, query, DnsSection.AUTHORITY, buf, authorityRecordCount);
decodeRecords(decoder, query, DnsSection.ADDITIONAL, buf, additionalRecordCount);
success = true;
return query;
} finally {
if (!success) {
query.release();
}
}
}
private static DnsQuery newQuery(ByteBuf buf, DnsQueryFactory supplier) {
int id = buf.readUnsignedShort();
int flags = buf.readUnsignedShort();
if (flags >> 15 == 1) {
throw new CorruptedFrameException("not a query");
}
DnsQuery query = supplier.newQuery(id, DnsOpCode.valueOf((byte) (flags >> 11 & 0xf)));
query.setRecursionDesired((flags >> 8 & 1) == 1);
query.setZ(flags >> 4 & 0x7);
return query;
}
private static void decodeQuestions(DnsRecordDecoder decoder,
DnsQuery query, ByteBuf buf, int questionCount) throws Exception {
for (int i = questionCount; i > 0; --i) {
query.addRecord(DnsSection.QUESTION, decoder.decodeQuestion(buf));
}
}
private static void decodeRecords(DnsRecordDecoder decoder,
DnsQuery query, DnsSection section, ByteBuf buf, int count) throws Exception {
for (int i = count; i > 0; --i) {
DnsRecord r = decoder.decodeRecord(buf);
if (r == null) {
break;
}
query.addRecord(section, r);
}
}
static void encodeDnsResponse(DnsRecordEncoder encoder, DnsResponse response, ByteBuf buf) throws Exception {
boolean success = false;
try {
encodeHeader(response, buf);
encodeQuestions(encoder, response, buf);
encodeRecords(encoder, response, DnsSection.ANSWER, buf);
encodeRecords(encoder, response, DnsSection.AUTHORITY, buf);
encodeRecords(encoder, response, DnsSection.ADDITIONAL, buf);
success = true;
} finally {
if (!success) {
buf.release();
}
}
}
/**
* Encodes the header that is always 12 bytes long.
*
* @param response the response header being encoded
* @param buf the buffer the encoded data should be written to
*/
private static void encodeHeader(DnsResponse response, ByteBuf buf) {
buf.writeShort(response.id());
int flags = 32768;
flags |= (response.opCode().byteValue() & 0xFF) << 11;
if (response.isAuthoritativeAnswer()) {
flags |= 1 << 10;
}
if (response.isTruncated()) {
flags |= 1 << 9;
}
if (response.isRecursionDesired()) {
flags |= 1 << 8;
}
if (response.isRecursionAvailable()) {
flags |= 1 << 7;
}
flags |= response.z() << 4;
flags |= response.code().intValue();
buf.writeShort(flags);
buf.writeShort(response.count(DnsSection.QUESTION));
buf.writeShort(response.count(DnsSection.ANSWER));
buf.writeShort(response.count(DnsSection.AUTHORITY));
buf.writeShort(response.count(DnsSection.ADDITIONAL));
}
private static void encodeQuestions(DnsRecordEncoder encoder, DnsResponse response, ByteBuf buf) throws Exception {
int count = response.count(DnsSection.QUESTION);
for (int i = 0; i < count; ++i) {
encoder.encodeQuestion(response.<DnsQuestion>recordAt(DnsSection.QUESTION, i), buf);
}
}
private static void encodeRecords(DnsRecordEncoder encoder,
DnsResponse response, DnsSection section, ByteBuf buf) throws Exception {
int count = response.count(section);
for (int i = 0; i < count; ++i) {
encoder.encodeRecord(response.recordAt(section, i), buf);
}
}
interface DnsQueryFactory {
DnsQuery newQuery(int id, DnsOpCode dnsOpCode);
}
private DnsMessageUtil() {
}
}

View File

@ -0,0 +1,55 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.internal.ObjectUtil;
public final class TcpDnsQueryDecoder extends LengthFieldBasedFrameDecoder {
private final DnsRecordDecoder decoder;
/**
* Creates a new decoder with {@linkplain DnsRecordDecoder#DEFAULT the default record decoder}.
*/
public TcpDnsQueryDecoder() {
this(DnsRecordDecoder.DEFAULT, 65535);
}
/**
* Creates a new decoder with the specified {@code decoder}.
*/
public TcpDnsQueryDecoder(DnsRecordDecoder decoder, int maxFrameLength) {
super(maxFrameLength, 0, 2, 0, 2);
this.decoder = ObjectUtil.checkNotNull(decoder, "decoder");
}
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame = (ByteBuf) super.decode(ctx, in);
if (frame == null) {
return null;
}
return DnsMessageUtil.decodeDnsQuery(decoder, frame.slice(), new DnsMessageUtil.DnsQueryFactory() {
@Override
public DnsQuery newQuery(int id, DnsOpCode dnsOpCode) {
return new DefaultDnsQuery(id, dnsOpCode);
}
});
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.util.internal.ObjectUtil;
import java.util.List;
@ChannelHandler.Sharable
public final class TcpDnsResponseEncoder extends MessageToMessageEncoder<DnsResponse> {
private final DnsRecordEncoder encoder;
/**
* Creates a new encoder with {@linkplain DnsRecordEncoder#DEFAULT the default record encoder}.
*/
public TcpDnsResponseEncoder() {
this(DnsRecordEncoder.DEFAULT);
}
/**
* Creates a new encoder with the specified {@code encoder}.
*/
public TcpDnsResponseEncoder(DnsRecordEncoder encoder) {
this.encoder = ObjectUtil.checkNotNull(encoder, "encoder");
}
@Override
protected void encode(ChannelHandlerContext ctx, DnsResponse response, List<Object> out) throws Exception {
ByteBuf buf = ctx.alloc().ioBuffer(1024);
buf.writerIndex(buf.writerIndex() + 2);
DnsMessageUtil.encodeDnsResponse(encoder, response, buf);
buf.setShort(0, buf.readableBytes() - 2);
out.add(buf);
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.dns;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;
import java.util.Random;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TcpDnsTest {
private static final String QUERY_DOMAIN = "www.example.com";
private static final long TTL = 600;
private static final byte[] QUERY_RESULT = new byte[]{(byte) 192, (byte) 168, 1, 1};
@Test
public void testQueryDecode() {
EmbeddedChannel channel = new EmbeddedChannel(new TcpDnsQueryDecoder());
int randomID = new Random().nextInt(60000 - 1000) + 1000;
DnsQuery query = new DefaultDnsQuery(randomID, DnsOpCode.QUERY)
.setRecord(DnsSection.QUESTION, new DefaultDnsQuestion(QUERY_DOMAIN, DnsRecordType.A));
assertTrue(channel.writeInbound(query));
DnsQuery readQuery = channel.readInbound();
assertThat(readQuery, is(query));
assertThat(readQuery.recordAt(DnsSection.QUESTION), is(query.recordAt(DnsSection.QUESTION)));
assertThat(readQuery.recordAt(DnsSection.QUESTION).name(), is(query.recordAt(DnsSection.QUESTION).name()));
assertFalse(channel.finish());
}
@Test
public void testResponseEncode() {
EmbeddedChannel channel = new EmbeddedChannel(new TcpDnsResponseEncoder());
int randomID = new Random().nextInt(60000 - 1000) + 1000;
DnsQuery query = new DefaultDnsQuery(randomID, DnsOpCode.QUERY)
.setRecord(DnsSection.QUESTION, new DefaultDnsQuestion(QUERY_DOMAIN, DnsRecordType.A));
DnsQuestion question = query.recordAt(DnsSection.QUESTION);
channel.writeInbound(newResponse(query, question, QUERY_RESULT));
DnsResponse readResponse = channel.readInbound();
assertThat(readResponse.recordAt(DnsSection.QUESTION), is((DnsRecord) question));
DnsRawRecord record = new DefaultDnsRawRecord(question.name(),
DnsRecordType.A, TTL, Unpooled.wrappedBuffer(QUERY_RESULT));
assertThat(readResponse.recordAt(DnsSection.ANSWER), is((DnsRecord) record));
assertThat(readResponse.<DnsRawRecord>recordAt(DnsSection.ANSWER).content(), is(record.content()));
ReferenceCountUtil.release(readResponse);
ReferenceCountUtil.release(record);
assertFalse(channel.finish());
}
private static DefaultDnsResponse newResponse(DnsQuery query, DnsQuestion question, byte[]... addresses) {
DefaultDnsResponse response = new DefaultDnsResponse(query.id());
response.addRecord(DnsSection.QUESTION, question);
for (byte[] address : addresses) {
DefaultDnsRawRecord queryAnswer = new DefaultDnsRawRecord(question.name(),
DnsRecordType.A, TTL, Unpooled.wrappedBuffer(address));
response.addRecord(DnsSection.ANSWER, queryAnswer);
}
return response;
}
}

View File

@ -0,0 +1,166 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.example.dns.tcp;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.dns.DefaultDnsQuery;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsResponse;
import io.netty.handler.codec.dns.DnsOpCode;
import io.netty.handler.codec.dns.DnsQuery;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.dns.TcpDnsQueryDecoder;
import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
import io.netty.handler.codec.dns.TcpDnsResponseEncoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.NetUtil;
import java.util.Random;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
public final class TcpDnsServer {
private static final String QUERY_DOMAIN = "www.example.com";
private static final int DNS_SERVER_PORT = 53;
private static final String DNS_SERVER_HOST = "127.0.0.1";
private static final byte[] QUERY_RESULT = new byte[]{(byte) 192, (byte) 168, 1, 1};
public static void main(String[] args) throws Exception {
ServerBootstrap bootstrap = new ServerBootstrap().group(new NioEventLoopGroup(1),
new NioEventLoopGroup())
.channel(NioServerSocketChannel.class)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new TcpDnsQueryDecoder(), new TcpDnsResponseEncoder(),
new SimpleChannelInboundHandler<DnsQuery>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx,
DnsQuery msg) throws Exception {
DnsQuestion question = msg.recordAt(DnsSection.QUESTION);
System.out.println("Query domain: " + question);
//always return 192.168.1.1
ctx.writeAndFlush(newResponse(msg, question, 600, QUERY_RESULT));
}
private DefaultDnsResponse newResponse(DnsQuery query,
DnsQuestion question,
long ttl, byte[]... addresses) {
DefaultDnsResponse response = new DefaultDnsResponse(query.id());
response.addRecord(DnsSection.QUESTION, question);
for (byte[] address : addresses) {
DefaultDnsRawRecord queryAnswer = new DefaultDnsRawRecord(
question.name(),
DnsRecordType.A, ttl, Unpooled.wrappedBuffer(address));
response.addRecord(DnsSection.ANSWER, queryAnswer);
}
return response;
}
});
}
});
final Channel channel = bootstrap.bind(DNS_SERVER_PORT).channel();
Executors.newSingleThreadScheduledExecutor().schedule(new Runnable() {
@Override
public void run() {
try {
clientQuery();
channel.close();
} catch (Exception e) {
e.printStackTrace();
}
}
}, 1000, TimeUnit.MILLISECONDS);
channel.closeFuture().sync();
}
// copy from TcpDnsClient.java
private static void clientQuery() throws Exception {
NioEventLoopGroup group = new NioEventLoopGroup();
try {
Bootstrap b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) {
ch.pipeline().addLast(new TcpDnsQueryEncoder())
.addLast(new TcpDnsResponseDecoder())
.addLast(new SimpleChannelInboundHandler<DefaultDnsResponse>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, DefaultDnsResponse msg) {
try {
handleQueryResp(msg);
} finally {
ctx.close();
}
}
});
}
});
final Channel ch = b.connect(DNS_SERVER_HOST, DNS_SERVER_PORT).sync().channel();
int randomID = new Random().nextInt(60000 - 1000) + 1000;
DnsQuery query = new DefaultDnsQuery(randomID, DnsOpCode.QUERY)
.setRecord(DnsSection.QUESTION, new DefaultDnsQuestion(QUERY_DOMAIN, DnsRecordType.A));
ch.writeAndFlush(query).sync();
boolean success = ch.closeFuture().await(10, TimeUnit.SECONDS);
if (!success) {
System.err.println("dns query timeout!");
ch.close().sync();
}
} finally {
group.shutdownGracefully();
}
}
private static void handleQueryResp(DefaultDnsResponse msg) {
if (msg.count(DnsSection.QUESTION) > 0) {
DnsQuestion question = msg.recordAt(DnsSection.QUESTION, 0);
System.out.printf("name: %s%n", question.name());
}
for (int i = 0, count = msg.count(DnsSection.ANSWER); i < count; i++) {
DnsRecord record = msg.recordAt(DnsSection.ANSWER, i);
if (record.type() == DnsRecordType.A) {
//just print the IP after query
DnsRawRecord raw = (DnsRawRecord) record;
System.out.println(NetUtil.bytesToIpAddress(ByteBufUtil.getBytes(raw.content())));
}
}
}
}