Allow random sampling

This commit is contained in:
Andrea Cavalli 2021-11-19 22:15:31 +01:00
parent 29d9aad8bf
commit 3810c49fa1
3 changed files with 30 additions and 9 deletions

View File

@ -13,11 +13,13 @@ import org.apache.lucene.facet.Facets;
import org.apache.lucene.facet.FacetsCollector;
import org.apache.lucene.facet.FacetsCollectorManager;
import org.apache.lucene.facet.LabelAndValue;
import org.apache.lucene.facet.RandomSamplingFacetsCollector;
import org.apache.lucene.facet.range.DoubleRange;
import org.apache.lucene.facet.range.DoubleRangeFacetCounts;
import org.apache.lucene.facet.range.LongRange;
import org.apache.lucene.facet.range.LongRangeFacetCounts;
import org.apache.lucene.facet.range.Range;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
@ -31,8 +33,12 @@ import org.jetbrains.annotations.Nullable;
public class DecimalBucketMultiCollectorManager implements CollectorMultiManager<Buckets, Buckets> {
private final FacetsCollectorManager facetsCollectorManager;
private final Range[] bucketRanges;
private final List<Query> queries;
private final Query normalizationQuery;
private final @Nullable Integer sampleSize;
private final String bucketField;
private final BucketValueSource bucketValueSource;
@ -43,8 +49,6 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
private final double maximum;
private final int buckets;
private final Range[] bucketRanges;
// todo: replace with an argument
private static final boolean USE_LONGS = true;
@ -54,7 +58,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
String bucketField,
BucketValueSource bucketValueSource,
List<Query> queries,
Query normalizationQuery) {
Query normalizationQuery,
@Nullable Integer sampleSize) {
this.queries = queries;
this.normalizationQuery = normalizationQuery;
var bucketsInt = (int) Math.ceil(buckets);
@ -65,6 +70,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
this.totalLength = bucketLength * bucketsInt;
this.bucketField = bucketField;
this.bucketValueSource = bucketValueSource;
this.sampleSize = sampleSize;
if (USE_LONGS) {
this.bucketRanges = new LongRange[bucketsInt];
@ -90,6 +96,17 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
);
}
}
this.facetsCollectorManager = new FacetsCollectorManager() {
@Override
public FacetsCollector newCollector() throws IOException {
if (sampleSize != null) {
return new RandomSamplingFacetsCollector(sampleSize);
} else {
return super.newCollector();
}
}
};
}
public double[] newBuckets() {
@ -97,7 +114,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
}
public Buckets search(IndexSearcher indexSearcher) throws IOException {
var facetsCollectorManager = new FacetsCollectorManager();
//var facetsCollector = facetsCollectorManager.newCollector();
//FacetsCollector.search(indexSearcher, normalizationQuery, 10, facetsCollector);
var facetsCollector = indexSearcher.search(normalizationQuery, facetsCollectorManager);
double[] reducedNormalizationBuckets = newBuckets();
List<DoubleArrayList> seriesReducedBuckets = new ArrayList<>(queries.size());
@ -112,6 +130,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
if (USE_LONGS) {
LongValuesSource valuesSource;
if (bucketValueSource instanceof NullValueSource) {
valuesSource = null;
} else if (bucketValueSource instanceof ConstantValueSource constantValueSource) {
valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue());
@ -144,7 +163,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
(DoubleRange[]) bucketRanges
);
}
FacetResult children = facets.getTopChildren(100, bucketField);
FacetResult children = facets.getTopChildren(0, bucketField);
for (LabelAndValue labelAndValue : children.labelValues) {
var index = Integer.parseInt(labelAndValue.label);
reducedBuckets.set(index, reducedBuckets.getDouble(index) + labelAndValue.value.doubleValue());
@ -167,7 +186,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
normalizationFacets = new LongRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
null,
normalizationQuery,
(LongRange[]) bucketRanges
);
} else {
@ -184,7 +203,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
normalizationFacets = new DoubleRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
null,
normalizationQuery,
(DoubleRange[]) bucketRanges
);
}

View File

@ -2,6 +2,7 @@ package it.cavallium.dbengine.lucene.searcher;
import it.cavallium.dbengine.lucene.collector.BucketValueSource;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public record BucketParams(double min, double max, int buckets, String bucketFieldName,
@NotNull BucketValueSource valueSource) {}
@NotNull BucketValueSource valueSource, @Nullable Integer sampleSize) {}

View File

@ -45,7 +45,8 @@ public class DecimalBucketMultiSearcher {
bucketParams.bucketFieldName(),
bucketParams.valueSource(),
queries,
normalizationQuery
normalizationQuery,
bucketParams.sampleSize()
);
})
.flatMap(cmm -> Flux