diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 7dbc148f80..444d92827c 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -51,7 +51,6 @@ import io.netty.resolver.HostsFileEntriesResolver; import io.netty.resolver.InetNameResolver; import io.netty.resolver.ResolvedAddressTypes; import io.netty.util.NetUtil; -import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Future; @@ -1190,129 +1189,107 @@ public class DnsNameResolver extends InetNameResolver { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { - try { - final DatagramDnsResponse res = (DatagramDnsResponse) msg; - final int queryId = res.id(); + final DatagramDnsResponse res = (DatagramDnsResponse) msg; + final int queryId = res.id(); - if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res); - } + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: UDP [{}: {}], {}", ch, queryId, res.sender(), res); + } - final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); - if (qCtx == null) { - logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); - return; - } + final DnsQueryContext qCtx = queryContextManager.get(res.sender(), queryId); + if (qCtx == null) { + logger.warn("{} Received a DNS response with an unknown ID: {}", ch, queryId); + res.release(); + return; + } - // Check if the response was truncated and if we can fallback to TCP to retry. - if (res.isTruncated() && socketChannelFactory != null) { - // Let's retain as we may need it later on. - res.retain(); + // Check if the response was truncated and if we can fallback to TCP to retry. + if (!res.isTruncated() || socketChannelFactory == null) { + qCtx.finish(res); + return; + } - Bootstrap bs = new Bootstrap(); - bs.option(ChannelOption.SO_REUSEADDR, true) - .group(executor()) - .channelFactory(socketChannelFactory) - .handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) { - ch.pipeline().addLast(TCP_ENCODER); - ch.pipeline().addLast(new TcpDnsResponseDecoder()); - ch.pipeline().addLast(new ChannelHandler() { - private boolean finish; + Bootstrap bs = new Bootstrap(); + bs.option(ChannelOption.SO_REUSEADDR, true) + .group(executor()) + .channelFactory(socketChannelFactory) + .handler(TCP_ENCODER); + bs.connect(res.sender()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + if (logger.isDebugEnabled()) { + logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause()); + } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - try { - Channel channel = ctx.channel(); - DnsResponse response = (DnsResponse) msg; - int queryId = response.id(); + // TCP fallback failed, just use the truncated response. + qCtx.finish(res); + return; + } + final Channel channel = future.channel(); - if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId, - channel.remoteAddress(), response); - } + Promise> promise = + channel.eventLoop().newPromise(); + final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(DnsNameResolver.this, channel, + (InetSocketAddress) channel.remoteAddress(), qCtx.question(), + EMPTY_ADDITIONALS, promise); - DnsQueryContext tcpCtx = queryContextManager.get(res.sender(), queryId); - if (tcpCtx == null) { - logger.warn("{} Received a DNS response with an unknown ID: {}", - channel, queryId); - qCtx.finish(res); - return; - } - - // Release the original response as we will use the response that we - // received via TCP fallback. - res.release(); - - tcpCtx.finish(new AddressedEnvelopeAdapter( - (InetSocketAddress) ctx.channel().remoteAddress(), - (InetSocketAddress) ctx.channel().localAddress(), - response)); - - finish = true; - } finally { - ReferenceCountUtil.release(msg); - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - if (!finish) { - if (logger.isDebugEnabled()) { - logger.debug("{} Error during processing response: TCP [{}: {}]", - ctx.channel(), queryId, - ctx.channel().remoteAddress(), cause); - } - // TCP fallback failed, just use the truncated response as - qCtx.finish(res); - } - } - }); - } - }); - bs.connect(res.sender()).addListener(new ChannelFutureListener() { + channel.pipeline().addLast(new TcpDnsResponseDecoder()); + channel.pipeline().addLast(new ChannelHandler() { @Override - public void operationComplete(ChannelFuture future) { - if (future.isSuccess()) { - final Channel channel = future.channel(); + public void channelRead(ChannelHandlerContext ctx, Object msg) { + Channel channel = ctx.channel(); + DnsResponse response = (DnsResponse) msg; + int queryId = response.id(); - Promise> promise = - channel.eventLoop().newPromise(); - new TcpDnsQueryContext(DnsNameResolver.this, channel, - (InetSocketAddress) channel.remoteAddress(), qCtx.question(), - EMPTY_ADDITIONALS, promise).query(true, future.channel().newPromise()); - promise.addListener( - new FutureListener>() { - @Override - public void operationComplete( - Future> future) { - channel.close(); + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: TCP [{}: {}], {}", channel, queryId, + channel.remoteAddress(), response); + } - if (future.isSuccess()) { - qCtx.finish(future.getNow()); - } else { - // TCP fallback failed, just use the truncated response. - qCtx.finish(res); - } - } - }); + DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId); + if (foundCtx == tcpCtx) { + tcpCtx.finish(new AddressedEnvelopeAdapter( + (InetSocketAddress) ctx.channel().remoteAddress(), + (InetSocketAddress) ctx.channel().localAddress(), + response)); } else { - if (logger.isDebugEnabled()) { - logger.debug("{} Unable to fallback to TCP [{}]", queryId, future.cause()); - } + response.release(); + tcpCtx.tryFailure("Received TCP response with unexpected ID", null, false); + logger.warn("{} Received a DNS response with an unexpected ID: {}", + channel, queryId); + } + } + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (tcpCtx.tryFailure("TCP fallback error", cause, false) && logger.isDebugEnabled()) { + logger.debug("{} Error during processing response: TCP [{}: {}]", + ctx.channel(), queryId, + ctx.channel().remoteAddress(), cause); + } + } + }); + + promise.addListener( + new FutureListener>() { + @Override + public void operationComplete( + Future> future) { + channel.close(); + + if (future.isSuccess()) { + qCtx.finish(future.getNow()); + res.release(); + } else { // TCP fallback failed, just use the truncated response. qCtx.finish(res); } } }); - } else { - qCtx.finish(res); + tcpCtx.query(true, future.channel().newPromise()); } - } finally { - ReferenceCountUtil.safeRelease(msg); - } + }); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index 8fa497f2fc..d2f709d1fd 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -150,7 +150,7 @@ abstract class DnsQueryContext implements FutureListener envelope) { final DnsResponse res = envelope.content(); if (res.count(DnsSection.QUESTION) != 1) { logger.warn("Received a DNS response with invalid number of questions: {}", envelope); - return; - } - - if (!question().equals(res.recordAt(DnsSection.QUESTION))) { + } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) { logger.warn("Received a mismatching DNS response: {}", envelope); - return; + } else if (trySuccess(envelope)) { + return; // Ownership transferred, don't release } - - setSuccess(envelope); + envelope.release(); } - private void setSuccess(AddressedEnvelope envelope) { - Promise> promise = this.promise; - @SuppressWarnings("unchecked") - AddressedEnvelope castResponse = - (AddressedEnvelope) envelope.retain(); - if (!promise.trySuccess(castResponse)) { - // We failed to notify the promise as it was failed before, thus we need to release the envelope - envelope.release(); - } + @SuppressWarnings("unchecked") + private boolean trySuccess(AddressedEnvelope envelope) { + return promise.trySuccess((AddressedEnvelope) envelope); } - private void setFailure(String message, Throwable cause) { + boolean tryFailure(String message, Throwable cause, boolean timeout) { + if (promise.isDone()) { + return false; + } final InetSocketAddress nameServerAddr = nameServerAddr(); final StringBuilder buf = new StringBuilder(message.length() + 64); @@ -206,14 +203,14 @@ abstract class DnsQueryContext implements FutureListener 2); // skip length field + int txnId = in.read() << 8 | (in.read() & 0xff); + IoBuffer ioBuffer = IoBuffer.allocate(1024); - new DnsMessageEncoder().encode(ioBuffer, messageRef.get()); + // Must replace the transactionId with the one from the TCP request + DnsMessageModifier modifier = modifierFrom(messageRef.get()); + modifier.setTransactionId(txnId); + new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage()); ioBuffer.flip(); ByteBuffer lenBuffer = ByteBuffer.allocate(2); @@ -2856,7 +2870,7 @@ public class DnsNameResolverTest { } else { assertTrue(envelope.content().isTruncated()); } - envelope.release(); + assertTrue(envelope.release()); } finally { dnsServer2.stop(); if (resolver != null) {