2022-06-14 21:58:26 +02:00
|
|
|
package it.cavallium.dbengine.lucene.searcher;
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
|
|
import java.util.ArrayList;
|
|
|
|
import java.util.Collection;
|
|
|
|
import java.util.HashMap;
|
|
|
|
import java.util.HashSet;
|
|
|
|
import java.util.List;
|
|
|
|
import java.util.Map;
|
|
|
|
import java.util.NoSuchElementException;
|
2022-06-15 18:36:22 +02:00
|
|
|
import java.util.Optional;
|
2022-06-14 21:58:26 +02:00
|
|
|
import java.util.Set;
|
|
|
|
import org.apache.lucene.index.Term;
|
|
|
|
import org.apache.lucene.index.TermStates;
|
|
|
|
import org.apache.lucene.search.CollectionStatistics;
|
|
|
|
import org.apache.lucene.search.IndexSearcher;
|
|
|
|
import org.apache.lucene.search.Query;
|
|
|
|
import org.apache.lucene.search.QueryVisitor;
|
|
|
|
import org.apache.lucene.search.TermStatistics;
|
2022-06-15 18:36:22 +02:00
|
|
|
import org.jetbrains.annotations.Nullable;
|
2022-06-14 21:58:26 +02:00
|
|
|
|
|
|
|
public class ShardIndexSearcher extends IndexSearcher {
|
|
|
|
public final int myNodeID;
|
|
|
|
private final IndexSearcher[] searchers;
|
|
|
|
|
2022-06-15 18:36:22 +02:00
|
|
|
private final Map<FieldAndShar, CachedCollectionStatistics> collectionStatsCache;
|
2022-06-14 21:58:26 +02:00
|
|
|
private final Map<TermAndShard, TermStatistics> termStatsCache;
|
|
|
|
|
2022-06-15 18:36:22 +02:00
|
|
|
public ShardIndexSearcher(SharedShardStatistics sharedShardStatistics, List<IndexSearcher> searchers, int nodeID) {
|
2022-06-14 21:58:26 +02:00
|
|
|
super(searchers.get(nodeID).getIndexReader(), searchers.get(nodeID).getExecutor());
|
|
|
|
this.collectionStatsCache = sharedShardStatistics.collectionStatsCache;
|
|
|
|
this.termStatsCache = sharedShardStatistics.termStatsCache;
|
|
|
|
this.searchers = searchers.toArray(IndexSearcher[]::new);
|
|
|
|
myNodeID = nodeID;
|
|
|
|
}
|
|
|
|
|
|
|
|
public static List<IndexSearcher> create(Iterable<IndexSearcher> indexSearchersIterable) {
|
|
|
|
var it = indexSearchersIterable.iterator();
|
|
|
|
if (it.hasNext()) {
|
|
|
|
var is = it.next();
|
|
|
|
if (!it.hasNext()) {
|
|
|
|
return List.of(is);
|
|
|
|
}
|
|
|
|
if (is instanceof ShardIndexSearcher) {
|
|
|
|
if (indexSearchersIterable instanceof List<IndexSearcher> list) {
|
|
|
|
return list;
|
|
|
|
} else {
|
|
|
|
var result = new ArrayList<IndexSearcher>();
|
|
|
|
result.add(is);
|
|
|
|
do {
|
|
|
|
result.add(it.next());
|
|
|
|
} while (it.hasNext());
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
List<IndexSearcher> indexSearchers;
|
|
|
|
if (indexSearchersIterable instanceof List<IndexSearcher> collection) {
|
|
|
|
indexSearchers = collection;
|
|
|
|
} else if (indexSearchersIterable instanceof Collection<IndexSearcher> collection) {
|
|
|
|
indexSearchers = List.copyOf(collection);
|
|
|
|
} else {
|
|
|
|
indexSearchers = new ArrayList<>();
|
|
|
|
for (IndexSearcher i : indexSearchersIterable) {
|
|
|
|
indexSearchers.add(i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (indexSearchers.size() == 0) {
|
|
|
|
return List.of();
|
|
|
|
} else {
|
|
|
|
var sharedShardStatistics = new SharedShardStatistics();
|
|
|
|
List<IndexSearcher> result = new ArrayList<>(indexSearchers.size());
|
|
|
|
for (int nodeId = 0; nodeId < indexSearchers.size(); nodeId++) {
|
|
|
|
result.add(new ShardIndexSearcher(sharedShardStatistics, indexSearchers, nodeId));
|
|
|
|
}
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public Query rewrite(Query original) throws IOException {
|
|
|
|
final IndexSearcher localSearcher = new IndexSearcher(getIndexReader());
|
|
|
|
original = localSearcher.rewrite(original);
|
|
|
|
final Set<Term> terms = new HashSet<>();
|
|
|
|
original.visit(QueryVisitor.termCollector(terms));
|
|
|
|
|
|
|
|
// Make a single request to remote nodes for term
|
|
|
|
// stats:
|
|
|
|
for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
|
|
|
|
if (nodeID == myNodeID) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
final Set<Term> missing = new HashSet<>();
|
|
|
|
for (Term term : terms) {
|
2022-06-15 18:36:22 +02:00
|
|
|
final TermAndShard key = new TermAndShard(nodeID, term);
|
2022-06-14 21:58:26 +02:00
|
|
|
if (!termStatsCache.containsKey(key)) {
|
|
|
|
missing.add(term);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (missing.size() != 0) {
|
2022-06-15 18:36:22 +02:00
|
|
|
for (Map.Entry<Term, TermStatistics> ent : getNodeTermStats(missing, nodeID).entrySet()) {
|
2022-06-14 21:58:26 +02:00
|
|
|
if (ent.getValue() != null) {
|
|
|
|
final TermAndShard key = new TermAndShard(nodeID, ent.getKey());
|
|
|
|
termStatsCache.put(key, ent.getValue());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return original;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Mock: in a real env, this would hit the wire and get
|
|
|
|
// term stats from remote node
|
2022-06-15 18:36:22 +02:00
|
|
|
Map<Term, TermStatistics> getNodeTermStats(Set<Term> terms, int nodeID) throws IOException {
|
2022-06-14 21:58:26 +02:00
|
|
|
var s = searchers[nodeID];
|
|
|
|
final Map<Term, TermStatistics> stats = new HashMap<>();
|
|
|
|
if (s == null) {
|
|
|
|
throw new NoSuchElementException("node=" + nodeID);
|
|
|
|
}
|
|
|
|
for (Term term : terms) {
|
|
|
|
final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
|
|
|
|
if (ts.docFreq() > 0) {
|
|
|
|
stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return stats;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq)
|
|
|
|
throws IOException {
|
|
|
|
assert term != null;
|
|
|
|
long distributedDocFreq = 0;
|
|
|
|
long distributedTotalTermFreq = 0;
|
|
|
|
for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
|
|
|
|
|
|
|
|
final TermStatistics subStats;
|
|
|
|
if (nodeID == myNodeID) {
|
|
|
|
subStats = super.termStatistics(term, docFreq, totalTermFreq);
|
|
|
|
} else {
|
2022-06-15 18:36:22 +02:00
|
|
|
final TermAndShard key = new TermAndShard(nodeID, term);
|
2022-06-14 21:58:26 +02:00
|
|
|
subStats = termStatsCache.get(key);
|
|
|
|
if (subStats == null) {
|
|
|
|
continue; // term not found
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
long nodeDocFreq = subStats.docFreq();
|
|
|
|
distributedDocFreq += nodeDocFreq;
|
|
|
|
|
|
|
|
long nodeTotalTermFreq = subStats.totalTermFreq();
|
|
|
|
distributedTotalTermFreq += nodeTotalTermFreq;
|
|
|
|
}
|
|
|
|
assert distributedDocFreq > 0;
|
|
|
|
return new TermStatistics(term.bytes(), distributedDocFreq, distributedTotalTermFreq);
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public CollectionStatistics collectionStatistics(String field) throws IOException {
|
|
|
|
// TODO: we could compute this on init and cache,
|
|
|
|
// since we are re-inited whenever any nodes have a
|
|
|
|
// new reader
|
|
|
|
long docCount = 0;
|
|
|
|
long sumTotalTermFreq = 0;
|
|
|
|
long sumDocFreq = 0;
|
|
|
|
long maxDoc = 0;
|
|
|
|
|
|
|
|
for (int nodeID = 0; nodeID < searchers.length; nodeID++) {
|
|
|
|
final FieldAndShar key = new FieldAndShar(nodeID, field);
|
|
|
|
final CollectionStatistics nodeStats;
|
|
|
|
if (nodeID == myNodeID) {
|
|
|
|
nodeStats = super.collectionStatistics(field);
|
2022-06-15 18:36:22 +02:00
|
|
|
collectionStatsCache.put(key, new CachedCollectionStatistics(nodeStats));
|
2022-06-14 21:58:26 +02:00
|
|
|
} else {
|
2022-06-15 18:36:22 +02:00
|
|
|
var nodeStatsOptional = collectionStatsCache.get(key);
|
|
|
|
if (nodeStatsOptional == null) {
|
|
|
|
nodeStatsOptional = new CachedCollectionStatistics(computeNodeCollectionStatistics(key));
|
|
|
|
collectionStatsCache.put(key, nodeStatsOptional);
|
|
|
|
}
|
|
|
|
nodeStats = nodeStatsOptional.collectionStatistics();
|
2022-06-14 21:58:26 +02:00
|
|
|
}
|
|
|
|
if (nodeStats == null) {
|
|
|
|
continue; // field not in sub at all
|
|
|
|
}
|
|
|
|
|
|
|
|
long nodeDocCount = nodeStats.docCount();
|
|
|
|
docCount += nodeDocCount;
|
|
|
|
|
|
|
|
long nodeSumTotalTermFreq = nodeStats.sumTotalTermFreq();
|
|
|
|
sumTotalTermFreq += nodeSumTotalTermFreq;
|
|
|
|
|
|
|
|
long nodeSumDocFreq = nodeStats.sumDocFreq();
|
|
|
|
sumDocFreq += nodeSumDocFreq;
|
|
|
|
|
|
|
|
assert nodeStats.maxDoc() >= 0;
|
|
|
|
maxDoc += nodeStats.maxDoc();
|
|
|
|
}
|
|
|
|
|
|
|
|
if (maxDoc == 0) {
|
|
|
|
return null; // field not found across any node whatsoever
|
|
|
|
} else {
|
|
|
|
return new CollectionStatistics(field, maxDoc, docCount, sumTotalTermFreq, sumDocFreq);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-06-15 18:36:22 +02:00
|
|
|
private CollectionStatistics computeNodeCollectionStatistics(FieldAndShar fieldAndShard) throws IOException {
|
|
|
|
var searcher = searchers[fieldAndShard.nodeID];
|
|
|
|
return searcher.collectionStatistics(fieldAndShard.field);
|
|
|
|
}
|
|
|
|
|
|
|
|
public record CachedCollectionStatistics(@Nullable CollectionStatistics collectionStatistics) {}
|
|
|
|
|
2022-06-14 21:58:26 +02:00
|
|
|
public static class TermAndShard {
|
|
|
|
private final int nodeID;
|
|
|
|
private final Term term;
|
|
|
|
|
|
|
|
public TermAndShard(int nodeID, Term term) {
|
|
|
|
this.nodeID = nodeID;
|
|
|
|
this.term = term;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public int hashCode() {
|
|
|
|
return (nodeID + term.hashCode());
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public boolean equals(Object _other) {
|
|
|
|
if (!(_other instanceof final TermAndShard other)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
return term.equals(other.term) && nodeID == other.nodeID;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public static class FieldAndShar {
|
|
|
|
private final int nodeID;
|
|
|
|
private final String field;
|
|
|
|
|
|
|
|
public FieldAndShar(int nodeID, String field) {
|
|
|
|
this.nodeID = nodeID;
|
|
|
|
this.field = field;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public int hashCode() {
|
|
|
|
return (nodeID + field.hashCode());
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public boolean equals(Object _other) {
|
|
|
|
if (!(_other instanceof final FieldAndShar other)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
return field.equals(other.field) && nodeID == other.nodeID;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public String toString() {
|
|
|
|
return "FieldAndShardVersion(field="
|
|
|
|
+ field
|
|
|
|
+ " nodeID="
|
|
|
|
+ nodeID
|
|
|
|
+ ")";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|