common-utils/src/main/java/org/warp/commonutils/type/FloatPriorityQueue.java

444 lines
11 KiB
Java

package org.warp.commonutils.type;
import it.unimi.dsi.fastutil.objects.Object2FloatMap;
import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.function.Consumer;
import java.util.function.IntFunction;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
/**
 * A {@link Queue} of distinct elements, each carrying a {@code float} score.
 *
 * <p>Every element is held in two structures that are always kept the same size (asserted by almost every
 * method):
 * <ul>
 *   <li>{@code internalQueue} holds one {@link ScoredValue} per element. For the public constructors it
 *       is a {@link PriorityQueue}, so it defines the poll/peek order. Iteration order is unspecified, as
 *       it is for {@code PriorityQueue}.</li>
 *   <li>{@code contentValues} maps each element to the score it currently has in {@code internalQueue}.
 *       It also serves as an O(1) membership index.</li>
 * </ul>
 *
 * <p>Removal and replacement build a probe {@code ScoredValue.of(0, element)} and pass it to
 * {@link Queue#remove(Object)}. This assumes {@link ScoredValue} equality ignores the score (TODO confirm
 * against {@code ScoredValue#equals}). Whether the head has the highest or the lowest score depends on
 * {@code ScoredValue}'s natural ordering, which is defined elsewhere. {@link #addTop(Object)} presumes the
 * highest score is polled first.
 *
 * <p>Instances are not thread-safe. See {@link #synchronize(FloatPriorityQueue)} and
 * {@link #synchronizedPq(int)}.
 *
 * @param <T> element type
 */
public class FloatPriorityQueue<T> implements Queue<T> {

    /** Element to the score currently stored for it in {@link #internalQueue}; also an O(1) membership index. */
    private final Object2FloatMap<T> contentValues;
    /** Backing queue with exactly one {@link ScoredValue} per element. */
    private final Queue<ScoredValue<T>> internalQueue;

    /** Creates an empty queue. */
    public static <T> FloatPriorityQueue<T> of() {
        return new FloatPriorityQueue<>(0);
    }

    /**
     * Creates a queue containing one element with the given score.
     *
     * @param value the element
     * @param score the element's score
     * @return a new queue holding only {@code value}
     */
    public static <T> FloatPriorityQueue<T> of(T value, float score) {
        var pq = new FloatPriorityQueue<T>(1);
        // Use the two-argument offer so the caller's score is kept; offer(T) would store 0.
        pq.offer(value, score);
        return pq;
    }

    /**
     * Creates a queue containing one scored element.
     *
     * @param value the scored element
     * @return a new queue holding only {@code value}
     */
    public static <T> FloatPriorityQueue<T> of(ScoredValue<T> value) {
        var pq = new FloatPriorityQueue<T>(1);
        pq.offer(value);
        return pq;
    }

    /**
     * Creates a queue from the given scored elements. If two entries share the same element, the later
     * one replaces the earlier one's score.
     *
     * @param values the scored elements
     * @return a new queue holding {@code values}
     */
    @SafeVarargs
    public static <T> FloatPriorityQueue<T> of(ScoredValue<T>... values) {
        var pq = new FloatPriorityQueue<T>(values.length);
        for (ScoredValue<T> value : values) {
            pq.offer(value);
        }
        return pq;
    }

    /**
     * Creates a queue by copying the entries of an existing {@link PriorityQueue}. The source queue is not
     * retained. If two entries share the same element, the one visited later replaces the other's score.
     *
     * @param internalQueue the entries to copy
     */
    public FloatPriorityQueue(PriorityQueue<ScoredValue<T>> internalQueue) {
        this(internalQueue.size());
        for (ScoredValue<T> tScoredValue : internalQueue) {
            add(tScoredValue.getValue(), tScoredValue.getScore());
        }
    }

    /** Wraps existing structures without copying; used to build the synchronized view. */
    private FloatPriorityQueue(Object2FloatMap<T> contentValues, Queue<ScoredValue<T>> internalQueue) {
        this.contentValues = contentValues;
        this.internalQueue = internalQueue;
    }

    /** Creates an empty queue with default capacity. */
    public FloatPriorityQueue() {
        this.contentValues = new Object2FloatOpenHashMap<>();
        internalQueue = new PriorityQueue<>();
    }

    /**
     * Creates an empty queue presized for {@code initialCapacity} elements.
     *
     * @param initialCapacity expected number of elements
     */
    public FloatPriorityQueue(int initialCapacity) {
        this.contentValues = new Object2FloatOpenHashMap<>(initialCapacity);
        // PriorityQueue rejects a capacity below 1, so clamp (e.g. for of()).
        internalQueue = new PriorityQueue<>(Math.max(1, initialCapacity));
    }

    /**
     * Returns a synchronized view that shares its state with {@code queue}. Any access made directly
     * through {@code queue} afterwards bypasses the lock.
     *
     * @param queue the queue whose storage is shared
     * @return a synchronized queue backed by the same structures
     */
    public static <T> FloatPriorityQueue<T> synchronize(FloatPriorityQueue<T> queue) {
        return new SynchronizedFloatPriorityQueue<>(queue.contentValues, queue.internalQueue);
    }

    /**
     * Creates a new, empty synchronized queue.
     *
     * @param initialCapacity expected number of elements
     * @return a synchronized queue
     */
    public static <T> FloatPriorityQueue<T> synchronizedPq(int initialCapacity) {
        var pq = new FloatPriorityQueue<T>(initialCapacity);
        return new SynchronizedFloatPriorityQueue<>(pq.contentValues, pq.internalQueue);
    }

    @Override
    public int size() {
        assert contentValues.size() == internalQueue.size();
        return internalQueue.size();
    }

    @Override
    public boolean isEmpty() {
        assert contentValues.size() == internalQueue.size();
        return internalQueue.isEmpty();
    }

    /**
     * Returns whether {@code o} is queued. Uses the O(1) map lookup instead of scanning the queue, the
     * same check {@link #containsAll(Collection)} uses.
     */
    @Override
    public boolean contains(Object o) {
        assert contentValues.size() == internalQueue.size();
        return contentValues.containsKey(o);
    }

    /**
     * Returns an iterator over the elements in the backing queue's iteration order (unspecified for a
     * {@link PriorityQueue}). {@link Iterator#remove()} is supported and removes the element from both
     * internal structures. This is what makes {@link #retainAll(Collection)} and {@code removeIf} work.
     */
    @NotNull
    @Override
    public Iterator<T> iterator() {
        assert contentValues.size() == internalQueue.size();
        var it = internalQueue.iterator();
        return new Iterator<>() {
            /** Entry returned by the last next() call; needed to update contentValues on remove(). */
            private ScoredValue<T> lastReturned;

            @Override
            public boolean hasNext() {
                assert contentValues.size() == internalQueue.size();
                return it.hasNext();
            }

            @Override
            public T next() {
                assert contentValues.size() == internalQueue.size();
                lastReturned = it.next();
                return getValueOrNull(lastReturned);
            }

            @Override
            public void remove() {
                // The backing iterator throws IllegalStateException (no next() yet, or a second remove())
                // before contentValues is touched, so the two structures stay in sync.
                it.remove();
                contentValues.removeFloat(lastReturned.getValue());
                lastReturned = null;
                assert contentValues.size() == internalQueue.size();
            }
        };
    }

    /** Returns the element of {@code scoredValue}, or {@code null} if {@code scoredValue} is null. */
    private static <T> T getValueOrNull(ScoredValue<T> scoredValue) {
        if (scoredValue == null) {
            return null;
        } else {
            return scoredValue.getValue();
        }
    }

    @NotNull
    @Override
    public Object @NotNull [] toArray() {
        assert contentValues.size() == internalQueue.size();
        return internalQueue.stream().map(FloatPriorityQueue::getValueOrNull).toArray(Object[]::new);
    }

    /**
     * Copies the elements into {@code a} if it is large enough, otherwise into a new array of the same
     * runtime component type, as required by {@link Collection#toArray(Object[])}. When {@code a} is
     * longer than needed, the slot after the last element is set to {@code null}.
     *
     * @throws ArrayStoreException if an element is not an instance of {@code a}'s component type
     */
    @SuppressWarnings("unchecked")
    @NotNull
    @Override
    public <T1> T1 @NotNull [] toArray(@NotNull T1 @NotNull [] a) {
        assert contentValues.size() == internalQueue.size();
        int size = internalQueue.size();
        T1[] result = a.length >= size
                ? a
                : (T1[]) Array.newInstance(a.getClass().getComponentType(), size);
        int index = 0;
        for (ScoredValue<T> scoredValue : internalQueue) {
            result[index++] = (T1) scoredValue.getValue();
        }
        if (result.length > size) {
            result[size] = null;
        }
        return result;
    }

    /**
     * Adds {@code t} with score 0, or resets an existing element's score to 0.
     *
     * @deprecated the score is implicit; use {@link #add(Object, float)}.
     */
    @Deprecated
    @Override
    public boolean add(T t) {
        assert contentValues.size() == internalQueue.size();
        return offer(t, 0);
    }

    /**
     * Moves {@code t} to the top by giving it the score {@code Integer.MAX_VALUE}, inserting it if absent.
     *
     * @param t the element
     * @return {@code false} if {@code t} already had that score, {@code true} otherwise
     */
    public boolean addTop(T t) {
        assert contentValues.size() == internalQueue.size();
        // Absent keys yield the map's default return value (0 unless changed), so they never match here.
        if (contentValues.getFloat(t) == Integer.MAX_VALUE) {
            return false;
        }
        if (contentValues.containsKey(t)) {
            internalQueue.remove(ScoredValue.of(0, t));
        }
        contentValues.put(t, Integer.MAX_VALUE);
        return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t));
    }

    /**
     * Adds {@code t} with {@code score}, or replaces the score of an element already present.
     *
     * @return {@code true} if {@code t} was not present before
     */
    public boolean add(T t, float score) {
        assert contentValues.size() == internalQueue.size();
        return offer(t, score);
    }

    /**
     * Removes {@code o} if present.
     *
     * @return whether the backing queue contained {@code o}
     */
    @Override
    public boolean remove(Object o) {
        assert contentValues.size() == internalQueue.size();
        contentValues.removeFloat(o);
        return internalQueue.remove(ScoredValue.of(0, o));
    }

    @Override
    public boolean containsAll(@NotNull Collection<?> c) {
        assert contentValues.size() == internalQueue.size();
        for (Object o : c) {
            //noinspection SuspiciousMethodCalls
            if (!contentValues.containsKey(o)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds every element of {@code c} with score 0. Elements already present have their score reset to 0.
     *
     * @return {@code true} if at least one element was newly inserted
     */
    @Override
    public boolean addAll(@NotNull Collection<? extends T> c) {
        assert contentValues.size() == internalQueue.size();
        boolean added = false;
        for (T t : c) {
            added |= add(t);
        }
        return added;
    }

    @Override
    public boolean removeAll(@NotNull Collection<?> c) {
        assert contentValues.size() == internalQueue.size();
        boolean removed = false;
        for (Object o : c) {
            removed |= remove(o);
        }
        return removed;
    }

    /** Keeps only the elements contained in {@code c}. Relies on {@link #iterator()} supporting remove(). */
    @Override
    public boolean retainAll(@NotNull Collection<?> c) {
        assert contentValues.size() == internalQueue.size();
        return removeIf(item -> !c.contains(item));
    }

    @Override
    public void clear() {
        assert contentValues.size() == internalQueue.size();
        contentValues.clear();
        internalQueue.clear();
    }

    /** Offers {@code t} with score 0. */
    @Override
    public boolean offer(T t) {
        assert contentValues.size() == internalQueue.size();
        return offer(ScoredValue.of(0, t));
    }

    /** Offers {@code t} with the given {@code score}. */
    public boolean offer(T t, float score) {
        assert contentValues.size() == internalQueue.size();
        return offer(ScoredValue.of(score, t));
    }

    /**
     * Inserts {@code value}, or replaces the score of an element that is already queued.
     *
     * <p>The score carried by {@code value} is stored in both {@code contentValues} and the backing queue,
     * so the two always agree. {@link #addTop(Object)} relies on this to tell whether an element is
     * already at the top score.
     *
     * @param value the scored element
     * @return {@code true} if the element was not present before, {@code false} if its score was replaced
     */
    public boolean offer(ScoredValue<T> value) {
        assert contentValues.size() == internalQueue.size();
        T item = value.getValue();
        boolean added = !contentValues.containsKey(item);
        if (!added) {
            // Drop the stale entry so the element is never queued twice (assumes score-agnostic equality).
            internalQueue.remove(value);
        }
        contentValues.put(item, value.getScore());
        internalQueue.add(value);
        return added;
    }

    /**
     * Removes and returns the head element.
     *
     * @throws java.util.NoSuchElementException if the queue is empty
     */
    @Override
    public T remove() {
        assert contentValues.size() == internalQueue.size();
        // Queue.remove() throws on an empty queue and never returns null.
        var val = internalQueue.remove();
        contentValues.removeFloat(val.getValue());
        return val.getValue();
    }

    /** Removes and returns the head element, or {@code null} if the queue is empty. */
    @Override
    public T poll() {
        assert contentValues.size() == internalQueue.size();
        var val = internalQueue.poll();
        if (val != null) {
            contentValues.removeFloat(val.getValue());
        }
        return getValueOrNull(val);
    }

    /**
     * Returns the head element without removing it.
     *
     * @throws java.util.NoSuchElementException if the queue is empty
     */
    @Override
    public T element() {
        assert contentValues.size() == internalQueue.size();
        return getValueOrNull(internalQueue.element());
    }

    /** Returns the head element without removing it, or {@code null} if the queue is empty. */
    @Override
    public T peek() {
        assert contentValues.size() == internalQueue.size();
        return getValueOrNull(internalQueue.peek());
    }

    /** Passes every scored entry to {@code action}, in the backing queue's iteration order. */
    public void forEachItem(Consumer<ScoredValue<T>> action) {
        assert contentValues.size() == internalQueue.size();
        internalQueue.forEach(action);
    }

    /** Returns a stream over the scored entries, in the backing queue's iteration order. */
    public Stream<ScoredValue<T>> streamItems() {
        assert contentValues.size() == internalQueue.size();
        return internalQueue.stream();
    }

    /** Returns a {@link FloatPriorityQueueView} over this queue. */
    public <U extends T> FloatPriorityQueueView<U> view() {
        return new FloatPriorityQueueView<>(this);
    }

    /**
     * A queue that runs every overridden operation while holding this wrapper instance's monitor.
     *
     * <p>It shares {@code contentValues} and {@code internalQueue} with the queue it was created from, so
     * any access through that original instance bypasses the lock. Iterators lock separately for each
     * {@code hasNext}/{@code next}/{@code remove} call, so a whole iteration is not atomic.
     * {@link #streamItems()} only locks while creating the stream, not while it is consumed. Methods not
     * overridden here (such as {@code view()}) do not take the lock themselves.
     */
    private static class SynchronizedFloatPriorityQueue<T> extends FloatPriorityQueue<T> {

        public SynchronizedFloatPriorityQueue(Object2FloatMap<T> contentValues, Queue<ScoredValue<T>> internalQueue) {
            super(contentValues, internalQueue);
        }

        @Override
        public synchronized int size() {
            return super.size();
        }

        @Override
        public synchronized boolean isEmpty() {
            return super.isEmpty();
        }

        @Override
        public synchronized boolean contains(Object o) {
            return super.contains(o);
        }

        @Override
        public synchronized @NotNull Iterator<T> iterator() {
            var it = super.iterator();
            return new Iterator<>() {
                @Override
                public boolean hasNext() {
                    synchronized (SynchronizedFloatPriorityQueue.this) {
                        return it.hasNext();
                    }
                }

                @Override
                public T next() {
                    synchronized (SynchronizedFloatPriorityQueue.this) {
                        return it.next();
                    }
                }

                @Override
                public void remove() {
                    // Delegate so retainAll/removeIf work on the synchronized view too.
                    synchronized (SynchronizedFloatPriorityQueue.this) {
                        it.remove();
                    }
                }
            };
        }

        @Override
        public synchronized @NotNull Object @NotNull [] toArray() {
            return super.toArray();
        }

        @SuppressWarnings("SuspiciousToArrayCall")
        @Override
        public synchronized <T1> @NotNull T1 @NotNull [] toArray(@NotNull T1 @NotNull [] a) {
            return super.toArray(a);
        }

        @Override
        public synchronized <T1> T1[] toArray(IntFunction<T1[]> generator) {
            //noinspection SuspiciousToArrayCall
            return super.toArray(generator);
        }

        @Override
        public synchronized boolean add(T t, float score) {
            return super.add(t, score);
        }

        @Override
        public synchronized boolean addTop(T t) {
            return super.addTop(t);
        }

        @Deprecated
        @Override
        public synchronized boolean add(T t) {
            return super.add(t);
        }

        @Override
        public synchronized boolean addAll(@NotNull Collection<? extends T> c) {
            return super.addAll(c);
        }

        @Override
        public synchronized boolean remove(Object o) {
            return super.remove(o);
        }

        @Override
        public synchronized boolean removeAll(@NotNull Collection<?> c) {
            return super.removeAll(c);
        }

        @Override
        public synchronized boolean containsAll(@NotNull Collection<?> c) {
            return super.containsAll(c);
        }

        @Override
        public synchronized boolean retainAll(@NotNull Collection<?> c) {
            return super.retainAll(c);
        }

        @Override
        public synchronized void clear() {
            super.clear();
        }

        @Override
        public synchronized boolean offer(T t) {
            return super.offer(t);
        }

        @Override
        public synchronized boolean offer(T t, float score) {
            return super.offer(t, score);
        }

        @Override
        public synchronized boolean offer(ScoredValue<T> value) {
            return super.offer(value);
        }

        @Override
        public synchronized T remove() {
            return super.remove();
        }

        @Override
        public synchronized T poll() {
            return super.poll();
        }

        @Override
        public synchronized T element() {
            return super.element();
        }

        @Override
        public synchronized T peek() {
            return super.peek();
        }

        @Override
        public synchronized void forEachItem(Consumer<ScoredValue<T>> action) {
            super.forEachItem(action);
        }

        @Override
        public synchronized Stream<ScoredValue<T>> streamItems() {
            return super.streamItems();
        }
    }
}