From b754302914af02dd97c802a4d647b9071307cd99 Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Tue, 22 Sep 2020 19:17:48 +0200 Subject: [PATCH] Sum the old priority with the new priority if an element is offered a second time --- .../commonutils/type/FloatPriorityQueue.java | 266 ++++++++++++++++-- 1 file changed, 247 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/warp/commonutils/type/FloatPriorityQueue.java b/src/main/java/org/warp/commonutils/type/FloatPriorityQueue.java index fe5a817..fa1fa0d 100644 --- a/src/main/java/org/warp/commonutils/type/FloatPriorityQueue.java +++ b/src/main/java/org/warp/commonutils/type/FloatPriorityQueue.java @@ -1,18 +1,20 @@ package org.warp.commonutils.type; -import com.google.common.collect.Queues; +import it.unimi.dsi.fastutil.objects.Object2FloatMap; +import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap; import java.lang.reflect.Array; import java.util.Collection; import java.util.Iterator; import java.util.PriorityQueue; import java.util.Queue; import java.util.function.Consumer; -import java.util.stream.Collectors; +import java.util.function.IntFunction; import java.util.stream.Stream; import org.jetbrains.annotations.NotNull; public class FloatPriorityQueue implements Queue { + private final Object2FloatMap contentValues; private final Queue> internalQueue; public static FloatPriorityQueue of() { @@ -40,62 +42,75 @@ public class FloatPriorityQueue implements Queue { } public FloatPriorityQueue(PriorityQueue> internalQueue) { - this.internalQueue = internalQueue; + this(internalQueue.size()); + for (ScoredValue tScoredValue : internalQueue) { + add(tScoredValue.getValue(), tScoredValue.getScore()); + } } - private FloatPriorityQueue(Queue> internalQueue) { + private FloatPriorityQueue(Object2FloatMap contentValues, Queue> internalQueue) { + this.contentValues = contentValues; this.internalQueue = internalQueue; } public FloatPriorityQueue() { + this.contentValues = new Object2FloatOpenHashMap<>(); internalQueue = new PriorityQueue<>(); } public FloatPriorityQueue(int initialCapacity) { + this.contentValues = new Object2FloatOpenHashMap<>(initialCapacity); internalQueue = new PriorityQueue<>(Math.max(1, initialCapacity)); } public static FloatPriorityQueue synchronize(FloatPriorityQueue queue) { - return new FloatPriorityQueue(Queues.synchronizedQueue(queue.internalQueue)); + return new SynchronizedFloatPriorityQueue(queue.contentValues, queue.internalQueue); } public static FloatPriorityQueue synchronizedPq(int initialCapacity) { - return new FloatPriorityQueue(Queues.synchronizedQueue(new PriorityQueue<>(Math.max(1, initialCapacity)))); + var pq = new FloatPriorityQueue(initialCapacity); + return new SynchronizedFloatPriorityQueue(pq.contentValues, pq.internalQueue); } @Override public int size() { + assert contentValues.size() == internalQueue.size(); return internalQueue.size(); } @Override public boolean isEmpty() { + assert contentValues.isEmpty(); return internalQueue.isEmpty(); } @Override public boolean contains(Object o) { + assert contentValues.size() == internalQueue.size(); return internalQueue.contains(ScoredValue.of(0, o)); } @NotNull @Override public Iterator iterator() { + assert contentValues.size() == internalQueue.size(); var it = internalQueue.iterator(); return new Iterator() { @Override public boolean hasNext() { + assert contentValues.size() == internalQueue.size(); return it.hasNext(); } @Override public T next() { + assert contentValues.size() == internalQueue.size(); return getValueOrNull(it.next()); } }; } - private T getValueOrNull(ScoredValue scoredValue) { + private static T getValueOrNull(ScoredValue scoredValue) { if (scoredValue == null) { return null; } else { @@ -106,100 +121,313 @@ public class FloatPriorityQueue implements Queue { @NotNull @Override public Object[] toArray() { - return internalQueue.stream().map(this::getValueOrNull).toArray(Object[]::new); + assert contentValues.size() == internalQueue.size(); + return internalQueue.stream().map(FloatPriorityQueue::getValueOrNull).toArray(Object[]::new); } @NotNull @Override public T1[] toArray(@NotNull T1[] a) { + assert contentValues.size() == internalQueue.size(); return internalQueue .stream() - .map(this::getValueOrNull) + .map(FloatPriorityQueue::getValueOrNull) .toArray(i -> (T1[]) Array.newInstance(a.getClass().getComponentType(), i)); } @Deprecated @Override public boolean add(T t) { - return internalQueue.add(ScoredValue.of(0, t)); + assert contentValues.size() == internalQueue.size(); + return offer(t, 0); } public boolean addTop(T t) { - return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t)); + assert contentValues.size() == internalQueue.size(); + if (contentValues.getFloat(t) == Integer.MAX_VALUE) { + return false; + } else { + if (contentValues.containsKey(t)) { + internalQueue.remove(ScoredValue.of(0, t)); + } + contentValues.put(t, Integer.MAX_VALUE); + return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t)); + } } public boolean add(T t, float score) { - return internalQueue.add(ScoredValue.of(score, t)); + assert contentValues.size() == internalQueue.size(); + return offer(t, score); } @Override public boolean remove(Object o) { + assert contentValues.size() == internalQueue.size(); + contentValues.removeFloat(o); return internalQueue.remove(ScoredValue.of(0, o)); } @Override public boolean containsAll(@NotNull Collection c) { - return internalQueue.containsAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList())); + assert contentValues.size() == internalQueue.size(); + for (Object o : c) { + //noinspection SuspiciousMethodCalls + if (!contentValues.containsKey(o)) { + return false; + } + } + return true; } @Override public boolean addAll(@NotNull Collection c) { - return internalQueue.addAll(c.stream().map(val -> ScoredValue.of(0, (T) val)).collect(Collectors.toList())); + assert contentValues.size() == internalQueue.size(); + boolean added = false; + for (T t : c) { + added |= add(t); + } + return added; } @Override public boolean removeAll(@NotNull Collection c) { - return internalQueue.removeAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList())); + assert contentValues.size() == internalQueue.size(); + boolean removed = false; + for (Object o : c) { + removed |= remove(o); + } + return removed; } @Override public boolean retainAll(@NotNull Collection c) { - return internalQueue.retainAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList())); + assert contentValues.size() == internalQueue.size(); + return removeIf(item -> !c.contains(item)); } @Override public void clear() { + assert contentValues.size() == internalQueue.size(); + contentValues.clear(); internalQueue.clear(); } @Override public boolean offer(T t) { + assert contentValues.size() == internalQueue.size(); return offer(ScoredValue.of(0, t)); } public boolean offer(T t, float score) { + assert contentValues.size() == internalQueue.size(); return offer(ScoredValue.of(score, t)); } public boolean offer(ScoredValue value) { - return this.internalQueue.offer(value); + assert contentValues.size() == internalQueue.size(); + + boolean added = true; + float oldValue; + if (contentValues.containsKey(value.getValue())) { + internalQueue.remove(value); + oldValue = contentValues.getFloat(value.getValue()); + added = false; + } else { + oldValue = 0f; + } + contentValues.put(value.getValue(), oldValue + value.getScore()); + internalQueue.add(value); + return added; } @Override public T remove() { - return getValueOrNull(internalQueue.remove()); + assert contentValues.size() == internalQueue.size(); + var val = internalQueue.remove(); + if (val != null) { + contentValues.removeFloat(val.getValue()); + } + return getValueOrNull(val); } @Override public T poll() { - return getValueOrNull(internalQueue.poll()); + assert contentValues.size() == internalQueue.size(); + var val = internalQueue.poll(); + + if (val != null) { + contentValues.removeFloat(val.getValue()); + } + return getValueOrNull(val); } @Override public T element() { + assert contentValues.size() == internalQueue.size(); return getValueOrNull(internalQueue.element()); } @Override public T peek() { + assert contentValues.size() == internalQueue.size(); return getValueOrNull(internalQueue.peek()); } public void forEachItem(Consumer> action) { + assert contentValues.size() == internalQueue.size(); internalQueue.forEach(action); } public Stream> streamItems() { + assert contentValues.size() == internalQueue.size(); return internalQueue.stream(); } + + private static class SynchronizedFloatPriorityQueue extends FloatPriorityQueue { + + public SynchronizedFloatPriorityQueue(Object2FloatMap contentValues, Queue> internalQueue) { + super(contentValues, internalQueue); + } + + @Override + public synchronized int size() { + return super.size(); + } + + @Override + public synchronized boolean isEmpty() { + return super.isEmpty(); + } + + @Override + public synchronized boolean contains(Object o) { + return super.contains(o); + } + + @Override + public synchronized @NotNull Iterator iterator() { + var it = super.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + synchronized (SynchronizedFloatPriorityQueue.this) { + return it.hasNext(); + } + } + + @Override + public T next() { + synchronized (SynchronizedFloatPriorityQueue.this) { + return it.next(); + } + } + }; + } + + @Override + public synchronized @NotNull Object[] toArray() { + return super.toArray(); + } + + @Override + public synchronized @NotNull T1[] toArray(@NotNull T1[] a) { + return super.toArray(a); + } + + @Override + public synchronized T1[] toArray(IntFunction generator) { + //noinspection SuspiciousToArrayCall + return super.toArray(generator); + } + + @Override + public synchronized boolean add(T t, float score) { + return super.add(t, score); + } + + @Override + public synchronized boolean addTop(T t) { + return super.addTop(t); + } + + @Override + public synchronized boolean add(T t) { + return super.add(t); + } + + @Override + public synchronized boolean addAll(@NotNull Collection c) { + return super.addAll(c); + } + + @Override + public synchronized boolean remove(Object o) { + return super.remove(o); + } + + @Override + public synchronized boolean removeAll(@NotNull Collection c) { + return super.removeAll(c); + } + + @Override + public synchronized boolean containsAll(@NotNull Collection c) { + return super.containsAll(c); + } + + @Override + public synchronized boolean retainAll(@NotNull Collection c) { + return super.retainAll(c); + } + + @Override + public synchronized void clear() { + super.clear(); + } + + @Override + public synchronized boolean offer(T t) { + return super.offer(t); + } + + @Override + public synchronized boolean offer(T t, float score) { + return super.offer(t, score); + } + + @Override + public synchronized boolean offer(ScoredValue value) { + return super.offer(value); + } + + @Override + public synchronized T remove() { + return super.remove(); + } + + @Override + public synchronized T poll() { + return super.poll(); + } + + @Override + public synchronized T element() { + return super.element(); + } + + @Override + public synchronized T peek() { + return super.peek(); + } + + @Override + public synchronized void forEachItem(Consumer> action) { + super.forEachItem(action); + } + + @Override + public synchronized Stream> streamItems() { + return super.streamItems(); + } + } }