Added optional pending timeouts counter parameter to HashedWheelTimer constructor and ensured that pending timeouts don't exceed provided max pending timeouts.

Motivation:
If the rate at which new timeouts are created is very high and the created timeouts are not cancelled, then the JVM can crash because of out of heap space. There should be a guard in the implementation to prevent this.

Modifications:
The constructor of HashedWheelTimer now takes an optional max pending timeouts parameter beyond which it will reject new timeouts by throwing RejectedExecutionException.

Result:
After this change, if the max pending timeouts parameter is passed as constructor argument to HashedWheelTimer, then it keeps a track of pending timeouts that aren't yet expired or cancelled. When a new timeout is being created, it checks for current pending timeouts and if it's equal to or greater than provided max pending timeouts, then it throws RejectedExecutionException.
This commit is contained in:
Aniket Bhatnagar 2016-09-22 14:46:47 +01:00 committed by Scott Mitchell
parent 68100cdd7f
commit 7d0f9e80e2
3 changed files with 163 additions and 26 deletions

View File

@ -26,9 +26,11 @@ import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLong;
/**
* A {@link Timer} optimized for approximated I/O timeout scheduling.
@ -105,6 +107,8 @@ public class HashedWheelTimer implements Timer {
private final CountDownLatch startTimeInitialized = new CountDownLatch(1);
private final Queue<HashedWheelTimeout> timeouts = PlatformDependent.newMpscQueue();
private final Queue<HashedWheelTimeout> cancelledTimeouts = PlatformDependent.newMpscQueue();
private final AtomicLong pendingTimeouts = new AtomicLong(0);
private final long maxPendingTimeouts;
private volatile long startTime;
@ -195,20 +199,48 @@ public class HashedWheelTimer implements Timer {
/**
* Creates a new timer.
*
* @param threadFactory a {@link ThreadFactory} that creates a
* background {@link Thread} which is dedicated to
* {@link TimerTask} execution.
* @param tickDuration the duration between tick
* @param unit the time unit of the {@code tickDuration}
* @param ticksPerWheel the size of the wheel
* @param leakDetection {@code true} if leak detection should be enabled always, if false it will only be enabled
* if the worker thread is not a daemon thread.
* @param threadFactory a {@link ThreadFactory} that creates a
* background {@link Thread} which is dedicated to
* {@link TimerTask} execution.
* @param tickDuration the duration between tick
* @param unit the time unit of the {@code tickDuration}
* @param ticksPerWheel the size of the wheel
* @param leakDetection {@code true} if leak detection should be enabled always,
* if false it will only be enabled if the worker thread is not
* a daemon thread.
* @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null}
* @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is &lt;= 0
*/
public HashedWheelTimer(
ThreadFactory threadFactory,
long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) {
this(threadFactory, tickDuration, unit, ticksPerWheel, leakDetection, -1);
}
/**
* Creates a new timer.
*
* @param threadFactory a {@link ThreadFactory} that creates a
* background {@link Thread} which is dedicated to
* {@link TimerTask} execution.
* @param tickDuration the duration between tick
* @param unit the time unit of the {@code tickDuration}
* @param ticksPerWheel the size of the wheel
* @param leakDetection {@code true} if leak detection should be enabled always,
* if false it will only be enabled if the worker thread is not
* a daemon thread.
* @param maxPendingTimeouts The maximum number of pending timeouts after which call to
* {@code newTimeout} will result in
* {@link java.util.concurrent.RejectedExecutionException}
* being thrown. No maximum pending timeouts limit is assumed if
* this value is 0 or negative.
* @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null}
* @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is &lt;= 0
*/
public HashedWheelTimer(
ThreadFactory threadFactory,
long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) {
long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection,
long maxPendingTimeouts) {
if (threadFactory == null) {
throw new NullPointerException("threadFactory");
@ -239,6 +271,8 @@ public class HashedWheelTimer implements Timer {
workerThread = threadFactory.newThread(worker);
leak = leakDetection || !workerThread.isDaemon() ? leakDetector.open(this) : null;
this.maxPendingTimeouts = maxPendingTimeouts;
}
private static HashedWheelBucket[] createWheel(int ticksPerWheel) {
@ -347,6 +381,16 @@ public class HashedWheelTimer implements Timer {
if (unit == null) {
throw new NullPointerException("unit");
}
if (shouldLimitTimeouts()) {
long pendingTimeoutsCount = pendingTimeouts.incrementAndGet();
if (pendingTimeoutsCount > maxPendingTimeouts) {
pendingTimeouts.decrementAndGet();
throw new RejectedExecutionException("Number of pending timeouts ("
+ pendingTimeoutsCount + ") is greater than or equal to maximum allowed pending "
+ "timeouts (" + maxPendingTimeouts + ")");
}
}
start();
// Add the timeout to the timeout queue which will be processed on the next tick.
@ -357,6 +401,10 @@ public class HashedWheelTimer implements Timer {
return timeout;
}
private boolean shouldLimitTimeouts() {
return maxPendingTimeouts > 0;
}
private final class Worker implements Runnable {
private final Set<Timeout> unprocessedTimeouts = new HashSet<Timeout>();
@ -558,6 +606,8 @@ public class HashedWheelTimer implements Timer {
HashedWheelBucket bucket = this.bucket;
if (bucket != null) {
bucket.remove(this);
} else if (timer.shouldLimitTimeouts()) {
timer.pendingTimeouts.decrementAndGet();
}
}
@ -706,6 +756,9 @@ public class HashedWheelTimer implements Timer {
timeout.prev = null;
timeout.next = null;
timeout.bucket = null;
if (timeout.timer.shouldLimitTimeouts()) {
timeout.timer.pendingTimeouts.decrementAndGet();
}
}
/**

View File

@ -16,6 +16,7 @@
package io.netty.util;
import java.util.Set;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
/**
@ -30,8 +31,9 @@ public interface Timer {
*
* @return a handle which is associated with the specified task
*
* @throws IllegalStateException if this timer has been
* {@linkplain #stop() stopped} already
* @throws IllegalStateException if this timer has been {@linkplain #stop() stopped} already
* @throws RejectedExecutionException if the pending timeouts are too many and creating new timeout
* can cause instability in the system.
*/
Timeout newTimeout(TimerTask task, long delay, TimeUnit unit);

View File

@ -15,12 +15,14 @@
*/
package io.netty.util;
import com.google.common.base.Supplier;
import org.junit.Test;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@ -66,11 +68,7 @@ public class HashedWheelTimerTest {
public void testStopTimer() throws InterruptedException {
final Timer timerProcessed = new HashedWheelTimer();
for (int i = 0; i < 3; i ++) {
timerProcessed.newTimeout(new TimerTask() {
@Override
public void run(Timeout timeout) throws Exception {
}
}, 1, TimeUnit.MILLISECONDS);
timerProcessed.newTimeout(createNoOpTimerTask(), 1, TimeUnit.MILLISECONDS);
}
Thread.sleep(1000L); // sleep for a second
assertEquals("Number of unprocessed timeouts should be 0", 0, timerProcessed.stop().size());
@ -87,7 +85,7 @@ public class HashedWheelTimerTest {
assertFalse("Number of unprocessed timeouts should be greater than 0", timerUnprocessed.stop().isEmpty());
}
@Test(expected = IllegalStateException.class)
@Test
public void testTimerShouldThrowExceptionAfterShutdownForNewTimeouts() throws InterruptedException {
final Timer timer = new HashedWheelTimer();
for (int i = 0; i < 3; i ++) {
@ -99,20 +97,29 @@ public class HashedWheelTimerTest {
}
timer.stop();
Thread.sleep(1000L); // sleep for a second
timer.newTimeout(new TimerTask() {
waitUntilTrue(1000, new Supplier<Boolean>() {
@Override
public void run(Timeout timeout) throws Exception {
fail("This should not run");
public Boolean get() {
try {
timer.newTimeout(new TimerTask() {
@Override
public void run(Timeout timeout) throws Exception {
fail("This should not run");
}
}, 1, TimeUnit.SECONDS);
return false;
} catch (IllegalStateException e) {
return true;
}
}
}, 1, TimeUnit.SECONDS);
});
}
@Test
public void testTimerOverflowWheelLength() throws InterruptedException {
final HashedWheelTimer timer = new HashedWheelTimer(
Executors.defaultThreadFactory(), 100, TimeUnit.MILLISECONDS, 32);
Executors.defaultThreadFactory(), 100, TimeUnit.MILLISECONDS, 32);
final AtomicInteger counter = new AtomicInteger();
timer.newTimeout(new TimerTask() {
@ -122,8 +129,12 @@ public class HashedWheelTimerTest {
timer.newTimeout(this, 1, TimeUnit.SECONDS);
}
}, 1, TimeUnit.SECONDS);
Thread.sleep(3500);
assertEquals(3, counter.get());
waitUntilTrue(3500, new Supplier<Boolean>() {
@Override
public Boolean get() {
return 3 == counter.get();
}
});
timer.stop();
}
@ -149,9 +160,80 @@ public class HashedWheelTimerTest {
for (int i = 0; i < scheduledTasks; i++) {
long delay = queue.take();
assertTrue("Timeout + " + scheduledTasks + " delay " + delay + " must be " + timeout + " < " + maxTimeout,
delay >= timeout && delay < maxTimeout);
delay >= timeout && delay < maxTimeout);
}
timer.stop();
}
@Test
public void testRejectedExecutionExceptionWhenTooManyTimeoutsAreAddedBackToBack() {
HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 100,
TimeUnit.MILLISECONDS, 32, true, 2);
timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
try {
timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
fail("Timer allowed adding 3 timeouts when maxPendingTimeouts was 2");
} catch (RejectedExecutionException e) {
// Expected
} finally {
timer.stop();
}
}
@Test
public void testNewTimeoutShouldStopThrowingRejectedExecutionExceptionWhenExistingTimeoutIsCancelled()
throws InterruptedException {
final HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 100,
TimeUnit.MILLISECONDS, 32, true, 2);
timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
Timeout timeoutToCancel = timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
timeoutToCancel.cancel();
waitUntilNewTimeoutCanBeAdded(200, timer, 1000);
}
@Test
public void testNewTimeoutShouldStopThrowingRejectedExecutionExceptionWhenExistingTimeoutIsExecuted()
throws InterruptedException {
final HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 25,
TimeUnit.MILLISECONDS, 4, true, 2);
timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.SECONDS);
timer.newTimeout(createNoOpTimerTask(), 90, TimeUnit.MILLISECONDS);
waitUntilNewTimeoutCanBeAdded(200, timer, 200);
timer.stop();
}
private static void waitUntilNewTimeoutCanBeAdded(final long timeoutMilli,
final HashedWheelTimer timer, final long newTimeoutDelay) {
waitUntilTrue(timeoutMilli, new Supplier<Boolean>() {
@Override
public Boolean get() {
try {
timer.newTimeout(createNoOpTimerTask(), newTimeoutDelay, TimeUnit.MILLISECONDS);
return true;
} catch (RejectedExecutionException e) {
return false;
}
}
});
}
private static void waitUntilTrue(long timeoutMilli, Supplier<Boolean> condition) {
long startTime = System.currentTimeMillis();
while (System.currentTimeMillis() - startTime < timeoutMilli) {
if (condition.get()) {
return;
}
}
throw new AssertionError("Timed out waiting for condition to be true.");
}
private static TimerTask createNoOpTimerTask() {
return new TimerTask() {
@Override
public void run(final Timeout timeout) throws Exception {
}
};
}
}