From d1f05ea4e7b480b072a8f528542fae4a9f017661 Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Thu, 13 Jan 2011 14:56:38 +0900 Subject: [PATCH] Fixed a race condition in MemoryAwareThreadPoolExecutor Replaced a semaphore with a custom concurrency construct to fix a known race condition in MemoryAwareThreadPoolExecutor --- .../MemoryAwareThreadPoolExecutor.java | 93 +++++++++++-------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/jboss/netty/handler/execution/MemoryAwareThreadPoolExecutor.java b/src/main/java/org/jboss/netty/handler/execution/MemoryAwareThreadPoolExecutor.java index 4376a983aa..af9287f752 100644 --- a/src/main/java/org/jboss/netty/handler/execution/MemoryAwareThreadPoolExecutor.java +++ b/src/main/java/org/jboss/netty/handler/execution/MemoryAwareThreadPoolExecutor.java @@ -21,7 +21,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.RejectedExecutionHandler; -import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -148,9 +147,7 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { private final ConcurrentMap channelCounters = new ConcurrentIdentityHashMap(); - private final AtomicLong totalCounter = new AtomicLong(); - - private final Semaphore semaphore = new Semaphore(0); + private final Limiter totalLimiter; /** * Creates a new instance. @@ -250,7 +247,13 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { } settings = new Settings( - objectSizeEstimator, maxChannelMemorySize, maxTotalMemorySize); + objectSizeEstimator, maxChannelMemorySize); + + if (maxTotalMemorySize == 0) { + totalLimiter = null; + } else { + totalLimiter = new Limiter(maxTotalMemorySize); + } // Misuse check misuseDetector.increase(); @@ -279,7 +282,7 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { settings = new Settings( objectSizeEstimator, - settings.maxChannelMemorySize, settings.maxTotalMemorySize); + settings.maxChannelMemorySize); } /** @@ -306,20 +309,20 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { settings = new Settings( settings.objectSizeEstimator, - maxChannelMemorySize, settings.maxTotalMemorySize); + maxChannelMemorySize); } /** * Returns the maximum total size of the queued events for this pool. */ public long getMaxTotalMemorySize() { - return settings.maxTotalMemorySize; + return totalLimiter.limit; } /** - * Sets the maximum total size of the queued events for this pool. - * Specify {@code 0} to disable. + * @deprecated maxTotalMemorySize is not modifiable anymore. */ + @Deprecated public void setMaxTotalMemorySize(long maxTotalMemorySize) { if (maxTotalMemorySize < 0) { throw new IllegalArgumentException( @@ -330,10 +333,6 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { throw new IllegalStateException( "can't be changed after a task is executed"); } - - settings = new Settings( - settings.objectSizeEstimator, - settings.maxChannelMemorySize, maxTotalMemorySize); } @Override @@ -342,12 +341,8 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { command = new MemoryAwareRunnable(command); } - boolean pause = increaseCounter(command); + increaseCounter(command); doExecute(command); - if (pause) { - //System.out.println("ACQUIRE: " + command); - semaphore.acquireUninterruptibly(); - } } /** @@ -380,17 +375,15 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { decreaseCounter(r); } - protected boolean increaseCounter(Runnable task) { + protected void increaseCounter(Runnable task) { if (!shouldCount(task)) { - return false; + return; } Settings settings = this.settings; - long maxTotalMemorySize = settings.maxTotalMemorySize; long maxChannelMemorySize = settings.maxChannelMemorySize; int increment = settings.objectSizeEstimator.estimateSize(task); - long totalCounter = this.totalCounter.addAndGet(increment); if (task instanceof ChannelEventRunnable) { ChannelEventRunnable eventTask = (ChannelEventRunnable) task; @@ -413,8 +406,9 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { ((MemoryAwareRunnable) task).estimatedSize = increment; } - //System.out.println("I: " + totalCounter + ", " + increment); - return maxTotalMemorySize != 0 && totalCounter >= maxTotalMemorySize; + if (totalLimiter != null) { + totalLimiter.increase(increment); + } } protected void decreaseCounter(Runnable task) { @@ -423,7 +417,6 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { } Settings settings = this.settings; - long maxTotalMemorySize = settings.maxTotalMemorySize; long maxChannelMemorySize = settings.maxChannelMemorySize; int increment; @@ -433,14 +426,8 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { increment = ((MemoryAwareRunnable) task).estimatedSize; } - long totalCounter = this.totalCounter.addAndGet(-increment); - - //System.out.println("D: " + totalCounter + ", " + increment); - if (maxTotalMemorySize != 0 && totalCounter + increment >= maxTotalMemorySize) { - //System.out.println("RELEASE: " + task); - while (semaphore.hasQueuedThreads()) { - semaphore.release(); - } + if (totalLimiter != null) { + totalLimiter.decrease(increment); } if (task instanceof ChannelEventRunnable) { @@ -503,13 +490,11 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { private static final class Settings { final ObjectSizeEstimator objectSizeEstimator; final long maxChannelMemorySize; - final long maxTotalMemorySize; Settings(ObjectSizeEstimator objectSizeEstimator, - long maxChannelMemorySize, long maxTotalMemorySize) { + long maxChannelMemorySize) { this.objectSizeEstimator = objectSizeEstimator; this.maxChannelMemorySize = maxChannelMemorySize; - this.maxTotalMemorySize = maxTotalMemorySize; } } @@ -541,4 +526,38 @@ public class MemoryAwareThreadPoolExecutor extends ThreadPoolExecutor { task.run(); } } + + + private static class Limiter { + + final long limit; + private long counter; + private int waiters; + + Limiter(long limit) { + super(); + this.limit = limit; + } + + synchronized void increase(long amount) { + while (counter >= limit) { + waiters ++; + try { + wait(); + } catch (InterruptedException e) { + // Ignore + } finally { + waiters --; + } + } + counter += amount; + } + + synchronized void decrease(long amount) { + counter -= amount; + if (counter < limit && waiters > 0) { + notifyAll(); + } + } + } }