Fixed a race condition in MemoryAwareThreadPoolExecutor

Replaced a semaphore with a custom concurrency construct to fix a
known race condition in MemoryAwareThreadPoolExecutor
This commit is contained in:
Trustin Lee 2011-01-13 14:56:38 +09:00
parent 9f55834823
commit 23f33629ca

View File

@ -21,7 +21,6 @@ import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
@ -148,9 +147,7 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
private final ConcurrentMap<Channel, AtomicLong> channelCounters =
new ConcurrentIdentityHashMap<Channel, AtomicLong>();
private final AtomicLong totalCounter = new AtomicLong();
private final Semaphore semaphore = new Semaphore(0);
private final Limiter totalLimiter;
/**
* Creates a new instance.
@ -250,7 +247,13 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
}
settings = new Settings(
objectSizeEstimator, maxChannelMemorySize, maxTotalMemorySize);
objectSizeEstimator, maxChannelMemorySize);
if (maxTotalMemorySize == 0) {
totalLimiter = null;
} else {
totalLimiter = new Limiter(maxTotalMemorySize);
}
// Misuse check
misuseDetector.increase();
@ -279,7 +282,7 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
settings = new Settings(
objectSizeEstimator,
settings.maxChannelMemorySize, settings.maxTotalMemorySize);
settings.maxChannelMemorySize);
}
/**
@ -306,20 +309,20 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
settings = new Settings(
settings.objectSizeEstimator,
maxChannelMemorySize, settings.maxTotalMemorySize);
maxChannelMemorySize);
}
/**
* Returns the maximum total size of the queued events for this pool.
*/
public long getMaxTotalMemorySize() {
return settings.maxTotalMemorySize;
return totalLimiter.limit;
}
/**
* Sets the maximum total size of the queued events for this pool.
* Specify {@code 0} to disable.
* @deprecated <tt>maxTotalMemorySize</tt> is not modifiable anymore.
*/
@Deprecated
public void setMaxTotalMemorySize(long maxTotalMemorySize) {
if (maxTotalMemorySize < 0) {
throw new IllegalArgumentException(
@ -330,10 +333,6 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
throw new IllegalStateException(
"can't be changed after a task is executed");
}
settings = new Settings(
settings.objectSizeEstimator,
settings.maxChannelMemorySize, maxTotalMemorySize);
}
@Override
@ -342,12 +341,8 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
command = new MemoryAwareRunnable(command);
}
boolean pause = increaseCounter(command);
increaseCounter(command);
doExecute(command);
if (pause) {
//System.out.println("ACQUIRE: " + command);
semaphore.acquireUninterruptibly();
}
}
/**
@ -380,17 +375,15 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
decreaseCounter(r);
}
protected boolean increaseCounter(Runnable task) {
protected void increaseCounter(Runnable task) {
if (!shouldCount(task)) {
return false;
return;
}
Settings settings = this.settings;
long maxTotalMemorySize = settings.maxTotalMemorySize;
long maxChannelMemorySize = settings.maxChannelMemorySize;
int increment = settings.objectSizeEstimator.estimateSize(task);
long totalCounter = this.totalCounter.addAndGet(increment);
if (task instanceof ChannelEventRunnable) {
ChannelEventRunnable eventTask = (ChannelEventRunnable) task;
@ -413,8 +406,9 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
((MemoryAwareRunnable) task).estimatedSize = increment;
}
//System.out.println("I: " + totalCounter + ", " + increment);
return maxTotalMemorySize != 0 && totalCounter >= maxTotalMemorySize;
if (totalLimiter != null) {
totalLimiter.increase(increment);
}
}
protected void decreaseCounter(Runnable task) {
@ -423,7 +417,6 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
}
Settings settings = this.settings;
long maxTotalMemorySize = settings.maxTotalMemorySize;
long maxChannelMemorySize = settings.maxChannelMemorySize;
int increment;
@ -433,14 +426,8 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
increment = ((MemoryAwareRunnable) task).estimatedSize;
}
long totalCounter = this.totalCounter.addAndGet(-increment);
//System.out.println("D: " + totalCounter + ", " + increment);
if (maxTotalMemorySize != 0 && totalCounter + increment >= maxTotalMemorySize) {
//System.out.println("RELEASE: " + task);
while (semaphore.hasQueuedThreads()) {
semaphore.release();
}
if (totalLimiter != null) {
totalLimiter.decrease(increment);
}
if (task instanceof ChannelEventRunnable) {
@ -503,13 +490,11 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
private static final class Settings {
final ObjectSizeEstimator objectSizeEstimator;
final long maxChannelMemorySize;
final long maxTotalMemorySize;
Settings(ObjectSizeEstimator objectSizeEstimator,
long maxChannelMemorySize, long maxTotalMemorySize) {
long maxChannelMemorySize) {
this.objectSizeEstimator = objectSizeEstimator;
this.maxChannelMemorySize = maxChannelMemorySize;
this.maxTotalMemorySize = maxTotalMemorySize;
}
}
@ -543,4 +528,38 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor {
task.run();
}
}
private static class Limiter {
final long limit;
private long counter;
private int waiters;
Limiter(long limit) {
super();
this.limit = limit;
}
synchronized void increase(long amount) {
while (counter >= limit) {
waiters ++;
try {
wait();
} catch (InterruptedException e) {
// Ignore
} finally {
waiters --;
}
}
counter += amount;
}
synchronized void decrease(long amount) {
counter -= amount;
if (counter < limit && waiters > 0) {
notifyAll();
}
}
}
}