Test lucene generator
This commit is contained in:
parent
2eb4a84afa
commit
f478ea97cd
12
pom.xml
12
pom.xml
@ -157,7 +157,7 @@
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>1.7.32</version>
|
||||
<version>1.7.35</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
@ -213,6 +213,11 @@
|
||||
<groupId>org.apache.lucene</groupId>
|
||||
<artifactId>lucene-facet</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.lucene</groupId>
|
||||
<artifactId>lucene-test-framework</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains</groupId>
|
||||
<artifactId>annotations</artifactId>
|
||||
@ -394,6 +399,11 @@
|
||||
<artifactId>lucene-facet</artifactId>
|
||||
<version>9.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.lucene</groupId>
|
||||
<artifactId>lucene-test-framework</artifactId>
|
||||
<version>9.0.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains</groupId>
|
||||
<artifactId>annotations</artifactId>
|
||||
|
@ -441,12 +441,12 @@ public class LuceneUtils {
|
||||
return result;
|
||||
}
|
||||
|
||||
public static int totalHitsThreshold(boolean complete) {
|
||||
return complete ? Integer.MAX_VALUE : 1;
|
||||
public static int totalHitsThreshold(@Nullable Boolean complete) {
|
||||
return complete == null || complete ? Integer.MAX_VALUE : 1;
|
||||
}
|
||||
|
||||
public static long totalHitsThresholdLong(boolean complete) {
|
||||
return complete ? Long.MAX_VALUE : 1;
|
||||
public static long totalHitsThresholdLong(@Nullable Boolean complete) {
|
||||
return complete == null || complete ? Long.MAX_VALUE : 1;
|
||||
}
|
||||
|
||||
public static TotalHitsCount convertTotalHitsCount(TotalHits totalHits) {
|
||||
|
@ -12,15 +12,28 @@ import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetLong, int limitInt, long limitLong,
|
||||
@NotNull PageLimits pageLimits, @Nullable Sort sort, boolean computePreciseHitsCount,
|
||||
Duration timeout) {
|
||||
@NotNull PageLimits pageLimits, @Nullable Sort sort,
|
||||
@Nullable Boolean computePreciseHitsCount, Duration timeout) {
|
||||
|
||||
public LocalQueryParams {
|
||||
if ((sort == null) != (computePreciseHitsCount == null)) {
|
||||
if (sort == null) {
|
||||
throw new IllegalArgumentException("If computePreciseHitsCount is set, sort must be set");
|
||||
} else {
|
||||
throw new IllegalArgumentException("If sort is set, computePreciseHitsCount must be set");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sorted params with long offsets
|
||||
*/
|
||||
public LocalQueryParams(@NotNull Query query,
|
||||
long offsetLong,
|
||||
long limitLong,
|
||||
@NotNull PageLimits pageLimits,
|
||||
@Nullable Sort sort,
|
||||
boolean computePreciseHitsCount,
|
||||
@Nullable Boolean computePreciseHitsCount,
|
||||
Duration timeout) {
|
||||
this(query,
|
||||
safeLongToInt(offsetLong),
|
||||
@ -34,6 +47,9 @@ public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetL
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sorted params with int offsets
|
||||
*/
|
||||
public LocalQueryParams(@NotNull Query query,
|
||||
int offsetInt,
|
||||
int limitInt,
|
||||
@ -44,6 +60,38 @@ public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetL
|
||||
this(query, offsetInt, offsetInt, limitInt, limitInt, pageLimits, sort, computePreciseHitsCount, timeout);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unsorted params with int offsets
|
||||
*/
|
||||
public LocalQueryParams(@NotNull Query query,
|
||||
int offsetInt,
|
||||
int limitInt,
|
||||
@NotNull PageLimits pageLimits,
|
||||
Duration timeout) {
|
||||
this(query, offsetInt, offsetInt, limitInt, limitInt, pageLimits, null, null, timeout);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Unsorted params with long offsets
|
||||
*/
|
||||
public LocalQueryParams(@NotNull Query query,
|
||||
long offsetLong,
|
||||
long limitLong,
|
||||
@NotNull PageLimits pageLimits,
|
||||
Duration timeout) {
|
||||
this(query,
|
||||
safeLongToInt(offsetLong),
|
||||
offsetLong,
|
||||
safeLongToInt(limitLong),
|
||||
limitLong,
|
||||
pageLimits,
|
||||
null,
|
||||
null,
|
||||
timeout
|
||||
);
|
||||
}
|
||||
|
||||
public boolean isSorted() {
|
||||
return sort != null;
|
||||
}
|
||||
|
@ -33,7 +33,6 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
private final Iterator<LeafReaderContext> leavesIterator;
|
||||
private final boolean computeScores;
|
||||
private final CustomHitsThresholdChecker hitsThresholdChecker;
|
||||
private final MaxScoreAccumulator minScoreAcc;
|
||||
private final @Nullable Long limit;
|
||||
|
||||
private long remainingOffset;
|
||||
@ -48,8 +47,7 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
LuceneGenerator(IndexSearcher shard,
|
||||
LocalQueryParams localQueryParams,
|
||||
int shardIndex,
|
||||
CustomHitsThresholdChecker hitsThresholdChecker,
|
||||
MaxScoreAccumulator minScoreAcc) {
|
||||
CustomHitsThresholdChecker hitsThresholdChecker) {
|
||||
this.shard = shard;
|
||||
this.shardIndex = shardIndex;
|
||||
this.query = localQueryParams.query();
|
||||
@ -59,20 +57,20 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
List<LeafReaderContext> leaves = shard.getTopReaderContext().leaves();
|
||||
this.leavesIterator = leaves.iterator();
|
||||
this.hitsThresholdChecker = hitsThresholdChecker;
|
||||
this.minScoreAcc = minScoreAcc;
|
||||
}
|
||||
|
||||
public static Flux<ScoreDoc> reactive(IndexSearcher shard,
|
||||
LocalQueryParams localQueryParams,
|
||||
int shardIndex,
|
||||
CustomHitsThresholdChecker hitsThresholdChecker,
|
||||
MaxScoreAccumulator minScoreAcc) {
|
||||
CustomHitsThresholdChecker hitsThresholdChecker) {
|
||||
if (localQueryParams.sort() != null) {
|
||||
return Flux.error(new IllegalArgumentException("Sorting is not allowed"));
|
||||
}
|
||||
return Flux
|
||||
.<ScoreDoc, LuceneGenerator>generate(() -> new LuceneGenerator(shard,
|
||||
localQueryParams,
|
||||
shardIndex,
|
||||
hitsThresholdChecker,
|
||||
minScoreAcc
|
||||
hitsThresholdChecker
|
||||
),
|
||||
(s, sink) -> {
|
||||
ScoreDoc val = s.get();
|
||||
@ -92,7 +90,7 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
while (remainingOffset > 0) {
|
||||
skipNext();
|
||||
}
|
||||
if ((limit != null && totalHits > limit) || hitsThresholdChecker.isThresholdReached(true)) {
|
||||
if ((limit != null && totalHits >= limit) || hitsThresholdChecker.isThresholdReached(true)) {
|
||||
return null;
|
||||
}
|
||||
return getNext();
|
||||
@ -142,10 +140,6 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
totalHits++;
|
||||
hitsThresholdChecker.incrementHitCount();
|
||||
|
||||
if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) {
|
||||
updateGlobalMinCompetitiveScore(scorer);
|
||||
}
|
||||
|
||||
return transformDoc(doc);
|
||||
}
|
||||
docIdSetIterator = null;
|
||||
@ -166,9 +160,6 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
this.scorer = scorer;
|
||||
this.leaf = leaf;
|
||||
this.docIdSetIterator = scorer.iterator();
|
||||
if (minScoreAcc != null) {
|
||||
updateGlobalMinCompetitiveScore(scorer);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@ -185,23 +176,6 @@ public class LuceneGenerator implements Supplier<ScoreDoc> {
|
||||
return !liveDocs.get(doc);
|
||||
}
|
||||
|
||||
protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException {
|
||||
assert minScoreAcc != null;
|
||||
DocAndScore maxMinScore = minScoreAcc.get();
|
||||
if (maxMinScore != null) {
|
||||
// since we tie-break on doc id and collect in doc id order we can require
|
||||
// the next float if the global minimum score is set on a document id that is
|
||||
// smaller than the ids in the current leaf
|
||||
float score =
|
||||
leaf.docBase >= maxMinScore.docBase ? Math.nextUp(maxMinScore.score) : maxMinScore.score;
|
||||
if (score > minCompetitiveScore) {
|
||||
assert hitsThresholdChecker.isThresholdReached(true);
|
||||
scorer.setMinCompetitiveScore(score);
|
||||
minCompetitiveScore = score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void clearState() {
|
||||
docIdSetIterator = null;
|
||||
scorer = null;
|
||||
|
@ -16,8 +16,7 @@ public class LuceneMultiGenerator implements Supplier<ScoreDoc> {
|
||||
|
||||
public LuceneMultiGenerator(List<IndexSearcher> shards,
|
||||
LocalQueryParams localQueryParams,
|
||||
CustomHitsThresholdChecker hitsThresholdChecker,
|
||||
MaxScoreAccumulator minScoreAcc) {
|
||||
CustomHitsThresholdChecker hitsThresholdChecker) {
|
||||
this.generators = IntStream
|
||||
.range(0, shards.size())
|
||||
.mapToObj(shardIndex -> {
|
||||
@ -25,8 +24,7 @@ public class LuceneMultiGenerator implements Supplier<ScoreDoc> {
|
||||
return (Supplier<ScoreDoc>) new LuceneGenerator(shard,
|
||||
localQueryParams,
|
||||
shardIndex,
|
||||
hitsThresholdChecker,
|
||||
minScoreAcc
|
||||
hitsThresholdChecker
|
||||
);
|
||||
})
|
||||
.iterator();
|
||||
|
@ -80,7 +80,8 @@ public class OfficialSearcher implements MultiSearcher {
|
||||
LLUtils.ensureBlocking();
|
||||
|
||||
var collector = sharedManager.newCollector();
|
||||
assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive();
|
||||
assert queryParams.computePreciseHitsCount() == null
|
||||
|| (queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive());
|
||||
|
||||
shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout()));
|
||||
return collector;
|
||||
|
@ -85,7 +85,8 @@ public class SortedByScoreFullMultiSearcher implements MultiSearcher {
|
||||
|
||||
var collector = sharedManager.newCollector();
|
||||
try {
|
||||
assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive();
|
||||
assert queryParams.computePreciseHitsCount() == null ||
|
||||
queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive();
|
||||
|
||||
shard.search(queryParams.query(), collector);
|
||||
return collector;
|
||||
|
@ -79,7 +79,8 @@ public class SortedScoredFullMultiSearcher implements MultiSearcher {
|
||||
|
||||
var collector = sharedManager.newCollector();
|
||||
try {
|
||||
assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive();
|
||||
assert queryParams.computePreciseHitsCount() == null
|
||||
|| queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive();
|
||||
|
||||
shard.search(queryParams.query(), collector);
|
||||
return collector;
|
||||
|
@ -64,7 +64,7 @@ public class UnsortedStreamingMultiSearcher implements MultiSearcher {
|
||||
return Flux.fromIterable(shards).index().flatMap(tuple -> {
|
||||
var shardIndex = (int) (long) tuple.getT1();
|
||||
var shard = tuple.getT2();
|
||||
return LuceneGenerator.reactive(shard, localQueryParams, shardIndex, hitsThreshold, maxScoreAccumulator);
|
||||
return LuceneGenerator.reactive(shard, localQueryParams, shardIndex, hitsThreshold);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -0,0 +1,180 @@
|
||||
package it.cavallium.dbengine.lucene.searcher;
|
||||
|
||||
import it.cavallium.dbengine.lucene.ExponentialPageLimits;
import it.cavallium.dbengine.lucene.MaxScoreAccumulator;
import it.unimi.dsi.fastutil.longs.LongList;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CustomHitsThresholdChecker;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class LuceneGeneratorTest {
|
||||
|
||||
private static IndexSearcher is;
|
||||
private static ExponentialPageLimits pageLimits;
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll() throws IOException {
|
||||
ByteBuffersDirectory dir = new ByteBuffersDirectory();
|
||||
IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig());
|
||||
iw.addDocument(List.of(new LongPoint("position", 1, 1)));
|
||||
iw.addDocument(List.of(new LongPoint("position", 2, 3)));
|
||||
iw.addDocument(List.of(new LongPoint("position", 4, -1)));
|
||||
iw.addDocument(List.of(new LongPoint("position", 3, -54)));
|
||||
// Exactly 4 dummies
|
||||
iw.addDocument(List.<IndexableField>of(new StringField("dummy", "dummy", Store.NO)));
|
||||
iw.addDocument(List.<IndexableField>of(new StringField("dummy", "dummy", Store.NO)));
|
||||
iw.addDocument(List.<IndexableField>of(new StringField("dummy", "dummy", Store.NO)));
|
||||
iw.addDocument(List.<IndexableField>of(new StringField("dummy", "dummy", Store.NO)));
|
||||
iw.commit();
|
||||
|
||||
DirectoryReader ir = DirectoryReader.open(iw);
|
||||
is = new IndexSearcher(ir);
|
||||
pageLimits = new ExponentialPageLimits();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test() throws IOException {
|
||||
var query = LongPoint.newRangeQuery("position",
|
||||
LongList.of(1, -1).toLongArray(),
|
||||
LongList.of(2, 3).toLongArray()
|
||||
);
|
||||
int limit = Integer.MAX_VALUE;
|
||||
var limitThresholdChecker = CustomHitsThresholdChecker.create(limit);
|
||||
|
||||
var expectedResults = fixResults(fixResults(List.of(is.search(query, limit, Sort.RELEVANCE, true).scoreDocs)));
|
||||
|
||||
var reactiveGenerator = LuceneGenerator.reactive(is,
|
||||
new LocalQueryParams(query,
|
||||
0L,
|
||||
limit,
|
||||
pageLimits,
|
||||
null,
|
||||
true,
|
||||
Duration.ofDays(1)
|
||||
),
|
||||
-1,
|
||||
limitThresholdChecker
|
||||
);
|
||||
var results = fixResults(reactiveGenerator.collectList().block());
|
||||
|
||||
Assertions.assertNotEquals(0, results.size());
|
||||
|
||||
Assertions.assertEquals(expectedResults, results);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLimit0() {
|
||||
var query = LongPoint.newRangeQuery("position",
|
||||
LongList.of(1, -1).toLongArray(),
|
||||
LongList.of(2, 3).toLongArray()
|
||||
);
|
||||
var limitThresholdChecker = CustomHitsThresholdChecker.create(0);
|
||||
var reactiveGenerator = LuceneGenerator.reactive(is,
|
||||
new LocalQueryParams(query,
|
||||
0L,
|
||||
0,
|
||||
pageLimits,
|
||||
Sort.RELEVANCE,
|
||||
true,
|
||||
Duration.ofDays(1)
|
||||
),
|
||||
-1,
|
||||
limitThresholdChecker
|
||||
);
|
||||
var results = reactiveGenerator.collectList().block();
|
||||
|
||||
Assertions.assertNotNull(results);
|
||||
Assertions.assertEquals(0, results.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLimitRestrictingResults() {
|
||||
var query = new TermQuery(new Term("dummy", "dummy"));
|
||||
int limit = 3; // the number of dummies - 1
|
||||
var limitThresholdChecker = CustomHitsThresholdChecker.create(limit);
|
||||
var reactiveGenerator = LuceneGenerator.reactive(is,
|
||||
new LocalQueryParams(query,
|
||||
0L,
|
||||
limit,
|
||||
pageLimits,
|
||||
null,
|
||||
true,
|
||||
Duration.ofDays(1)
|
||||
),
|
||||
-1,
|
||||
limitThresholdChecker
|
||||
);
|
||||
var results = reactiveGenerator.collectList().block();
|
||||
|
||||
Assertions.assertNotNull(results);
|
||||
Assertions.assertEquals(limit, results.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDummies() throws IOException {
|
||||
var query = new TermQuery(new Term("dummy", "dummy"));
|
||||
int limit = Integer.MAX_VALUE;
|
||||
|
||||
var expectedResults = fixResults(List.of(is.search(query, limit).scoreDocs));
|
||||
|
||||
var limitThresholdChecker = CustomHitsThresholdChecker.create(limit);
|
||||
var reactiveGenerator = LuceneGenerator.reactive(is,
|
||||
new LocalQueryParams(query,
|
||||
0,
|
||||
limit,
|
||||
pageLimits,
|
||||
null,
|
||||
true,
|
||||
Duration.ofDays(1)
|
||||
),
|
||||
-1,
|
||||
limitThresholdChecker
|
||||
);
|
||||
var results = fixResults(reactiveGenerator.collectList().block());
|
||||
|
||||
Assertions.assertEquals(4, results.size());
|
||||
Assertions.assertEquals(expectedResults, fixResults(results));
|
||||
}
|
||||
|
||||
private List<ScoreDoc> fixResults(List<ScoreDoc> results) {
|
||||
Assertions.assertNotNull(results);
|
||||
return results.stream().map(MyScoreDoc::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static class MyScoreDoc extends ScoreDoc {
|
||||
|
||||
public MyScoreDoc(ScoreDoc scoreDoc) {
|
||||
super(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (obj == this) {
|
||||
return true;
|
||||
}
|
||||
if (obj instanceof ScoreDoc sd) {
|
||||
return this.score == sd.score && this.doc == sd.doc && this.shardIndex == sd.shardIndex;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user