From a84bd61631efd0ca8b0f08bebc4d9a380fd6d7b1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 14 Oct 2019 22:10:15 +0800 Subject: [PATCH] Fix DnsNameResolver TCP fallback test and message leaks (#9647) Motivation A memory leak related to DNS resolution was reported in #9634, specifically linked to the TCP retry fallback functionality that was introduced relatively recently. Upon inspection it's apparent that there are some error paths where the original UDP response might not be fully released, and more significantly the TCP response actually leaks every time on the fallback success path. It turns out that a bug in the unit test meant that the intended TCP fallback path was not actually exercised, so it did not expose the main leak in question. Modifications - Fix DnsNameResolverTest#testTruncated0 dummy server fallback logic to first read the transaction ID of the retried query and use it in the replayed response - Adjust semantics of the internal DnsQueryContext#finish method to always take refcount ownership of the passed-in envelope - Reorder some logic in DnsResponseHandler fallback handling to verify the context of the response is expected, and ensure that the query responses are either released or propagated in all cases. 
This also reduces a number of redundant retain/release pairings Result Fixes #9634 --- .../netty/resolver/dns/DnsNameResolver.java | 189 ++++++++---------- .../netty/resolver/dns/DnsQueryContext.java | 41 ++-- .../resolver/dns/DnsNameResolverTest.java | 46 +++-- 3 files changed, 132 insertions(+), 144 deletions(-) diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 7dbc148f80..444d92827c 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -51,7 +51,6 @@ import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.InetNameResolver; import io.netty.resolver.ResolvedAddressTypes; import io.netty.util.NetUtil; -import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Future; @@ -1190,129 +1189,107 @@ public class DnsNameResolver extends InetNameResolver { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { - try { - final DatagramDnsResponse res = (DatagramDnsResponse) msg; - final int queryId = res.id(); + final DatagramDnsResponse res = (DatagramDnsResponse) msg; + final int queryId = res.id(); - if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res); - } + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res); + } - final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); - if (qCtx == null) { - logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); - return; - } + final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); + if (qCtx == null) { + logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); + 
res.release(); + return; + } - // Check if the response was truncated and if we can fallback to TCP to retry. - if (res.isTruncated() && socketChannelFactory != null) { - // Let's retain as we may need it later on. - res.retain(); + // Check if the response was truncated and if we can fallback to TCP to retry. + if (!res.isTruncated() || socketChannelFactory == null) { + qCtx.finish(res); + return; + } - Bootstrap bs = new Bootstrap(); - bs.option(ChannelOption.SO_REUSEADDR, true) - .group(executor()) - .channelFactory(socketChannelFactory) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) { - ch.pipeline().addLast(TCP_ENCODER); - ch.pipeline().addLast(new TcpDnsResponseDecoder()); - ch.pipeline().addLast(new ChannelHandler() { - private boolean finish; + Bootstrap bs = new Bootstrap(); + bs.option(ChannelOption.SO_REUSEADDR, true) + .group(executor()) + .channelFactory(socketChannelFactory) + .handler(TCP_ENCODER); + bs.connect(res.sender()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + if (logger.isDebugEnabled()) { + logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause()); + } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - try { - Channel channel = ctx.channel(); - DnsResponse response = (DnsResponse) msg; - int queryId = response.id(); + // TCP fallback failed, just use the truncated response. 
+ qCtx.finish(res); + return; + } + final Channel channel = future.channel(); - if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId, - channel.remoteAddress(), response); - } + Promise> promise = + channel.eventLoop().newPromise(); + final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(DnsNameResolver.this, channel, + (InetSocketAddress) channel.remoteAddress(), qCtx.question(), + EMPTY_ADDITIONALS, promise); - DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId); - if (tcpCtx == null) { - logger.warn("{} Received a DNS response with an unknown ID: {}", - channel, queryId); - qCtx.finish(res); - return; - } - - // Release the original response as we will use the response that we - // received via TCP fallback. - res.release(); - - tcpCtx.finish(new AddressedEnvelopeAdapter( - (InetSocketAddress) ctx.channel().remoteAddress(), - (InetSocketAddress) ctx.channel().localAddress(), - response)); - - finish = true; - } finally { - ReferenceCountUtil.release(msg); - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - if (!finish) { - if (logger.isDebugEnabled()) { - logger.debug("{} Error during processing response: TCP [{}: {}]", - ctx.channel(), queryId, - ctx.channel().remoteAddress(), cause); - } - // TCP fallback failed, just use the truncated response as - qCtx.finish(res); - } - } - }); - } - }); - bs.connect(res.sender()).addListener(new ChannelFutureListener() { + channel.pipeline().addLast(new TcpDnsResponseDecoder()); + channel.pipeline().addLast(new ChannelHandler() { @Override - public void operationComplete(ChannelFuture future) { - if (future.isSuccess()) { - final Channel channel = future.channel(); + public void channelRead(ChannelHandlerContext ctx, Object msg) { + Channel channel = ctx.channel(); + DnsResponse response = (DnsResponse) msg; + int queryId = response.id(); - Promise> promise = - channel.eventLoop().newPromise(); - new 
TcpDnsQueryContext(DnsNameResolver.this, channel, - (InetSocketAddress) channel.remoteAddress(), qCtx.question(), - EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise()); - promise.addListener( - new FutureListener>() { - @Override - public void operationComplete( - Future> future) { - channel.close(); + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId, + channel.remoteAddress(), response); + } - if (future.isSuccess()) { - qCtx.finish(future.getNow()); - } else { - // TCP fallback failed, just use the truncated response. - qCtx.finish(res); - } - } - }); + DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId); + if (foundCtx == tcpCtx) { + tcpCtx.finish(new AddressedEnvelopeAdapter( + (InetSocketAddress) ctx.channel().remoteAddress(), + (InetSocketAddress) ctx.channel().localAddress(), + response)); } else { - if (logger.isDebugEnabled()) { - logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause()); - } + response.release(); + tcpCtx.tryFailure("Received TCP response with unexpected ID", null, false); + logger.warn("{} Received a DNS response with an unexpected ID: {}", + channel, queryId); + } + } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (tcpCtx.tryFailure("TCP fallback error", cause, false) && logger.isDebugEnabled()) { + logger.debug("{} Error during processing response: TCP [{}: {}]", + ctx.channel(), queryId, + ctx.channel().remoteAddress(), cause); + } + } + }); + + promise.addListener( + new FutureListener>() { + @Override + public void operationComplete( + Future> future) { + channel.close(); + + if (future.isSuccess()) { + qCtx.finish(future.getNow()); + res.release(); + } else { // TCP fallback failed, just use the truncated response. 
qCtx.finish(res); } } }); - } else { - qCtx.finish(res); + tcpCtx.query(true, future.channel().newPromise()); } - } finally { - ReferenceCountUtil.safeRelease(msg); - } + }); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 8fa497f2fc..d2f709d1fd 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -150,7 +150,7 @@ abstract class DnsQueryContext implements FutureListener envelope) { final DnsResponse res = envelope.content(); if (res.count(DnsSection.QUESTION) != 1) { logger.warn("Received a DNS response with invalid number of questions: {}", envelope); - return; - } - - if (!question().equals(res.recordAt(DnsSection.QUESTION))) { + } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) { logger.warn("Received a mismatching DNS response: {}", envelope); - return; + } else if (trySuccess(envelope)) { + return; // Ownership transferred, don't release } - - setSuccess(envelope); + envelope.release(); } - private void setSuccess(AddressedEnvelope envelope) { - Promise> promise = this.promise; - @SuppressWarnings("unchecked") - AddressedEnvelope castResponse = - (AddressedEnvelope) envelope.retain(); - if (!promise.trySuccess(castResponse)) { - // We failed to notify the promise as it was failed before, thus we need to release the envelope - envelope.release(); - } + @SuppressWarnings("unchecked") + private boolean trySuccess(AddressedEnvelope envelope) { + return promise.trySuccess((AddressedEnvelope) envelope); } - private void setFailure(String message, Throwable cause) { + boolean tryFailure(String message, Throwable cause, boolean timeout) { + if (promise.isDone()) { + return false; + } final InetSocketAddress nameServerAddr = nameServerAddr(); final StringBuilder buf = new StringBuilder(message.length() + 64); @@ -206,14 
+203,14 @@ abstract class DnsQueryContext implements FutureListener 2); // skip length field + int txnId = in.read() << 8 | (in.read() & 0xff); + IoBuffer ioBuffer = IoBuffer.allocate(1024); - new DnsMessageEncoder().encode(ioBuffer, messageRef.get()); + // Must replace the transactionId with the one from the TCP request + DnsMessageModifier modifier = modifierFrom(messageRef.get()); + modifier.setTransactionId(txnId); + new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage()); ioBuffer.flip(); ByteBuffer lenBuffer = ByteBuffer.allocate(2); @@ -2856,7 +2870,7 @@ public class DnsNameResolverTest { } else { assertTrue(envelope.content().isTruncated()); } - envelope.release(); + assertTrue(envelope.release()); } finally { dnsServer2.stop(); if (resolver != null) {