CavalliumDBEngine/src/main/java/it/cavallium/dbengine/lucene/collector/ScoringShardsCollectorMulti...

164 lines
5.3 KiB
Java

package it.cavallium.dbengine.lucene.collector;
import static it.cavallium.dbengine.lucene.searcher.CurrentPageInfo.TIE_BREAKER;
import it.cavallium.dbengine.lucene.LuceneUtils;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopFieldDocs;
import org.jetbrains.annotations.Nullable;
import reactor.core.scheduler.Schedulers;
public class ScoringShardsCollectorMultiManager implements CollectorMultiManager<TopDocs, TopDocs> {
private static final boolean USE_CLASSIC_REDUCE = false;
private final Query query;
@Nullable
private final Sort sort;
private final int numHits;
private final FieldDoc after;
private final int totalHitsThreshold;
private final @Nullable Integer startN;
private final @Nullable Integer topN;
private final CollectorManager<TopFieldCollector, TopFieldDocs> sharedCollectorManager;
public ScoringShardsCollectorMultiManager(Query query,
@Nullable final Sort sort,
final int numHits,
final FieldDoc after,
final int totalHitsThreshold,
int startN,
int topN) {
this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) topN);
}
public ScoringShardsCollectorMultiManager(Query query,
@Nullable final Sort sort,
final int numHits,
final FieldDoc after,
final int totalHitsThreshold,
int startN) {
this(query, sort, numHits, after, totalHitsThreshold, (Integer) startN, (Integer) 2147483630);
}
public ScoringShardsCollectorMultiManager(Query query,
@Nullable final Sort sort,
final int numHits,
final FieldDoc after,
final int totalHitsThreshold) {
this(query, sort, numHits, after, totalHitsThreshold, null, null);
}
private ScoringShardsCollectorMultiManager(Query query,
@Nullable final Sort sort,
final int numHits,
final FieldDoc after,
final int totalHitsThreshold,
@Nullable Integer startN,
@Nullable Integer topN) {
this.query = query;
this.sort = sort;
this.numHits = numHits;
this.after = after;
this.totalHitsThreshold = totalHitsThreshold;
this.startN = startN;
if (topN != null && startN != null && (long) topN + (long) startN > 2147483630) {
this.topN = 2147483630 - startN;
} else if (topN != null && topN > 2147483630) {
this.topN = 2147483630;
} else {
this.topN = topN;
}
this.sharedCollectorManager = TopFieldCollector.createSharedManager(sort == null ? Sort.RELEVANCE : sort, numHits, after, totalHitsThreshold);
}
public CollectorManager<TopFieldCollector, TopDocs> get(IndexSearcher indexSearcher, int shardIndex) {
return new CollectorManager<>() {
@Override
public TopFieldCollector newCollector() throws IOException {
return sharedCollectorManager.newCollector();
}
@Override
public TopDocs reduce(Collection<TopFieldCollector> collectors) throws IOException {
if (Schedulers.isInNonBlockingThread()) {
throw new UnsupportedOperationException("Called reduce in a nonblocking thread");
}
if (USE_CLASSIC_REDUCE) {
final TopFieldDocs[] topDocs = new TopFieldDocs[collectors.size()];
int i = 0;
for (TopFieldCollector collector : collectors) {
topDocs[i++] = collector.topDocs();
}
var result = LuceneUtils.mergeTopDocs(sort, null, null, topDocs);
if (sort != null && sort.needsScores()) {
TopFieldCollector.populateScores(result.scoreDocs, indexSearcher, query);
}
return result;
} else {
TopDocs[] topDocs;
if (sort != null) {
topDocs = new TopFieldDocs[collectors.size()];
var i = 0;
for (TopFieldCollector collector : collectors) {
topDocs[i] = collector.topDocs();
// Populate scores of topfieldcollector. By default it doesn't popupate the scores
if (topDocs[i].scoreDocs.length > 0 && Float.isNaN(topDocs[i].scoreDocs[0].score) && sort.needsScores()) {
TopFieldCollector.populateScores(topDocs[i].scoreDocs, indexSearcher, query);
}
for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) {
scoreDoc.shardIndex = shardIndex;
}
i++;
}
} else {
topDocs = new TopDocs[collectors.size()];
var i = 0;
for (TopFieldCollector collector : collectors) {
topDocs[i] = collector.topDocs();
for (ScoreDoc scoreDoc : topDocs[i].scoreDocs) {
scoreDoc.shardIndex = shardIndex;
}
i++;
}
}
return LuceneUtils.mergeTopDocs(sort, null, null, topDocs);
}
}
};
}
@Override
public ScoreMode scoreMode() {
throw new NotImplementedException();
}
@SuppressWarnings({"SuspiciousToArrayCall", "IfStatementWithIdenticalBranches"})
@Override
public TopDocs reduce(List<TopDocs> topDocs) {
TopDocs[] arr;
if (sort != null) {
arr = topDocs.toArray(TopFieldDocs[]::new);
} else {
arr = topDocs.toArray(TopDocs[]::new);
}
return LuceneUtils.mergeTopDocs(sort, startN, topN, arr);
}
}