Share term statistics across shards

Andrea Cavalli 2022-06-14 21:58:26 +02:00
parent 0d830fbd21
commit 17c40757ba
7 changed files with 365 additions and 50 deletions
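When an index is split into multiple shards, each shard's IndexSearcher computes its ranking weights (IDF and similar factors) from purely local docFreq/docCount values, so the same query can score identical documents differently depending on which shard they landed in. This commit adds a ShardIndexSearcher that shares term and collection statistics across all shards of an index, so every shard scores against index-wide statistics; a usage sketch follows the new ShardIndexSearcher.java file below.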

pom.xml

@@ -13,7 +13,7 @@
   <revision>0-SNAPSHOT</revision>
   <dbengine.ci>false</dbengine.ci>
   <micrometer.version>1.8.5</micrometer.version>
-  <lucene.version>9.1.0</lucene.version>
+  <lucene.version>9.2.0</lucene.version>
   <junit.jupiter.version>5.8.2</junit.jupiter.version>
 </properties>
 <repositories>
@@ -205,7 +205,7 @@
 <dependency>
   <groupId>org.rocksdb</groupId>
   <artifactId>rocksdbjni</artifactId>
-  <version>7.2.2</version>
+  <version>7.3.1</version>
 </dependency>
 <dependency>
   <groupId>org.apache.lucene</groupId>

LLIndexSearchers.java

@@ -5,6 +5,7 @@ import io.netty5.buffer.api.Owned;
 import io.netty5.buffer.api.Resource;
 import io.netty5.buffer.api.Send;
 import io.netty5.buffer.api.internal.ResourceSupport;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher;
 import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
 import java.io.Closeable;
 import java.io.IOException;
@@ -13,6 +14,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.IndexReader;
@@ -98,12 +100,19 @@ public interface LLIndexSearchers extends Closeable {
     private final List<IndexSearcher> indexSearchersVals;
 
     public ShardedIndexSearchers(List<LLIndexSearcher> indexSearchers) {
-      this.indexSearchers = new ArrayList<>(indexSearchers.size());
-      this.indexSearchersVals = new ArrayList<>(indexSearchers.size());
+      var shardedIndexSearchers = new ArrayList<LLIndexSearcher>(indexSearchers.size());
+      List<IndexSearcher> shardedIndexSearchersVals = new ArrayList<>(indexSearchers.size());
       for (LLIndexSearcher indexSearcher : indexSearchers) {
-        this.indexSearchers.add(indexSearcher);
-        this.indexSearchersVals.add(indexSearcher.getIndexSearcher());
+        shardedIndexSearchersVals.add(indexSearcher.getIndexSearcher());
       }
+      shardedIndexSearchersVals = ShardIndexSearcher.create(shardedIndexSearchersVals);
+      int i = 0;
+      for (IndexSearcher shardedIndexSearcher : shardedIndexSearchersVals) {
+        shardedIndexSearchers.add(new WrappedLLIndexSearcher(shardedIndexSearcher, indexSearchers.get(i)));
+        i++;
+      }
+      this.indexSearchers = shardedIndexSearchers;
+      this.indexSearchersVals = shardedIndexSearchersVals;
     }
 
     @Override
@@ -156,5 +165,30 @@ public interface LLIndexSearchers extends Closeable {
         indexSearcher.close();
       }
     }
+
+    private static class WrappedLLIndexSearcher extends LLIndexSearcher {
+
+      private final LLIndexSearcher parent;
+
+      public WrappedLLIndexSearcher(IndexSearcher indexSearcher, LLIndexSearcher parent) {
+        super(indexSearcher, parent.getClosed());
+        this.parent = parent;
+      }
+
+      @Override
+      public IndexSearcher getIndexSearcher() {
+        return indexSearcher;
+      }
+
+      @Override
+      public IndexReader getIndexReader() {
+        return indexSearcher.getIndexReader();
+      }
+
+      @Override
+      protected void onClose() throws IOException {
+        parent.onClose();
+      }
+    }
   }
 }
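With this change, ShardedIndexSearchers no longer hands out the raw per-shard searchers: the group is passed through ShardIndexSearcher.create(...) and each result is re-wrapped in a WrappedLLIndexSearcher, which swaps in the statistics-sharing searcher while delegating the close lifecycle to the original LLIndexSearcher.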

DecimalBucketMultiCollectorManager.java

@@ -183,7 +183,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
           (DoubleRange[]) bucketRanges
       );
     }
-    FacetResult children = facets.getTopChildren(0, bucketField);
+    FacetResult children = facets.getTopChildren(1, bucketField);
     if (AMORTIZE && randomSamplingEnabled) {
       var cfg = new FacetsConfig();
       for (Range bucketRange : bucketRanges) {
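The topN argument changes from 0 to 1 here, presumably forced by the Lucene 9.2.0 upgrade in this commit's pom.xml: starting with 9.2, the Facets.getTopChildren implementations validate the argument and throw IllegalArgumentException when topN is not positive.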

ShardIndexSearcher.java

@@ -0,0 +1,263 @@
+package it.cavallium.dbengine.lucene.searcher;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermStates;
+import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.TermStatistics;
+
+public class ShardIndexSearcher extends IndexSearcher {
+
+  public final int myNodeID;
+  private final IndexSearcher[] searchers;
+
+  private final Map<FieldAndShard, CollectionStatistics> collectionStatsCache;
+  private final Map<TermAndShard, TermStatistics> termStatsCache;
+
+  public ShardIndexSearcher(SharedShardStatistics sharedShardStatistics,
+      List<IndexSearcher> searchers,
+      int nodeID) {
+    super(searchers.get(nodeID).getIndexReader(), searchers.get(nodeID).getExecutor());
+    this.collectionStatsCache = sharedShardStatistics.collectionStatsCache;
+    this.termStatsCache = sharedShardStatistics.termStatsCache;
+    this.searchers = searchers.toArray(IndexSearcher[]::new);
+    myNodeID = nodeID;
+  }
+
+  public static List<IndexSearcher> create(Iterable<IndexSearcher> indexSearchersIterable) {
+    var it = indexSearchersIterable.iterator();
+    if (it.hasNext()) {
+      var is = it.next();
+      if (!it.hasNext()) {
+        return List.of(is);
+      }
+      if (is instanceof ShardIndexSearcher) {
+        if (indexSearchersIterable instanceof List<IndexSearcher> list) {
+          return list;
+        } else {
+          var result = new ArrayList<IndexSearcher>();
+          result.add(is);
+          do {
+            result.add(it.next());
+          } while (it.hasNext());
+          return result;
+        }
+      }
+    }
+    List<IndexSearcher> indexSearchers;
+    if (indexSearchersIterable instanceof List<IndexSearcher> collection) {
+      indexSearchers = collection;
+    } else if (indexSearchersIterable instanceof Collection<IndexSearcher> collection) {
+      indexSearchers = List.copyOf(collection);
+    } else {
+      indexSearchers = new ArrayList<>();
+      for (IndexSearcher i : indexSearchersIterable) {
+        indexSearchers.add(i);
+      }
+    }
+    if (indexSearchers.size() == 0) {
+      return List.of();
+    } else {
+      var sharedShardStatistics = new SharedShardStatistics();
+      List<IndexSearcher> result = new ArrayList<>(indexSearchers.size());
+      for (int nodeId = 0; nodeId < indexSearchers.size(); nodeId++) {
+        result.add(new ShardIndexSearcher(sharedShardStatistics, indexSearchers, nodeId));
+      }
+      return result;
+    }
+  }
+
+  @Override
+  public Query rewrite(Query original) throws IOException {
+    final IndexSearcher localSearcher = new IndexSearcher(getIndexReader());
+    original = localSearcher.rewrite(original);
+    final Set<Term> terms = new HashSet<>();
+    original.visit(QueryVisitor.termCollector(terms));
+
+    // Make a single request to remote nodes for term stats:
+    for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+      if (nodeID == myNodeID) {
+        continue;
+      }
+
+      final Set<Term> missing = new HashSet<>();
+      for (Term term : terms) {
+        final TermAndShard key = new TermAndShard(nodeID, term);
+        if (!termStatsCache.containsKey(key)) {
+          missing.add(term);
+        }
+      }
+      if (missing.size() != 0) {
+        for (Map.Entry<Term, TermStatistics> ent : getNodeTermStats(missing, nodeID).entrySet()) {
+          if (ent.getValue() != null) {
+            final TermAndShard key = new TermAndShard(nodeID, ent.getKey());
+            termStatsCache.put(key, ent.getValue());
+          }
+        }
+      }
+    }
+
+    return original;
+  }
+
+  // Mock: in a real env, this would hit the wire and get term stats from a remote node
+  Map<Term, TermStatistics> getNodeTermStats(Set<Term> terms, int nodeID) throws IOException {
+    var s = searchers[nodeID];
+    final Map<Term, TermStatistics> stats = new HashMap<>();
+    if (s == null) {
+      throw new NoSuchElementException("node=" + nodeID);
+    }
+    for (Term term : terms) {
+      final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
+      if (ts.docFreq() > 0) {
+        stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
+      }
+    }
+    return stats;
+  }
+
+  @Override
+  public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
+    assert term != null;
+    long distributedDocFreq = 0;
+    long distributedTotalTermFreq = 0;
+    for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+      final TermStatistics subStats;
+      if (nodeID == myNodeID) {
+        subStats = super.termStatistics(term, docFreq, totalTermFreq);
+      } else {
+        final TermAndShard key = new TermAndShard(nodeID, term);
+        subStats = termStatsCache.get(key);
+        if (subStats == null) {
+          continue; // term not found
+        }
+      }
+
+      long nodeDocFreq = subStats.docFreq();
+      distributedDocFreq += nodeDocFreq;
+
+      long nodeTotalTermFreq = subStats.totalTermFreq();
+      distributedTotalTermFreq += nodeTotalTermFreq;
+    }
+    assert distributedDocFreq > 0;
+    return new TermStatistics(term.bytes(), distributedDocFreq, distributedTotalTermFreq);
+  }
+
+  @Override
+  public CollectionStatistics collectionStatistics(String field) throws IOException {
+    // TODO: we could compute this on init and cache,
+    // since we are re-inited whenever any nodes have a
+    // new reader
+    long docCount = 0;
+    long sumTotalTermFreq = 0;
+    long sumDocFreq = 0;
+    long maxDoc = 0;
+
+    for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
+      final FieldAndShard key = new FieldAndShard(nodeID, field);
+      final CollectionStatistics nodeStats;
+      if (nodeID == myNodeID) {
+        nodeStats = super.collectionStatistics(field);
+      } else {
+        nodeStats = collectionStatsCache.get(key);
+      }
+      if (nodeStats == null) {
+        continue; // field not in sub at all
+      }
+
+      long nodeDocCount = nodeStats.docCount();
+      docCount += nodeDocCount;
+
+      long nodeSumTotalTermFreq = nodeStats.sumTotalTermFreq();
+      sumTotalTermFreq += nodeSumTotalTermFreq;
+
+      long nodeSumDocFreq = nodeStats.sumDocFreq();
+      sumDocFreq += nodeSumDocFreq;
+
+      assert nodeStats.maxDoc() >= 0;
+      maxDoc += nodeStats.maxDoc();
+    }
+
+    if (maxDoc == 0) {
+      return null; // field not found across any node whatsoever
+    } else {
+      return new CollectionStatistics(field, maxDoc, docCount, sumTotalTermFreq, sumDocFreq);
+    }
+  }
+
+  public static class TermAndShard {
+
+    private final int nodeID;
+    private final Term term;
+
+    public TermAndShard(int nodeID, Term term) {
+      this.nodeID = nodeID;
+      this.term = term;
+    }
+
+    @Override
+    public int hashCode() {
+      return (nodeID + term.hashCode());
+    }
+
+    @Override
+    public boolean equals(Object _other) {
+      if (!(_other instanceof final TermAndShard other)) {
+        return false;
+      }
+      return term.equals(other.term) && nodeID == other.nodeID;
+    }
+  }
+
+  public static class FieldAndShard {
+
+    private final int nodeID;
+    private final String field;
+
+    public FieldAndShard(int nodeID, String field) {
+      this.nodeID = nodeID;
+      this.field = field;
+    }
+
+    @Override
+    public int hashCode() {
+      return (nodeID + field.hashCode());
+    }
+
+    @Override
+    public boolean equals(Object _other) {
+      if (!(_other instanceof final FieldAndShard other)) {
+        return false;
+      }
+      return field.equals(other.field) && nodeID == other.nodeID;
+    }
+
+    @Override
+    public String toString() {
+      return "FieldAndShard(field=" + field + " nodeID=" + nodeID + ")";
+    }
+  }
+}
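The class above appears adapted from the ShardSearchingTestBase pattern in Lucene's test framework: rewrite() warms the shared cache with the other shards' term statistics, and termStatistics()/collectionStatistics() then return sums across all shards. Below is a minimal self-contained sketch of that behavior, using two single-document in-memory shards; the demo class, field name, and document contents are illustrative and not part of the commit:

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.store.ByteBuffersDirectory;
import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher;

public class SharedStatsDemo {

  public static void main(String[] args) throws Exception {
    // Two single-document "shards": the term "apple" occurs once in each.
    List<IndexSearcher> perShard = new ArrayList<>();
    for (String text : new String[] {"apple banana", "apple"}) {
      var dir = new ByteBuffersDirectory();
      try (var writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
        var doc = new Document();
        doc.add(new TextField("text", text, Store.NO));
        writer.addDocument(doc);
      }
      perShard.add(new IndexSearcher(DirectoryReader.open(dir)));
    }

    // Wrap the group: each returned searcher still targets one shard, but
    // shares a SharedShardStatistics cache with all of the others.
    List<IndexSearcher> shared = ShardIndexSearcher.create(perShard);
    IndexSearcher shard0 = shared.get(0);

    // rewrite() pre-fetches the other shards' term stats into the shared cache.
    var term = new Term("text", "apple");
    shard0.rewrite(new TermQuery(term));

    // Local docFreq is 1, but the reported statistic sums both shards.
    var ts = TermStates.build(shard0.getIndexReader().getContext(), term, true);
    var stats = shard0.termStatistics(term, ts.docFreq(), ts.totalTermFreq());
    System.out.println("docFreq across shards: " + stats.docFreq()); // prints 2
  }
}
```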

SharedShardStatistics.java

@@ -0,0 +1,14 @@
+package it.cavallium.dbengine.lucene.searcher;
+
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher.FieldAndShard;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher.TermAndShard;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.TermStatistics;
+
+public class SharedShardStatistics {
+
+  public final Map<FieldAndShard, CollectionStatistics> collectionStatsCache = new ConcurrentHashMap<>();
+  public final Map<TermAndShard, TermStatistics> termStatsCache = new ConcurrentHashMap<>();
+}
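Both caches are plain ConcurrentHashMaps with no eviction: every ShardIndexSearcher in a group reads and writes them, potentially from concurrent search threads. That appears safe only because ShardIndexSearcher.create(...) allocates a fresh SharedShardStatistics per searcher group, so the cached statistics live no longer than the underlying index readers, consistent with the TODO note in collectionStatistics above.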

StandardSearcher.java

@@ -4,7 +4,6 @@ import static it.cavallium.dbengine.client.UninterruptibleScheduler.uninterrupti
 import static it.cavallium.dbengine.database.LLUtils.singleOrClose;
 import static java.util.Objects.requireNonNull;
 
-import io.netty5.buffer.api.Send;
 import it.cavallium.dbengine.database.LLKeyScore;
 import it.cavallium.dbengine.database.LLUtils;
 import it.cavallium.dbengine.database.disk.LLIndexSearchers;
@@ -75,51 +74,54 @@ public class StandardSearcher implements MultiSearcher {
         }
       })
       .subscribeOn(uninterruptibleScheduler(Schedulers.boundedElastic()))
-      .flatMap(sharedManager -> Flux.fromIterable(indexSearchers).<TopDocsCollector<?>>handle((shard, sink) -> {
-        LLUtils.ensureBlocking();
-        try {
-          var collector = sharedManager.newCollector();
-          assert queryParams.computePreciseHitsCount() == null || (queryParams.computePreciseHitsCount() == collector
-              .scoreMode()
-              .isExhaustive());
-          shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout()));
-          sink.next(collector);
-        } catch (IOException e) {
-          sink.error(e);
-        }
-      }).collectList().handle((collectors, sink) -> {
-        LLUtils.ensureBlocking();
-        try {
-          if (collectors.size() <= 1) {
-            sink.next(sharedManager.reduce((List) collectors));
-          } else if (queryParams.isSorted() && !queryParams.isSortedByScore()) {
-            final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
-            int i = 0;
-            for (var collector : collectors) {
-              var topFieldDocs = ((TopFieldCollector) collector).topDocs();
-              for (ScoreDoc scoreDoc : topFieldDocs.scoreDocs) {
-                scoreDoc.shardIndex = i;
-              }
-              topDocs[i++] = topFieldDocs;
-            }
-            sink.next(TopDocs.merge(requireNonNull(queryParams.sort()), 0, queryParams.limitInt(), topDocs));
-          } else {
-            final TopDocs[] topDocs = new TopDocs[collectors.size()];
-            int i = 0;
-            for (var collector : collectors) {
-              var topScoreDocs = collector.topDocs();
-              for (ScoreDoc scoreDoc : topScoreDocs.scoreDocs) {
-                scoreDoc.shardIndex = i;
-              }
-              topDocs[i++] = topScoreDocs;
-            }
-            sink.next(TopDocs.merge(0, queryParams.limitInt(), topDocs));
-          }
-        } catch (IOException ex) {
-          sink.error(ex);
-        }
-      }));
+      .flatMap(sharedManager -> Flux
+          .fromIterable(indexSearchers)
+          .<TopDocsCollector<?>>handle((shard, sink) -> {
+            LLUtils.ensureBlocking();
+            try {
+              var collector = sharedManager.newCollector();
+              assert queryParams.computePreciseHitsCount() == null || (queryParams.computePreciseHitsCount()
+                  == collector.scoreMode().isExhaustive());
+              shard.search(queryParams.query(), LuceneUtils.withTimeout(collector, queryParams.timeout()));
+              sink.next(collector);
+            } catch (IOException e) {
+              sink.error(e);
+            }
+          })
+          .collectList()
+          .handle((collectors, sink) -> {
+            LLUtils.ensureBlocking();
+            try {
+              if (collectors.size() <= 1) {
+                sink.next(sharedManager.reduce((List) collectors));
+              } else if (queryParams.isSorted() && !queryParams.isSortedByScore()) {
+                final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
+                int i = 0;
+                for (var collector : collectors) {
+                  var topFieldDocs = ((TopFieldCollector) collector).topDocs();
+                  for (ScoreDoc scoreDoc : topFieldDocs.scoreDocs) {
+                    scoreDoc.shardIndex = i;
+                  }
+                  topDocs[i++] = topFieldDocs;
+                }
+                sink.next(TopDocs.merge(requireNonNull(queryParams.sort()), 0, queryParams.limitInt(), topDocs));
+              } else {
+                final TopDocs[] topDocs = new TopDocs[collectors.size()];
+                int i = 0;
+                for (var collector : collectors) {
+                  var topScoreDocs = collector.topDocs();
+                  for (ScoreDoc scoreDoc : topScoreDocs.scoreDocs) {
+                    scoreDoc.shardIndex = i;
+                  }
+                  topDocs[i++] = topScoreDocs;
+                }
+                sink.next(TopDocs.merge(0, queryParams.limitInt(), topDocs));
+              }
+            } catch (IOException ex) {
+              sink.error(ex);
+            }
+          }));
   }
 
   /**

UnsortedUnscoredSimpleMultiSearcher.java

@@ -1,6 +1,7 @@
 package it.cavallium.dbengine;
 
 import static it.cavallium.dbengine.client.UninterruptibleScheduler.uninterruptibleScheduler;
+import static it.cavallium.dbengine.database.LLUtils.singleOrClose;
 import static it.cavallium.dbengine.lucene.searcher.GlobalQueryRewrite.NO_REWRITE;
 
 import io.netty5.buffer.api.Send;
@@ -14,6 +15,7 @@ import it.cavallium.dbengine.lucene.searcher.LocalQueryParams;
 import it.cavallium.dbengine.lucene.searcher.LocalSearcher;
 import it.cavallium.dbengine.lucene.searcher.LuceneSearchResult;
 import it.cavallium.dbengine.lucene.searcher.MultiSearcher;
+import it.cavallium.dbengine.lucene.searcher.ShardIndexSearcher;
 import it.cavallium.dbengine.utils.SimpleResource;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -40,7 +42,7 @@ public class UnsortedUnscoredSimpleMultiSearcher implements MultiSearcher {
       String keyFieldName,
       GlobalQueryRewrite transformer) {
-    return indexSearchersMono.flatMap(indexSearchers -> {
+    return singleOrClose(indexSearchersMono, indexSearchers -> {
       Mono<LocalQueryParams> queryParamsMono;
       if (transformer == NO_REWRITE) {
         queryParamsMono = Mono.just(queryParams);