diff --git a/pom.xml b/pom.xml index 2741ad7..a6a58e6 100644 --- a/pom.xml +++ b/pom.xml @@ -157,7 +157,7 @@ org.slf4j slf4j-api - 1.7.32 + 1.7.35 org.apache.logging.log4j @@ -213,6 +213,11 @@ org.apache.lucene lucene-facet + + org.apache.lucene + lucene-test-framework + test + org.jetbrains annotations @@ -394,6 +399,11 @@ lucene-facet 9.0.0 + + org.apache.lucene + lucene-test-framework + 9.0.0 + org.jetbrains annotations diff --git a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java index ef22ef3..bdcc5e7 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java +++ b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java @@ -441,12 +441,12 @@ public class LuceneUtils { return result; } - public static int totalHitsThreshold(boolean complete) { - return complete ? Integer.MAX_VALUE : 1; + public static int totalHitsThreshold(@Nullable Boolean complete) { + return complete == null || complete ? Integer.MAX_VALUE : 1; } - public static long totalHitsThresholdLong(boolean complete) { - return complete ? Long.MAX_VALUE : 1; + public static long totalHitsThresholdLong(@Nullable Boolean complete) { + return complete == null || complete ? Long.MAX_VALUE : 1; } public static TotalHitsCount convertTotalHitsCount(TotalHits totalHits) { diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/LocalQueryParams.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/LocalQueryParams.java index fecee5c..e48f11f 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/LocalQueryParams.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/LocalQueryParams.java @@ -12,15 +12,28 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetLong, int limitInt, long limitLong, - @NotNull PageLimits pageLimits, @Nullable Sort sort, boolean computePreciseHitsCount, - Duration timeout) { + @NotNull PageLimits pageLimits, @Nullable Sort sort, + @Nullable Boolean computePreciseHitsCount, Duration timeout) { + public LocalQueryParams { + if ((sort == null) != (computePreciseHitsCount == null)) { + if (sort == null) { + throw new IllegalArgumentException("If computePreciseHitsCount is set, sort must be set"); + } else { + throw new IllegalArgumentException("If sort is set, computePreciseHitsCount must be set"); + } + } + } + + /** + * Sorted params with long offsets + */ public LocalQueryParams(@NotNull Query query, long offsetLong, long limitLong, @NotNull PageLimits pageLimits, @Nullable Sort sort, - boolean computePreciseHitsCount, + @Nullable Boolean computePreciseHitsCount, Duration timeout) { this(query, safeLongToInt(offsetLong), @@ -34,6 +47,9 @@ public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetL ); } + /** + * Sorted params with int offsets + */ public LocalQueryParams(@NotNull Query query, int offsetInt, int limitInt, @@ -44,6 +60,38 @@ public record LocalQueryParams(@NotNull Query query, int offsetInt, long offsetL this(query, offsetInt, offsetInt, limitInt, limitInt, pageLimits, sort, computePreciseHitsCount, timeout); } + /** + * Unsorted params with int offsets + */ + public LocalQueryParams(@NotNull Query query, + int offsetInt, + int limitInt, + @NotNull PageLimits pageLimits, + Duration timeout) { + this(query, offsetInt, offsetInt, limitInt, limitInt, pageLimits, null, null, timeout); + } + + + /** + * Unsorted params with long offsets + */ + public LocalQueryParams(@NotNull Query query, + long offsetLong, + long limitLong, + @NotNull PageLimits pageLimits, + Duration timeout) { + this(query, + safeLongToInt(offsetLong), + offsetLong, + safeLongToInt(limitLong), + limitLong, + pageLimits, + null, + null, + timeout + ); + } + public boolean isSorted() { return sort != null; } diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneGenerator.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneGenerator.java index 570182a..ce2238e 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneGenerator.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneGenerator.java @@ -33,7 +33,6 @@ public class LuceneGenerator implements Supplier { private final Iterator leavesIterator; private final boolean computeScores; private final CustomHitsThresholdChecker hitsThresholdChecker; - private final MaxScoreAccumulator minScoreAcc; private final @Nullable Long limit; private long remainingOffset; @@ -48,8 +47,7 @@ public class LuceneGenerator implements Supplier { LuceneGenerator(IndexSearcher shard, LocalQueryParams localQueryParams, int shardIndex, - CustomHitsThresholdChecker hitsThresholdChecker, - MaxScoreAccumulator minScoreAcc) { + CustomHitsThresholdChecker hitsThresholdChecker) { this.shard = shard; this.shardIndex = shardIndex; this.query = localQueryParams.query(); @@ -59,20 +57,20 @@ public class LuceneGenerator implements Supplier { List leaves = shard.getTopReaderContext().leaves(); this.leavesIterator = leaves.iterator(); this.hitsThresholdChecker = hitsThresholdChecker; - this.minScoreAcc = minScoreAcc; } public static Flux reactive(IndexSearcher shard, LocalQueryParams localQueryParams, int shardIndex, - CustomHitsThresholdChecker hitsThresholdChecker, - MaxScoreAccumulator minScoreAcc) { + CustomHitsThresholdChecker hitsThresholdChecker) { + if (localQueryParams.sort() != null) { + return Flux.error(new IllegalArgumentException("Sorting is not allowed")); + } return Flux .generate(() -> new LuceneGenerator(shard, localQueryParams, shardIndex, - hitsThresholdChecker, - minScoreAcc + hitsThresholdChecker ), (s, sink) -> { ScoreDoc val = s.get(); @@ -92,7 +90,7 @@ public class LuceneGenerator implements Supplier { while (remainingOffset > 0) { skipNext(); } - if ((limit != null && totalHits > limit) || hitsThresholdChecker.isThresholdReached(true)) { + if ((limit != null && totalHits >= limit) || hitsThresholdChecker.isThresholdReached(true)) { return null; } return getNext(); @@ -142,10 +140,6 @@ public class LuceneGenerator implements Supplier { totalHits++; hitsThresholdChecker.incrementHitCount(); - if (minScoreAcc != null && (totalHits & minScoreAcc.modInterval) == 0) { - updateGlobalMinCompetitiveScore(scorer); - } - return transformDoc(doc); } docIdSetIterator = null; @@ -166,9 +160,6 @@ public class LuceneGenerator implements Supplier { this.scorer = scorer; this.leaf = leaf; this.docIdSetIterator = scorer.iterator(); - if (minScoreAcc != null) { - updateGlobalMinCompetitiveScore(scorer); - } return true; } return false; @@ -185,23 +176,6 @@ public class LuceneGenerator implements Supplier { return !liveDocs.get(doc); } - protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException { - assert minScoreAcc != null; - DocAndScore maxMinScore = minScoreAcc.get(); - if (maxMinScore != null) { - // since we tie-break on doc id and collect in doc id order we can require - // the next float if the global minimum score is set on a document id that is - // smaller than the ids in the current leaf - float score = - leaf.docBase >= maxMinScore.docBase ? Math.nextUp(maxMinScore.score) : maxMinScore.score; - if (score > minCompetitiveScore) { - assert hitsThresholdChecker.isThresholdReached(true); - scorer.setMinCompetitiveScore(score); - minCompetitiveScore = score; - } - } - } - private void clearState() { docIdSetIterator = null; scorer = null; diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneMultiGenerator.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneMultiGenerator.java index 5de4fd3..b0ae9e2 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneMultiGenerator.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/LuceneMultiGenerator.java @@ -16,8 +16,7 @@ public class LuceneMultiGenerator implements Supplier { public LuceneMultiGenerator(List shards, LocalQueryParams localQueryParams, - CustomHitsThresholdChecker hitsThresholdChecker, - MaxScoreAccumulator minScoreAcc) { + CustomHitsThresholdChecker hitsThresholdChecker) { this.generators = IntStream .range(0, shards.size()) .mapToObj(shardIndex -> { @@ -25,8 +24,7 @@ public class LuceneMultiGenerator implements Supplier { return (Supplier) new LuceneGenerator(shard, localQueryParams, shardIndex, - hitsThresholdChecker, - minScoreAcc + hitsThresholdChecker ); }) .iterator(); diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/OfficialSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/OfficialSearcher.java index bb85c82..02bdddc 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/OfficialSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/OfficialSearcher.java @@ -80,7 +80,8 @@ public class OfficialSearcher implements MultiSearcher { LLUtils.ensureBlocking(); var collector = sharedManager.newCollector(); - assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive(); + assert queryParams.computePreciseHitsCount() == null + || (queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive()); shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout())); return collector; diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedByScoreFullMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedByScoreFullMultiSearcher.java index 1112722..08e1279 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedByScoreFullMultiSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedByScoreFullMultiSearcher.java @@ -85,7 +85,8 @@ public class SortedByScoreFullMultiSearcher implements MultiSearcher { var collector = sharedManager.newCollector(); try { - assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive(); + assert queryParams.computePreciseHitsCount() == null || + queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive(); shard.search(queryParams.query(), collector); return collector; diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedScoredFullMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedScoredFullMultiSearcher.java index 7bbe75b..ebd65e7 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedScoredFullMultiSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/SortedScoredFullMultiSearcher.java @@ -79,7 +79,8 @@ public class SortedScoredFullMultiSearcher implements MultiSearcher { var collector = sharedManager.newCollector(); try { - assert queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive(); + assert queryParams.computePreciseHitsCount() == null + || queryParams.computePreciseHitsCount() == collector.scoreMode().isExhaustive(); shard.search(queryParams.query(), collector); return collector; diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java index a24436a..c1d3ef8 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java @@ -64,7 +64,7 @@ public class UnsortedStreamingMultiSearcher implements MultiSearcher { return Flux.fromIterable(shards).index().flatMap(tuple -> { var shardIndex = (int) (long) tuple.getT1(); var shard = tuple.getT2(); - return LuceneGenerator.reactive(shard, localQueryParams, shardIndex, hitsThreshold, maxScoreAccumulator); + return LuceneGenerator.reactive(shard, localQueryParams, shardIndex, hitsThreshold); }); }); } diff --git a/src/test/java/it/cavallium/dbengine/lucene/searcher/LuceneGeneratorTest.java b/src/test/java/it/cavallium/dbengine/lucene/searcher/LuceneGeneratorTest.java new file mode 100644 index 0000000..a6f7c00 --- /dev/null +++ b/src/test/java/it/cavallium/dbengine/lucene/searcher/LuceneGeneratorTest.java @@ -0,0 +1,180 @@ +package it.cavallium.dbengine.lucene.searcher; + +import it.cavallium.dbengine.lucene.ExponentialPageLimits; +import it.cavallium.dbengine.lucene.MaxScoreAccumulator; +import it.unimi.dsi.fastutil.longs.LongList; +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.CustomHitsThresholdChecker; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +public class LuceneGeneratorTest { + + private static IndexSearcher is; + private static ExponentialPageLimits pageLimits; + + @BeforeAll + public static void beforeAll() throws IOException { + ByteBuffersDirectory dir = new ByteBuffersDirectory(); + IndexWriter iw = new IndexWriter(dir, new IndexWriterConfig()); + iw.addDocument(List.of(new LongPoint("position", 1, 1))); + iw.addDocument(List.of(new LongPoint("position", 2, 3))); + iw.addDocument(List.of(new LongPoint("position", 4, -1))); + iw.addDocument(List.of(new LongPoint("position", 3, -54))); + // Exactly 4 dummies + iw.addDocument(List.of(new StringField("dummy", "dummy", Store.NO))); + iw.addDocument(List.of(new StringField("dummy", "dummy", Store.NO))); + iw.addDocument(List.of(new StringField("dummy", "dummy", Store.NO))); + iw.addDocument(List.of(new StringField("dummy", "dummy", Store.NO))); + iw.commit(); + + DirectoryReader ir = DirectoryReader.open(iw); + is = new IndexSearcher(ir); + pageLimits = new ExponentialPageLimits(); + } + + @Test + public void test() throws IOException { + var query = LongPoint.newRangeQuery("position", + LongList.of(1, -1).toLongArray(), + LongList.of(2, 3).toLongArray() + ); + int limit = Integer.MAX_VALUE; + var limitThresholdChecker = CustomHitsThresholdChecker.create(limit); + + var expectedResults = fixResults(fixResults(List.of(is.search(query, limit, Sort.RELEVANCE, true).scoreDocs))); + + var reactiveGenerator = LuceneGenerator.reactive(is, + new LocalQueryParams(query, + 0L, + limit, + pageLimits, + null, + true, + Duration.ofDays(1) + ), + -1, + limitThresholdChecker + ); + var results = fixResults(reactiveGenerator.collectList().block()); + + Assertions.assertNotEquals(0, results.size()); + + Assertions.assertEquals(expectedResults, results); + } + + @Test + public void testLimit0() { + var query = LongPoint.newRangeQuery("position", + LongList.of(1, -1).toLongArray(), + LongList.of(2, 3).toLongArray() + ); + var limitThresholdChecker = CustomHitsThresholdChecker.create(0); + var reactiveGenerator = LuceneGenerator.reactive(is, + new LocalQueryParams(query, + 0L, + 0, + pageLimits, + Sort.RELEVANCE, + true, + Duration.ofDays(1) + ), + -1, + limitThresholdChecker + ); + var results = reactiveGenerator.collectList().block(); + + Assertions.assertNotNull(results); + Assertions.assertEquals(0, results.size()); + } + + @Test + public void testLimitRestrictingResults() { + var query = new TermQuery(new Term("dummy", "dummy")); + int limit = 3; // the number of dummies - 1 + var limitThresholdChecker = CustomHitsThresholdChecker.create(limit); + var reactiveGenerator = LuceneGenerator.reactive(is, + new LocalQueryParams(query, + 0L, + limit, + pageLimits, + null, + true, + Duration.ofDays(1) + ), + -1, + limitThresholdChecker + ); + var results = reactiveGenerator.collectList().block(); + + Assertions.assertNotNull(results); + Assertions.assertEquals(limit, results.size()); + } + + @Test + public void testDummies() throws IOException { + var query = new TermQuery(new Term("dummy", "dummy")); + int limit = Integer.MAX_VALUE; + + var expectedResults = fixResults(List.of(is.search(query, limit).scoreDocs)); + + var limitThresholdChecker = CustomHitsThresholdChecker.create(limit); + var reactiveGenerator = LuceneGenerator.reactive(is, + new LocalQueryParams(query, + 0, + limit, + pageLimits, + null, + true, + Duration.ofDays(1) + ), + -1, + limitThresholdChecker + ); + var results = fixResults(reactiveGenerator.collectList().block()); + + Assertions.assertEquals(4, results.size()); + Assertions.assertEquals(expectedResults, fixResults(results)); + } + + private List fixResults(List results) { + Assertions.assertNotNull(results); + return results.stream().map(MyScoreDoc::new).collect(Collectors.toList()); + } + + private static class MyScoreDoc extends ScoreDoc { + + public MyScoreDoc(ScoreDoc scoreDoc) { + super(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof ScoreDoc sd) { + return this.score == sd.score && this.doc == sd.doc && this.shardIndex == sd.shardIndex; + } + return false; + } + } + +}