Sum the old priority with the new priority if an element is offered a second time

This commit is contained in:
Andrea Cavalli 2020-09-22 19:17:48 +02:00
parent f99ef5906e
commit b754302914

View File

@ -1,18 +1,20 @@
package org.warp.commonutils.type;
import com.google.common.collect.Queues;
import it.unimi.dsi.fastutil.objects.Object2FloatMap;
import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.function.IntFunction;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
public class FloatPriorityQueue<T> implements Queue<T> {
private final Object2FloatMap<T> contentValues;
private final Queue<ScoredValue<T>> internalQueue;
public static <T> FloatPriorityQueue<T> of() {
@ -40,62 +42,75 @@ public class FloatPriorityQueue<T> implements Queue<T> {
}
public FloatPriorityQueue(PriorityQueue<ScoredValue<T>> internalQueue) {
this.internalQueue = internalQueue;
this(internalQueue.size());
for (ScoredValue<T> tScoredValue : internalQueue) {
add(tScoredValue.getValue(), tScoredValue.getScore());
}
}
private FloatPriorityQueue(Queue<ScoredValue<T>> internalQueue) {
private FloatPriorityQueue(Object2FloatMap<T> contentValues, Queue<ScoredValue<T>> internalQueue) {
this.contentValues = contentValues;
this.internalQueue = internalQueue;
}
public FloatPriorityQueue() {
this.contentValues = new Object2FloatOpenHashMap<>();
internalQueue = new PriorityQueue<>();
}
public FloatPriorityQueue(int initialCapacity) {
this.contentValues = new Object2FloatOpenHashMap<>(initialCapacity);
internalQueue = new PriorityQueue<>(Math.max(1, initialCapacity));
}
public static <T> FloatPriorityQueue<T> synchronize(FloatPriorityQueue<T> queue) {
return new FloatPriorityQueue<T>(Queues.synchronizedQueue(queue.internalQueue));
return new SynchronizedFloatPriorityQueue<T>(queue.contentValues, queue.internalQueue);
}
public static <T> FloatPriorityQueue<T> synchronizedPq(int initialCapacity) {
return new FloatPriorityQueue<T>(Queues.synchronizedQueue(new PriorityQueue<>(Math.max(1, initialCapacity))));
var pq = new FloatPriorityQueue<T>(initialCapacity);
return new SynchronizedFloatPriorityQueue<T>(pq.contentValues, pq.internalQueue);
}
@Override
public int size() {
assert contentValues.size() == internalQueue.size();
return internalQueue.size();
}
@Override
public boolean isEmpty() {
assert contentValues.isEmpty();
return internalQueue.isEmpty();
}
@Override
public boolean contains(Object o) {
assert contentValues.size() == internalQueue.size();
return internalQueue.contains(ScoredValue.of(0, o));
}
@NotNull
@Override
public Iterator<T> iterator() {
assert contentValues.size() == internalQueue.size();
var it = internalQueue.iterator();
return new Iterator<T>() {
@Override
public boolean hasNext() {
assert contentValues.size() == internalQueue.size();
return it.hasNext();
}
@Override
public T next() {
assert contentValues.size() == internalQueue.size();
return getValueOrNull(it.next());
}
};
}
private T getValueOrNull(ScoredValue<T> scoredValue) {
private static <T> T getValueOrNull(ScoredValue<T> scoredValue) {
if (scoredValue == null) {
return null;
} else {
@ -106,100 +121,313 @@ public class FloatPriorityQueue<T> implements Queue<T> {
@NotNull
@Override
public Object[] toArray() {
return internalQueue.stream().map(this::getValueOrNull).toArray(Object[]::new);
assert contentValues.size() == internalQueue.size();
return internalQueue.stream().map(FloatPriorityQueue::getValueOrNull).toArray(Object[]::new);
}
@NotNull
@Override
public <T1> T1[] toArray(@NotNull T1[] a) {
assert contentValues.size() == internalQueue.size();
return internalQueue
.stream()
.map(this::getValueOrNull)
.map(FloatPriorityQueue::getValueOrNull)
.toArray(i -> (T1[]) Array.newInstance(a.getClass().getComponentType(), i));
}
@Deprecated
@Override
public boolean add(T t) {
return internalQueue.add(ScoredValue.of(0, t));
assert contentValues.size() == internalQueue.size();
return offer(t, 0);
}
public boolean addTop(T t) {
return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t));
assert contentValues.size() == internalQueue.size();
if (contentValues.getFloat(t) == Integer.MAX_VALUE) {
return false;
} else {
if (contentValues.containsKey(t)) {
internalQueue.remove(ScoredValue.of(0, t));
}
contentValues.put(t, Integer.MAX_VALUE);
return internalQueue.add(ScoredValue.of(Integer.MAX_VALUE, t));
}
}
public boolean add(T t, float score) {
return internalQueue.add(ScoredValue.of(score, t));
assert contentValues.size() == internalQueue.size();
return offer(t, score);
}
@Override
public boolean remove(Object o) {
assert contentValues.size() == internalQueue.size();
contentValues.removeFloat(o);
return internalQueue.remove(ScoredValue.of(0, o));
}
@Override
public boolean containsAll(@NotNull Collection<?> c) {
return internalQueue.containsAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList()));
assert contentValues.size() == internalQueue.size();
for (Object o : c) {
//noinspection SuspiciousMethodCalls
if (!contentValues.containsKey(o)) {
return false;
}
}
return true;
}
@Override
public boolean addAll(@NotNull Collection<? extends T> c) {
return internalQueue.addAll(c.stream().map(val -> ScoredValue.of(0, (T) val)).collect(Collectors.toList()));
assert contentValues.size() == internalQueue.size();
boolean added = false;
for (T t : c) {
added |= add(t);
}
return added;
}
@Override
public boolean removeAll(@NotNull Collection<?> c) {
return internalQueue.removeAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList()));
assert contentValues.size() == internalQueue.size();
boolean removed = false;
for (Object o : c) {
removed |= remove(o);
}
return removed;
}
@Override
public boolean retainAll(@NotNull Collection<?> c) {
return internalQueue.retainAll(c.stream().map(val -> ScoredValue.of(0, val)).collect(Collectors.toList()));
assert contentValues.size() == internalQueue.size();
return removeIf(item -> !c.contains(item));
}
@Override
public void clear() {
assert contentValues.size() == internalQueue.size();
contentValues.clear();
internalQueue.clear();
}
@Override
public boolean offer(T t) {
assert contentValues.size() == internalQueue.size();
return offer(ScoredValue.of(0, t));
}
public boolean offer(T t, float score) {
assert contentValues.size() == internalQueue.size();
return offer(ScoredValue.of(score, t));
}
public boolean offer(ScoredValue<T> value) {
return this.internalQueue.offer(value);
assert contentValues.size() == internalQueue.size();
boolean added = true;
float oldValue;
if (contentValues.containsKey(value.getValue())) {
internalQueue.remove(value);
oldValue = contentValues.getFloat(value.getValue());
added = false;
} else {
oldValue = 0f;
}
contentValues.put(value.getValue(), oldValue + value.getScore());
internalQueue.add(value);
return added;
}
@Override
public T remove() {
return getValueOrNull(internalQueue.remove());
assert contentValues.size() == internalQueue.size();
var val = internalQueue.remove();
if (val != null) {
contentValues.removeFloat(val.getValue());
}
return getValueOrNull(val);
}
@Override
public T poll() {
return getValueOrNull(internalQueue.poll());
assert contentValues.size() == internalQueue.size();
var val = internalQueue.poll();
if (val != null) {
contentValues.removeFloat(val.getValue());
}
return getValueOrNull(val);
}
@Override
public T element() {
assert contentValues.size() == internalQueue.size();
return getValueOrNull(internalQueue.element());
}
@Override
public T peek() {
assert contentValues.size() == internalQueue.size();
return getValueOrNull(internalQueue.peek());
}
public void forEachItem(Consumer<ScoredValue<T>> action) {
assert contentValues.size() == internalQueue.size();
internalQueue.forEach(action);
}
public Stream<ScoredValue<T>> streamItems() {
assert contentValues.size() == internalQueue.size();
return internalQueue.stream();
}
private static class SynchronizedFloatPriorityQueue<T> extends FloatPriorityQueue<T> {
public SynchronizedFloatPriorityQueue(Object2FloatMap<T> contentValues, Queue<ScoredValue<T>> internalQueue) {
super(contentValues, internalQueue);
}
@Override
public synchronized int size() {
return super.size();
}
@Override
public synchronized boolean isEmpty() {
return super.isEmpty();
}
@Override
public synchronized boolean contains(Object o) {
return super.contains(o);
}
@Override
public synchronized @NotNull Iterator<T> iterator() {
var it = super.iterator();
return new Iterator<T>() {
@Override
public boolean hasNext() {
synchronized (SynchronizedFloatPriorityQueue.this) {
return it.hasNext();
}
}
@Override
public T next() {
synchronized (SynchronizedFloatPriorityQueue.this) {
return it.next();
}
}
};
}
@Override
public synchronized @NotNull Object[] toArray() {
return super.toArray();
}
@Override
public synchronized <T1> @NotNull T1[] toArray(@NotNull T1[] a) {
return super.toArray(a);
}
@Override
public synchronized <T1> T1[] toArray(IntFunction<T1[]> generator) {
//noinspection SuspiciousToArrayCall
return super.toArray(generator);
}
@Override
public synchronized boolean add(T t, float score) {
return super.add(t, score);
}
@Override
public synchronized boolean addTop(T t) {
return super.addTop(t);
}
@Override
public synchronized boolean add(T t) {
return super.add(t);
}
@Override
public synchronized boolean addAll(@NotNull Collection<? extends T> c) {
return super.addAll(c);
}
@Override
public synchronized boolean remove(Object o) {
return super.remove(o);
}
@Override
public synchronized boolean removeAll(@NotNull Collection<?> c) {
return super.removeAll(c);
}
@Override
public synchronized boolean containsAll(@NotNull Collection<?> c) {
return super.containsAll(c);
}
@Override
public synchronized boolean retainAll(@NotNull Collection<?> c) {
return super.retainAll(c);
}
@Override
public synchronized void clear() {
super.clear();
}
@Override
public synchronized boolean offer(T t) {
return super.offer(t);
}
@Override
public synchronized boolean offer(T t, float score) {
return super.offer(t, score);
}
@Override
public synchronized boolean offer(ScoredValue<T> value) {
return super.offer(value);
}
@Override
public synchronized T remove() {
return super.remove();
}
@Override
public synchronized T poll() {
return super.poll();
}
@Override
public synchronized T element() {
return super.element();
}
@Override
public synchronized T peek() {
return super.peek();
}
@Override
public synchronized void forEachItem(Consumer<ScoredValue<T>> action) {
super.forEachItem(action);
}
@Override
public synchronized Stream<ScoredValue<T>> streamItems() {
return super.streamItems();
}
}
}