Fix IllegalReferenceCountException in DnsNameResolver

Related: #3797

Motivation:

There is a race condition where DnsNameResolver.query() can attempt to
increase the reference count of the DNS response which was released
already by other thread.

Modifications:

- Make DnsCacheEntry a top-level class for clear access control
- Use 'synchronized' to avoid the race condition
  - Add DnsCacheEntry.retainedResponse() to make sure that the response
    is never released while it is retained
  - Make retainedResponse() return null when the response has been
    released already, so that DnsNameResolver.query() knows that the
    cached entry has been released

Result:

The forementioned race condition has been fixed.
This commit is contained in:
Trustin Lee 2015-06-03 19:12:25 +09:00
parent 09ecc34924
commit 311532feb0
3 changed files with 128 additions and 51 deletions

View File

@ -0,0 +1,88 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.resolver.dns;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.OneTimeTask;
import io.netty.util.internal.PlatformDependent;
import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
final class DnsCacheEntry {
private enum State {
INIT,
SCHEDULED_EXPIRATION,
RELEASED
}
private final AddressedEnvelope<DnsResponse, InetSocketAddress> response;
private final Throwable cause;
private volatile ScheduledFuture<?> expirationFuture;
private boolean released;
@SuppressWarnings("unchecked")
DnsCacheEntry(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> response) {
this.response = (AddressedEnvelope<DnsResponse, InetSocketAddress>) response.retain();
cause = null;
}
DnsCacheEntry(Throwable cause) {
this.cause = cause;
response = null;
}
Throwable cause() {
return cause;
}
synchronized AddressedEnvelope<DnsResponse, InetSocketAddress> retainedResponse() {
if (released) {
// Released by other thread via either the expiration task or clearCache()
return null;
}
return response.retain();
}
void scheduleExpiration(EventLoop loop, Runnable task, long delay, TimeUnit unit) {
assert expirationFuture == null: "expiration task scheduled already";
expirationFuture = loop.schedule(task, delay, unit);
}
void release() {
synchronized (this) {
if (released) {
return;
}
released = true;
ReferenceCountUtil.safeRelease(response);
}
ScheduledFuture<?> expirationFuture = this.expirationFuture;
if (expirationFuture != null) {
expirationFuture.cancel(false);
}
}
}

View File

@ -685,16 +685,19 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
final EventLoop eventLoop = ch.eventLoop(); final EventLoop eventLoop = ch.eventLoop();
final DnsCacheEntry cachedResult = queryCache.get(question); final DnsCacheEntry cachedResult = queryCache.get(question);
if (cachedResult != null) { if (cachedResult != null) {
if (cachedResult.response != null) { AddressedEnvelope<DnsResponse, InetSocketAddress> response = cachedResult.retainedResponse();
return eventLoop.newSucceededFuture(cachedResult.response.retain()); if (response != null) {
return eventLoop.newSucceededFuture(response);
} else { } else {
return eventLoop.newFailedFuture(cachedResult.cause); Throwable cause = cachedResult.cause();
if (cause != null) {
return eventLoop.newFailedFuture(cause);
}
} }
} else {
return query0(
nameServerAddresses, question,
eventLoop.<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
} }
return query0(
nameServerAddresses, question,
eventLoop.<AddressedEnvelope<? extends DnsResponse, InetSocketAddress>>newPromise());
} }
/** /**
@ -716,14 +719,18 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
final DnsCacheEntry cachedResult = queryCache.get(question); final DnsCacheEntry cachedResult = queryCache.get(question);
if (cachedResult != null) { if (cachedResult != null) {
if (cachedResult.response != null) { AddressedEnvelope<DnsResponse, InetSocketAddress> response = cachedResult.retainedResponse();
return cast(promise).setSuccess(cachedResult.response.retain()); if (response != null) {
return cast(promise).setSuccess(response);
} else { } else {
return cast(promise).setFailure(cachedResult.cause); Throwable cause = cachedResult.cause();
if (cause != null) {
return cast(promise).setFailure(cause);
}
} }
} else {
return query0(nameServerAddresses, question, promise);
} }
return query0(nameServerAddresses, question, promise);
} }
private Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0( private Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
@ -739,7 +746,16 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
} }
} }
void cache(final DnsQuestion question, DnsCacheEntry entry, long delaySeconds) { void cacheSuccess(
DnsQuestion question, AddressedEnvelope<? extends DnsResponse, InetSocketAddress> res, long delaySeconds) {
cache(question, new DnsCacheEntry(res), delaySeconds);
}
void cacheFailure(DnsQuestion question, Throwable cause, long delaySeconds) {
cache(question, new DnsCacheEntry(cause), delaySeconds);
}
private void cache(final DnsQuestion question, DnsCacheEntry entry, long delaySeconds) {
DnsCacheEntry oldEntry = queryCache.put(question, entry); DnsCacheEntry oldEntry = queryCache.put(question, entry);
if (oldEntry != null) { if (oldEntry != null) {
oldEntry.release(); oldEntry.release();
@ -747,13 +763,15 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
boolean scheduled = false; boolean scheduled = false;
try { try {
entry.expirationFuture = ch.eventLoop().schedule(new OneTimeTask() { entry.scheduleExpiration(
@Override ch.eventLoop(),
public void run() { new OneTimeTask() {
clearCache(question); @Override
} public void run() {
}, delaySeconds, TimeUnit.SECONDS); clearCache(question);
}
},
delaySeconds, TimeUnit.SECONDS);
scheduled = true; scheduled = true;
} finally { } finally {
if (!scheduled) { if (!scheduled) {
@ -852,7 +870,7 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
// Ensure that the found TTL is between minTtl and maxTtl. // Ensure that the found TTL is between minTtl and maxTtl.
ttl = Math.max(minTtl(), Math.min(maxTtl, ttl)); ttl = Math.max(minTtl(), Math.min(maxTtl, ttl));
DnsNameResolver.this.cache(question, new DnsCacheEntry(res), ttl); DnsNameResolver.this.cacheSuccess(question, res, ttl);
} }
@Override @Override
@ -861,32 +879,4 @@ public class DnsNameResolver extends SimpleNameResolver<InetSocketAddress> {
} }
} }
static final class DnsCacheEntry {
final AddressedEnvelope<DnsResponse, InetSocketAddress> response;
final Throwable cause;
volatile ScheduledFuture<?> expirationFuture;
@SuppressWarnings("unchecked")
DnsCacheEntry(AddressedEnvelope<? extends DnsResponse, InetSocketAddress> response) {
this.response = (AddressedEnvelope<DnsResponse, InetSocketAddress>) response.retain();
cause = null;
}
DnsCacheEntry(Throwable cause) {
this.cause = cause;
response = null;
}
void release() {
AddressedEnvelope<DnsResponse, InetSocketAddress> response = this.response;
if (response != null) {
ReferenceCountUtil.safeRelease(response);
}
ScheduledFuture<?> expirationFuture = this.expirationFuture;
if (expirationFuture != null) {
expirationFuture.cancel(false);
}
}
}
} }

View File

@ -28,7 +28,6 @@ import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse; import io.netty.handler.codec.dns.DnsResponse;
import io.netty.resolver.dns.DnsNameResolver.DnsCacheEntry;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture; import io.netty.util.concurrent.ScheduledFuture;
import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.OneTimeTask;
@ -219,6 +218,6 @@ final class DnsQueryContext {
return; return;
} }
parent.cache(question, new DnsCacheEntry(cause), negativeTtl); parent.cacheFailure(question, cause, negativeTtl);
} }
} }