From 3810c49fa1c872ed6a583a978be866fa6fad2488 Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Fri, 19 Nov 2021 22:15:31 +0100 Subject: [PATCH] Allow random sampling --- .../DecimalBucketMultiCollectorManager.java | 33 +++++++++++++++---- .../lucene/searcher/BucketParams.java | 3 +- .../searcher/DecimalBucketMultiSearcher.java | 3 +- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java b/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java index 8a765d0..7a71505 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java +++ b/src/main/java/it/cavallium/dbengine/lucene/collector/DecimalBucketMultiCollectorManager.java @@ -13,11 +13,13 @@ import org.apache.lucene.facet.Facets; import org.apache.lucene.facet.FacetsCollector; import org.apache.lucene.facet.FacetsCollectorManager; import org.apache.lucene.facet.LabelAndValue; +import org.apache.lucene.facet.RandomSamplingFacetsCollector; import org.apache.lucene.facet.range.DoubleRange; import org.apache.lucene.facet.range.DoubleRangeFacetCounts; import org.apache.lucene.facet.range.LongRange; import org.apache.lucene.facet.range.LongRangeFacetCounts; import org.apache.lucene.facet.range.Range; +import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.IndexSearcher; @@ -31,8 +33,12 @@ import org.jetbrains.annotations.Nullable; public class DecimalBucketMultiCollectorManager implements CollectorMultiManager { + private final FacetsCollectorManager facetsCollectorManager; + private final Range[] bucketRanges; + private final List queries; private final Query normalizationQuery; + private final @Nullable Integer sampleSize; private final String bucketField; private final BucketValueSource bucketValueSource; @@ -43,8 +49,6 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager private final double maximum; private final int buckets; - private final Range[] bucketRanges; - // todo: replace with an argument private static final boolean USE_LONGS = true; @@ -54,7 +58,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager String bucketField, BucketValueSource bucketValueSource, List queries, - Query normalizationQuery) { + Query normalizationQuery, + @Nullable Integer sampleSize) { this.queries = queries; this.normalizationQuery = normalizationQuery; var bucketsInt = (int) Math.ceil(buckets); @@ -65,6 +70,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager this.totalLength = bucketLength * bucketsInt; this.bucketField = bucketField; this.bucketValueSource = bucketValueSource; + this.sampleSize = sampleSize; if (USE_LONGS) { this.bucketRanges = new LongRange[bucketsInt]; @@ -90,6 +96,17 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager ); } } + + this.facetsCollectorManager = new FacetsCollectorManager() { + @Override + public FacetsCollector newCollector() throws IOException { + if (sampleSize != null) { + return new RandomSamplingFacetsCollector(sampleSize); + } else { + return super.newCollector(); + } + } + }; } public double[] newBuckets() { @@ -97,7 +114,8 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager } public Buckets search(IndexSearcher indexSearcher) throws IOException { - var facetsCollectorManager = new FacetsCollectorManager(); + //var facetsCollector = facetsCollectorManager.newCollector(); + //FacetsCollector.search(indexSearcher, normalizationQuery, 10, facetsCollector); var facetsCollector = indexSearcher.search(normalizationQuery, facetsCollectorManager); double[] reducedNormalizationBuckets = newBuckets(); List seriesReducedBuckets = new ArrayList<>(queries.size()); @@ -112,6 +130,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager if (USE_LONGS) { LongValuesSource valuesSource; if (bucketValueSource instanceof NullValueSource) { + valuesSource = null; } else if (bucketValueSource instanceof ConstantValueSource constantValueSource) { valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue()); @@ -144,7 +163,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager (DoubleRange[]) bucketRanges ); } - FacetResult children = facets.getTopChildren(100, bucketField); + FacetResult children = facets.getTopChildren(0, bucketField); for (LabelAndValue labelAndValue : children.labelValues) { var index = Integer.parseInt(labelAndValue.label); reducedBuckets.set(index, reducedBuckets.getDouble(index) + labelAndValue.value.doubleValue()); @@ -167,7 +186,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager normalizationFacets = new LongRangeFacetCounts(bucketField, valuesSource, facetsCollector, - null, + normalizationQuery, (LongRange[]) bucketRanges ); } else { @@ -184,7 +203,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager normalizationFacets = new DoubleRangeFacetCounts(bucketField, valuesSource, facetsCollector, - null, + normalizationQuery, (DoubleRange[]) bucketRanges ); } diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/BucketParams.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/BucketParams.java index e501000..bc3d217 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/BucketParams.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/BucketParams.java @@ -2,6 +2,7 @@ package it.cavallium.dbengine.lucene.searcher; import it.cavallium.dbengine.lucene.collector.BucketValueSource; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public record BucketParams(double min, double max, int buckets, String bucketFieldName, - @NotNull BucketValueSource valueSource) {} + @NotNull BucketValueSource valueSource, @Nullable Integer sampleSize) {} diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java index 7131245..d3a6fd0 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java +++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java @@ -45,7 +45,8 @@ public class DecimalBucketMultiSearcher { bucketParams.bucketFieldName(), bucketParams.valueSource(), queries, - normalizationQuery + normalizationQuery, + bucketParams.sampleSize() ); }) .flatMap(cmm -> Flux