Allow to use absolute values

This commit is contained in:
Andrea Cavalli 2021-11-20 01:12:17 +01:00
parent 3810c49fa1
commit 06d98040b1

View File

@ -3,6 +3,7 @@ package it.cavallium.dbengine.lucene.collector;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -19,6 +20,8 @@ import org.apache.lucene.facet.range.DoubleRangeFacetCounts;
import org.apache.lucene.facet.range.LongRange; import org.apache.lucene.facet.range.LongRange;
import org.apache.lucene.facet.range.LongRangeFacetCounts; import org.apache.lucene.facet.range.LongRangeFacetCounts;
import org.apache.lucene.facet.range.Range; import org.apache.lucene.facet.range.Range;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.DoubleValuesSource; import org.apache.lucene.search.DoubleValuesSource;
@ -37,7 +40,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
private final Range[] bucketRanges; private final Range[] bucketRanges;
private final List<Query> queries; private final List<Query> queries;
private final Query normalizationQuery; private final @Nullable Query normalizationQuery;
private final @Nullable Integer sampleSize; private final @Nullable Integer sampleSize;
private final String bucketField; private final String bucketField;
@ -58,7 +61,7 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
String bucketField, String bucketField,
BucketValueSource bucketValueSource, BucketValueSource bucketValueSource,
List<Query> queries, List<Query> queries,
Query normalizationQuery, @Nullable Query normalizationQuery,
@Nullable Integer sampleSize) { @Nullable Integer sampleSize) {
this.queries = queries; this.queries = queries;
this.normalizationQuery = normalizationQuery; this.normalizationQuery = normalizationQuery;
@ -114,9 +117,18 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
} }
public Buckets search(IndexSearcher indexSearcher) throws IOException { public Buckets search(IndexSearcher indexSearcher) throws IOException {
//var facetsCollector = facetsCollectorManager.newCollector(); Query globalQuery;
//FacetsCollector.search(indexSearcher, normalizationQuery, 10, facetsCollector); if (normalizationQuery != null) {
var facetsCollector = indexSearcher.search(normalizationQuery, facetsCollectorManager); globalQuery = normalizationQuery;
} else {
var booleanQueryBuilder = new BooleanQuery.Builder();
for (Query query : queries) {
booleanQueryBuilder.add(query, Occur.SHOULD);
}
booleanQueryBuilder.setMinimumNumberShouldMatch(1);
globalQuery = booleanQueryBuilder.build();
}
var facetsCollector = indexSearcher.search(globalQuery, facetsCollectorManager);
double[] reducedNormalizationBuckets = newBuckets(); double[] reducedNormalizationBuckets = newBuckets();
List<DoubleArrayList> seriesReducedBuckets = new ArrayList<>(queries.size()); List<DoubleArrayList> seriesReducedBuckets = new ArrayList<>(queries.size());
for (int i = 0; i < queries.size(); i++) { for (int i = 0; i < queries.size(); i++) {
@ -172,45 +184,49 @@ public class DecimalBucketMultiCollectorManager implements CollectorMultiManager
} }
Facets normalizationFacets; Facets normalizationFacets;
if (USE_LONGS) { if (normalizationQuery != null) {
LongValuesSource valuesSource; if (USE_LONGS) {
if (bucketValueSource instanceof NullValueSource) { LongValuesSource valuesSource;
valuesSource = null; if (bucketValueSource instanceof NullValueSource) {
} else if (bucketValueSource instanceof ConstantValueSource constantValueSource) { valuesSource = null;
valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue()); } else if (bucketValueSource instanceof ConstantValueSource constantValueSource) {
} else if (bucketValueSource instanceof LongBucketValueSource longBucketValueSource) { valuesSource = LongValuesSource.constant(constantValueSource.constant().longValue());
valuesSource = longBucketValueSource.source(); } else if (bucketValueSource instanceof LongBucketValueSource longBucketValueSource) {
valuesSource = longBucketValueSource.source();
} else {
throw new IllegalArgumentException("Wrong value source type: " + bucketValueSource);
}
normalizationFacets = new LongRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
null,
(LongRange[]) bucketRanges
);
} else { } else {
throw new IllegalArgumentException("Wrong value source type: " + bucketValueSource); DoubleValuesSource valuesSource;
if (bucketValueSource instanceof NullValueSource) {
valuesSource = null;
} else if (bucketValueSource instanceof ConstantValueSource constantValueSource) {
valuesSource = DoubleValuesSource.constant(constantValueSource.constant().longValue());
} else if (bucketValueSource instanceof DoubleBucketValueSource doubleBucketValueSource) {
valuesSource = doubleBucketValueSource.source();
} else {
throw new IllegalArgumentException("Wrong value source type: " + bucketValueSource);
}
normalizationFacets = new DoubleRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
null,
(DoubleRange[]) bucketRanges
);
}
var normalizationChildren = normalizationFacets.getTopChildren(0, bucketField);
for (LabelAndValue labelAndValue : normalizationChildren.labelValues) {
var index = Integer.parseInt(labelAndValue.label);
reducedNormalizationBuckets[index] += labelAndValue.value.doubleValue();
} }
normalizationFacets = new LongRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
normalizationQuery,
(LongRange[]) bucketRanges
);
} else { } else {
DoubleValuesSource valuesSource; Arrays.fill(reducedNormalizationBuckets, 1);
if (bucketValueSource instanceof NullValueSource) {
valuesSource = null;
} else if (bucketValueSource instanceof ConstantValueSource constantValueSource) {
valuesSource = DoubleValuesSource.constant(constantValueSource.constant().longValue());
} else if (bucketValueSource instanceof DoubleBucketValueSource doubleBucketValueSource) {
valuesSource = doubleBucketValueSource.source();
} else {
throw new IllegalArgumentException("Wrong value source type: " + bucketValueSource);
}
normalizationFacets = new DoubleRangeFacetCounts(bucketField,
valuesSource,
facetsCollector,
normalizationQuery,
(DoubleRange[]) bucketRanges
);
}
var normalizationChildren = normalizationFacets.getTopChildren(0, bucketField);
for (LabelAndValue labelAndValue : normalizationChildren.labelValues) {
var index = Integer.parseInt(labelAndValue.label);
reducedNormalizationBuckets[index] += labelAndValue.value.doubleValue();
} }
return new Buckets(seriesReducedBuckets, DoubleArrayList.wrap(reducedNormalizationBuckets)); return new Buckets(seriesReducedBuckets, DoubleArrayList.wrap(reducedNormalizationBuckets));
} }