Fix DnsNameResolver TCP fallback test and message leaks (#9647)
Motivation A memory leak related to DNS resolution was reported in #9634, specifically linked to the TCP retry fallback functionality that was introduced relatively recently. Upon inspection it's apparent that there are some error paths where the original UDP response might not be fully released, and more significantly the TCP response actually leaks every time on the fallback success path. It turns out that a bug in the unit test meant that the intended TCP fallback path was not actually exercised, so it did not expose the main leak in question. Modifications - Fix DnsNameResolverTest#testTruncated0 dummy server fallback logic to first read transaction id of retried query and use it in replayed response - Adjust semantic of internal DnsQueryContext#finish method to always take refcount ownership of passed in envelope - Reorder some logic in DnsResponseHandler fallback handling to verify the context of the response is expected, and ensure that the query response are either released or propagated in all cases. This also reduces a number of redundant retain/release pairings Result Fixes #9634
This commit is contained in:
parent
6c05d16967
commit
2e5dd28800
@ -50,7 +50,6 @@ import io.netty.resolver.HostsFileEntriesResolver;
|
||||
import io.netty.resolver.InetNameResolver;
|
||||
import io.netty.resolver.ResolvedAddressTypes;
|
||||
import io.netty.util.NetUtil;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
import io.netty.util.concurrent.EventExecutor;
|
||||
import io.netty.util.concurrent.FastThreadLocal;
|
||||
import io.netty.util.concurrent.Future;
|
||||
@ -1199,129 +1198,107 @@ public class DnsNameResolver extends InetNameResolver {
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) {
|
||||
try {
|
||||
final DatagramDnsResponse res = (DatagramDnsResponse) msg;
|
||||
final int queryId = res.id();
|
||||
final DatagramDnsResponse res = (DatagramDnsResponse) msg;
|
||||
final int queryId = res.id();
|
||||
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
|
||||
}
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res);
|
||||
}
|
||||
|
||||
final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
|
||||
if (qCtx == null) {
|
||||
logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId);
|
||||
return;
|
||||
}
|
||||
final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId);
|
||||
if (qCtx == null) {
|
||||
logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId);
|
||||
res.release();
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the response was truncated and if we can fallback to TCP to retry.
|
||||
if (res.isTruncated() && socketChannelFactory != null) {
|
||||
// Let's retain as we may need it later on.
|
||||
res.retain();
|
||||
// Check if the response was truncated and if we can fallback to TCP to retry.
|
||||
if (!res.isTruncated() || socketChannelFactory == null) {
|
||||
qCtx.finish(res);
|
||||
return;
|
||||
}
|
||||
|
||||
Bootstrap bs = new Bootstrap();
|
||||
bs.option(ChannelOption.SO_REUSEADDR, true)
|
||||
.group(executor())
|
||||
.channelFactory(socketChannelFactory)
|
||||
.handler(new ChannelInitializer<Channel>() {
|
||||
@Override
|
||||
protected void initChannel(Channel ch) {
|
||||
ch.pipeline().addLast(TCP_ENCODER);
|
||||
ch.pipeline().addLast(new TcpDnsResponseDecoder());
|
||||
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
|
||||
private boolean finish;
|
||||
Bootstrap bs = new Bootstrap();
|
||||
bs.option(ChannelOption.SO_REUSEADDR, true)
|
||||
.group(executor())
|
||||
.channelFactory(socketChannelFactory)
|
||||
.handler(TCP_ENCODER);
|
||||
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) {
|
||||
if (!future.isSuccess()) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) {
|
||||
try {
|
||||
Channel channel = ctx.channel();
|
||||
DnsResponse response = (DnsResponse) msg;
|
||||
int queryId = response.id();
|
||||
// TCP fallback failed, just use the truncated response.
|
||||
qCtx.finish(res);
|
||||
return;
|
||||
}
|
||||
final Channel channel = future.channel();
|
||||
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
|
||||
channel.remoteAddress(), response);
|
||||
}
|
||||
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
|
||||
channel.eventLoop().newPromise();
|
||||
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(DnsNameResolver.this, channel,
|
||||
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
|
||||
EMPTY_ADDITIONALS, promise);
|
||||
|
||||
DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId);
|
||||
if (tcpCtx == null) {
|
||||
logger.warn("{} Received a DNS response with an unknown ID: {}",
|
||||
channel, queryId);
|
||||
qCtx.finish(res);
|
||||
return;
|
||||
}
|
||||
|
||||
// Release the original response as we will use the response that we
|
||||
// received via TCP fallback.
|
||||
res.release();
|
||||
|
||||
tcpCtx.finish(new AddressedEnvelopeAdapter(
|
||||
(InetSocketAddress) ctx.channel().remoteAddress(),
|
||||
(InetSocketAddress) ctx.channel().localAddress(),
|
||||
response));
|
||||
|
||||
finish = true;
|
||||
} finally {
|
||||
ReferenceCountUtil.release(msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
|
||||
if (!finish) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} Error during processing response: TCP [{}: {}]",
|
||||
ctx.channel(), queryId,
|
||||
ctx.channel().remoteAddress(), cause);
|
||||
}
|
||||
// TCP fallback failed, just use the truncated response as
|
||||
qCtx.finish(res);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
bs.connect(res.sender()).addListener(new ChannelFutureListener() {
|
||||
channel.pipeline().addLast(new TcpDnsResponseDecoder());
|
||||
channel.pipeline().addLast(new ChannelInboundHandlerAdapter() {
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) {
|
||||
if (future.isSuccess()) {
|
||||
final Channel channel = future.channel();
|
||||
public void channelRead(ChannelHandlerContext ctx, Object msg) {
|
||||
Channel channel = ctx.channel();
|
||||
DnsResponse response = (DnsResponse) msg;
|
||||
int queryId = response.id();
|
||||
|
||||
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
|
||||
channel.eventLoop().newPromise();
|
||||
new TcpDnsQueryContext(DnsNameResolver.this, channel,
|
||||
(InetSocketAddress) channel.remoteAddress(), qCtx.question(),
|
||||
EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise());
|
||||
promise.addListener(
|
||||
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
|
||||
@Override
|
||||
public void operationComplete(
|
||||
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
|
||||
channel.close();
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId,
|
||||
channel.remoteAddress(), response);
|
||||
}
|
||||
|
||||
if (future.isSuccess()) {
|
||||
qCtx.finish(future.getNow());
|
||||
} else {
|
||||
// TCP fallback failed, just use the truncated response.
|
||||
qCtx.finish(res);
|
||||
}
|
||||
}
|
||||
});
|
||||
DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
|
||||
if (foundCtx == tcpCtx) {
|
||||
tcpCtx.finish(new AddressedEnvelopeAdapter(
|
||||
(InetSocketAddress) ctx.channel().remoteAddress(),
|
||||
(InetSocketAddress) ctx.channel().localAddress(),
|
||||
response));
|
||||
} else {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause());
|
||||
}
|
||||
response.release();
|
||||
tcpCtx.tryFailure("Received TCP response with unexpected ID", null, false);
|
||||
logger.warn("{} Received a DNS response with an unexpected ID: {}",
|
||||
channel, queryId);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
|
||||
if (tcpCtx.tryFailure("TCP fallback error", cause, false) && logger.isDebugEnabled()) {
|
||||
logger.debug("{} Error during processing response: TCP [{}: {}]",
|
||||
ctx.channel(), queryId,
|
||||
ctx.channel().remoteAddress(), cause);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
promise.addListener(
|
||||
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
|
||||
@Override
|
||||
public void operationComplete(
|
||||
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
|
||||
channel.close();
|
||||
|
||||
if (future.isSuccess()) {
|
||||
qCtx.finish(future.getNow());
|
||||
res.release();
|
||||
} else {
|
||||
// TCP fallback failed, just use the truncated response.
|
||||
qCtx.finish(res);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
qCtx.finish(res);
|
||||
tcpCtx.query(true, future.channel().newPromise());
|
||||
}
|
||||
} finally {
|
||||
ReferenceCountUtil.safeRelease(msg);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -159,7 +159,7 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
|
||||
|
||||
private void onQueryWriteCompletion(ChannelFuture writeFuture) {
|
||||
if (!writeFuture.isSuccess()) {
|
||||
setFailure("failed to send a query via " + protocol(), writeFuture.cause());
|
||||
tryFailure("failed to send a query via " + protocol(), writeFuture.cause(), false);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -174,40 +174,37 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
|
||||
return;
|
||||
}
|
||||
|
||||
setFailure("query via " + protocol() + " timed out after " +
|
||||
queryTimeoutMillis + " milliseconds", null);
|
||||
tryFailure("query via " + protocol() + " timed out after " +
|
||||
queryTimeoutMillis + " milliseconds", null, true);
|
||||
}
|
||||
}, queryTimeoutMillis, TimeUnit.MILLISECONDS);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes ownership of passed envelope
|
||||
*/
|
||||
void finish(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
|
||||
final DnsResponse res = envelope.content();
|
||||
if (res.count(DnsSection.QUESTION) != 1) {
|
||||
logger.warn("Received a DNS response with invalid number of questions: {}", envelope);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
|
||||
} else if (!question().equals(res.recordAt(DnsSection.QUESTION))) {
|
||||
logger.warn("Received a mismatching DNS response: {}", envelope);
|
||||
return;
|
||||
} else if (trySuccess(envelope)) {
|
||||
return; // Ownership transferred, don't release
|
||||
}
|
||||
|
||||
setSuccess(envelope);
|
||||
envelope.release();
|
||||
}
|
||||
|
||||
private void setSuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
|
||||
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise = this.promise;
|
||||
@SuppressWarnings("unchecked")
|
||||
AddressedEnvelope<DnsResponse, InetSocketAddress> castResponse =
|
||||
(AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope.retain();
|
||||
if (!promise.trySuccess(castResponse)) {
|
||||
// We failed to notify the promise as it was failed before, thus we need to release the envelope
|
||||
envelope.release();
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
private boolean trySuccess(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> envelope) {
|
||||
return promise.trySuccess((AddressedEnvelope<DnsResponse, InetSocketAddress>) envelope);
|
||||
}
|
||||
|
||||
private void setFailure(String message, Throwable cause) {
|
||||
boolean tryFailure(String message, Throwable cause, boolean timeout) {
|
||||
if (promise.isDone()) {
|
||||
return false;
|
||||
}
|
||||
final InetSocketAddress nameServerAddr = nameServerAddr();
|
||||
|
||||
final StringBuilder buf = new StringBuilder(message.length() + 64);
|
||||
@ -218,14 +215,14 @@ abstract class DnsQueryContext implements FutureListener<AddressedEnvelope<DnsRe
|
||||
.append(" (no stack trace available)");
|
||||
|
||||
final DnsNameResolverException e;
|
||||
if (cause == null) {
|
||||
if (timeout) {
|
||||
// This was caused by an timeout so use DnsNameResolverTimeoutException to allow the user to
|
||||
// handle it special (like retry the query).
|
||||
e = new DnsNameResolverTimeoutException(nameServerAddr, question(), buf.toString());
|
||||
} else {
|
||||
e = new DnsNameResolverException(nameServerAddr, question(), buf.toString(), cause);
|
||||
}
|
||||
promise.tryFailure(e);
|
||||
return promise.tryFailure(e);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -71,6 +71,7 @@ import org.junit.Test;
|
||||
import org.junit.rules.ExpectedException;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.DatagramSocket;
|
||||
import java.net.Inet4Address;
|
||||
import java.net.InetAddress;
|
||||
@ -2759,6 +2760,25 @@ public class DnsNameResolverTest {
|
||||
testTruncated0(true, true);
|
||||
}
|
||||
|
||||
private static DnsMessageModifier modifierFrom(DnsMessage message) {
|
||||
DnsMessageModifier modifier = new DnsMessageModifier();
|
||||
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
|
||||
modifier.setAdditionalRecords(message.getAdditionalRecords());
|
||||
modifier.setAnswerRecords(message.getAnswerRecords());
|
||||
modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer());
|
||||
modifier.setAuthorityRecords(message.getAuthorityRecords());
|
||||
modifier.setMessageType(message.getMessageType());
|
||||
modifier.setOpCode(message.getOpCode());
|
||||
modifier.setQuestionRecords(message.getQuestionRecords());
|
||||
modifier.setRecursionAvailable(message.isRecursionAvailable());
|
||||
modifier.setRecursionDesired(message.isRecursionDesired());
|
||||
modifier.setReserved(message.isReserved());
|
||||
modifier.setResponseCode(message.getResponseCode());
|
||||
modifier.setTransactionId(message.getTransactionId());
|
||||
modifier.setTruncated(message.isTruncated());
|
||||
return modifier;
|
||||
}
|
||||
|
||||
private static void testTruncated0(boolean tcpFallback, final boolean truncatedBecauseOfMtu) throws IOException {
|
||||
final String host = "somehost.netty.io";
|
||||
final String txt = "this is a txt record";
|
||||
@ -2784,20 +2804,7 @@ public class DnsNameResolverTest {
|
||||
|
||||
if (!truncatedBecauseOfMtu) {
|
||||
// Create a copy of the message but set the truncated flag.
|
||||
DnsMessageModifier modifier = new DnsMessageModifier();
|
||||
modifier.setAcceptNonAuthenticatedData(message.isAcceptNonAuthenticatedData());
|
||||
modifier.setAdditionalRecords(message.getAdditionalRecords());
|
||||
modifier.setAnswerRecords(message.getAnswerRecords());
|
||||
modifier.setAuthoritativeAnswer(message.isAuthoritativeAnswer());
|
||||
modifier.setAuthorityRecords(message.getAuthorityRecords());
|
||||
modifier.setMessageType(message.getMessageType());
|
||||
modifier.setOpCode(message.getOpCode());
|
||||
modifier.setQuestionRecords(message.getQuestionRecords());
|
||||
modifier.setRecursionAvailable(message.isRecursionAvailable());
|
||||
modifier.setRecursionDesired(message.isRecursionDesired());
|
||||
modifier.setReserved(message.isReserved());
|
||||
modifier.setResponseCode(message.getResponseCode());
|
||||
modifier.setTransactionId(message.getTransactionId());
|
||||
DnsMessageModifier modifier = modifierFrom(message);
|
||||
modifier.setTruncated(true);
|
||||
return modifier.getDnsMessage();
|
||||
}
|
||||
@ -2842,8 +2849,15 @@ public class DnsNameResolverTest {
|
||||
// If we are configured to use TCP as a fallback lets replay the dns message over TCP
|
||||
Socket socket = serverSocket.accept();
|
||||
|
||||
InputStream in = socket.getInputStream();
|
||||
assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field
|
||||
int txnId = in.read() << 8 | (in.read() & 0xff);
|
||||
|
||||
IoBuffer ioBuffer = IoBuffer.allocate(1024);
|
||||
new DnsMessageEncoder().encode(ioBuffer, messageRef.get());
|
||||
// Must replace the transactionId with the one from the TCP request
|
||||
DnsMessageModifier modifier = modifierFrom(messageRef.get());
|
||||
modifier.setTransactionId(txnId);
|
||||
new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage());
|
||||
ioBuffer.flip();
|
||||
|
||||
ByteBuffer lenBuffer = ByteBuffer.allocate(2);
|
||||
@ -2884,7 +2898,7 @@ public class DnsNameResolverTest {
|
||||
} else {
|
||||
assertTrue(envelope.content().isTruncated());
|
||||
}
|
||||
envelope.release();
|
||||
assertTrue(envelope.release());
|
||||
} finally {
|
||||
dnsServer2.stop();
|
||||
if (resolver != null) {
|
||||
|
Loading…
Reference in New Issue
Block a user