Allow random sampling

This commit is contained in:
Andrea Cavalli 2021-11-19 22:15:31 +01:00
parent 29d9aad8bf
commit 3810c49fa1
3 changed files with 30 additions and 9 deletions

View File

@ -13,11 +13,13 @@ import org.apache.lucene.facet.Facets;
import org.apache.lucene.facet.FacetsCollector; import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollectorManager; import org.apache.lucene.facet.FacetsCollectorManager;
import org.apache.lucene.facet.LabelAndValue; import org.apache.lucene.facet.LabelAndValue;
import org.apache.lucene.facet.RandomSamplingFacetsCollector;
import org.apache.lucene.facet.range.DoubleRange; import org.apache.lucene.facet.range.DoubleRange;
import org.apache.lucene.facet.range.DoubleRangeFacetCounts; import org.apache.lucene.facet.range.DoubleRangeFacetCounts;
import org.apache.lucene.facet.range.LongRange; import org.apache.lucene.facet.range.LongRange;
import org.apache.lucene.facet.range.LongRangeFacetCounts; import org.apache.lucene.facet.range.LongRangeFacetCounts;
import org.apache.lucene.facet.range.Range; import org.apache.lucene.facet.range.Range;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
@ -31,8 +33,12 @@ import org.jetbrains.annotations.Nullable;
public class DecimalBucketMultiCollectorManager implements CollectorMultiManager<Buckets, Buckets> { public class DecimalBucketMultiCollectorManager implements CollectorMultiManager<Buckets, Buckets> {
private final FacetsCollectorManager facetsCollectorManager;
private final Range[] bucketRanges;
private final List<Query> queries; private final List<Query> queries;
private final Query normalizationQuery; private final Query normalizationQuery;
private final @Nullable Integer sampleSize;
private final String bucketField; private final String bucketField;
private final BucketValueSource bucketValueSource; private final BucketValueSource bucketValueSource;
@ -43,8 +49,6 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
private final double maximum; private final double maximum;
private final int buckets; private final int buckets;
private final Range[] bucketRanges;
// todo: replace with an argument // todo: replace with an argument
private static final boolean USE_LONGS = true; private static final boolean USE_LONGS = true;
@ -54,7 +58,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
String bucketField, String bucketField,
BucketValueSource bucketValueSource, BucketValueSource bucketValueSource,
List<Query> queries, List<Query> queries,
Query normalizationQuery) { Query normalizationQuery,
@Nullable Integer sampleSize) {
this.queries = queries; this.queries = queries;
this.normalizationQuery = normalizationQuery; this.normalizationQuery = normalizationQuery;
var bucketsInt = (int) Math.ceil(buckets); var bucketsInt = (int) Math.ceil(buckets);
@ -65,6 +70,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
this.totalLength = bucketLength * bucketsInt; this.totalLength = bucketLength * bucketsInt;
this.bucketField = bucketField; this.bucketField = bucketField;
this.bucketValueSource = bucketValueSource; this.bucketValueSource = bucketValueSource;
this.sampleSize = sampleSize;
if (USE_LONGS) { if (USE_LONGS) {
this.bucketRanges = new LongRange[bucketsInt]; this.bucketRanges = new LongRange[bucketsInt];
@ -90,6 +96,17 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
); );
} }
} }
this.facetsCollectorManager = new FacetsCollectorManager() {
@Override
public FacetsCollector newCollector() throws IOException {
if (sampleSize != null) {
return new RandomSamplingFacetsCollector(sampleSize);
} else {
return super.newCollector();
}
}
};
} }
public double[] newBuckets() { public double[] newBuckets() {
@ -97,7 +114,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
} }
public Buckets search(IndexSearcher indexSearcher) throws IOException { public Buckets search(IndexSearcher indexSearcher) throws IOException {
var facetsCollectorManager = new FacetsCollectorManager(); //var facetsCollector = facetsCollectorManager.newCollector();
//FacetsCollector.search(indexSearcher, normalizationQuery, 10, facetsCollector);
var facetsCollector = indexSearcher.search(normalizationQuery, facetsCollectorManager); var facetsCollector = indexSearcher.search(normalizationQuery, facetsCollectorManager);
double[] reducedNormalizationBuckets = newBuckets(); double[] reducedNormalizationBuckets = newBuckets();
List<DoubleArrayList> seriesReducedBuckets = new ArrayList<>(queries.size()); List<DoubleArrayList> seriesReducedBuckets = new ArrayList<>(queries.size());
@ -112,6 +130,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
if (USE_LONGS) { if (USE_LONGS) {
LongValuesSource valuesSource; LongValuesSource valuesSource;
if (bucketValueSource instanceof NullValueSource) { if (bucketValueSource instanceof NullValueSource) {
valuesSource = null; valuesSource = null;
} else if (bucketValueSource instanceof ConstantValueSource constantValueSource) { } else if (bucketValueSource instanceof ConstantValueSource constantValueSource) {
valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue()); valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue());
@ -144,7 +163,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
(DoubleRange[]) bucketRanges (DoubleRange[]) bucketRanges
); );
} }
FacetResult children = facets.getTopChildren(100, bucketField); FacetResult children = facets.getTopChildren(0, bucketField);
for (LabelAndValue labelAndValue : children.labelValues) { for (LabelAndValue labelAndValue : children.labelValues) {
var index = Integer.parseInt(labelAndValue.label); var index = Integer.parseInt(labelAndValue.label);
reducedBuckets.set(index, reducedBuckets.getDouble(index) + labelAndValue.value.doubleValue()); reducedBuckets.set(index, reducedBuckets.getDouble(index) + labelAndValue.value.doubleValue());
@ -167,7 +186,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
normalizationFacets = new LongRangeFacetCounts(bucketField, normalizationFacets = new LongRangeFacetCounts(bucketField,
valuesSource, valuesSource,
facetsCollector, facetsCollector,
null, normalizationQuery,
(LongRange[]) bucketRanges (LongRange[]) bucketRanges
); );
} else { } else {
@ -184,7 +203,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
normalizationFacets = new DoubleRangeFacetCounts(bucketField, normalizationFacets = new DoubleRangeFacetCounts(bucketField,
valuesSource, valuesSource,
facetsCollector, facetsCollector,
null, normalizationQuery,
(DoubleRange[]) bucketRanges (DoubleRange[]) bucketRanges
); );
} }

View File

@ -2,6 +2,7 @@ package it.cavallium.dbengine.lucene.searcher;
import it.cavallium.dbengine.lucene.collector.BucketValueSource; import it.cavallium.dbengine.lucene.collector.BucketValueSource;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public record BucketParams(double min, double max, int buckets, String bucketFieldName, public record BucketParams(double min, double max, int buckets, String bucketFieldName,
@NotNull BucketValueSource valueSource) {} @NotNull BucketValueSource valueSource, @Nullable Integer sampleSize) {}

View File

@ -45,7 +45,8 @@ public class DecimalBucketMultiSearcher {
bucketParams.bucketFieldName(), bucketParams.bucketFieldName(),
bucketParams.valueSource(), bucketParams.valueSource(),
queries, queries,
normalizationQuery normalizationQuery,
bucketParams.sampleSize()
); );
}) })
.flatMap(cmm -> Flux .flatMap(cmm -> Flux