diff --git a/src/main/java/org/warp/commonutils/concurrency/executor/BoundedExecutorServiceImpl.java b/src/main/java/org/warp/commonutils/concurrency/executor/BoundedExecutorServiceImpl.java index aa49413..10e10cb 100644 --- a/src/main/java/org/warp/commonutils/concurrency/executor/BoundedExecutorServiceImpl.java +++ b/src/main/java/org/warp/commonutils/concurrency/executor/BoundedExecutorServiceImpl.java @@ -51,6 +51,24 @@ class BoundedExecutorServiceImpl extends ThreadPoolExecutor implements BoundedEx return this.submit(task); } + @Override + public Future submit(Runnable task) { + preChecks(); + return super.submit(task); + } + + @Override + public Future submit(Callable task) { + preChecks(); + return super.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + preChecks(); + return super.submit(task, result); + } + /** * Submits task to execution pool, but blocks while number of running threads * has reached the bound limit @@ -60,6 +78,32 @@ class BoundedExecutorServiceImpl extends ThreadPoolExecutor implements BoundedEx this.execute(task); } + @Override + public void execute(Runnable command) { + preChecks(); + super.execute(command); + } + + private void preChecks() { + try { + blockIfFull(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + protected void beforeExecute(Thread t, Runnable r) { + var queueSize = getQueue().size(); + synchronized (queueSizeStatusLock) { + if (queueSizeStatus != null) queueSizeStatus.accept(queueSize >= maxQueueSize, queueSize); + } + + semaphore.release(); + + super.beforeExecute(t, r); + } + private void blockIfFull() throws InterruptedException { if (semaphore.availablePermits() == 0) { synchronized (queueSizeStatusLock) { @@ -68,22 +112,4 @@ class BoundedExecutorServiceImpl extends ThreadPoolExecutor implements BoundedEx } semaphore.acquire(); } - - @Override - public void beforeExecute(Thread t, Runnable r) { - try { - blockIfFull(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - - var queueSize = getQueue().size(); - synchronized (queueSizeStatusLock) { - if (queueSizeStatus != null) queueSizeStatus.accept(queueSize >= maxQueueSize, queueSize); - } - - semaphore.release(); - - super.beforeExecute(t, r); - } } \ No newline at end of file diff --git a/src/test/java/org/warp/commonutils/BoundedQueueTest.java b/src/test/java/org/warp/commonutils/BoundedQueueTest.java new file mode 100644 index 0000000..6ca49a6 --- /dev/null +++ b/src/test/java/org/warp/commonutils/BoundedQueueTest.java @@ -0,0 +1,51 @@ +package org.warp.commonutils; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.opentest4j.AssertionFailedError; +import org.warp.commonutils.concurrency.executor.BoundedExecutorService; +import org.warp.commonutils.type.ShortNamedThreadFactory; + +public class BoundedQueueTest { + + @Test + public void testBoundedQueue() throws InterruptedException { + int maxQueueSize = 2; + AtomicInteger queueSize = new AtomicInteger(); + AtomicReference failedError = new AtomicReference<>(); + var executor = BoundedExecutorService.create(maxQueueSize, + 1, + 1, + 0L, + TimeUnit.MILLISECONDS, + new ShortNamedThreadFactory("test"), + (isQueueFull, currentQueueSize) -> { + try { + if (currentQueueSize >= maxQueueSize) { + Assertions.assertTrue(isQueueFull); + } else { + Assertions.assertFalse(isQueueFull); + } + } catch (AssertionFailedError ex) { + if (failedError.get() == null) { + failedError.set(ex); + } + ex.printStackTrace(); + } + } + ); + + for (int i = 0; i < 10000; i++) { + queueSize.incrementAndGet(); + executor.executeButBlockIfFull(queueSize::decrementAndGet); + } + + executor.shutdown(); + executor.awaitTermination(10, TimeUnit.SECONDS); + + Assertions.assertNull(failedError.get()); + } +}