Fix DnsNameResolver TCP fallback test and message leaks (#9647)

Motivation

A memory leak related to DNS resolution was reported in #9634,
specifically linked to the TCP retry fallback functionality that was
introduced relatively recently. Upon inspection it's apparent that there
are some error paths where the original UDP response might not be fully
released, and more significantly the TCP response actually leaks every
time on the fallback success path.

It turns out that a bug in the unit test meant that the intended TCP
fallback path was not actually exercised, so it did not expose the main
leak in question.

Modifications

- Fix DnsNameResolverTest#testTruncated0 dummy server fallback logic to
first read transaction id of retried query and use it in replayed
response
- Adjust semantic of internal DnsQueryContext#finish method to always
take refcount ownership of passed in envelope
- Reorder some logic in DnsResponseHandler fallback handling to verify
the context of the response is expected, and ensure that the query
response are either released or propagated in all cases. This also
reduces a number of redundant retain/release pairings

Result

Fixes #9634
This commit is contained in:
Nick Hill 2019-10-14 22:10:15 +08:00 committed by Norman Maurer
parent 6c05d16967
commit 2e5dd28800
3 changed files with 132 additions and 144 deletions

View File

@ -50,7 +50,6 @@ import io.netty.resolver.HostsFileEntriesResolver;
import io.netty.resolver.InetNameResolver;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Future;
@ -1199,129 +1198,107 @@ public class DnsNameResolver extends InetNameResolver {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
final DatagramDnsResponse res = (DatagramDnsResponse) msg;
final int queryId = res.id();
final DatagramDnsResponse res = (DatagramDnsResponse) msg;
final int queryId = res.id();
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
}
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
}
final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
if (qCtx == null) {
logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId);
return;
}
final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
if (qCtx == null) {
logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId);
res.release();
return;
}
// Check if the response was truncated and if we can fallback to TCP to retry.
if (res.isTruncated() && socketChannelFactory != null) {
// Let's retain as we may need it later on.
res.retain();
// Check if the response was truncated and if we can fallback to TCP to retry.
if (!res.isTruncated() || socketChannelFactory == null) {
qCtx.finish(res);
return;
}
Bootstrap bs = new Bootstrap();
bs.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(TCP_ENCODER);
ch.pipeline().addLast(new TcpDnsResponseDecoder());
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private boolean finish;
Bootstrap bs = new Bootstrap();
bs.option(ChannelOption.SO_REUSEADDR, true)
.group(executor())
.channelFactory(socketChannelFactory)
.handler(TCP_ENCODER);
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
if (logger.isDebugEnabled()) {
logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause());
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
Channel channel = ctx.channel();
DnsResponse response = (DnsResponse) msg;
int queryId = response.id();
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
return;
}
final Channel channel = future.channel();
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
channel.remoteAddress(), response);
}
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
channel.eventLoop().newPromise();
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(DnsNameResolver.this, channel,
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
EMPTY_ADDITIONALS, promise);
DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId);
if (tcpCtx == null) {
logger.warn("{} Received a DNS response with an unknown ID: {}",
channel, queryId);
qCtx.finish(res);
return;
}
// Release the original response as we will use the response that we
// received via TCP fallback.
res.release();
tcpCtx.finish(new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
finish = true;
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!finish) {
if (logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
ctx.channel().remoteAddress(), cause);
}
// TCP fallback failed, just use the truncated response as
qCtx.finish(res);
}
}
});
}
});
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
channel.pipeline().addLast(new TcpDnsResponseDecoder());
channel.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void operationComplete(ChannelFuture future) {
if (future.isSuccess()) {
final Channel channel = future.channel();
public void channelRead(ChannelHandlerContext ctx, Object msg) {
Channel channel = ctx.channel();
DnsResponse response = (DnsResponse) msg;
int queryId = response.id();
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
channel.eventLoop().newPromise();
new TcpDnsQueryContext(DnsNameResolver.this, channel,
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise());
promise.addListener(
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@Override
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
channel.close();
if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
channel.remoteAddress(), response);
}
if (future.isSuccess()) {
qCtx.finish(future.getNow());
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
}
}
});
DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
if (foundCtx == tcpCtx) {
tcpCtx.finish(new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
} else {
if (logger.isDebugEnabled()) {
logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause());
}
response.release();
tcpCtx.tryFailure("Received TCP response with unexpected ID", null, false);
logger.warn("{} Received a DNS response with an unexpected ID: {}",
channel, queryId);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (tcpCtx.tryFailure("TCP fallback error", cause, false) && logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
ctx.channel().remoteAddress(), cause);
}
}
});
promise.addListener(
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@Override
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
channel.close();
if (future.isSuccess()) {
qCtx.finish(future.getNow());
res.release();
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finish(res);
}
}
});
} else {
qCtx.finish(res);
tcpCtx.query(true, future.channel().newPromise());
}
} finally {
ReferenceCountUtil.safeRelease(msg);
}
});
}
@Override

View File

@ -159,7 +159,7 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
private void onQueryWriteCompletion(ChannelFuture writeFuture) {
if (!writeFuture.isSuccess()) {
setFailure("failed to send a query via " + protocol(), writeFuture.cause());
tryFailure("failed to send a query via " + protocol(), writeFuture.cause(), false);
return;
}
@ -174,40 +174,37 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
return;
}
setFailure("query via " + protocol() + " timed out after " +
queryTimeoutMillis + " milliseconds", null);
tryFailure("query via " + protocol() + " timed out after " +
queryTimeoutMillis + " milliseconds", null, true);
}
}, queryTimeoutMillis, TimeUnit.MILLISECONDS);
}
}
/**
* Takes ownership of passed envelope
*/
void finish(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
final DnsResponse res = envelope.content();
if (res.count(DnsSection.QUESTION) != 1) {
logger.warn("Received a DNS response with invalid number of questions: {}", envelope);
return;
}
if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
} else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
logger.warn("Received a mismatching DNS response: {}", envelope);
return;
} else if (trySuccess(envelope)) {
return; // Ownership transferred, don't release
}
setSuccess(envelope);
envelope.release();
}
private void setSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise = this.promise;
@SuppressWarnings("unchecked")
AddressedEnvelope<DnsResponse, InetSocketAddress> castResponse =
(AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope.retain();
if (!promise.trySuccess(castResponse)) {
// We failed to notify the promise as it was failed before, thus we need to release the envelope
envelope.release();
}
@SuppressWarnings("unchecked")
private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
}
private void setFailure(String message, Throwable cause) {
boolean tryFailure(String message, Throwable cause, boolean timeout) {
if (promise.isDone()) {
return false;
}
final InetSocketAddress nameServerAddr = nameServerAddr();
final StringBuilder buf = new StringBuilder(message.length() + 64);
@ -218,14 +215,14 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
.append(" (no stack trace available)");
final DnsNameResolverException e;
if (cause == null) {
if (timeout) {
// This was caused by an timeout so use DnsNameResolverTimeoutException to allow the user to
// handle it special (like retry the query).
e = new DnsNameResolverTimeoutException(nameServerAddr, question(), buf.toString());
} else {
e = new DnsNameResolverException(nameServerAddr, question(), buf.toString(), cause);
}
promise.tryFailure(e);
return promise.tryFailure(e);
}
@Override

View File

@ -71,6 +71,7 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.IOException;
import java.io.InputStream;
import java.net.DatagramSocket;
import java.net.Inet4Address;
import java.net.InetAddress;
@ -2759,6 +2760,25 @@ public class DnsNameResolverTest {
testTruncated0(true, true);
}
private static DnsMessageModifier modifierFrom(DnsMessage message) {
DnsMessageModifier modifier = new DnsMessageModifier();
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
modifier.setAdditionalRecords(message.getAdditionalRecords());
modifier.setAnswerRecords(message.getAnswerRecords());
modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer());
modifier.setAuthorityRecords(message.getAuthorityRecords());
modifier.setMessageType(message.getMessageType());
modifier.setOpCode(message.getOpCode());
modifier.setQuestionRecords(message.getQuestionRecords());
modifier.setRecursionAvailable(message.isRecursionAvailable());
modifier.setRecursionDesired(message.isRecursionDesired());
modifier.setReserved(message.isReserved());
modifier.setResponseCode(message.getResponseCode());
modifier.setTransactionId(message.getTransactionId());
modifier.setTruncated(message.isTruncated());
return modifier;
}
private static void testTruncated0(boolean tcpFallback, final boolean truncatedBecauseOfMtu) throws IOException {
final String host = "somehost.netty.io";
final String txt = "this is a txt record";
@ -2784,20 +2804,7 @@ public class DnsNameResolverTest {
if (!truncatedBecauseOfMtu) {
// Create a copy of the message but set the truncated flag.
DnsMessageModifier modifier = new DnsMessageModifier();
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
modifier.setAdditionalRecords(message.getAdditionalRecords());
modifier.setAnswerRecords(message.getAnswerRecords());
modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer());
modifier.setAuthorityRecords(message.getAuthorityRecords());
modifier.setMessageType(message.getMessageType());
modifier.setOpCode(message.getOpCode());
modifier.setQuestionRecords(message.getQuestionRecords());
modifier.setRecursionAvailable(message.isRecursionAvailable());
modifier.setRecursionDesired(message.isRecursionDesired());
modifier.setReserved(message.isReserved());
modifier.setResponseCode(message.getResponseCode());
modifier.setTransactionId(message.getTransactionId());
DnsMessageModifier modifier = modifierFrom(message);
modifier.setTruncated(true);
return modifier.getDnsMessage();
}
@ -2842,8 +2849,15 @@ public class DnsNameResolverTest {
// If we are configured to use TCP as a fallback lets replay the dns message over TCP
Socket socket = serverSocket.accept();
InputStream in = socket.getInputStream();
assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field
int txnId = in.read() << 8 | (in.read() & 0xff);
IoBuffer ioBuffer = IoBuffer.allocate(1024);
new DnsMessageEncoder().encode(ioBuffer, messageRef.get());
// Must replace the transactionId with the one from the TCP request
DnsMessageModifier modifier = modifierFrom(messageRef.get());
modifier.setTransactionId(txnId);
new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage());
ioBuffer.flip();
ByteBuffer lenBuffer = ByteBuffer.allocate(2);
@ -2884,7 +2898,7 @@ public class DnsNameResolverTest {
} else {
assertTrue(envelope.content().isTruncated());
}
envelope.release();
assertTrue(envelope.release());
} finally {
dnsServer2.stop();
if (resolver != null) {