package org.warp.commonutils.type; import it.unimi.dsi.fastutil.objects.Object2FloatMap; import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap; import java.lang.reflect.Array; import java.util.Collection; import java.util.Iterator; import java.util.PriorityQueue; import java.util.Queue; import java.util.function.Consumer; import java.util.function.IntFunction; import java.util.stream.Stream; import org.jetbrains.annotations.NotNull; public class FloatPriorityQueue implements Queue { private final Object2FloatMap contentValues; private final Queue> internalQueue; public static FloatPriorityQueue of() { return new FloatPriorityQueue<>(0); } public static FloatPriorityQueue of(T value, float score) { var pq = new FloatPriorityQueue(1); pq.offer(value); return pq; } public static FloatPriorityQueue of(ScoredValue value) { var pq = new FloatPriorityQueue(1); pq.offer(value); return pq; } @SafeVarargs public static FloatPriorityQueue of(ScoredValue... values) { var pq = new FloatPriorityQueue(values.length); for (ScoredValue value : values) { pq.offer(value); } return pq; } public FloatPriorityQueue(PriorityQueue> internalQueue) { this(internalQueue.size()); for (ScoredValue tScoredValue : internalQueue) { add(tScoredValue.getValue(), tScoredValue.getScore()); } } private FloatPriorityQueue(Object2FloatMap contentValues, Queue> internalQueue) { this.contentValues = contentValues; this.internalQueue = internalQueue; } public FloatPriorityQueue() { this.contentValues = new Object2FloatOpenHashMap<>(); internalQueue = new PriorityQueue<>(); } public FloatPriorityQueue(int initialCapacity) { this.contentValues = new Object2FloatOpenHashMap<>(initialCapacity); internalQueue = new PriorityQueue<>(Math.max(1, initialCapacity)); } public static FloatPriorityQueue synchronize(FloatPriorityQueue queue) { return new SynchronizedFloatPriorityQueue<>(queue.contentValues, queue.internalQueue); } public static FloatPriorityQueue synchronizedPq(int initialCapacity) { var pq = new FloatPriorityQueue(initialCapacity); return new SynchronizedFloatPriorityQueue<>(pq.contentValues, pq.internalQueue); } @Override public int size() { assert contentValues.size() == internalQueue.size(); return internalQueue.size(); } @Override public boolean isEmpty() { assert contentValues.size() == internalQueue.size(); return internalQueue.isEmpty(); } @Override public boolean contains(Object o) { assert contentValues.size() == internalQueue.size(); return internalQueue.contains(ScoredValue.of(0, o)); } @NotNull @Override public Iterator iterator() { assert contentValues.size() == internalQueue.size(); var it = internalQueue.iterator(); return new Iterator<>() { @Override public boolean hasNext() { assert contentValues.size() == internalQueue.size(); return it.hasNext(); } @Override public T next() { assert contentValues.size() == internalQueue.size(); return getValueOrNull(it.next()); } }; } private static T getValueOrNull(ScoredValue scoredValue) { if (scoredValue == null) { return null; } else { return scoredValue.getValue(); } } @NotNull @Override public Object @NotNull [] toArray() { assert contentValues.size() == internalQueue.size(); return internalQueue.stream().map(FloatPriorityQueue::getValueOrNull).toArray(Object[]::new); } @SuppressWarnings({"SuspiciousToArrayCall", "unchecked"}) @NotNull @Override public T1 @NotNull [] toArray(@NotNull T1 @NotNull [] a) { assert contentValues.size() == internalQueue.size(); return internalQueue .stream() .map(FloatPriorityQueue::getValueOrNull) .toArray(i -> (T1[]) Array.newInstance(a.getClass().getComponentType(), i)); } @Deprecated @Override public boolean add(T t) { assert contentValues.size() == internalQueue.size(); return offer(t, 0); } public boolean addTop(T t) { assert contentValues.size() == internalQueue.size(); if (contentValues.getFloat(t) == Integer.MAX_VALUE) { return false; } else { if (contentValues.containsKey(t)) { internalQueue.remove(ScoredValue.of(0, t)); } contentValues.put(t, Integer.MAX_VALUE); return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t)); } } public boolean add(T t, float score) { assert contentValues.size() == internalQueue.size(); return offer(t, score); } @Override public boolean remove(Object o) { assert contentValues.size() == internalQueue.size(); contentValues.removeFloat(o); return internalQueue.remove(ScoredValue.of(0, o)); } @Override public boolean containsAll(@NotNull Collection c) { assert contentValues.size() == internalQueue.size(); for (Object o : c) { //noinspection SuspiciousMethodCalls if (!contentValues.containsKey(o)) { return false; } } return true; } @Override public boolean addAll(@NotNull Collection c) { assert contentValues.size() == internalQueue.size(); boolean added = false; for (T t : c) { added |= add(t); } return added; } @Override public boolean removeAll(@NotNull Collection c) { assert contentValues.size() == internalQueue.size(); boolean removed = false; for (Object o : c) { removed |= remove(o); } return removed; } @Override public boolean retainAll(@NotNull Collection c) { assert contentValues.size() == internalQueue.size(); return removeIf(item -> !c.contains(item)); } @Override public void clear() { assert contentValues.size() == internalQueue.size(); contentValues.clear(); internalQueue.clear(); } @Override public boolean offer(T t) { assert contentValues.size() == internalQueue.size(); return offer(ScoredValue.of(0, t)); } public boolean offer(T t, float score) { assert contentValues.size() == internalQueue.size(); return offer(ScoredValue.of(score, t)); } public boolean offer(ScoredValue value) { int contentValuesSize = contentValues.size(); int internalQueueSize = internalQueue.size(); assert contentValuesSize == internalQueueSize; boolean added = true; float oldValue; if (contentValues.containsKey(value.getValue())) { internalQueue.remove(value); oldValue = contentValues.getFloat(value.getValue()); added = false; } else { oldValue = 0f; } contentValues.put(value.getValue(), oldValue + value.getScore()); internalQueue.add(value); return added; } @Override public T remove() { assert contentValues.size() == internalQueue.size(); var val = internalQueue.remove(); if (val != null) { contentValues.removeFloat(val.getValue()); } return getValueOrNull(val); } @Override public T poll() { assert contentValues.size() == internalQueue.size(); var val = internalQueue.poll(); if (val != null) { contentValues.removeFloat(val.getValue()); } return getValueOrNull(val); } @Override public T element() { assert contentValues.size() == internalQueue.size(); return getValueOrNull(internalQueue.element()); } @Override public T peek() { assert contentValues.size() == internalQueue.size(); return getValueOrNull(internalQueue.peek()); } public void forEachItem(Consumer> action) { assert contentValues.size() == internalQueue.size(); internalQueue.forEach(action); } public Stream> streamItems() { assert contentValues.size() == internalQueue.size(); return internalQueue.stream(); } public FloatPriorityQueueView view() { return new FloatPriorityQueueView<>(this); } private static class SynchronizedFloatPriorityQueue extends FloatPriorityQueue { public SynchronizedFloatPriorityQueue(Object2FloatMap contentValues, Queue> internalQueue) { super(contentValues, internalQueue); } @Override public synchronized int size() { return super.size(); } @Override public synchronized boolean isEmpty() { return super.isEmpty(); } @Override public synchronized boolean contains(Object o) { return super.contains(o); } @Override public synchronized @NotNull Iterator iterator() { var it = super.iterator(); return new Iterator<>() { @Override public boolean hasNext() { synchronized (SynchronizedFloatPriorityQueue.this) { return it.hasNext(); } } @Override public T next() { synchronized (SynchronizedFloatPriorityQueue.this) { return it.next(); } } }; } @Override public synchronized @NotNull Object @NotNull [] toArray() { return super.toArray(); } @SuppressWarnings("SuspiciousToArrayCall") @Override public synchronized @NotNull T1 @NotNull [] toArray(@NotNull T1 @NotNull [] a) { return super.toArray(a); } @Override public synchronized T1[] toArray(IntFunction generator) { //noinspection SuspiciousToArrayCall return super.toArray(generator); } @Override public synchronized boolean add(T t, float score) { return super.add(t, score); } @Override public synchronized boolean addTop(T t) { return super.addTop(t); } @Deprecated @Override public synchronized boolean add(T t) { return super.add(t); } @Override public synchronized boolean addAll(@NotNull Collection c) { return super.addAll(c); } @Override public synchronized boolean remove(Object o) { return super.remove(o); } @Override public synchronized boolean removeAll(@NotNull Collection c) { return super.removeAll(c); } @Override public synchronized boolean containsAll(@NotNull Collection c) { return super.containsAll(c); } @Override public synchronized boolean retainAll(@NotNull Collection c) { return super.retainAll(c); } @Override public synchronized void clear() { super.clear(); } @Override public synchronized boolean offer(T t) { return super.offer(t); } @Override public synchronized boolean offer(T t, float score) { return super.offer(t, score); } @Override public synchronized boolean offer(ScoredValue value) { return super.offer(value); } @Override public synchronized T remove() { return super.remove(); } @Override public synchronized T poll() { return super.poll(); } @Override public synchronized T element() { return super.element(); } @Override public synchronized T peek() { return super.peek(); } @Override public synchronized void forEachItem(Consumer> action) { super.forEachItem(action); } @Override public synchronized Stream> streamItems() { return super.streamItems(); } } }