Fix LMDB PriorityQueue

This commit is contained in:
Andrea Cavalli 2021-10-14 15:55:58 +02:00
parent d6b6a211a8
commit e1d1e1fb05
13 changed files with 765 additions and 88 deletions

View File

@ -24,12 +24,7 @@ public class EmptyPriorityQueue<T> implements PriorityQueue<T> {
}
@Override
public void updateTop() {
}
@Override
public void updateTop(T newTop) {
public void replaceTop(T newTop) {
assert newTop == null;
}

View File

@ -4,11 +4,12 @@ import static org.lmdbjava.DbiFlags.*;
import io.net5.buffer.ByteBuf;
import io.net5.buffer.PooledByteBufAllocator;
import io.net5.buffer.Unpooled;
import it.cavallium.dbengine.database.LLUtils;
import it.cavallium.dbengine.database.disk.LLTempLMDBEnv;
import java.io.IOException;
import java.util.Iterator;
import java.util.Objects;
import java.util.StringJoiner;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.jetbrains.annotations.NotNull;
@ -17,7 +18,7 @@ import org.lmdbjava.CursorIterable;
import org.lmdbjava.CursorIterable.KeyVal;
import org.lmdbjava.Dbi;
import org.lmdbjava.Env;
import org.lmdbjava.PutFlags;
import org.lmdbjava.GetOp;
import org.lmdbjava.Txn;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler;
@ -32,7 +33,7 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
private static final boolean FORCE_THREAD_LOCAL = true;
private static final AtomicLong NEXT_LMDB_QUEUE_ID = new AtomicLong(0);
private static final ByteBuf EMPTY = Unpooled.directBuffer(1, 1).writeByte(1).asReadOnly();
private static final AtomicLong NEXT_ITEM_UID = new AtomicLong(0);
private final AtomicBoolean closed = new AtomicBoolean();
private final Runnable onClose;
@ -57,7 +58,7 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
var name = "$queue_" + NEXT_LMDB_QUEUE_ID.getAndIncrement();
this.codec = codec;
this.env = env.getEnvAndIncrementRef();
this.lmdb = this.env.openDbi(name, codec::compareDirect, MDB_CREATE);
this.lmdb = this.env.openDbi(name, codec::compareDirect, MDB_CREATE, MDB_DUPSORT, MDB_DUPFIXED);
this.writing = true;
this.iterating = false;
@ -153,12 +154,14 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
}
private static void ensureThread() {
LLUtils.ensureBlocking();
}
private static void ensureItThread() {
if (!(Thread.currentThread() instanceof LMDBThread)) {
throw new IllegalStateException("Must run in LMDB scheduler");
}
ensureThread();
//if (!(Thread.currentThread() instanceof LMDBThread)) {
// throw new IllegalStateException("Must run in LMDB scheduler");
//}
}
@Override
@ -166,8 +169,10 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
ensureThread();
switchToMode(true, false);
var buf = codec.serialize(this::allocate, element);
var uid = allocate(Long.BYTES);
uid.writeLong(NEXT_ITEM_UID.getAndIncrement());
try {
if (lmdb.put(rwTxn, buf, EMPTY, PutFlags.MDB_NOOVERWRITE)) {
if (lmdb.put(rwTxn, buf, uid)) {
if (++size == 1) {
topValid = true;
top = element;
@ -230,10 +235,13 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
top = null;
} else {
topValid = false;
top = null;
}
cur.delete();
return data;
} else {
topValid = true;
top = null;
return null;
}
} finally {
@ -242,14 +250,10 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
}
@Override
public void updateTop() {
// do nothing
}
@Override
public void updateTop(T newTop) {
public void replaceTop(T newTop) {
ensureThread();
assert codec.compare(newTop, databaseTop()) == 0;
this.pop();
this.add(newTop);
}
@Override
@ -276,11 +280,12 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
public boolean remove(@NotNull T element) {
ensureThread();
Objects.requireNonNull(element);
switchToMode(true, false);
switchToMode(true, true);
var buf = codec.serialize(this::allocate, element);
try {
var deleted = lmdb.delete(rwTxn, buf);
if (deleted) {
var deletable = cur.get(buf, GetOp.MDB_SET);
if (deletable) {
cur.delete();
if (topValid && codec.compare(top, element) == 0) {
if (--size == 0) {
top = null;
@ -294,7 +299,7 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
}
}
}
return deleted;
return deletable;
} finally {
endMode();
}
@ -315,6 +320,7 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
var it = cit.iterator();
return Tuples.of(cit, it);
}, (t, sink) -> {
try {
ensureItThread();
var it = t.getT2();
if (it.hasNext()) {
@ -323,14 +329,17 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
sink.complete();
}
return t;
} catch (Throwable ex) {
sink.error(ex);
return t;
}
}, t -> {
ensureItThread();
var cit = t.getT1();
cit.close();
iterating = false;
endMode();
})
.subscribeOn(scheduler, false);
});
}
@Override
@ -400,4 +409,10 @@ public class LMDBPriorityQueue<T> implements PriorityQueue<T> {
return scheduler;
}
@Override
public String toString() {
return new StringJoiner(", ", LMDBPriorityQueue.class.getSimpleName() + "[", "]")
.add("size=" + size)
.toString();
}
}

View File

@ -1,7 +1,6 @@
package it.cavallium.dbengine.lucene;
import java.io.Closeable;
import java.util.Iterator;
public interface PriorityQueue<T> extends ResourceIterable<T>, Closeable {
@ -22,28 +21,9 @@ public interface PriorityQueue<T> extends ResourceIterable<T>, Closeable {
T pop();
/**
* Should be called when the Object at top changes values. Still log(n) worst case, but it's at least twice as fast
* to
*
* <pre class="prettyprint">
* pq.top().change();
* pq.updateTop();
* </pre>
* <p>
* instead of
*
* <pre class="prettyprint">
* o = pq.pop();
* o.change();
* pq.push(o);
* </pre>
* Replace the top of the pq with {@code newTop}
*/
void updateTop();
/**
* Replace the top of the pq with {@code newTop} and run {@link #updateTop()}.
*/
void updateTop(T newTop);
void replaceTop(T newTop);
/**
* Returns the number of elements currently stored in the PriorityQueue.

View File

@ -16,11 +16,13 @@
*/
package it.cavallium.dbengine.lucene.collector;
import it.cavallium.dbengine.database.LLUtils;
import it.cavallium.dbengine.database.disk.LLTempLMDBEnv;
import it.cavallium.dbengine.lucene.FullDocs;
import it.cavallium.dbengine.lucene.LLScoreDoc;
import it.cavallium.dbengine.lucene.LLScoreDocCodec;
import it.cavallium.dbengine.lucene.LMDBPriorityQueue;
import it.cavallium.dbengine.lucene.LuceneUtils;
import it.cavallium.dbengine.lucene.MaxScoreAccumulator;
import java.io.IOException;
import java.util.Collection;
@ -33,6 +35,8 @@ import it.cavallium.dbengine.lucene.MaxScoreAccumulator.DocAndScore;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TotalHits;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
* A {@link Collector} implementation that collects the top-scoring hits, returning them as a {@link
@ -59,9 +63,9 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
private static class SimpleLMDBFullScoreDocCollector extends LMDBFullScoreDocCollector {
SimpleLMDBFullScoreDocCollector(LLTempLMDBEnv env,
SimpleLMDBFullScoreDocCollector(LLTempLMDBEnv env, @Nullable Long limit,
HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {
super(env, hitsThresholdChecker, minScoreAcc);
super(env, limit, hitsThresholdChecker, minScoreAcc);
}
@Override
@ -94,9 +98,10 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
updateGlobalMinCompetitiveScore(scorer);
}
// If there is a limit, and it's reached, use the replacement logic
if (limit != null && pq.size() >= limit) {
var pqTop = pq.top();
if (pqTop != null) {
if (score <= pqTop.score()) {
if (pqTop != null && score <= pqTop.score()) {
if (totalHitsRelation == TotalHits.Relation.EQUAL_TO) {
// we just reached totalHitsThreshold, we can start setting the min
// competitive score now
@ -106,10 +111,17 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
// with equal score to pqTop.score cannot compete since HitQueue favors
// documents with lower doc Ids. Therefore reject those docs too.
return;
} else {
// Remove the top element, then add the following element
pq.replaceTop(new LLScoreDoc(doc + docBase, score, -1));
// The minimum competitive score will be updated later
}
}
} else {
// There is no limit or the limit has not been reached. Add the document to the queue
pq.add(new LLScoreDoc(doc + docBase, score, -1));
pq.updateTop();
// The minimum competitive score will be updated later
}
// Update the minimum competitive score
updateMinCompetitiveScore(scorer);
}
};
@ -129,6 +141,16 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
* <p><b>NOTE</b>: The instances returned by this method pre-allocate a full array of length
* <code>numHits</code>, and fill the array with sentinel objects.
*/
public static LMDBFullScoreDocCollector create(LLTempLMDBEnv env, long numHits, int totalHitsThreshold) {
return create(env, numHits, HitsThresholdChecker.create(totalHitsThreshold), null);
}
/**
* Creates a new {@link LMDBFullScoreDocCollector} given the number of hits to count accurately.
*
* <p><b>NOTE</b>: A value of {@link Integer#MAX_VALUE} will make the hit count accurate
* but will also likely make query processing slower.
*/
public static LMDBFullScoreDocCollector create(LLTempLMDBEnv env, int totalHitsThreshold) {
return create(env, HitsThresholdChecker.create(totalHitsThreshold), null);
}
@ -142,13 +164,56 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
throw new IllegalArgumentException("hitsThresholdChecker must be non null");
}
return new SimpleLMDBFullScoreDocCollector(env, hitsThresholdChecker, minScoreAcc);
return new SimpleLMDBFullScoreDocCollector(env, null, hitsThresholdChecker, minScoreAcc);
}
static LMDBFullScoreDocCollector create(
LLTempLMDBEnv env,
@NotNull Long numHits,
HitsThresholdChecker hitsThresholdChecker,
MaxScoreAccumulator minScoreAcc) {
if (hitsThresholdChecker == null) {
throw new IllegalArgumentException("hitsThresholdChecker must be non null");
}
return new SimpleLMDBFullScoreDocCollector(env,
(numHits < 0 || numHits >= 2147483630L) ? null : numHits,
hitsThresholdChecker,
minScoreAcc
);
}
/**
* Create a CollectorManager which uses a shared hit counter to maintain number of hits and a
* shared {@link MaxScoreAccumulator} to propagate the minimum score accross segments
*/
public static CollectorManager<LMDBFullScoreDocCollector, FullDocs<LLScoreDoc>> createSharedManager(
LLTempLMDBEnv env,
long numHits,
int totalHitsThreshold) {
return new CollectorManager<>() {
private final HitsThresholdChecker hitsThresholdChecker =
HitsThresholdChecker.createShared(totalHitsThreshold);
private final MaxScoreAccumulator minScoreAcc = new MaxScoreAccumulator();
@Override
public LMDBFullScoreDocCollector newCollector() {
return LMDBFullScoreDocCollector.create(env, numHits, hitsThresholdChecker, minScoreAcc);
}
@Override
public FullDocs<LLScoreDoc> reduce(Collection<LMDBFullScoreDocCollector> collectors) {
return reduceShared(collectors);
}
};
}
/**
* Create a CollectorManager which uses a shared {@link MaxScoreAccumulator} to propagate
* the minimum score accross segments
*/
public static CollectorManager<LMDBFullScoreDocCollector, FullDocs<LLScoreDoc>> createSharedManager(
LLTempLMDBEnv env,
int totalHitsThreshold) {
@ -165,6 +230,12 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
@Override
public FullDocs<LLScoreDoc> reduce(Collection<LMDBFullScoreDocCollector> collectors) {
return reduceShared(collectors);
}
};
}
private static FullDocs<LLScoreDoc> reduceShared(Collection<LMDBFullScoreDocCollector> collectors) {
@SuppressWarnings("unchecked")
final FullDocs<LLScoreDoc>[] fullDocs = new FullDocs[collectors.size()];
int i = 0;
@ -173,20 +244,19 @@ public abstract class LMDBFullScoreDocCollector extends FullDocsCollector<LLScor
}
return FullDocs.merge(null, fullDocs);
}
};
}
int docBase;
final @Nullable Long limit;
final HitsThresholdChecker hitsThresholdChecker;
final MaxScoreAccumulator minScoreAcc;
float minCompetitiveScore;
// prevents instantiation
LMDBFullScoreDocCollector(LLTempLMDBEnv env,
LMDBFullScoreDocCollector(LLTempLMDBEnv env, @Nullable Long limit,
HitsThresholdChecker hitsThresholdChecker, MaxScoreAccumulator minScoreAcc) {
super(new LMDBPriorityQueue<>(env, new LLScoreDocCodec()));
assert hitsThresholdChecker != null;
this.limit = limit;
this.hitsThresholdChecker = hitsThresholdChecker;
this.minScoreAcc = minScoreAcc;
}

View File

@ -21,10 +21,10 @@ public class AdaptiveMultiSearcher implements MultiSearcher, Closeable {
private static final MultiSearcher unsortedUnscoredContinuous
= new UnsortedUnscoredStreamingMultiSearcher();
private final UnsortedScoredFullMultiSearcher scoredFull;
private final UnsortedScoredFullMultiSearcher unsortedScoredFull;
public AdaptiveMultiSearcher() throws IOException {
scoredFull = new UnsortedScoredFullMultiSearcher();
unsortedScoredFull = new UnsortedScoredFullMultiSearcher();
}
@Override
@ -54,10 +54,11 @@ public class AdaptiveMultiSearcher implements MultiSearcher, Closeable {
if (queryParams.limit() == 0) {
return count.collectMulti(indexSearchersMono, queryParams, keyFieldName, transformer);
} else if (queryParams.isSorted() || queryParams.needsScores()) {
if (queryParams.isSorted() || realLimit <= (long) queryParams.pageLimits().getPageLimit(0)) {
if ((queryParams.isSorted() && !queryParams.isSortedByScore())
|| realLimit <= (long) queryParams.pageLimits().getPageLimit(0)) {
return scoredSimple.collectMulti(indexSearchersMono, queryParams, keyFieldName, transformer);
} else {
return scoredFull.collectMulti(indexSearchersMono, queryParams, keyFieldName, transformer);
return unsortedScoredFull.collectMulti(indexSearchersMono, queryParams, keyFieldName, transformer);
}
} else if (realLimit <= (long) queryParams.pageLimits().getPageLimit(0)) {
// Run single-page searches using the paged multi searcher
@ -71,7 +72,7 @@ public class AdaptiveMultiSearcher implements MultiSearcher, Closeable {
@Override
public void close() throws IOException {
scoredFull.close();
unsortedScoredFull.close();
}
@Override

View File

@ -2,6 +2,7 @@ package it.cavallium.dbengine.lucene.searcher;
import it.cavallium.dbengine.lucene.LuceneUtils;
import it.cavallium.dbengine.lucene.PageLimits;
import java.util.Objects;
import java.util.Optional;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
@ -16,6 +17,10 @@ public record LocalQueryParams(@NotNull Query query, int offset, int limit, @Not
return sort != null;
}
public boolean isSortedByScore() {
return Objects.equals(sort, Sort.RELEVANCE);
}
public boolean needsScores() {
return sort != null && sort.needsScores();
}

View File

@ -140,10 +140,10 @@ public class ScoredPagedMultiSearcher implements MultiSearcher {
AtomicReference<CurrentPageInfo> currentPageInfoRef = new AtomicReference<>(secondPageInfo);
return Mono
.fromSupplier(currentPageInfoRef::get)
.doOnNext(s -> logger.debug("Current page info: {}", s))
.doOnNext(s -> logger.trace("Current page info: {}", s))
.flatMap(currentPageInfo -> this.searchPage(queryParams, indexSearchers, true,
queryParams.pageLimits(), 0, currentPageInfo))
.doOnNext(s -> logger.debug("Next page info: {}", s.nextPageInfo()))
.doOnNext(s -> logger.trace("Next page info: {}", s.nextPageInfo()))
.doOnNext(s -> currentPageInfoRef.set(s.nextPageInfo()))
.repeatWhen(s -> s.takeWhile(n -> n > 0));
})

View File

@ -44,7 +44,7 @@ public class UnsortedScoredFullMultiSearcher implements MultiSearcher, Closeable
}
return queryParamsMono.flatMap(queryParams2 -> {
if (queryParams2.isSorted()) {
if (queryParams2.isSorted() && !queryParams2.isSortedByScore()) {
throw new IllegalArgumentException(UnsortedScoredFullMultiSearcher.this.getClass().getSimpleName()
+ " doesn't support sorted queries");
}
@ -70,7 +70,7 @@ public class UnsortedScoredFullMultiSearcher implements MultiSearcher, Closeable
.fromCallable(() -> {
LLUtils.ensureBlocking();
var totalHitsThreshold = queryParams.getTotalHitsThreshold();
return LMDBFullScoreDocCollector.createSharedManager(env, totalHitsThreshold);
return LMDBFullScoreDocCollector.createSharedManager(env, queryParams.limit(), totalHitsThreshold);
})
.flatMap(sharedManager -> Flux
.fromIterable(indexSearchers)

View File

@ -1,3 +1,3 @@
package it.cavallium.dbengine;
record ExpectedQueryType(boolean shard, boolean sorted, boolean complete, boolean onlyCount) {}
record ExpectedQueryType(boolean shard, boolean sorted, boolean sortedByScore, boolean complete, boolean onlyCount) {}

View File

@ -0,0 +1,77 @@
package it.cavallium.dbengine;
import it.cavallium.dbengine.lucene.PriorityQueue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.search.HitQueue;
import reactor.core.publisher.Flux;
public class PriorityQueueAdaptor<T> implements PriorityQueue<T> {
private final org.apache.lucene.util.PriorityQueue<T> hitQueue;
public PriorityQueueAdaptor(org.apache.lucene.util.PriorityQueue<T> hitQueue) {
this.hitQueue = hitQueue;
}
@Override
public void add(T element) {
hitQueue.add(element);
hitQueue.updateTop();
}
@Override
public T top() {
hitQueue.updateTop();
return hitQueue.top();
}
@Override
public T pop() {
var popped = hitQueue.pop();
hitQueue.updateTop();
return popped;
}
@Override
public void replaceTop(T newTop) {
hitQueue.updateTop(newTop);
}
@Override
public long size() {
return hitQueue.size();
}
@Override
public void clear() {
hitQueue.clear();
}
@Override
public boolean remove(T element) {
var removed = hitQueue.remove(element);
hitQueue.updateTop();
return removed;
}
@Override
public Flux<T> iterate() {
List<T> items = new ArrayList<>(hitQueue.size());
T item;
while ((item = hitQueue.pop()) != null) {
items.add(item);
}
for (T t : items) {
hitQueue.insertWithOverflow(t);
}
return Flux.fromIterable(items);
}
@Override
public void close() throws IOException {
hitQueue.clear();
hitQueue.updateTop();
}
}

View File

@ -0,0 +1,147 @@
package it.cavallium.dbengine;
import io.net5.buffer.ByteBuf;
import it.cavallium.dbengine.database.disk.LLTempLMDBEnv;
import it.cavallium.dbengine.lucene.LMDBCodec;
import it.cavallium.dbengine.lucene.LMDBPriorityQueue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.function.Function;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
public class TestLMDB {
private LLTempLMDBEnv env;
private LMDBPriorityQueue<Integer> queue;
@BeforeEach
public void beforeEach() throws IOException {
this.env = new LLTempLMDBEnv();
this.queue = new LMDBPriorityQueue<>(env, new LMDBCodec<Integer>() {
@Override
public ByteBuf serialize(Function<Integer, ByteBuf> allocator, Integer data) {
return allocator.apply(Integer.BYTES).writeInt(data).asReadOnly();
}
@Override
public Integer deserialize(ByteBuf b) {
return b.getInt(0);
}
@Override
public int compare(Integer o1, Integer o2) {
return Integer.compare(o1, o2);
}
@Override
public int compareDirect(ByteBuf o1, ByteBuf o2) {
return Integer.compare(o1.getInt(0), o2.getInt(0));
}
});
}
@Test
public void testNoOp() {
}
@Test
public void testEmptyTop() {
Assertions.assertNull(queue.top());
}
@Test
public void testAddSingle() {
queue.add(2);
Assertions.assertEquals(2, queue.top());
}
@Test
public void testAddSame() {
queue.add(2);
queue.add(2);
Assertions.assertEquals(2, queue.top());
Assertions.assertEquals(2, queue.size());
}
@Test
public void testAddMulti() {
for (int i = 0; i < 1000; i++) {
queue.add(i);
}
Assertions.assertEquals(0, queue.top());
}
@Test
public void testAddMultiClear() {
for (int i = 0; i < 1000; i++) {
queue.add(i);
}
queue.clear();
Assertions.assertNull(queue.top());
}
@Test
public void testAddRemove() {
queue.add(0);
queue.remove(0);
Assertions.assertNull(queue.top());
}
@Test
public void testAddRemoveNonexistent() {
queue.add(0);
queue.remove(1);
Assertions.assertEquals(0, queue.top());
}
@Test
public void testAddMultiSameRemove() {
queue.add(0);
queue.add(0);
queue.add(1);
queue.remove(0);
Assertions.assertEquals(2, queue.size());
Assertions.assertEquals(0, queue.top());
}
@Test
public void testAddMultiRemove() {
for (int i = 0; i < 1000; i++) {
queue.add(i);
}
queue.remove(0);
Assertions.assertEquals(1, queue.top());
}
@Test
public void testSort() {
var sortedNumbers = new ArrayList<Integer>();
for (int i = 0; i < 1000; i++) {
sortedNumbers.add(i);
}
var shuffledNumbers = new ArrayList<>(sortedNumbers);
Collections.shuffle(shuffledNumbers);
for (Integer number : shuffledNumbers) {
queue.add(number);
}
var newSortedNumbers = new ArrayList<>();
Integer popped;
while ((popped = queue.pop()) != null) {
newSortedNumbers.add(popped);
}
Assertions.assertEquals(sortedNumbers, newSortedNumbers);
}
@AfterEach
public void afterEach() throws IOException {
queue.close();
env.close();
}
}

View File

@ -0,0 +1,371 @@
package it.cavallium.dbengine;
import com.google.common.collect.Lists;
import io.net5.buffer.ByteBuf;
import it.cavallium.dbengine.database.disk.LLTempLMDBEnv;
import it.cavallium.dbengine.lucene.LLScoreDoc;
import it.cavallium.dbengine.lucene.LMDBCodec;
import it.cavallium.dbengine.lucene.LMDBPriorityQueue;
import it.cavallium.dbengine.lucene.PriorityQueue;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
public class TestLMDBHitQueue {
public static final int NUM_HITS = 1024;
private LLTempLMDBEnv env;
private Closeable lmdbQueue;
private TestingPriorityQueue testingPriorityQueue;
protected static boolean lessThan(ScoreDoc hitA, ScoreDoc hitB) {
if (hitA.score == hitB.score) {
return hitA.doc > hitB.doc;
} else {
return hitA.score < hitB.score;
}
}
private static int compareScoreDoc(ScoreDoc hitA, ScoreDoc hitB) {
if (hitA.score == hitB.score) {
if (hitA.doc == hitB.doc) {
return Integer.compare(hitA.shardIndex, hitB.shardIndex);
} else {
return Integer.compare(hitB.doc, hitA.doc);
}
} else {
return Float.compare(hitA.score, hitB.score);
}
}
private static void assertEqualsScoreDoc(ScoreDoc expected, ScoreDoc actual) {
Assertions.assertEquals(toLLScoreDoc(expected), toLLScoreDoc(actual));
}
private static void assertEqualsScoreDoc(List<ScoreDoc> expected, List<ScoreDoc> actual) {
Assertions.assertEquals(expected.size(), actual.size());
var list1 = expected.iterator();
var list2 = actual.iterator();
while (list1.hasNext() && list2.hasNext()) {
Assertions.assertFalse(lessThan(list1.next(), list2.next()));
}
}
@BeforeEach
public void beforeEach() throws IOException {
this.env = new LLTempLMDBEnv();
var lmdbQueue = new LMDBPriorityQueue<ScoreDoc>(env, new LMDBCodec<>() {
@Override
public ByteBuf serialize(Function<Integer, ByteBuf> allocator, ScoreDoc data) {
var buf = allocator.apply(Float.BYTES + Integer.BYTES + Integer.BYTES);
setScore(buf, data.score);
setDoc(buf, data.doc);
setShardIndex(buf, data.shardIndex);
buf.writerIndex(Float.BYTES + Integer.BYTES + Integer.BYTES);
return buf.asReadOnly();
}
@Override
public ScoreDoc deserialize(ByteBuf buf) {
return new ScoreDoc(getDoc(buf), getScore(buf), getShardIndex(buf));
}
@Override
public int compare(ScoreDoc hitA, ScoreDoc hitB) {
return compareScoreDoc(hitA, hitB);
}
@Override
public int compareDirect(ByteBuf hitA, ByteBuf hitB) {
var scoreA = getScore(hitA);
var scoreB = getScore(hitB);
if (scoreA == scoreB) {
var docA = getDoc(hitA);
var docB = getDoc(hitB);
if (docA == docB) {
return Integer.compare(getShardIndex(hitA), getShardIndex(hitB));
} else {
return Integer.compare(docB, docA);
}
} else {
return Float.compare(scoreA, scoreB);
}
}
private static float getScore(ByteBuf hit) {
return hit.getFloat(0);
}
private static int getDoc(ByteBuf hit) {
return hit.getInt(Float.BYTES);
}
private static int getShardIndex(ByteBuf hit) {
return hit.getInt(Float.BYTES + Integer.BYTES);
}
private static void setScore(ByteBuf hit, float score) {
hit.setFloat(0, score);
}
private static void setDoc(ByteBuf hit, int doc) {
hit.setInt(Float.BYTES, doc);
}
private static void setShardIndex(ByteBuf hit, int shardIndex) {
hit.setInt(Float.BYTES + Integer.BYTES, shardIndex);
}
});
this.lmdbQueue = lmdbQueue;
PriorityQueueAdaptor<ScoreDoc> hitQueue = new PriorityQueueAdaptor<>(new HitQueue(NUM_HITS, false));
Assertions.assertEquals(0, lmdbQueue.size());
Assertions.assertEquals(0, hitQueue.size());
this.testingPriorityQueue = new TestingPriorityQueue(hitQueue, lmdbQueue);
}
@Test
public void testNoOp() {
}
@Test
public void testEmptyTop() {
Assertions.assertNull(testingPriorityQueue.top());
}
@Test
public void testAddSingle() {
var item = new ScoreDoc(0, 0, 0);
testingPriorityQueue.add(item);
assertEqualsScoreDoc(item, testingPriorityQueue.top());
}
@Test
public void testAddMulti() {
for (int i = 0; i < 1000; i++) {
var item = new ScoreDoc(i, i >> 1, -1);
testingPriorityQueue.addUnsafe(item);
}
assertEqualsScoreDoc(new ScoreDoc(1, 0, -1), testingPriorityQueue.top());
}
@Test
public void testAddMultiClear() {
for (int i = 0; i < 1000; i++) {
var item = new ScoreDoc(i, i >> 1, -1);
testingPriorityQueue.addUnsafe(item);
}
testingPriorityQueue.clear();
Assertions.assertNull(testingPriorityQueue.top());
}
@Test
public void testAddRemove() {
var item = new ScoreDoc(0, 0, -1);
testingPriorityQueue.add(item);
testingPriorityQueue.remove(item);
Assertions.assertNull(testingPriorityQueue.top());
}
@Test
public void testAddRemoveNonexistent() {
var item = new ScoreDoc(0, 0, 0);
testingPriorityQueue.addUnsafe(item);
testingPriorityQueue.remove(new ScoreDoc(2, 0, 0));
assertEqualsScoreDoc(item, testingPriorityQueue.top());
}
@Test
public void testAddMultiRemove1() {
ScoreDoc toRemove = null;
ScoreDoc top = null;
for (int i = 0; i < 1000; i++) {
var item = new ScoreDoc(i, i >> 1, -1);
if (i == 1) {
toRemove = item;
} else if (i == 0) {
top = item;
}
testingPriorityQueue.addUnsafe(item);
}
testingPriorityQueue.removeUnsafe(toRemove);
assertEqualsScoreDoc(top, testingPriorityQueue.top());
}
@Test
public void testAddMultiRemove2() {
ScoreDoc toRemove = null;
ScoreDoc top = null;
for (int i = 0; i < 1000; i++) {
var item = new ScoreDoc(i, i >> 1, -1);
if (i == 0) {
toRemove = item;
} else if (i == 1) {
top = item;
}
testingPriorityQueue.addUnsafe(item);
}
testingPriorityQueue.removeUnsafe(new ScoreDoc(0, 0, -1));
assertEqualsScoreDoc(top, testingPriorityQueue.top());
}
@Test
public void testSort() {
var sortedNumbers = new ArrayList<ScoreDoc>();
for (int i = 0; i < 1000; i++) {
sortedNumbers.add(new ScoreDoc(i, i >> 1, -1));
}
sortedNumbers.sort(TestLMDBHitQueue::compareScoreDoc);
var shuffledNumbers = new ArrayList<>(sortedNumbers);
Collections.shuffle(shuffledNumbers, new Random(1000));
for (ScoreDoc scoreDoc : shuffledNumbers) {
testingPriorityQueue.addUnsafe(scoreDoc);
}
var newSortedNumbers = new ArrayList<ScoreDoc>();
ScoreDoc popped;
while ((popped = testingPriorityQueue.popUnsafe()) != null) {
newSortedNumbers.add(popped);
}
assertEqualsScoreDoc(sortedNumbers, newSortedNumbers);
}
@AfterEach
public void afterEach() throws IOException {
lmdbQueue.close();
env.close();
}
private static class TestingPriorityQueue implements PriorityQueue<ScoreDoc> {
private final PriorityQueue<ScoreDoc> referenceQueue;
private final PriorityQueue<ScoreDoc> testQueue;
public TestingPriorityQueue(PriorityQueue<ScoreDoc> referenceQueue, PriorityQueue<ScoreDoc> testQueue) {
this.referenceQueue = referenceQueue;
this.testQueue = testQueue;
}
@Override
public void add(ScoreDoc element) {
referenceQueue.add(element);
testQueue.add(element);
ensureEquality();
}
public void addUnsafe(ScoreDoc element) {
referenceQueue.add(element);
testQueue.add(element);
}
@Override
public ScoreDoc top() {
var top1 = referenceQueue.top();
var top2 = testQueue.top();
assertEqualsScoreDoc(top1, top2);
return top2;
}
public ScoreDoc topUnsafe() {
var top1 = referenceQueue.top();
var top2 = testQueue.top();
return top2;
}
@Override
public ScoreDoc pop() {
var top1 = referenceQueue.pop();
var top2 = testQueue.pop();
assertEqualsScoreDoc(top1, top2);
return top2;
}
public ScoreDoc popUnsafe() {
var top1 = referenceQueue.pop();
var top2 = testQueue.pop();
return top2;
}
@Override
public void replaceTop(ScoreDoc newTop) {
referenceQueue.replaceTop(newTop);
testQueue.replaceTop(newTop);
}
@Override
public long size() {
var size1 = referenceQueue.size();
var size2 = testQueue.size();
Assertions.assertEquals(size1, size2);
return size2;
}
@Override
public void clear() {
referenceQueue.clear();
testQueue.clear();
}
@Override
public boolean remove(ScoreDoc element) {
var removed1 = referenceQueue.remove(element);
var removed2 = testQueue.remove(element);
Assertions.assertEquals(removed1, removed2);
return removed2;
}
public boolean removeUnsafe(ScoreDoc element) {
var removed1 = referenceQueue.remove(element);
var removed2 = testQueue.remove(element);
return removed2;
}
@Override
public Flux<ScoreDoc> iterate() {
//noinspection BlockingMethodInNonBlockingContext
var it1 = referenceQueue.iterate().collectList().blockOptional().orElseThrow();
//noinspection BlockingMethodInNonBlockingContext
var it2 = testQueue.iterate().collectList().blockOptional().orElseThrow();
assertEqualsScoreDoc(it1, it2);
return Flux.fromIterable(it2);
}
@Override
public void close() throws IOException {
}
private void ensureEquality() {
Assertions.assertEquals(referenceQueue.size(), testQueue.size());
var referenceQueueElements = Lists.newArrayList(referenceQueue
.iterate()
.map(TestLMDBHitQueue::toLLScoreDoc)
.toIterable());
var testQueueElements = Lists.newArrayList(testQueue
.iterate()
.map(TestLMDBHitQueue::toLLScoreDoc)
.toIterable());
Assertions.assertEquals(referenceQueueElements, testQueueElements);
}
}
public static LLScoreDoc toLLScoreDoc(ScoreDoc scoreDoc) {
if (scoreDoc == null) return null;
return new LLScoreDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
}
}

View File

@ -13,12 +13,14 @@ import it.cavallium.dbengine.client.LuceneIndex;
import it.cavallium.dbengine.client.MultiSort;
import it.cavallium.dbengine.client.SearchResultKey;
import it.cavallium.dbengine.client.SearchResultKeys;
import it.cavallium.dbengine.client.query.BasicType;
import it.cavallium.dbengine.client.query.ClientQueryParams;
import it.cavallium.dbengine.client.query.ClientQueryParamsBuilder;
import it.cavallium.dbengine.client.query.QueryParser;
import it.cavallium.dbengine.client.query.current.data.MatchAllDocsQuery;
import it.cavallium.dbengine.client.query.current.data.MatchNoDocsQuery;
import it.cavallium.dbengine.client.query.current.data.NoSort;
import it.cavallium.dbengine.client.query.current.data.ScoreSort;
import it.cavallium.dbengine.client.query.current.data.TotalHitsCount;
import it.cavallium.dbengine.database.LLLuceneIndex;
import it.cavallium.dbengine.database.LLScoreMode;
@ -146,7 +148,7 @@ public class TestLuceneSearches {
sink.next(new UnsortedUnscoredSimpleMultiSearcher(new CountLocalSearcher()));
} else {
sink.next(new ScoredPagedMultiSearcher());
if (!info.sorted()) {
if (!info.sorted() || info.sortedByScore()) {
sink.next(new UnsortedScoredFullMultiSearcher());
}
if (!info.sorted()) {
@ -247,7 +249,9 @@ public class TestLuceneSearches {
@ParameterizedTest
@MethodSource("provideQueryArgumentsScoreModeAndSort")
public void testSearchNoDocs(boolean shards, MultiSort<SearchResultKey<String>> multiSort) {
runSearchers(new ExpectedQueryType(shards, multiSort.isSorted(), true, false), searcher -> {
var sorted = multiSort.isSorted();
var sortedByScore = multiSort.getQuerySort().getBasicType$() == BasicType.ScoreSort;
runSearchers(new ExpectedQueryType(shards, sorted, sortedByScore, true, false), searcher -> {
var luceneIndex = getLuceneIndex(shards, searcher);
ClientQueryParamsBuilder<SearchResultKey<String>> queryBuilder = ClientQueryParams.builder();
queryBuilder.query(new MatchNoDocsQuery());
@ -269,7 +273,8 @@ public class TestLuceneSearches {
@MethodSource("provideQueryArgumentsScoreModeAndSort")
public void testSearchAllDocs(boolean shards, MultiSort<SearchResultKey<String>> multiSort) {
var sorted = multiSort.isSorted();
runSearchers(new ExpectedQueryType(shards, sorted, true, false), (LocalSearcher searcher) -> {
var sortedByScore = multiSort.getQuerySort().getBasicType$() == BasicType.ScoreSort;
runSearchers(new ExpectedQueryType(shards, sorted, sortedByScore, true, false), (LocalSearcher searcher) -> {
var luceneIndex = getLuceneIndex(shards, searcher);
ClientQueryParamsBuilder<SearchResultKey<String>> queryBuilder = ClientQueryParams.builder();
queryBuilder.query(new MatchAllDocsQuery());
@ -282,12 +287,23 @@ public class TestLuceneSearches {
assertHitsIfPossible(ELEMENTS.size(), hits);
var keys = getResults(results);
assertResults(ELEMENTS.keySet().stream().toList(), keys, false);
assertResults(ELEMENTS.keySet().stream().toList(), keys, false, sortedByScore);
}
});
}
private void assertResults(List<String> expectedKeys, List<Scored> resultKeys, boolean sorted) {
private void assertResults(List<String> expectedKeys, List<Scored> resultKeys, boolean sorted, boolean sortedByScore) {
if (sortedByScore) {
float lastScore = Float.NEGATIVE_INFINITY;
for (Scored resultKey : resultKeys) {
if (!Float.isNaN(resultKey.score())) {
Assertions.assertTrue(resultKey.score() >= lastScore);
lastScore = resultKey.score();
}
}
}
if (!sorted) {
var results = resultKeys.stream().map(Scored::key).collect(Collectors.toSet());
Assertions.assertEquals(new HashSet<>(expectedKeys), results);