diff --git a/pom.xml b/pom.xml
index adac45e..777c6da 100644
--- a/pom.xml
+++ b/pom.xml
@@ -13,7 +13,7 @@
 		0-SNAPSHOT
 		false
 		1.8.5
-		9.1.0
+		9.2.0
 		5.8.2
@@ -205,7 +205,7 @@
 			org.rocksdb
 			rocksdbjni
-			7.2.2
+			7.3.1
 		org.apache.lucene
diff --git a/src/main/java/it/cavallium/dbengine/database/disk/LLIndexSearchers.java b/src/main/java/it/cavallium/dbengine/database/disk/LLIndexSearchers.java
index 19579f3..5fb6c72 100644
--- a/src/main/java/it/cavallium/dbengine/database/disk/LLIndexSearchers.java
+++ b/src/main/java/it/cavallium/dbengine/database/disk/LLIndexSearchers.java
@@ -5,6 +5,7 @@ import io.netty5.buffer.api.Owned;
 import io.netty5.buffer.api.Resource;
 import io.netty5.buffer.api.Send;
 import io.netty5.buffer.api.internal.ResourceSupport;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import java.io.Closeable;
 import java.io.IOException;
@@ -13,6 +14,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.IndexReader;
@@ -98,12 +100,19 @@ public interface LLIndexSearchers extends Closeable {
 		private final List indexSearchersVals;
 
 		public ShardedIndexSearchers(List indexSearchers) {
-			this.indexSearchers = new ArrayList<>(indexSearchers.size());
-			this.indexSearchersVals = new ArrayList<>(indexSearchers.size());
+			var shardedIndexSearchers = new ArrayList(indexSearchers.size());
+			List shardedIndexSearchersVals = new ArrayList<>(indexSearchers.size());
 			for (LLIndexSearcher indexSearcher : indexSearchers) {
-				this.indexSearchers.add(indexSearcher);
-				this.indexSearchersVals.add(indexSearcher.getIndexSearcher());
+				shardedIndexSearchersVals.add(indexSearcher.getIndexSearcher());
 			}
+			shardedIndexSearchersVals = ShardIndexSearcher.create(shardedIndexSearchersVals);
+			int i = 0;
+			for (IndexSearcher shardedIndexSearcher : shardedIndexSearchersVals) {
+				shardedIndexSearchers.add(new WrappedLLIndexSearcher(shardedIndexSearcher, indexSearchers.get(i)));
+				i++;
+			}
+			this.indexSearchers = shardedIndexSearchers;
+			this.indexSearchersVals = shardedIndexSearchersVals;
 		}
 
 		@Override
@@ -156,5 +165,30 @@ public interface LLIndexSearchers extends Closeable {
 				indexSearcher.close();
 			}
 		}
+
+		private static class WrappedLLIndexSearcher extends LLIndexSearcher {
+
+			private final LLIndexSearcher parent;
+
+			public WrappedLLIndexSearcher(IndexSearcher indexSearcher, LLIndexSearcher parent) {
+				super(indexSearcher, parent.getClosed());
+				this.parent = parent;
+			}
+
+			@Override
+			public IndexSearcher getIndexSearcher() {
+				return indexSearcher;
+			}
+
+			@Override
+			public IndexReader getIndexReader() {
+				return indexSearcher.getIndexReader();
+			}
+
+			@Override
+			protected void onClose() throws IOException {
+				parent.onClose();
+			}
+		}
 	}
 }
diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java
index 372992a..5a3e042 100644
--- a/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java
+++ b/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java
@@ -183,7 +183,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
 					(DoubleRange[]) bucketRanges
 			);
 		}
-		FacetResult children = facets.getTopChildren(0, bucketField);
+		FacetResult children = facets.getTopChildren(1, bucketField);
 		if (AMORTIZE && randomSamplingEnabled) {
 			var cfg = new FacetsConfig();
 			for (Range bucketRange : bucketRanges) {
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/ShardIndexSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/ShardIndexSearcher.java
new file mode 100644
index 0000000..70fa473
--- /dev/null
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/ShardIndexSearcher.java
@@ -0,0 +1,263 @@
+package it.cavallium.dbengine.lucene.searcher;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermStates;
+import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.TermStatistics;
+
+public class ShardIndexSearcher extends IndexSearcher {
+	public final int myNodeID;
+	private final IndexSearcher[] searchers;
+
+	private final Map collectionStatsCache;
+	private final Map termStatsCache;
+
+	public ShardIndexSearcher(SharedShardStatistics sharedShardStatistics,
+			List searchers,
+			int nodeID) {
+		super(searchers.get(nodeID).getIndexReader(), searchers.get(nodeID).getExecutor());
+		this.collectionStatsCache = sharedShardStatistics.collectionStatsCache;
+		this.termStatsCache = sharedShardStatistics.termStatsCache;
+		this.searchers = searchers.toArray(IndexSearcher[]::new);
+		myNodeID = nodeID;
+	}
+
+	public static List create(Iterable indexSearchersIterable) {
+		var it = indexSearchersIterable.iterator();
+		if (it.hasNext()) {
+			var is = it.next();
+			if (!it.hasNext()) {
+				return List.of(is);
+			}
+			if (is instanceof ShardIndexSearcher) {
+				if (indexSearchersIterable instanceof List list) {
+					return list;
+				} else {
+					var result = new ArrayList();
+					result.add(is);
+					do {
+						result.add(it.next());
+					} while (it.hasNext());
+					return result;
+				}
+			}
+		}
+		List indexSearchers;
+		if (indexSearchersIterable instanceof List collection) {
+			indexSearchers = collection;
+		} else if (indexSearchersIterable instanceof Collection collection) {
+			indexSearchers = List.copyOf(collection);
+		} else {
+			indexSearchers = new ArrayList<>();
+			for (IndexSearcher i : indexSearchersIterable) {
+				indexSearchers.add(i);
+			}
+		}
+		if (indexSearchers.size() == 0) {
+			return List.of();
+		} else {
+			var sharedShardStatistics = new SharedShardStatistics();
+			List result = new ArrayList<>(indexSearchers.size());
+			for (int nodeId = 0; nodeId < indexSearchers.size(); nodeId++) {
+				result.add(new ShardIndexSearcher(sharedShardStatistics, indexSearchers, nodeId));
+			}
+			return result;
+		}
+	}
+
+	@Override
+	public Query rewrite(Query original) throws IOException {
+		final IndexSearcher localSearcher = new IndexSearcher(getIndexReader());
+		original = localSearcher.rewrite(original);
+		final Set terms = new HashSet<>();
+		original.visit(QueryVisitor.termCollector(terms));
+
+		// Make a single request to remote nodes for term
+		// stats:
+		for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+			if (nodeID == myNodeID) {
+				continue;
+			}
+
+			final Set missing = new HashSet<>();
+			for (Term term : terms) {
+				final TermAndShard key =
+						new TermAndShard(nodeID, term);
+				if (!termStatsCache.containsKey(key)) {
+					missing.add(term);
+				}
+			}
+			if (missing.size() != 0) {
+				for (Map.Entry ent :
+						getNodeTermStats(missing, nodeID).entrySet()) {
+					if (ent.getValue() != null) {
+						final TermAndShard key = new TermAndShard(nodeID, ent.getKey());
+						termStatsCache.put(key, ent.getValue());
+					}
+				}
+			}
+		}
+
+		return original;
+	}
+
+	// Mock: in a real env, this would hit the wire and get
+	// term stats from remote node
+	Map getNodeTermStats(Set terms, int nodeID)
+			throws IOException {
+		var s = searchers[nodeID];
+		final Map stats = new HashMap<>();
+		if (s == null) {
+			throw new NoSuchElementException("node=" + nodeID);
+		}
+		for (Term term : terms) {
+			final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
+			if (ts.docFreq() > 0) {
+				stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
+			}
+		}
+		return stats;
+	}
+
+	@Override
+	public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq)
+			throws IOException {
+		assert term != null;
+		long distributedDocFreq = 0;
+		long distributedTotalTermFreq = 0;
+		for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+
+			final TermStatistics subStats;
+			if (nodeID == myNodeID) {
+				subStats = super.termStatistics(term, docFreq, totalTermFreq);
+			} else {
+				final TermAndShard key =
+						new TermAndShard(nodeID, term);
+				subStats = termStatsCache.get(key);
+				if (subStats == null) {
+					continue; // term not found
+				}
+			}
+
+			long nodeDocFreq = subStats.docFreq();
+			distributedDocFreq += nodeDocFreq;
+
+			long nodeTotalTermFreq = subStats.totalTermFreq();
+			distributedTotalTermFreq += nodeTotalTermFreq;
+		}
+		assert distributedDocFreq > 0;
+		return new TermStatistics(term.bytes(), distributedDocFreq, distributedTotalTermFreq);
+	}
+
+	@Override
+	public CollectionStatistics collectionStatistics(String field) throws IOException {
+		// TODO: we could compute this on init and cache,
+		// since we are re-inited whenever any nodes have a
+		// new reader
+		long docCount = 0;
+		long sumTotalTermFreq = 0;
+		long sumDocFreq = 0;
+		long maxDoc = 0;
+
+		for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+			final FieldAndShar key = new FieldAndShar(nodeID, field);
+			final CollectionStatistics nodeStats;
+			if (nodeID == myNodeID) {
+				nodeStats = super.collectionStatistics(field);
+			} else {
+				nodeStats = collectionStatsCache.get(key);
+			}
+			if (nodeStats == null) {
+				continue; // field not in sub at all
+			}
+
+			long nodeDocCount = nodeStats.docCount();
+			docCount += nodeDocCount;
+
+			long nodeSumTotalTermFreq = nodeStats.sumTotalTermFreq();
+			sumTotalTermFreq += nodeSumTotalTermFreq;
+
+			long nodeSumDocFreq = nodeStats.sumDocFreq();
+			sumDocFreq += nodeSumDocFreq;
+
+			assert nodeStats.maxDoc() >= 0;
+			maxDoc += nodeStats.maxDoc();
+		}
+
+		if (maxDoc == 0) {
+			return null; // field not found across any node whatsoever
+		} else {
+			return new CollectionStatistics(field, maxDoc, docCount, sumTotalTermFreq, sumDocFreq);
+		}
+	}
+
+	public static class TermAndShard {
+		private final int nodeID;
+		private final Term term;
+
+		public TermAndShard(int nodeID, Term term) {
+			this.nodeID = nodeID;
+			this.term = term;
+		}
+
+		@Override
+		public int hashCode() {
+			return (nodeID + term.hashCode());
+		}
+
+		@Override
+		public boolean equals(Object _other) {
+			if (!(_other instanceof final TermAndShard other)) {
+				return false;
+			}
+
+			return term.equals(other.term) && nodeID == other.nodeID;
+		}
+	}
+
+
+	public static class FieldAndShar {
+		private final int nodeID;
+		private final String field;
+
+		public FieldAndShar(int nodeID, String field) {
+			this.nodeID = nodeID;
+			this.field = field;
+		}
+
+		@Override
+		public int hashCode() {
+			return (nodeID + field.hashCode());
+		}
+
+		@Override
+		public boolean equals(Object _other) {
+			if (!(_other instanceof final FieldAndShar other)) {
+				return false;
+			}
+
+			return field.equals(other.field) && nodeID == other.nodeID;
+		}
+
+		@Override
+		public String toString() {
+			return "FieldAndShardVersion(field="
+					+ field
+					+ " nodeID="
+					+ nodeID
+					+ ")";
+		}
+	}
+}
\ No newline at end of file
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/SharedShardStatistics.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/SharedShardStatistics.java
new file mode 100644
index 0000000..51280db
--- /dev/null
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/SharedShardStatistics.java
@@ -0,0 +1,14 @@
+package it.cavallium.dbengine.lucene.searcher;
+
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher.FieldAndShar;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher.TermAndShard;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.TermStatistics;
+
+public class SharedShardStatistics {
+
+	public final Map collectionStatsCache = new ConcurrentHashMap<>();
+	public final Map termStatsCache = new ConcurrentHashMap<>();
+}
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/StandardSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/StandardSearcher.java
index c5c7dda..bec8356 100644
--- a/src/main/java/it/cavallium/dbengine/lucene/searcher/StandardSearcher.java
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/StandardSearcher.java
@@ -4,7 +4,6 @@ import static it.cavallium.dbengine.client.UninterruptibleScheduler.uninterrupti
 import static it.cavallium.dbengine.database.LLUtils.singleOrClose;
 import static java.util.Objects.requireNonNull;
 
-import io.netty5.buffer.api.Send;
 import it.cavallium.dbengine.database.LLKeyScore;
 import it.cavallium.dbengine.database.LLUtils;
 import it.cavallium.dbengine.database.disk.LLIndexSearchers;
@@ -75,51 +74,54 @@
 					}
 				})
 				.subscribeOn(uninterruptibleScheduler(Schedulers.boundedElastic()))
-				.flatMap(sharedManager -> Flux.fromIterable(indexSearchers).>handle((shard, sink) -> {
-					LLUtils.ensureBlocking();
-					try {
-						var collector = sharedManager.newCollector();
-						assert queryParams.computePreciseHitsCount() == null || (queryParams.computePreciseHitsCount() == collector
-								.scoreMode()
-								.isExhaustive());
+				.flatMap(sharedManager -> Flux
+						.fromIterable(indexSearchers)
+						.>handle((shard, sink) -> {
+							LLUtils.ensureBlocking();
+							try {
+								var collector = sharedManager.newCollector();
+								assert queryParams.computePreciseHitsCount() == null || (queryParams.computePreciseHitsCount()
+										== collector.scoreMode().isExhaustive());
 
-						shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout()));
-						sink.next(collector);
-					} catch (IOException e) {
-						sink.error(e);
-					}
-				}).collectList().handle((collectors, sink) -> {
-					LLUtils.ensureBlocking();
-					try {
-						if (collectors.size() <= 1) {
-							sink.next(sharedManager.reduce((List) collectors));
-						} else if (queryParams.isSorted() && !queryParams.isSortedByScore()) {
-							final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
-							int i = 0;
-							for (var collector : collectors) {
-								var topFieldDocs = ((TopFieldCollector) collector).topDocs();
-								for (ScoreDoc scoreDoc : topFieldDocs.scoreDocs) {
-									scoreDoc.shardIndex = i;
-								}
-								topDocs[i++] = topFieldDocs;
+								shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout()));
+								sink.next(collector);
+							} catch (IOException e) {
+								sink.error(e);
 							}
-							sink.next(TopDocs.merge(requireNonNull(queryParams.sort()), 0, queryParams.limitInt(), topDocs));
-						} else {
-							final TopDocs[] topDocs = new TopDocs[collectors.size()];
-							int i = 0;
-							for (var collector : collectors) {
-								var topScoreDocs = collector.topDocs();
-								for (ScoreDoc scoreDoc : topScoreDocs.scoreDocs) {
-									scoreDoc.shardIndex = i;
+						})
+						.collectList()
+						.handle((collectors, sink) -> {
+							LLUtils.ensureBlocking();
+							try {
+								if (collectors.size() <= 1) {
+									sink.next(sharedManager.reduce((List) collectors));
+								} else if (queryParams.isSorted() && !queryParams.isSortedByScore()) {
+									final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
+									int i = 0;
+									for (var collector : collectors) {
+										var topFieldDocs = ((TopFieldCollector) collector).topDocs();
+										for (ScoreDoc scoreDoc : topFieldDocs.scoreDocs) {
+											scoreDoc.shardIndex = i;
+										}
+										topDocs[i++] = topFieldDocs;
+									}
+									sink.next(TopDocs.merge(requireNonNull(queryParams.sort()), 0, queryParams.limitInt(), topDocs));
+								} else {
+									final TopDocs[] topDocs = new TopDocs[collectors.size()];
+									int i = 0;
+									for (var collector : collectors) {
+										var topScoreDocs = collector.topDocs();
+										for (ScoreDoc scoreDoc : topScoreDocs.scoreDocs) {
+											scoreDoc.shardIndex = i;
+										}
+										topDocs[i++] = topScoreDocs;
+									}
+									sink.next(TopDocs.merge(0, queryParams.limitInt(), topDocs));
 								}
-							topDocs[i++] = topScoreDocs;
+							} catch (IOException ex) {
+								sink.error(ex);
 							}
-							sink.next(TopDocs.merge(0, queryParams.limitInt(), topDocs));
-						}
-					} catch (IOException ex) {
-						sink.error(ex);
-					}
-				}));
+						}));
 	}
 
 	/**
diff --git a/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java b/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java
index 11e26f3..711ffbd 100644
--- a/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java
+++ b/src/test/java/it/cavallium/dbengine/UnsortedUnscoredSimpleMultiSearcher.java
@@ -1,6 +1,7 @@
 package it.cavallium.dbengine;
 
 import static it.cavallium.dbengine.client.UninterruptibleScheduler.uninterruptibleScheduler;
+import static it.cavallium.dbengine.database.LLUtils.singleOrClose;
 import static it.cavallium.dbengine.lucene.searcher.GlobalQueryRewrite.NO_REWRITE;
 
 import io.netty5.buffer.api.Send;
@@ -14,6 +15,7 @@ import it.cavallium.dbengine.lucene.searcher.LocalQueryParams;
 import it.cavallium.dbengine.lucene.searcher.LocalSearcher;
 import it.cavallium.dbengine.lucene.searcher.LuceneSearchResult;
 import it.cavallium.dbengine.lucene.searcher.MultiSearcher;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher;
 import it.cavallium.dbengine.utils.SimpleResource;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -40,7 +42,7 @@ public class UnsortedUnscoredSimpleMultiSearcher implements MultiSearcher {
 			String keyFieldName,
 			GlobalQueryRewrite transformer) {
 
-		return indexSearchersMono.flatMap(indexSearchers -> {
+		return singleOrClose(indexSearchersMono, indexSearchers -> {
 			Mono queryParamsMono;
 			if (transformer == NO_REWRITE) {
 				queryParamsMono = Mono.just(queryParams);