diff --git a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java index 9375b62..328c9d3 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java +++ b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java @@ -403,8 +403,7 @@ public class LuceneUtils { @Nullable Sort sort, @Nullable Integer startN, @Nullable Integer topN, - TopDocs[] topDocs, - Comparator tieBreaker) { + TopDocs[] topDocs) { if ((startN == null) != (topN == null)) { throw new IllegalArgumentException("You must pass startN and topN together or nothing"); } @@ -420,14 +419,12 @@ public class LuceneUtils { defaultTopN += length; } result = TopDocs.merge(sort, 0, defaultTopN, - (TopFieldDocs[]) topDocs, - tieBreaker + (TopFieldDocs[]) topDocs ); } else { result = TopDocs.merge(sort, startN, topN, - (TopFieldDocs[]) topDocs, - tieBreaker + (TopFieldDocs[]) topDocs ); } } else { @@ -439,14 +436,12 @@ public class LuceneUtils { } result = TopDocs.merge(0, defaultTopN, - topDocs, - tieBreaker + topDocs ); } else { result = TopDocs.merge(startN, topN, - topDocs, - tieBreaker + topDocs ); } } diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/CollectorMultiManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/CollectorMultiManager.java index 7b76943..5107b1a 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/collector/CollectorMultiManager.java +++ b/src/main/java/it/cavallium/dbengine/lucene/collector/CollectorMultiManager.java @@ -1,8 +1,12 @@ package it.cavallium.dbengine.lucene.collector; +import java.util.List; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; -public interface CollectorMultiManager { +public interface CollectorMultiManager { ScoreMode scoreMode(); + + U reduce(List results); } diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/TopDocsCollectorUtils.java b/src/main/java/it/cavallium/dbengine/lucene/collector/OptimizedTopDocsCollector.java similarity index 82% rename from src/main/java/it/cavallium/dbengine/lucene/searcher/TopDocsCollectorUtils.java rename to src/main/java/it/cavallium/dbengine/lucene/collector/OptimizedTopDocsCollector.java index 1216e3f..79ffca3 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/TopDocsCollectorUtils.java +++ b/src/main/java/it/cavallium/dbengine/lucene/collector/OptimizedTopDocsCollector.java @@ -1,19 +1,24 @@ -package it.cavallium.dbengine.lucene.searcher; +package it.cavallium.dbengine.lucene.collector; import static it.cavallium.dbengine.lucene.searcher.PaginationInfo.ALLOW_UNSCORED_PAGINATION_MODE; import it.cavallium.dbengine.lucene.collector.UnscoredCollector; +import java.io.IOException; +import java.util.Collection; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.search.TopScoreDocCollector; -class TopDocsCollectorUtils { +public class OptimizedTopDocsCollector { @SuppressWarnings({"unchecked", "rawtypes"}) - public static TopDocsCollector getTopDocsCollector(Sort luceneSort, + public static TopDocsCollector create(Sort luceneSort, int limit, ScoreDoc after, int totalHitsThreshold, diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/ReactiveCollectorMultiManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/ReactiveCollectorMultiManager.java index 809ceee..afd13e9 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/collector/ReactiveCollectorMultiManager.java +++ b/src/main/java/it/cavallium/dbengine/lucene/collector/ReactiveCollectorMultiManager.java @@ -2,6 +2,7 @@ package it.cavallium.dbengine.lucene.collector; import java.io.IOException; import java.util.Collection; +import java.util.List; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; @@ -11,7 +12,7 @@ import org.apache.lucene.search.ScoreMode; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Sinks.Many; -public class ReactiveCollectorMultiManager implements CollectorMultiManager { +public class ReactiveCollectorMultiManager implements CollectorMultiManager { private final FluxSink scoreDocsSink; @@ -49,4 +50,9 @@ public class ReactiveCollectorMultiManager implements CollectorMultiManager { public ScoreMode scoreMode() { return ScoreMode.COMPLETE_NO_SCORES; } + + @Override + public Void reduce(List results) { + return null; + } } diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorManager.java deleted file mode 100644 index e8cefb2..0000000 --- a/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorManager.java +++ /dev/null @@ -1,132 +0,0 @@ -package it.cavallium.dbengine.lucene.collector; - -import static it.cavallium.dbengine.lucene.searcher.CurrentPageInfo.TIE_BREAKER; - -import it.cavallium.dbengine.lucene.LuceneUtils; -import java.io.IOException; -import java.util.Collection; -import java.util.List; -import java.util.Objects; -import org.apache.lucene.search.CollectorManager; -import org.apache.lucene.search.FieldDoc; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.Sort; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopFieldCollector; -import org.apache.lucene.search.TopFieldDocs; -import org.jetbrains.annotations.Nullable; -import reactor.core.scheduler.Schedulers; - -public class ScoringShardsCollectorManager implements CollectorManager { - - private final Query query; - @Nullable - private final Sort sort; - private final int numHits; - private final FieldDoc after; - private final int totalHitsThreshold; - private final @Nullable Integer startN; - private final @Nullable Integer topN; - private final CollectorManager sharedCollectorManager; - private List indexSearchers; - - public ScoringShardsCollectorManager(Query query, - @Nullable final Sort sort, - final int numHits, - final FieldDoc after, - final int totalHitsThreshold, - int startN, - int topN) { - this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) topN); - } - - public ScoringShardsCollectorManager(Query query, - @Nullable final Sort sort, - final int numHits, - final FieldDoc after, - final int totalHitsThreshold, - int startN) { - this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) 2147483630); - } - - public ScoringShardsCollectorManager(Query query, - @Nullable final Sort sort, - final int numHits, - final FieldDoc after, - final int totalHitsThreshold) { - this(query, sort, numHits, after, totalHitsThreshold, null, null); - } - - private ScoringShardsCollectorManager(Query query, - @Nullable final Sort sort, - final int numHits, - final FieldDoc after, - final int totalHitsThreshold, - @Nullable Integer startN, - @Nullable Integer topN) { - this.query = query; - this.sort = sort; - this.numHits = numHits; - this.after = after; - this.totalHitsThreshold = totalHitsThreshold; - this.startN = startN; - if (topN != null && startN != null && (long) topN + (long) startN > 2147483630) { - this.topN = 2147483630 - startN; - } else if (topN != null && topN > 2147483630) { - this.topN = 2147483630; - } else { - this.topN = topN; - } - this.sharedCollectorManager = TopFieldCollector.createSharedManager(sort == null ? Sort.RELEVANCE : sort, numHits, after, totalHitsThreshold); - } - - @Override - public TopFieldCollector newCollector() throws IOException { - return sharedCollectorManager.newCollector(); - } - - public void setIndexSearchers(List indexSearcher) { - this.indexSearchers = indexSearcher; - } - - @Override - public TopDocs reduce(Collection collectors) throws IOException { - if (Schedulers.isInNonBlockingThread()) { - throw new UnsupportedOperationException("Called reduce in a nonblocking thread"); - } - TopDocs result; - if (sort != null) { - TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()]; - var i = 0; - for (TopFieldCollector collector : collectors) { - topDocs[i] = collector.topDocs(); - - // Populate scores of topfieldcollector. By default it doesn't popupate the scores - if (topDocs[i].scoreDocs.length > 0 && Float.isNaN(topDocs[i].scoreDocs[0].score) && sort.needsScores()) { - Objects.requireNonNull(indexSearchers, "You must call setIndexSearchers before calling reduce!"); - TopFieldCollector.populateScores(topDocs[i].scoreDocs, indexSearchers.get(i), query); - } - - for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) { - scoreDoc.shardIndex = i; - } - i++; - } - result = LuceneUtils.mergeTopDocs(sort, startN, topN, topDocs, TIE_BREAKER); - } else { - TopDocs[] topDocs = new TopDocs[collectors.size()]; - var i = 0; - for (TopFieldCollector collector : collectors) { - topDocs[i] = collector.topDocs(); - for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) { - scoreDoc.shardIndex = i; - } - i++; - } - result = LuceneUtils.mergeTopDocs(null, startN, topN, topDocs, TIE_BREAKER); - } - return result; - } -} diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorMultiManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorMultiManager.java new file mode 100644 index 0000000..bd40051 --- /dev/null +++ b/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorMultiManager.java @@ -0,0 +1,183 @@ +package it.cavallium.dbengine.lucene.collector; + +import static it.cavallium.dbengine.lucene.searcher.CurrentPageInfo.TIE_BREAKER; + +import it.cavallium.dbengine.lucene.LuceneUtils; +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopFieldCollector; +import org.apache.lucene.search.TopFieldDocs; +import org.jetbrains.annotations.Nullable; +import reactor.core.scheduler.Schedulers; + +public class ScoringShardsCollectorMultiManager implements CollectorMultiManager { + + private static final boolean USE_CLASSIC_REDUCE = false; + private final Query query; + @Nullable + private final Sort sort; + private final int numHits; + private final FieldDoc after; + private final int totalHitsThreshold; + private final @Nullable Integer startN; + private final @Nullable Integer topN; + private final @Nullable Integer internalStartN; + private final @Nullable Integer internalTopN; + private final CollectorManager sharedCollectorManager; + + public ScoringShardsCollectorMultiManager(Query query, + @Nullable final Sort sort, + final int numHits, + final FieldDoc after, + final int totalHitsThreshold, + int startN, + int topN) { + this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) topN); + } + + public ScoringShardsCollectorMultiManager(Query query, + @Nullable final Sort sort, + final int numHits, + final FieldDoc after, + final int totalHitsThreshold, + int startN) { + this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) 2147483630); + } + + public ScoringShardsCollectorMultiManager(Query query, + @Nullable final Sort sort, + final int numHits, + final FieldDoc after, + final int totalHitsThreshold) { + this(query, sort, numHits, after, totalHitsThreshold, null, null); + } + + private ScoringShardsCollectorMultiManager(Query query, + @Nullable final Sort sort, + final int numHits, + final FieldDoc after, + final int totalHitsThreshold, + @Nullable Integer startN, + @Nullable Integer topN) { + this.query = query; + this.sort = sort; + this.numHits = numHits; + this.after = after; + this.totalHitsThreshold = totalHitsThreshold; + this.startN = startN; + if (topN != null && startN != null && (long) topN + (long) startN > 2147483630) { + this.topN = 2147483630 - startN; + } else if (topN != null && topN > 2147483630) { + this.topN = 2147483630; + } else { + this.topN = topN; + } + if (this.topN != null && this.startN != null) { + if (this.topN >= 2147483630) { + this.internalTopN = this.topN; + } else { + this.internalTopN = this.startN + this.topN; + } + } else if (this.topN == null && this.startN != null) { + this.internalTopN = null; + } else { + this.internalTopN = this.topN; + } + if (this.internalTopN != null) { + this.internalStartN = 0; + } else { + this.internalStartN = null; + } + this.sharedCollectorManager = TopFieldCollector.createSharedManager(sort == null ? Sort.RELEVANCE : sort, numHits, after, totalHitsThreshold); + } + + public CollectorManager get(IndexSearcher indexSearcher, int shardIndex) { + return new CollectorManager<>() { + @Override + public TopFieldCollector newCollector() throws IOException { + return sharedCollectorManager.newCollector(); + } + + @Override + public TopDocs reduce(Collection collectors) throws IOException { + if (Schedulers.isInNonBlockingThread()) { + throw new UnsupportedOperationException("Called reduce in a nonblocking thread"); + } + if (USE_CLASSIC_REDUCE) { + final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()]; + int i = 0; + for (TopFieldCollector collector : collectors) { + topDocs[i++] = collector.topDocs(); + } + var result = LuceneUtils.mergeTopDocs(sort, 0, numHits, topDocs); + + if (sort != null && sort.needsScores()) { + TopFieldCollector.populateScores(result.scoreDocs, indexSearcher, query); + } + + return result; + } else { + TopDocs result; + if (sort != null) { + TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()]; + var i = 0; + for (TopFieldCollector collector : collectors) { + topDocs[i] = collector.topDocs(); + + // Populate scores of topfieldcollector. By default it doesn't popupate the scores + if (topDocs[i].scoreDocs.length > 0 && Float.isNaN(topDocs[i].scoreDocs[0].score) && sort.needsScores()) { + TopFieldCollector.populateScores(topDocs[i].scoreDocs, indexSearcher, query); + } + + for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) { + scoreDoc.shardIndex = shardIndex; + } + i++; + } + result = LuceneUtils.mergeTopDocs(sort, internalStartN, internalTopN, topDocs); + } else { + TopDocs[] topDocs = new TopDocs[collectors.size()]; + var i = 0; + for (TopFieldCollector collector : collectors) { + topDocs[i] = collector.topDocs(); + for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) { + scoreDoc.shardIndex = shardIndex; + } + i++; + } + result = LuceneUtils.mergeTopDocs(null, internalStartN, internalTopN, topDocs); + } + return result; + } + } + }; + } + + @Override + public ScoreMode scoreMode() { + throw new NotImplementedException(); + } + + @SuppressWarnings({"SuspiciousToArrayCall", "IfStatementWithIdenticalBranches"}) + @Override + public TopDocs reduce(List topDocs) { + TopDocs[] arr; + if (sort != null) { + arr = topDocs.toArray(TopFieldDocs[]::new); + } else { + arr = topDocs.toArray(TopDocs[]::new); + } + return LuceneUtils.mergeTopDocs(sort, startN, topN, arr); + } +} diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/PagedLocalSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/PagedLocalSearcher.java index 1628b9d..715bbc3 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/PagedLocalSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/PagedLocalSearcher.java @@ -10,15 +10,14 @@ import it.cavallium.dbengine.database.LLUtils; import it.cavallium.dbengine.database.disk.LLIndexSearcher; import it.cavallium.dbengine.database.disk.LLIndexSearchers; import it.cavallium.dbengine.lucene.LuceneUtils; +import it.cavallium.dbengine.lucene.collector.OptimizedTopDocsCollector; import it.cavallium.dbengine.lucene.searcher.LLSearchTransformer.TransformerInput; import java.io.IOException; import java.util.Arrays; import java.util.List; -import java.util.Objects; import org.apache.lucene.search.Collector; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopFieldCollector; @@ -185,7 +184,7 @@ public class PagedLocalSearcher implements LocalSearcher { } else if (s.pageIndex() == 0 || (s.last() != null && s.remainingLimit() > 0)) { TopDocs pageTopDocs; try { - TopDocsCollector collector = TopDocsCollectorUtils.getTopDocsCollector(queryParams.sort(), + TopDocsCollector collector = OptimizedTopDocsCollector.create(queryParams.sort(), currentPageLimit, s.last(), queryParams.getTotalHitsThresholdInt(), allowPagination, queryParams.needsScores()); assert queryParams.complete() == collector.scoreMode().isExhaustive(); diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java index 4307668..397b795 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java @@ -8,7 +8,7 @@ import it.cavallium.dbengine.database.LLUtils; import it.cavallium.dbengine.database.disk.LLIndexSearchers; import it.cavallium.dbengine.lucene.LuceneUtils; import it.cavallium.dbengine.lucene.PageLimits; -import it.cavallium.dbengine.lucene.collector.ScoringShardsCollectorManager; +import it.cavallium.dbengine.lucene.collector.ScoringShardsCollectorMultiManager; import it.cavallium.dbengine.lucene.searcher.LLSearchTransformer.TransformerInput; import java.util.Arrays; import java.util.List; @@ -176,32 +176,29 @@ public class ScoredPagedMultiSearcher implements MultiSearcher { var pageLimit = pageLimits.getPageLimit(s.pageIndex()); var after = (FieldDoc) s.last(); var totalHitsThreshold = queryParams.getTotalHitsThresholdInt(); - return new ScoringShardsCollectorManager(query, sort, pageLimit, after, totalHitsThreshold, + return new ScoringShardsCollectorMultiManager(query, sort, pageLimit, after, totalHitsThreshold, resultsOffset, pageLimit); } else { return null; } }) - .flatMap(sharedManager -> Flux + .flatMap(cmm -> Flux .fromIterable(indexSearchers) - .flatMap(shard -> Mono.fromCallable(() -> { + .index() + .flatMap(shardWithIndex -> Mono.fromCallable(() -> { LLUtils.ensureBlocking(); - var collector = sharedManager.newCollector(); - assert queryParams.complete() == collector.scoreMode().isExhaustive(); - assert pageLimits.getPageLimit(s.pageIndex()) < Integer.MAX_VALUE || queryParams - .getScoreModeOptional() - .map(scoreMode -> scoreMode == collector.scoreMode()) - .orElse(true); + var index = (int) (long) shardWithIndex.getT1(); + var shard = shardWithIndex.getT2(); - shard.search(queryParams.query(), collector); - return collector; + var cm = cmm.get(shard, index); + + return shard.search(queryParams.query(), cm); })) .collectList() - .flatMap(collectors -> Mono.fromCallable(() -> { + .flatMap(results -> Mono.fromCallable(() -> { LLUtils.ensureBlocking(); - sharedManager.setIndexSearchers(indexSearchers); - var pageTopDocs = sharedManager.reduce(collectors); + var pageTopDocs = cmm.reduce(results); var pageLastDoc = LuceneUtils.getLastScoreDoc(pageTopDocs.scoreDocs); long nextRemainingLimit; diff --git a/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java b/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java index bda736b..53a8359 100644 --- a/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java +++ b/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java @@ -38,8 +38,8 @@ public class UnsortedUnscoredSimpleMultiSearcher implements MultiSearcher { if (transformer == LLSearchTransformer.NO_TRANSFORMATION) { queryParamsMono = Mono.just(queryParams); } else { - queryParamsMono = transformer.transform(Mono - .fromSupplier(() -> new TransformerInput(indexSearchers, queryParams))); + var transformerInput = Mono.just(new TransformerInput(indexSearchers, queryParams)); + queryParamsMono = transformer.transform(transformerInput); } return queryParamsMono.flatMap(queryParams2 -> {