From daa70476145938f04890b982b2ffe9f0d94dfb7e Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Fri, 24 Feb 2023 17:19:25 +0100 Subject: [PATCH] Partially replace foreach with collectOn --- .../dbengine/client/LuceneIndexImpl.java | 47 +-- .../dbengine/database/LLDictionary.java | 2 +- .../dbengine/database/LLLuceneIndex.java | 16 +- .../database/LLMultiDatabaseConnection.java | 23 +- .../dbengine/database/LLMultiLuceneIndex.java | 63 ++-- .../database/LLSearchResultShard.java | 1 + .../cavallium/dbengine/database/LLUtils.java | 6 - .../dbengine/database/SerializedKey.java | 5 + .../collections/DatabaseMapDictionary.java | 25 +- .../DatabaseMapDictionaryDeep.java | 1 - .../collections/DatabaseStageMap.java | 38 +- .../database/disk/LLLocalDictionary.java | 168 ++++----- .../disk/LLLocalKeyValueDatabase.java | 119 ++++--- .../database/disk/LLLocalLuceneIndex.java | 5 +- .../disk/LLLocalMultiLuceneIndex.java | 70 ++-- .../database/memory/LLMemoryDictionary.java | 7 +- .../dbengine/lucene/LuceneUtils.java | 2 +- .../searcher/DecimalBucketMultiSearcher.java | 11 +- .../searcher/ScoredPagedMultiSearcher.java | 9 +- .../UnsortedStreamingMultiSearcher.java | 1 - .../dbengine/utils/MoshiPolymorphic.java | 30 +- .../utils/PartitionByIntSpliterator.java | 89 +++++ .../utils/PartitionBySpliterator.java | 88 +++++ .../cavallium/dbengine/utils/StreamUtils.java | 330 ++++++++++++++++-- .../dbengine/tests/TestDictionaryMap.java | 17 +- .../dbengine/tests/TestLLDictionary.java | 11 +- 26 files changed, 828 insertions(+), 356 deletions(-) create mode 100644 src/main/java/it/cavallium/dbengine/database/SerializedKey.java create mode 100644 src/main/java/it/cavallium/dbengine/utils/PartitionByIntSpliterator.java create mode 100644 src/main/java/it/cavallium/dbengine/utils/PartitionBySpliterator.java diff --git a/src/main/java/it/cavallium/dbengine/client/LuceneIndexImpl.java b/src/main/java/it/cavallium/dbengine/client/LuceneIndexImpl.java index 70b986b..74236d2 100644 --- a/src/main/java/it/cavallium/dbengine/client/LuceneIndexImpl.java +++ b/src/main/java/it/cavallium/dbengine/client/LuceneIndexImpl.java @@ -1,6 +1,10 @@ package it.cavallium.dbengine.client; -import static it.cavallium.dbengine.utils.StreamUtils.toListClose; +import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.toListOn; +import static java.util.stream.Collectors.collectingAndThen; +import static java.util.stream.Collectors.toList; import it.cavallium.dbengine.client.Hits.CloseableHits; import it.cavallium.dbengine.client.Hits.LuceneHits; @@ -92,34 +96,19 @@ public class LuceneIndexImpl implements LuceneIndex { var mltDocumentFields = indicizer.getMoreLikeThisDocumentFields(key, mltDocumentValue); - var results = toListClose(luceneIndex .moreLikeThis(resolveSnapshot(queryParams.snapshot()), + return collectOn(LUCENE_SCHEDULER, luceneIndex.moreLikeThis(resolveSnapshot(queryParams.snapshot()), queryParams.toQueryParams(), indicizer.getKeyFieldName(), - mltDocumentFields - )); - LLSearchResultShard mergedResults = mergeResults(queryParams, results); - if (mergedResults != null) { - return mapResults(mergedResults); - } else { - return Hits.empty(); - } + mltDocumentFields), + collectingAndThen(toList(), toHitsCollector(queryParams))); } @Override public Hits> search(ClientQueryParams queryParams) { - var results = toListClose(luceneIndex .search(resolveSnapshot(queryParams.snapshot()), + return 
collectOn(LUCENE_SCHEDULER, luceneIndex.search(resolveSnapshot(queryParams.snapshot()), queryParams.toQueryParams(), - indicizer.getKeyFieldName() - )); - - var mergedResults = mergeResults(queryParams, results); - if (mergedResults != null) { - return mapResults(mergedResults); - } else { - return Hits.empty(); - } + indicizer.getKeyFieldName()), + collectingAndThen(toList(), toHitsCollector(queryParams))); } @Override @@ -192,6 +181,18 @@ public class LuceneIndexImpl implements LuceneIndex { luceneIndex.releaseSnapshot(snapshot); } + private Function, Hits>> toHitsCollector(ClientQueryParams queryParams) { + return (List results) -> resultsToHits(mergeResults(queryParams, results)); + } + + private Hits> resultsToHits(LLSearchResultShard resultShard) { + if (resultShard != null) { + return mapResults(resultShard); + } else { + return Hits.empty(); + } + } + @SuppressWarnings({"unchecked", "rawtypes"}) @Nullable private static LLSearchResultShard mergeResults(ClientQueryParams queryParams, List shards) { @@ -224,7 +225,7 @@ public class LuceneIndexImpl implements LuceneIndex { } else if (results.size() == 1) { resultsFlux = results.get(0); } else { - resultsFlux = results.parallelStream().flatMap(Function.identity()); + resultsFlux = results.stream().flatMap(Function.identity()); } if (luceneResources) { return new LuceneLLSearchResultShard(resultsFlux, count, (List) resources); diff --git a/src/main/java/it/cavallium/dbengine/database/LLDictionary.java b/src/main/java/it/cavallium/dbengine/database/LLDictionary.java index aae8d86..90433e1 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLDictionary.java +++ b/src/main/java/it/cavallium/dbengine/database/LLDictionary.java @@ -38,7 +38,7 @@ public interface LLDictionary extends LLKeyValueDatabaseStructure { void putMulti(Stream entries); - Stream updateMulti(Stream keys, Stream serializedKeys, + Stream updateMulti(Stream> keys, KVSerializationFunction updateFunction); Stream getRange(@Nullable LLSnapshot snapshot, diff --git a/src/main/java/it/cavallium/dbengine/database/LLLuceneIndex.java b/src/main/java/it/cavallium/dbengine/database/LLLuceneIndex.java index 09b4200..1f9c61f 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLLuceneIndex.java +++ b/src/main/java/it/cavallium/dbengine/database/LLLuceneIndex.java @@ -1,5 +1,8 @@ package it.cavallium.dbengine.database; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.fastReducing; + import com.google.common.collect.Multimap; import it.cavallium.dbengine.client.IBackuppable; import it.cavallium.dbengine.client.query.current.data.NoSort; @@ -8,10 +11,12 @@ import it.cavallium.dbengine.client.query.current.data.QueryParams; import it.cavallium.dbengine.client.query.current.data.TotalHitsCount; import it.cavallium.dbengine.lucene.collector.Buckets; import it.cavallium.dbengine.lucene.searcher.BucketParams; +import it.cavallium.dbengine.utils.StreamUtils; import java.io.IOException; import java.time.Duration; import java.util.List; import java.util.Map.Entry; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -69,11 +74,12 @@ public interface LLLuceneIndex extends LLSnapshottable, IBackuppable, SafeClosea false, timeout == null ? 
Long.MAX_VALUE : timeout.toMillis() ); - try (var stream = this.search(snapshot, params, null)) { - return stream.parallel().map(LLSearchResultShard::totalHitsCount).reduce(TotalHitsCount.of(0, true), - (a, b) -> TotalHitsCount.of(a.value() + b.value(), a.exact() && b.exact()) - ); - } + return collectOn(StreamUtils.LUCENE_SCHEDULER, + this.search(snapshot, params, null).map(LLSearchResultShard::totalHitsCount), + fastReducing(TotalHitsCount.of(0, true), + (a, b) -> TotalHitsCount.of(a.value() + b.value(), a.exact() && b.exact()) + ) + ); } boolean isLowMemoryMode(); diff --git a/src/main/java/it/cavallium/dbengine/database/LLMultiDatabaseConnection.java b/src/main/java/it/cavallium/dbengine/database/LLMultiDatabaseConnection.java index cdd9d20..9b44614 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLMultiDatabaseConnection.java +++ b/src/main/java/it/cavallium/dbengine/database/LLMultiDatabaseConnection.java @@ -1,5 +1,9 @@ package it.cavallium.dbengine.database; +import static it.cavallium.dbengine.utils.StreamUtils.ROCKSDB_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.executing; + import com.google.common.collect.Multimap; import io.micrometer.core.instrument.MeterRegistry; import it.cavallium.dbengine.client.ConnectionSettings.ConnectionPart; @@ -85,14 +89,13 @@ public class LLMultiDatabaseConnection implements LLDatabaseConnection { @Override public LLDatabaseConnection connect() { - // todo: parallelize? - for (LLDatabaseConnection connection : allConnections) { + collectOn(ROCKSDB_SCHEDULER, allConnections.stream(), executing(connection -> { try { connection.connect(); } catch (Exception ex) { LOG.error("Failed to open connection", ex); } - } + })); return this; } @@ -142,18 +145,15 @@ public class LLMultiDatabaseConnection implements LLDatabaseConnection { ); } else { record ShardToIndex(int shard, LLLuceneIndex connIndex) {} - var indices = connectionToShardMap.entrySet().stream().flatMap(entry -> { + var luceneIndices = new LLLuceneIndex[indexStructure.totalShards()]; + connectionToShardMap.entrySet().stream().flatMap(entry -> { var connectionIndexStructure = indexStructure.setActiveShards(new IntArrayList(entry.getValue())); LLLuceneIndex connIndex = entry.getKey().getLuceneIndex(clusterName, connectionIndexStructure, indicizerAnalyzers, indicizerSimilarities, luceneOptions, luceneHacks); return entry.getValue().intStream().mapToObj(shard -> new ShardToIndex(shard, connIndex)); - }).toList(); - var luceneIndices = new LLLuceneIndex[indexStructure.totalShards()]; - for (var index : indices) { - luceneIndices[index.shard] = index.connIndex; - } + }).forEach(index -> luceneIndices[index.shard] = index.connIndex); return new LLMultiLuceneIndex(clusterName, indexStructure, indicizerAnalyzers, @@ -167,13 +167,12 @@ public class LLMultiDatabaseConnection implements LLDatabaseConnection { @Override public void disconnect() { - // todo: parallelize? 
- for (LLDatabaseConnection connection : allConnections) { + collectOn(ROCKSDB_SCHEDULER, allConnections.stream(), executing(connection -> { try { connection.disconnect(); } catch (Exception ex) { LOG.error("Failed to close connection", ex); } - } + })); } } diff --git a/src/main/java/it/cavallium/dbengine/database/LLMultiLuceneIndex.java b/src/main/java/it/cavallium/dbengine/database/LLMultiLuceneIndex.java index 0669498..db83960 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLMultiLuceneIndex.java +++ b/src/main/java/it/cavallium/dbengine/database/LLMultiLuceneIndex.java @@ -1,6 +1,14 @@ package it.cavallium.dbengine.database; import static it.cavallium.dbengine.lucene.LuceneUtils.getLuceneIndexId; +import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collect; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.executing; +import static it.cavallium.dbengine.utils.StreamUtils.fastListing; +import static it.cavallium.dbengine.utils.StreamUtils.fastReducing; +import static it.cavallium.dbengine.utils.StreamUtils.fastSummingLong; +import static it.cavallium.dbengine.utils.StreamUtils.partitionByInt; import static java.util.stream.Collectors.groupingBy; import com.google.common.collect.Multimap; @@ -14,6 +22,7 @@ import it.cavallium.dbengine.rpc.current.data.IndicizerAnalyzers; import it.cavallium.dbengine.rpc.current.data.IndicizerSimilarities; import it.cavallium.dbengine.rpc.current.data.LuceneIndexStructure; import it.cavallium.dbengine.rpc.current.data.LuceneOptions; +import it.cavallium.dbengine.utils.StreamUtils; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import java.util.ArrayList; @@ -83,17 +92,11 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { @Override public long addDocuments(boolean atomic, Stream> documents) { - var groupedRequests = documents - .collect(groupingBy(term -> getLuceneIndexId(term.getKey(), totalShards), - Int2ObjectOpenHashMap::new, - Collectors.toList() - )); - - return groupedRequests - .int2ObjectEntrySet() - .stream() - .map(entry -> luceneIndicesById[entry.getIntKey()].addDocuments(atomic, entry.getValue().stream())) - .reduce(0L, Long::sum); + return collectOn(LUCENE_SCHEDULER, + partitionByInt(term -> getLuceneIndexId(term.getKey(), totalShards), documents) + .map(entry -> luceneIndicesById[entry.key()].addDocuments(atomic, entry.values().stream())), + fastSummingLong() + ); } @Override @@ -108,17 +111,11 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { @Override public long updateDocuments(Stream> documents) { - var groupedRequests = documents - .collect(groupingBy(term -> getLuceneIndexId(term.getKey(), totalShards), - Int2ObjectOpenHashMap::new, - Collectors.toList() - )); - - return groupedRequests - .int2ObjectEntrySet() - .stream() - .map(entry -> luceneIndicesById[entry.getIntKey()].updateDocuments(entry.getValue().stream())) - .reduce(0L, Long::sum); + return collectOn(LUCENE_SCHEDULER, + partitionByInt(term -> getLuceneIndexId(term.getKey(), totalShards), documents) + .map(entry -> luceneIndicesById[entry.key()].updateDocuments(entry.values().stream())), + fastSummingLong() + ); } @Override @@ -131,7 +128,7 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { QueryParams queryParams, @Nullable String keyFieldName, Multimap mltDocumentFields) { - return 
luceneIndicesSet.parallelStream().flatMap(luceneIndex -> luceneIndex.moreLikeThis(snapshot, + return luceneIndicesSet.stream().flatMap(luceneIndex -> luceneIndex.moreLikeThis(snapshot, queryParams, keyFieldName, mltDocumentFields @@ -166,7 +163,7 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { public Stream search(@Nullable LLSnapshot snapshot, QueryParams queryParams, @Nullable String keyFieldName) { - return luceneIndicesSet.parallelStream().flatMap(luceneIndex -> luceneIndex.search(snapshot, + return luceneIndicesSet.stream().flatMap(luceneIndex -> luceneIndex.search(snapshot, queryParams, keyFieldName )); @@ -177,7 +174,7 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { @NotNull List queries, @Nullable Query normalizationQuery, BucketParams bucketParams) { - return mergeShards(luceneIndicesSet.parallelStream().map(luceneIndex -> luceneIndex.computeBuckets(snapshot, + return mergeShards(luceneIndicesSet.stream().map(luceneIndex -> luceneIndex.computeBuckets(snapshot, queries, normalizationQuery, bucketParams @@ -191,34 +188,34 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { @Override public void close() { - luceneIndicesSet.parallelStream().forEach(SafeCloseable::close); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::close)); } @Override public void flush() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::flush); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::flush)); } @Override public void waitForMerges() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::waitForMerges); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::waitForMerges)); } @Override public void waitForLastMerges() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::waitForLastMerges); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::waitForLastMerges)); } @Override public void refresh(boolean force) { - luceneIndicesSet.parallelStream().forEach(index -> index.refresh(force)); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(index -> index.refresh(force))); } @Override public LLSnapshot takeSnapshot() { // Generate next snapshot index var snapshotIndex = nextSnapshotNumber.getAndIncrement(); - var snapshot = luceneIndicesSet.parallelStream().map(LLSnapshottable::takeSnapshot).toList(); + var snapshot = collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream().map(LLSnapshottable::takeSnapshot), fastListing()); registeredSnapshots.put(snapshotIndex, snapshot); return new LLSnapshot(snapshotIndex); } @@ -235,12 +232,12 @@ public class LLMultiLuceneIndex implements LLLuceneIndex { @Override public void pauseForBackup() { - this.luceneIndicesSet.forEach(IBackuppable::pauseForBackup); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::pauseForBackup)); } @Override public void resumeAfterBackup() { - this.luceneIndicesSet.forEach(IBackuppable::resumeAfterBackup); + collectOn(LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::resumeAfterBackup)); } @Override diff --git a/src/main/java/it/cavallium/dbengine/database/LLSearchResultShard.java b/src/main/java/it/cavallium/dbengine/database/LLSearchResultShard.java index 3a9d083..21a87b7 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLSearchResultShard.java +++ b/src/main/java/it/cavallium/dbengine/database/LLSearchResultShard.java @@ -63,6 +63,7 @@ public class LLSearchResultShard extends SimpleResource 
implements DiscardingClo @Override public void onClose() { + results.close(); } public static class ResourcesLLSearchResultShard extends LLSearchResultShard { diff --git a/src/main/java/it/cavallium/dbengine/database/LLUtils.java b/src/main/java/it/cavallium/dbengine/database/LLUtils.java index 95a7eb1..12be464 100644 --- a/src/main/java/it/cavallium/dbengine/database/LLUtils.java +++ b/src/main/java/it/cavallium/dbengine/database/LLUtils.java @@ -689,10 +689,4 @@ public class LLUtils { return term.getValueBytesRef(); } } - - public static void consume(Stream stream) { - try (stream) { - stream.forEach(NULL_CONSUMER); - } - } } diff --git a/src/main/java/it/cavallium/dbengine/database/SerializedKey.java b/src/main/java/it/cavallium/dbengine/database/SerializedKey.java new file mode 100644 index 0000000..b2ed6e4 --- /dev/null +++ b/src/main/java/it/cavallium/dbengine/database/SerializedKey.java @@ -0,0 +1,5 @@ +package it.cavallium.dbengine.database; + +import it.cavallium.dbengine.buffers.Buf; + +public record SerializedKey(T key, Buf serialized) {} diff --git a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionary.java b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionary.java index 61c28d8..3bfd2dd 100644 --- a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionary.java +++ b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionary.java @@ -1,6 +1,8 @@ package it.cavallium.dbengine.database.collections; -import static it.cavallium.dbengine.utils.StreamUtils.collectClose; +import static it.cavallium.dbengine.utils.StreamUtils.ROCKSDB_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.fastListing; import it.cavallium.dbengine.buffers.Buf; import it.cavallium.dbengine.buffers.BufDataInput; @@ -13,6 +15,7 @@ import it.cavallium.dbengine.database.LLDictionaryResultType; import it.cavallium.dbengine.database.LLEntry; import it.cavallium.dbengine.database.LLRange; import it.cavallium.dbengine.database.LLUtils; +import it.cavallium.dbengine.database.SerializedKey; import it.cavallium.dbengine.database.SubStageEntry; import it.cavallium.dbengine.database.UpdateMode; import it.cavallium.dbengine.database.UpdateReturnMode; @@ -22,6 +25,7 @@ import it.cavallium.dbengine.database.serialization.SerializationException; import it.cavallium.dbengine.database.serialization.SerializationFunction; import it.cavallium.dbengine.database.serialization.Serializer; import it.cavallium.dbengine.database.serialization.SerializerFixedBinaryLength; +import it.cavallium.dbengine.utils.StreamUtils; import it.unimi.dsi.fastutil.objects.Object2ObjectLinkedOpenHashMap; import it.unimi.dsi.fastutil.objects.Object2ObjectSortedMap; import it.unimi.dsi.fastutil.objects.Object2ObjectSortedMaps; @@ -166,7 +170,7 @@ public class DatabaseMapDictionary extends DatabaseMapDictionaryDeep get(@Nullable CompositeSnapshot snapshot) { - var map = collectClose(dictionary + Stream> stream = dictionary .getRange(resolveSnapshot(snapshot), range, false, true) .map(entry -> { Entry deserializedEntry; @@ -182,7 +186,12 @@ public class DatabaseMapDictionary extends DatabaseMapDictionaryDeep a, Object2ObjectLinkedOpenHashMap::new)); + }); + // serializedKey + // after this, it becomes serializedSuffixAndExt + var map = StreamUtils.collect(stream, + Collectors.toMap(Entry::getKey, Entry::getValue, (a, b) -> a, Object2ObjectLinkedOpenHashMap::new) + ); return 
map.isEmpty() ? null : map; } @@ -368,10 +377,9 @@ public class DatabaseMapDictionary extends DatabaseMapDictionaryDeep updateMulti(Stream keys, KVSerializationFunction updater) { - List sharedKeys = keys.toList(); - var serializedKeys = sharedKeys.stream().map(keySuffix -> serializeKeySuffixToKey(keySuffix)); + var serializedKeys = keys.map(keySuffix -> new SerializedKey<>(keySuffix, serializeKeySuffixToKey(keySuffix))); var serializedUpdater = getSerializedUpdater(updater); - return dictionary.updateMulti(sharedKeys.stream(), serializedKeys, serializedUpdater); + return dictionary.updateMulti(serializedKeys, serializedUpdater); } @Override @@ -487,9 +495,8 @@ public class DatabaseMapDictionary extends DatabaseMapDictionaryDeep> setAllValuesAndGetPrevious(Stream> entries) { - var previous = this.getAllValues(null, false); - dictionary.setRange(range, entries.map(entry -> serializeEntry(entry)), false); - return previous; + return getAllValues(null, false) + .onClose(() -> dictionary.setRange(range, entries.map(entry -> serializeEntry(entry)), false)); } @Override diff --git a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionaryDeep.java b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionaryDeep.java index 7c216d4..fa321ee 100644 --- a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionaryDeep.java +++ b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseMapDictionaryDeep.java @@ -276,7 +276,6 @@ public class DatabaseMapDictionaryDeep> implem public Stream> getAllStages(@Nullable CompositeSnapshot snapshot, boolean smallRange) { return dictionary .getRangeKeyPrefixes(resolveSnapshot(snapshot), range, keyPrefixLength + keySuffixLength, smallRange) - .parallel() .map(groupKeyWithoutExt -> { T deserializedSuffix; var splittedGroupSuffix = suffixSubList(groupKeyWithoutExt); diff --git a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseStageMap.java b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseStageMap.java index 4df4976..08cc843 100644 --- a/src/main/java/it/cavallium/dbengine/database/collections/DatabaseStageMap.java +++ b/src/main/java/it/cavallium/dbengine/database/collections/DatabaseStageMap.java @@ -1,9 +1,14 @@ package it.cavallium.dbengine.database.collections; -import static it.cavallium.dbengine.database.LLUtils.consume; +import static it.cavallium.dbengine.utils.StreamUtils.ROCKSDB_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.count; +import static it.cavallium.dbengine.utils.StreamUtils.executing; +import static it.cavallium.dbengine.utils.StreamUtils.iterating; import it.cavallium.dbengine.client.CompositeSnapshot; import it.cavallium.dbengine.database.Delta; +import it.cavallium.dbengine.database.LLLuceneIndex; import it.cavallium.dbengine.database.LLUtils; import it.cavallium.dbengine.database.SubStageEntry; import it.cavallium.dbengine.database.UpdateMode; @@ -60,7 +65,7 @@ public interface DatabaseStageMap> extends Dat } default Stream updateMulti(Stream keys, KVSerializationFunction updater) { - return keys.parallel().map(key -> this.updateValue(key, prevValue -> updater.apply(key, prevValue))); + return keys.map(key -> this.updateValue(key, prevValue -> updater.apply(key, prevValue))); } default boolean updateValue(T key, SerializationFunction<@Nullable U, @Nullable U> updater) { @@ -98,28 +103,24 @@ public interface DatabaseStageMap> extends Dat * GetMulti must 
return the elements in sequence! */ default Stream> getMulti(@Nullable CompositeSnapshot snapshot, Stream keys) { - return keys.parallel().map(key -> Optional.ofNullable(this.getValue(snapshot, key))); + return keys.map(key -> Optional.ofNullable(this.getValue(snapshot, key))); } default void putMulti(Stream> entries) { - try (var stream = entries.parallel()) { - stream.forEach(entry -> this.putValue(entry.getKey(), entry.getValue())); - } + collectOn(ROCKSDB_SCHEDULER, entries, executing(entry -> this.putValue(entry.getKey(), entry.getValue()))); } Stream> getAllStages(@Nullable CompositeSnapshot snapshot, boolean smallRange); default Stream> getAllValues(@Nullable CompositeSnapshot snapshot, boolean smallRange) { - return this.getAllStages(snapshot, smallRange).parallel().mapMulti((stage, mapper) -> { + return this.getAllStages(snapshot, smallRange).map(stage -> { var val = stage.getValue().get(snapshot); - if (val != null) { - mapper.accept(Map.entry(stage.getKey(), val)); - } - }); + return val != null ? Map.entry(stage.getKey(), val) : null; + }).filter(Objects::nonNull); } default void setAllValues(Stream> entries) { - consume(setAllValuesAndGetPrevious(entries)); + setAllValuesAndGetPrevious(entries).close(); } Stream> setAllValuesAndGetPrevious(Stream> entries); @@ -136,16 +137,15 @@ public interface DatabaseStageMap> extends Dat this.setAllValues(values.map(entriesReplacer)); } } else { - try (var values = this.getAllValues(null, smallRange).map(entriesReplacer)) { - values.forEach(replacedEntry -> this.at(null, replacedEntry.getKey()).set(replacedEntry.getValue())); - } + collectOn(ROCKSDB_SCHEDULER, + this.getAllValues(null, smallRange).map(entriesReplacer), + executing(replacedEntry -> this.at(null, replacedEntry.getKey()).set(replacedEntry.getValue())) + ); } } default void replaceAll(Consumer> entriesReplacer) { - try (var stream = this.getAllStages(null, false)) { - stream.forEach(entriesReplacer); - } + collectOn(ROCKSDB_SCHEDULER, this.getAllStages(null, false), executing(entriesReplacer)); } @Override @@ -225,7 +225,7 @@ public interface DatabaseStageMap> extends Dat @Override default long leavesCount(@Nullable CompositeSnapshot snapshot, boolean fast) { - return StreamUtils.countClose(this.getAllStages(snapshot, false)); + return count(this.getAllStages(snapshot, false)); } /** diff --git a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalDictionary.java b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalDictionary.java index a72add5..f9a9f60 100644 --- a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalDictionary.java +++ b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalDictionary.java @@ -5,6 +5,13 @@ import static it.cavallium.dbengine.database.LLUtils.MARKER_ROCKSDB; import static it.cavallium.dbengine.database.LLUtils.isBoundedRange; import static it.cavallium.dbengine.database.LLUtils.toStringSafe; import static it.cavallium.dbengine.database.disk.UpdateAtomicResultMode.DELTA; +import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.ROCKSDB_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collect; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.executing; +import static it.cavallium.dbengine.utils.StreamUtils.fastListing; +import static it.cavallium.dbengine.utils.StreamUtils.fastSummingLong; import static it.cavallium.dbengine.utils.StreamUtils.streamWhileNonNull; import 
static java.util.Objects.requireNonNull; import static it.cavallium.dbengine.utils.StreamUtils.batches; @@ -24,15 +31,18 @@ import it.cavallium.dbengine.database.LLRange; import it.cavallium.dbengine.database.LLSnapshot; import it.cavallium.dbengine.database.LLUtils; import it.cavallium.dbengine.database.OptionalBuf; +import it.cavallium.dbengine.database.SerializedKey; import it.cavallium.dbengine.database.UpdateMode; import it.cavallium.dbengine.database.UpdateReturnMode; import it.cavallium.dbengine.database.serialization.KVSerializationFunction; import it.cavallium.dbengine.rpc.current.data.DatabaseOptions; import it.cavallium.dbengine.utils.DBException; +import it.cavallium.dbengine.utils.StreamUtils; import java.io.IOException; import java.nio.ByteBuffer; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.Callable; @@ -40,6 +50,7 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinTask; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; @@ -484,47 +495,48 @@ public class LLLocalDictionary implements LLDictionary { @Override public void putMulti(Stream entries) { - batches(entries, Math.min(MULTI_GET_WINDOW, CAPPED_WRITE_BATCH_CAP)).parallel().forEach(entriesWindow -> { - try (var writeOptions = new WriteOptions()) { - assert !LLUtils.isInNonBlockingThread() : "Called putMulti in a nonblocking thread"; - if (USE_WRITE_BATCHES_IN_PUT_MULTI) { - try (var batch = new CappedWriteBatch(db, - CAPPED_WRITE_BATCH_CAP, - RESERVED_WRITE_BATCH_SIZE, - MAX_WRITE_BATCH_SIZE, - writeOptions - )) { - for (LLEntry entry : entriesWindow) { - batch.put(cfh, entry.getKey(), entry.getValue()); + collectOn(ROCKSDB_SCHEDULER, + batches(entries, Math.min(MULTI_GET_WINDOW, CAPPED_WRITE_BATCH_CAP)), + executing(entriesWindow -> { + try (var writeOptions = new WriteOptions()) { + assert !LLUtils.isInNonBlockingThread() : "Called putMulti in a nonblocking thread"; + if (USE_WRITE_BATCHES_IN_PUT_MULTI) { + try (var batch = new CappedWriteBatch(db, + CAPPED_WRITE_BATCH_CAP, + RESERVED_WRITE_BATCH_SIZE, + MAX_WRITE_BATCH_SIZE, + writeOptions + )) { + for (LLEntry entry : entriesWindow) { + batch.put(cfh, entry.getKey(), entry.getValue()); + } + batch.flush(); + } + } else { + for (LLEntry entry : entriesWindow) { + db.put(writeOptions, entry.getKey(), entry.getValue()); + } } - batch.flush(); + } catch (RocksDBException ex) { + throw new CompletionException(new DBException("Failed to write: " + ex.getMessage())); } - } else { - for (LLEntry entry : entriesWindow) { - db.put(writeOptions, entry.getKey(), entry.getValue()); - } - } - } catch (RocksDBException ex) { - throw new CompletionException(new DBException("Failed to write: " + ex.getMessage())); - } - }); + }) + ); } @Override - public Stream updateMulti(Stream keys, Stream serializedKeys, + public Stream updateMulti(Stream> keys, KVSerializationFunction updateFunction) { - record Key(K key, Buf serializedKey) {} record MappedInput(K key, Buf serializedKey, OptionalBuf mapped) {} - return batches(Streams.zip(keys, serializedKeys, Key::new), Math.min(MULTI_GET_WINDOW, CAPPED_WRITE_BATCH_CAP)) - .parallel() + return batches(keys, Math.min(MULTI_GET_WINDOW, CAPPED_WRITE_BATCH_CAP)) .flatMap(entriesWindow -> { try (var writeOptions = new 
WriteOptions()) { if (LLUtils.isInNonBlockingThread()) { throw new UnsupportedOperationException("Called updateMulti in a nonblocking thread"); } List keyBufsWindow = new ArrayList<>(entriesWindow.size()); - for (Key objects : entriesWindow) { - keyBufsWindow.add(objects.serializedKey()); + for (var objects : entriesWindow) { + keyBufsWindow.add(objects.serialized()); } ArrayList> mappedInputs; { @@ -572,12 +584,12 @@ public class LLLocalDictionary implements LLDictionary { writeOptions )) { int i = 0; - for (Key entry : entriesWindow) { + for (var entry : entriesWindow) { var valueToWrite = updatedValuesToWrite.get(i); if (valueToWrite == null) { - batch.delete(cfh, entry.serializedKey()); + batch.delete(cfh, entry.serialized()); } else { - batch.put(cfh, entry.serializedKey(), valueToWrite); + batch.put(cfh, entry.serialized(), valueToWrite); } i++; } @@ -585,8 +597,8 @@ public class LLLocalDictionary implements LLDictionary { } } else { int i = 0; - for (Key entry : entriesWindow) { - db.put(writeOptions, entry.serializedKey(), updatedValuesToWrite.get(i)); + for (var entry : entriesWindow) { + db.put(writeOptions, entry.serialized(), updatedValuesToWrite.get(i)); i++; } } @@ -618,7 +630,7 @@ public class LLLocalDictionary implements LLDictionary { if (range.isSingle()) { var rangeSingle = range.getSingle(); - return Stream.of(getRangeSingle(snapshot, rangeSingle).toList()); + return getRangeSingle(snapshot, rangeSingle).map(Collections::singletonList); } else { return getRangeMultiGrouped(snapshot, range, prefixLength, smallRange); } @@ -803,56 +815,55 @@ public class LLLocalDictionary implements LLDictionary { throw new DBException("Failed to set a range: " + ex.getMessage()); } - batches(entries, MULTI_GET_WINDOW) - .forEach(entriesList -> { - try (var writeOptions = new WriteOptions()) { - if (!USE_WRITE_BATCHES_IN_SET_RANGE) { - for (LLEntry entry : entriesList) { - db.put(writeOptions, entry.getKey(), entry.getValue()); - } - } else if (USE_CAPPED_WRITE_BATCH_IN_SET_RANGE) { - - try (var batch = new CappedWriteBatch(db, - CAPPED_WRITE_BATCH_CAP, - RESERVED_WRITE_BATCH_SIZE, - MAX_WRITE_BATCH_SIZE, - writeOptions - )) { - for (LLEntry entry : entriesList) { - batch.put(cfh, entry.getKey(), entry.getValue()); - } - batch.flush(); - } - } else { - try (var batch = new WriteBatch(RESERVED_WRITE_BATCH_SIZE)) { - for (LLEntry entry : entriesList) { - batch.put(cfh, LLUtils.asArray(entry.getKey()), LLUtils.asArray(entry.getValue())); - } - db.write(writeOptions, batch); - batch.clear(); - } - } - } catch (RocksDBException ex) { - throw new CompletionException(new DBException("Failed to write range", ex)); + collectOn(ROCKSDB_SCHEDULER, batches(entries, MULTI_GET_WINDOW), executing(entriesList -> { + try (var writeOptions = new WriteOptions()) { + if (!USE_WRITE_BATCHES_IN_SET_RANGE) { + for (LLEntry entry : entriesList) { + db.put(writeOptions, entry.getKey(), entry.getValue()); } - }); + } else if (USE_CAPPED_WRITE_BATCH_IN_SET_RANGE) { + + try (var batch = new CappedWriteBatch(db, + CAPPED_WRITE_BATCH_CAP, + RESERVED_WRITE_BATCH_SIZE, + MAX_WRITE_BATCH_SIZE, + writeOptions + )) { + for (LLEntry entry : entriesList) { + batch.put(cfh, entry.getKey(), entry.getValue()); + } + batch.flush(); + } + } else { + try (var batch = new WriteBatch(RESERVED_WRITE_BATCH_SIZE)) { + for (LLEntry entry : entriesList) { + batch.put(cfh, LLUtils.asArray(entry.getKey()), LLUtils.asArray(entry.getValue())); + } + db.write(writeOptions, batch); + batch.clear(); + } + } + } catch (RocksDBException ex) { + 
throw new CompletionException(new DBException("Failed to write range", ex)); + } + })); } else { if (USE_WRITE_BATCHES_IN_SET_RANGE) { throw new UnsupportedOperationException("Can't use write batches in setRange without window. Please fix the parameters"); } - this.getRange(null, range, false, smallRange).forEach(oldValue -> { + collectOn(ROCKSDB_SCHEDULER, this.getRange(null, range, false, smallRange), executing(oldValue -> { try (var writeOptions = new WriteOptions()) { db.delete(writeOptions, oldValue.getKey()); } catch (RocksDBException ex) { throw new CompletionException(new DBException("Failed to write range", ex)); } - }); + })); - entries.forEach(entry -> { + collectOn(ROCKSDB_SCHEDULER, entries, executing(entry -> { if (entry.getKey() != null && entry.getValue() != null) { this.putInternal(entry.getKey(), entry.getValue()); } - }); + })); } } @@ -1079,14 +1090,12 @@ public class LLLocalDictionary implements LLDictionary { readOpts.setVerifyChecksums(VERIFY_CHECKSUMS_WHEN_NOT_NEEDED); if (PARALLEL_EXACT_SIZE) { - var commonPool = ForkJoinPool.commonPool(); - var futures = IntStream + return collectOn(ROCKSDB_SCHEDULER, IntStream .range(-1, LLUtils.LEXICONOGRAPHIC_ITERATION_SEEKS.length) .mapToObj(idx -> Pair.of(idx == -1 ? new byte[0] : LLUtils.LEXICONOGRAPHIC_ITERATION_SEEKS[idx], idx + 1 >= LLUtils.LEXICONOGRAPHIC_ITERATION_SEEKS.length ? null : LLUtils.LEXICONOGRAPHIC_ITERATION_SEEKS[idx + 1] - )) - .map(range -> (Callable) () -> { + )).map(range -> { long partialCount = 0; try (var rangeReadOpts = new ReadOptions(readOpts)) { Slice sliceBegin; @@ -1116,6 +1125,8 @@ public class LLLocalDictionary implements LLDictionary { } return partialCount; } + } catch (RocksDBException ex) { + throw new CompletionException(new IOException("Failed to get size", ex)); } finally { if (sliceBegin != null) { sliceBegin.close(); @@ -1125,14 +1136,7 @@ public class LLLocalDictionary implements LLDictionary { } } } - }) - .map(task -> commonPool.submit(task)) - .toList(); - long count = 0; - for (ForkJoinTask future : futures) { - count += future.join(); - } - return count; + }), fastSummingLong()); } else { long count = 0; try (var rocksIterator = db.newIterator(readOpts, null, null)) { diff --git a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalKeyValueDatabase.java b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalKeyValueDatabase.java index ac2ae10..651c4de 100644 --- a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalKeyValueDatabase.java +++ b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalKeyValueDatabase.java @@ -1,6 +1,8 @@ package it.cavallium.dbengine.database.disk; import static it.cavallium.dbengine.database.LLUtils.MARKER_ROCKSDB; +import static it.cavallium.dbengine.utils.StreamUtils.collect; +import static it.cavallium.dbengine.utils.StreamUtils.iterating; import static java.lang.Boolean.parseBoolean; import static java.util.Objects.requireNonNull; import static org.rocksdb.ColumnFamilyOptionsInterface.DEFAULT_COMPACTION_MEMTABLE_MEMORY_BUDGET; @@ -663,7 +665,7 @@ public class LLLocalKeyValueDatabase extends Backuppable implements LLKeyValueDa logger.warn("Column {} doesn't exist", column); return; } - files.forEachOrdered(sst -> { + collect(files, iterating(sst -> { try (var opts = new IngestExternalFileOptions()) { opts.setIngestBehind(!replaceExisting); opts.setSnapshotConsistency(false); @@ -673,7 +675,7 @@ public class LLLocalKeyValueDatabase extends Backuppable implements LLKeyValueDa } catch (RocksDBException ex) { throw new 
DBException(new DBException("Failed to ingest SST file " + sst, ex)); } - }); + })); } private record RocksLevelOptions(CompressionType compressionType, CompressionOptions compressionOptions) {} @@ -1264,7 +1266,7 @@ public class LLLocalKeyValueDatabase extends Backuppable implements LLKeyValueDa } public void ingestSSTS(Stream sstsFlux) { - sstsFlux.map(path -> path.toAbsolutePath().toString()).forEachOrdered(sst -> { + collect(sstsFlux.map(path -> path.toAbsolutePath().toString()), iterating(sst -> { var closeReadLock = closeLock.readLock(); try (var opts = new IngestExternalFileOptions()) { try { @@ -1277,7 +1279,7 @@ public class LLLocalKeyValueDatabase extends Backuppable implements LLKeyValueDa } finally { closeLock.unlockRead(closeReadLock); } - }); + })); } @Override @@ -1576,67 +1578,64 @@ public class LLLocalKeyValueDatabase extends Backuppable implements LLKeyValueDa } Path basePath = dbPath; try { - Files - .walk(basePath, 1) - .filter(p -> !p.equals(basePath)) - .filter(p -> { - var fileName = p.getFileName().toString(); - if (fileName.startsWith("LOG.old.")) { - var parts = fileName.split("\\."); - if (parts.length == 3) { - try { - long nameSuffix = Long.parseUnsignedLong(parts[2]); - return true; - } catch (NumberFormatException ex) { - return false; - } + try (var f = Files.walk(basePath, 1)) { + f.filter(p -> !p.equals(basePath)).filter(p -> { + var fileName = p.getFileName().toString(); + if (fileName.startsWith("LOG.old.")) { + var parts = fileName.split("\\."); + if (parts.length == 3) { + try { + long nameSuffix = Long.parseUnsignedLong(parts[2]); + return true; + } catch (NumberFormatException ex) { + return false; } } - if (fileName.endsWith(".log")) { - var parts = fileName.split("\\."); - if (parts.length == 2) { - try { - int name = Integer.parseUnsignedInt(parts[0]); - return true; - } catch (NumberFormatException ex) { - return false; - } + } + if (fileName.endsWith(".log")) { + var parts = fileName.split("\\."); + if (parts.length == 2) { + try { + int name = Integer.parseUnsignedInt(parts[0]); + return true; + } catch (NumberFormatException ex) { + return false; } } + } + return false; + }).filter(p -> { + try { + BasicFileAttributes attrs = Files.readAttributes(p, BasicFileAttributes.class); + if (attrs.isRegularFile() && !attrs.isSymbolicLink() && !attrs.isDirectory()) { + long ctime = attrs.creationTime().toMillis(); + long atime = attrs.lastAccessTime().toMillis(); + long mtime = attrs.lastModifiedTime().toMillis(); + long lastTime = Math.max(Math.max(ctime, atime), mtime); + long safeTime; + if (p.getFileName().toString().startsWith("LOG.old.")) { + safeTime = System.currentTimeMillis() - Duration.ofHours(24).toMillis(); + } else { + safeTime = System.currentTimeMillis() - Duration.ofHours(12).toMillis(); + } + if (lastTime < safeTime) { + return true; + } + } + } catch (IOException ex) { + logger.error("Error when deleting unused log files", ex); return false; - }) - .filter(p -> { - try { - BasicFileAttributes attrs = Files.readAttributes(p, BasicFileAttributes.class); - if (attrs.isRegularFile() && !attrs.isSymbolicLink() && !attrs.isDirectory()) { - long ctime = attrs.creationTime().toMillis(); - long atime = attrs.lastAccessTime().toMillis(); - long mtime = attrs.lastModifiedTime().toMillis(); - long lastTime = Math.max(Math.max(ctime, atime), mtime); - long safeTime; - if (p.getFileName().toString().startsWith("LOG.old.")) { - safeTime = System.currentTimeMillis() - Duration.ofHours(24).toMillis(); - } else { - safeTime = System.currentTimeMillis() 
- Duration.ofHours(12).toMillis(); - } - if (lastTime < safeTime) { - return true; - } - } - } catch (IOException ex) { - logger.error("Error when deleting unused log files", ex); - return false; - } - return false; - }) - .forEach(path -> { - try { - Files.deleteIfExists(path); - System.out.println("Deleted log file \"" + path + "\""); - } catch (IOException e) { - logger.error(MARKER_ROCKSDB, "Failed to delete log file \"" + path + "\"", e); - } - }); + } + return false; + }).forEach(path -> { + try { + Files.deleteIfExists(path); + System.out.println("Deleted log file \"" + path + "\""); + } catch (IOException e) { + logger.error(MARKER_ROCKSDB, "Failed to delete log file \"" + path + "\"", e); + } + }); + } } catch (IOException ex) { logger.error(MARKER_ROCKSDB, "Failed to delete unused log files", ex); } diff --git a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalLuceneIndex.java b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalLuceneIndex.java index 09f3bd8..935b309 100644 --- a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalLuceneIndex.java +++ b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalLuceneIndex.java @@ -4,6 +4,8 @@ import static it.cavallium.dbengine.database.LLUtils.MARKER_LUCENE; import static it.cavallium.dbengine.database.LLUtils.toDocument; import static it.cavallium.dbengine.database.LLUtils.toFields; import static it.cavallium.dbengine.lucene.searcher.GlobalQueryRewrite.NO_REWRITE; +import static it.cavallium.dbengine.utils.StreamUtils.collect; +import static it.cavallium.dbengine.utils.StreamUtils.fastListing; import static java.util.Objects.requireNonNull; import com.google.common.collect.Multimap; @@ -365,7 +367,8 @@ public class LLLocalLuceneIndex extends SimpleResource implements IBackuppable, }); return count.sum(); } else { - var documentsList = documents.toList(); + var documentsList = collect(documents, fastListing()); + assert documentsList != null; var count = documentsList.size(); StopWatch stopWatch = StopWatch.createStarted(); try { diff --git a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalMultiLuceneIndex.java b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalMultiLuceneIndex.java index ea59c14..b6e6d97 100644 --- a/src/main/java/it/cavallium/dbengine/database/disk/LLLocalMultiLuceneIndex.java +++ b/src/main/java/it/cavallium/dbengine/database/disk/LLLocalMultiLuceneIndex.java @@ -1,6 +1,14 @@ package it.cavallium.dbengine.database.disk; import static it.cavallium.dbengine.lucene.LuceneUtils.getLuceneIndexId; +import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER; +import static it.cavallium.dbengine.utils.StreamUtils.collect; +import static it.cavallium.dbengine.utils.StreamUtils.collectOn; +import static it.cavallium.dbengine.utils.StreamUtils.executing; +import static it.cavallium.dbengine.utils.StreamUtils.fastListing; +import static it.cavallium.dbengine.utils.StreamUtils.fastReducing; +import static it.cavallium.dbengine.utils.StreamUtils.fastSummingLong; +import static it.cavallium.dbengine.utils.StreamUtils.partitionByInt; import static java.util.stream.Collectors.groupingBy; import com.google.common.collect.Multimap; @@ -36,6 +44,7 @@ import it.cavallium.dbengine.rpc.current.data.IndicizerSimilarities; import it.cavallium.dbengine.rpc.current.data.LuceneOptions; import it.cavallium.dbengine.utils.DBException; import it.cavallium.dbengine.utils.SimpleResource; +import it.cavallium.dbengine.utils.StreamUtils; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; 
import it.unimi.dsi.fastutil.ints.IntList; import java.io.Closeable; @@ -141,10 +150,12 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI private LLIndexSearchers getIndexSearchers(LLSnapshot snapshot) { // Resolve the snapshot of each shard - return LLIndexSearchers.of(Streams.mapWithIndex(this.luceneIndicesSet.parallelStream(), (luceneIndex, index) -> { - var subSnapshot = resolveSnapshot(snapshot, (int) index); - return luceneIndex.retrieveSearcher(subSnapshot); - }).toList()); + return LLIndexSearchers.of(StreamUtils.toListOn(StreamUtils.LUCENE_SCHEDULER, + Streams.mapWithIndex(this.luceneIndicesSet.stream(), (luceneIndex, index) -> { + var subSnapshot = resolveSnapshot(snapshot, (int) index); + return luceneIndex.retrieveSearcher(subSnapshot); + }) + )); } @Override @@ -154,17 +165,11 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI @Override public long addDocuments(boolean atomic, Stream> documents) { - var groupedRequests = documents - .collect(groupingBy(term -> getLuceneIndexId(term.getKey(), totalShards), - Int2ObjectOpenHashMap::new, - Collectors.toList() - )); - - return groupedRequests - .int2ObjectEntrySet() - .stream() - .map(entry -> luceneIndicesById[entry.getIntKey()].addDocuments(atomic, entry.getValue().stream())) - .reduce(0L, Long::sum); + return collectOn(LUCENE_SCHEDULER, + partitionByInt(term -> getLuceneIndexId(term.getKey(), totalShards), documents) + .map(entry -> luceneIndicesById[entry.key()].addDocuments(atomic, entry.values().stream())), + fastSummingLong() + ); } @Override @@ -179,17 +184,11 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI @Override public long updateDocuments(Stream> documents) { - var groupedRequests = documents - .collect(groupingBy(term -> getLuceneIndexId(term.getKey(), totalShards), - Int2ObjectOpenHashMap::new, - Collectors.toList() - )); - - return groupedRequests - .int2ObjectEntrySet() - .stream() - .map(entry -> luceneIndicesById[entry.getIntKey()].updateDocuments(entry.getValue().stream())) - .reduce(0L, Long::sum); + return collectOn(LUCENE_SCHEDULER, + partitionByInt(term -> getLuceneIndexId(term.getKey(), totalShards), documents) + .map(entry -> luceneIndicesById[entry.key()].updateDocuments(entry.values().stream())), + fastSummingLong() + ); } @Override @@ -266,7 +265,7 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI @Override protected void onClose() { - luceneIndicesSet.parallelStream().forEach(SafeCloseable::close); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(SafeCloseable::close)); if (multiSearcher instanceof Closeable closeable) { try { closeable.close(); @@ -278,29 +277,32 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI @Override public void flush() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::flush); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::flush)); } @Override public void waitForMerges() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::waitForMerges); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::waitForMerges)); } @Override public void waitForLastMerges() { - luceneIndicesSet.parallelStream().forEach(LLLuceneIndex::waitForLastMerges); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::waitForLastMerges)); } @Override public void 
refresh(boolean force) { - luceneIndicesSet.parallelStream().forEach(index -> index.refresh(force)); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(index -> index.refresh(force))); } @Override public LLSnapshot takeSnapshot() { // Generate next snapshot index var snapshotIndex = nextSnapshotNumber.getAndIncrement(); - var snapshot = luceneIndicesSet.parallelStream().map(LLSnapshottable::takeSnapshot).toList(); + var snapshot = collectOn(StreamUtils.LUCENE_SCHEDULER, + luceneIndicesSet.stream().map(LLSnapshottable::takeSnapshot), + fastListing() + ); registeredSnapshots.put(snapshotIndex, snapshot); return new LLSnapshot(snapshotIndex); } @@ -322,12 +324,12 @@ public class LLLocalMultiLuceneIndex extends SimpleResource implements LLLuceneI @Override public void pauseForBackup() { - this.luceneIndicesSet.forEach(IBackuppable::pauseForBackup); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::pauseForBackup)); } @Override public void resumeAfterBackup() { - this.luceneIndicesSet.forEach(IBackuppable::resumeAfterBackup); + collectOn(StreamUtils.LUCENE_SCHEDULER, luceneIndicesSet.stream(), executing(LLLuceneIndex::resumeAfterBackup)); } @Override diff --git a/src/main/java/it/cavallium/dbengine/database/memory/LLMemoryDictionary.java b/src/main/java/it/cavallium/dbengine/database/memory/LLMemoryDictionary.java index e82d1ac..86534f3 100644 --- a/src/main/java/it/cavallium/dbengine/database/memory/LLMemoryDictionary.java +++ b/src/main/java/it/cavallium/dbengine/database/memory/LLMemoryDictionary.java @@ -1,5 +1,6 @@ package it.cavallium.dbengine.database.memory; +import static it.cavallium.dbengine.utils.StreamUtils.count; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.mapping; @@ -13,6 +14,7 @@ import it.cavallium.dbengine.database.LLRange; import it.cavallium.dbengine.database.LLSnapshot; import it.cavallium.dbengine.database.LLUtils; import it.cavallium.dbengine.database.OptionalBuf; +import it.cavallium.dbengine.database.SerializedKey; import it.cavallium.dbengine.database.UpdateMode; import it.cavallium.dbengine.database.disk.BinarySerializationFunction; import it.cavallium.dbengine.database.serialization.KVSerializationFunction; @@ -242,8 +244,7 @@ public class LLMemoryDictionary implements LLDictionary { } @Override - public Stream updateMulti(Stream keys, - Stream serializedKeys, + public Stream updateMulti(Stream> keys, KVSerializationFunction updateFunction) { throw new UnsupportedOperationException("Not implemented"); } @@ -405,7 +406,7 @@ public class LLMemoryDictionary implements LLDictionary { @Override public boolean isRangeEmpty(@Nullable LLSnapshot snapshot, LLRange range, boolean fillCache) { - return getRangeKeys(snapshot, range, false, false).count() == 0; + return count(getRangeKeys(snapshot, range, false, false)) == 0; } @Override diff --git a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java index b4e1c41..8968182 100644 --- a/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java +++ b/src/main/java/it/cavallium/dbengine/lucene/LuceneUtils.java @@ -385,7 +385,7 @@ public class LuceneUtils { public static Stream convertHits(Stream hitsFlux, List indexSearchers, @Nullable String keyFieldName) { - return hitsFlux.parallel().mapMulti((hit, sink) -> { + return hitsFlux.mapMulti((hit, sink) -> { var mapped = mapHitBlocking(hit, indexSearchers, keyFieldName); if (mapped != null) 
 { sink.accept(mapped);
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java
index 96ffe15..5437aac 100644
--- a/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/DecimalBucketMultiSearcher.java
@@ -1,11 +1,16 @@
 package it.cavallium.dbengine.lucene.searcher;
 
+import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER;
+import static it.cavallium.dbengine.utils.StreamUtils.collectOn;
+import static it.cavallium.dbengine.utils.StreamUtils.fastListing;
+
 import com.google.common.collect.Streams;
 import it.cavallium.dbengine.database.disk.LLIndexSearchers;
 import it.cavallium.dbengine.lucene.collector.Buckets;
 import it.cavallium.dbengine.lucene.collector.DecimalBucketMultiCollectorManager;
 import it.cavallium.dbengine.utils.DBException;
 import java.io.IOException;
+import java.util.Collection;
 import java.util.List;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -30,7 +35,7 @@ public class DecimalBucketMultiSearcher {
 		}
 	}
 
-	private Buckets search(Iterable<IndexSearcher> indexSearchers,
+	private Buckets search(Collection<IndexSearcher> indexSearchers,
 			BucketParams bucketParams,
 			@NotNull List<Query> queries,
 			@Nullable Query normalizationQuery) {
@@ -44,12 +49,12 @@
 				bucketParams.collectionRate(),
 				bucketParams.sampleSize()
 		);
-		return cmm.reduce(Streams.stream(indexSearchers).parallel().map(indexSearcher -> {
+		return cmm.reduce(collectOn(LUCENE_SCHEDULER, indexSearchers.stream().map(indexSearcher -> {
 			try {
 				return cmm.search(indexSearcher);
 			} catch (IOException e) {
 				throw new DBException(e);
 			}
-		}).toList());
+		}), fastListing()));
 	}
 }
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java
index 05e1945..d9a2d7e 100644
--- a/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/ScoredPagedMultiSearcher.java
@@ -1,7 +1,9 @@
 package it.cavallium.dbengine.lucene.searcher;
 
 import static it.cavallium.dbengine.lucene.searcher.PaginationInfo.MAX_SINGLE_SEARCH_LIMIT;
+import static it.cavallium.dbengine.utils.StreamUtils.LUCENE_SCHEDULER;
 import static it.cavallium.dbengine.utils.StreamUtils.streamWhileNonNull;
+import static it.cavallium.dbengine.utils.StreamUtils.toListOn;
 
 import com.google.common.collect.Streams;
 import it.cavallium.dbengine.client.query.current.data.TotalHitsCount;
@@ -178,9 +180,8 @@
 			return null;
 		};
 		record IndexedShard(IndexSearcher indexSearcher, long shardIndex) {}
-		List<TopDocs> shardResults = Streams
-				.mapWithIndex(indexSearchers.stream(), IndexedShard::new)
-				.map(shardWithIndex -> {
+		List<TopDocs> shardResults = toListOn(LUCENE_SCHEDULER,
+				Streams.mapWithIndex(indexSearchers.stream(), IndexedShard::new).map(shardWithIndex -> {
 					var index = (int) shardWithIndex.shardIndex();
 					var shard = shardWithIndex.indexSearcher();
 
@@ -192,7 +193,7 @@
 					throw new DBException(e);
 				}
 			})
-				.toList();
+		);
 
 		var pageTopDocs = cmm.reduce(shardResults);
 
diff --git a/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java b/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java
index b004d79..00a6329 100644
--- a/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java
+++ b/src/main/java/it/cavallium/dbengine/lucene/searcher/UnsortedStreamingMultiSearcher.java
@@ -50,7 +50,6 @@ public class UnsortedStreamingMultiSearcher implements MultiSearcher {
 
 	private Stream<ScoreDoc> getScoreDocs(LocalQueryParams localQueryParams, List<IndexSearcher> shards) {
 		return mapWithIndex(shards.stream(),
 				(shard, shardIndex) -> LuceneGenerator.reactive(shard, localQueryParams, (int) shardIndex))
-				.parallel()
 				.flatMap(Function.identity());
 	}
 
diff --git a/src/main/java/it/cavallium/dbengine/utils/MoshiPolymorphic.java b/src/main/java/it/cavallium/dbengine/utils/MoshiPolymorphic.java
index 2c67e0c..3dfc408 100644
--- a/src/main/java/it/cavallium/dbengine/utils/MoshiPolymorphic.java
+++ b/src/main/java/it/cavallium/dbengine/utils/MoshiPolymorphic.java
@@ -206,7 +206,7 @@ public abstract class MoshiPolymorphic {
 							&& !Modifier.isTransient(modifiers)
 							&& !shouldIgnoreField(field.getName());
 					})
-					.collect(Collectors.toList());
+					.toList();
 			String[] fieldNames = new String[this.declaredFields.size()];
 			//noinspection unchecked
 			this.fieldGetters = new Function[this.declaredFields.size()];
@@ -215,7 +215,7 @@
 				fieldNames[i] = declaredField.getName();
 
 				switch (getterStyle) {
-					case STANDARD_GETTERS:
+					case STANDARD_GETTERS -> {
 						var getterMethod = declaredField
 								.getDeclaringClass()
 								.getMethod("get" + StringUtils.capitalize(declaredField.getName()));
 						fieldGetters[i] = obj -> {
 							try {
 								return getterMethod.invoke(obj);
@@ -226,11 +226,9 @@
 								throw new RuntimeException(e);
 							}
 						};
-						break;
-					case RECORDS_GETTERS:
-						var getterMethod2 = declaredField
-								.getDeclaringClass()
-								.getMethod(declaredField.getName());
+					}
+					case RECORDS_GETTERS -> {
+						var getterMethod2 = declaredField.getDeclaringClass().getMethod(declaredField.getName());
 						fieldGetters[i] = obj -> {
 							try {
 								return getterMethod2.invoke(obj);
@@ -238,16 +236,14 @@
 							throw new RuntimeException(e);
 						}
 					};
-						break;
-					case FIELDS:
-						fieldGetters[i] = t -> {
-							try {
-								return declaredField.get(t);
-							} catch (IllegalAccessException e) {
-								throw new RuntimeException(e);
-							}
-						};
-						break;
+					}
+					case FIELDS -> fieldGetters[i] = t -> {
+						try {
+							return declaredField.get(t);
+						} catch (IllegalAccessException e) {
+							throw new RuntimeException(e);
+						}
+					};
 				}
 
 				i++;
diff --git a/src/main/java/it/cavallium/dbengine/utils/PartitionByIntSpliterator.java b/src/main/java/it/cavallium/dbengine/utils/PartitionByIntSpliterator.java
new file mode 100644
index 0000000..d4389a2
--- /dev/null
+++ b/src/main/java/it/cavallium/dbengine/utils/PartitionByIntSpliterator.java
@@ -0,0 +1,89 @@
+package it.cavallium.dbengine.utils;
+
+import static java.util.Comparator.naturalOrder;
+
+import it.cavallium.dbengine.utils.PartitionByIntSpliterator.IntPartition;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Spliterator;
+import java.util.Spliterators.AbstractSpliterator;
+import java.util.function.Consumer;
+import java.util.function.ToIntFunction;
+
+/**
+ * Create a partitioned view of an existing spliterator
+ *
+ * @author cavallium
+ */
+public class PartitionByIntSpliterator<E> extends AbstractSpliterator<IntPartition<E>> {
+
+	private final Spliterator<E> spliterator;
+	private final ToIntFunction<? super E> partitionBy;
+	private HoldingConsumer<E> holder;
+	private Comparator<IntPartition<E>> comparator;
+
+	public PartitionByIntSpliterator(Spliterator<E> toWrap, ToIntFunction<? super E> partitionBy) {
+		super(Long.MAX_VALUE, toWrap.characteristics() & ~SIZED | NONNULL);
+		this.spliterator = toWrap;
+		this.partitionBy = partitionBy;
+	}
+
+	@Override
+	public boolean tryAdvance(Consumer<? super IntPartition<E>> action) {
+		final HoldingConsumer<E> h;
+		if (holder == null) {
+			h = new HoldingConsumer<>();
+			if (!spliterator.tryAdvance(h)) {
+				return false;
+			}
+			holder = h;
+		} else {
+			h = holder;
+		}
+		final ArrayList<E> partition = new ArrayList<>();
+		final int partitionKey = partitionBy.applyAsInt(h.value);
+		boolean didAdvance;
+		do {
+			partition.add(h.value);
+		} while ((didAdvance = spliterator.tryAdvance(h)) && partitionBy.applyAsInt(h.value) == partitionKey);
+		if (!didAdvance) {
+			holder = null;
+		}
+		action.accept(new IntPartition<>(partitionKey, partition));
+		return true;
+	}
+
+	static final class HoldingConsumer<T> implements Consumer<T> {
+
+		T value;
+
+		@Override
+		public void accept(T value) {
+			this.value = value;
+		}
+	}
+
+	@Override
+	public Comparator<? super IntPartition<E>> getComparator() {
+		final Comparator<IntPartition<E>> c = this.comparator;
+		return c != null ? c : (this.comparator = comparator());
+	}
+
+	private Comparator<IntPartition<E>> comparator() {
+		final Comparator<? super E> cmp;
+		{
+			var comparator = spliterator.getComparator();
+			//noinspection unchecked
+			cmp = comparator != null ? comparator : (Comparator<? super E>) naturalOrder();
+		}
+		return (left, right) -> {
+			var leftV = left.values;
+			var rightV = right.values;
+			final int c = cmp.compare(leftV.get(0), rightV.get(0));
+			return c != 0 ? c : cmp.compare(leftV.get(leftV.size() - 1), rightV.get(rightV.size() - 1));
+		};
+	}
+
+	public record IntPartition<T>(int key, List<T> values) {}
+}
\ No newline at end of file
diff --git a/src/main/java/it/cavallium/dbengine/utils/PartitionBySpliterator.java b/src/main/java/it/cavallium/dbengine/utils/PartitionBySpliterator.java
new file mode 100644
index 0000000..e53a95b
--- /dev/null
+++ b/src/main/java/it/cavallium/dbengine/utils/PartitionBySpliterator.java
@@ -0,0 +1,88 @@
+package it.cavallium.dbengine.utils;
+
+import it.cavallium.dbengine.utils.PartitionBySpliterator.Partition;
+import java.util.*;
+import java.util.Spliterators.AbstractSpliterator;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
+import static java.util.Comparator.naturalOrder;
+
+/**
+ * Create a partitioned view of an existing spliterator
+ *
+ * @author cavallium
+ */
+public class PartitionBySpliterator<E, K> extends AbstractSpliterator<Partition<K, E>> {
+
+	private final Spliterator<E> spliterator;
+	private final Function<? super E, K> partitionBy;
+	private HoldingConsumer<E> holder;
+	private Comparator<Partition<K, E>> comparator;
+
+	public PartitionBySpliterator(Spliterator<E> toWrap, Function<? super E, K> partitionBy) {
+		super(Long.MAX_VALUE, toWrap.characteristics() & ~SIZED | NONNULL);
+		this.spliterator = toWrap;
+		this.partitionBy = partitionBy;
+	}
+
+	@Override
+	public boolean tryAdvance(Consumer<? super Partition<K, E>> action) {
+		final HoldingConsumer<E> h;
+		if (holder == null) {
+			h = new HoldingConsumer<>();
+			if (!spliterator.tryAdvance(h)) {
+				return false;
+			}
+			holder = h;
+		} else {
+			h = holder;
+		}
+		final ArrayList<E> partition = new ArrayList<>();
+		final K partitionKey = partitionBy.apply(h.value);
+		boolean didAdvance;
+		do {
+			partition.add(h.value);
+		} while ((didAdvance = spliterator.tryAdvance(h)) && Objects.equals(partitionBy.apply(h.value), partitionKey));
+		if (!didAdvance) {
+			holder = null;
+		}
+		action.accept(new Partition<>(partitionKey, partition));
+		return true;
+	}
+
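+	// Usage sketch: StreamUtils.partitionBy (added later in this patch) wraps this
+	// spliterator to group *consecutive* elements that map to the same key, e.g.:
+	//   Stream<Partition<Character, String>> p =
+	//       StreamUtils.partitionBy(w -> w.charAt(0), Stream.of("aa", "ab", "ba"));
+	//   // -> Partition[key=a, values=[aa, ab]], Partition[key=b, values=[ba]]
+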
+	static final class HoldingConsumer<T> implements Consumer<T> {
+
+		T value;
+
+		@Override
+		public void accept(T value) {
+			this.value = value;
+		}
+	}
+
+	@Override
+	public Comparator<? super Partition<K, E>> getComparator() {
+		final Comparator<Partition<K, E>> c = this.comparator;
+		return c != null ? c : (this.comparator = comparator());
+	}
+
+	private Comparator<Partition<K, E>> comparator() {
+		final Comparator<? super E> cmp;
+		{
+			var comparator = spliterator.getComparator();
+			//noinspection unchecked
+			cmp = comparator != null ? comparator : (Comparator<? super E>) naturalOrder();
+		}
+		return (left, right) -> {
+			var leftV = left.values;
+			var rightV = right.values;
+			final int c = cmp.compare(leftV.get(0), rightV.get(0));
+			return c != 0 ? c : cmp.compare(leftV.get(leftV.size() - 1), rightV.get(rightV.size() - 1));
+		};
+	}
+
+	public record Partition<K, T>(K key, List<T> values) {}
+}
\ No newline at end of file
diff --git a/src/main/java/it/cavallium/dbengine/utils/StreamUtils.java b/src/main/java/it/cavallium/dbengine/utils/StreamUtils.java
index c1e9f35..4efa9f0 100644
--- a/src/main/java/it/cavallium/dbengine/utils/StreamUtils.java
+++ b/src/main/java/it/cavallium/dbengine/utils/StreamUtils.java
@@ -2,26 +2,63 @@ package it.cavallium.dbengine.utils;
 
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Streams;
-import it.cavallium.dbengine.database.SubStageEntry;
-import it.cavallium.dbengine.database.collections.DatabaseStage;
+import it.cavallium.dbengine.utils.PartitionByIntSpliterator.IntPartition;
+import it.cavallium.dbengine.utils.PartitionBySpliterator.Partition;
 import java.util.ArrayList;
-import java.util.Collection;
+import java.util.Collections;
 import java.util.Comparator;
+import java.util.EnumSet;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Set;
 import java.util.Spliterator;
-import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.atomic.LongAdder;
+import java.util.function.BiConsumer;
+import java.util.function.BinaryOperator;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.Supplier;
+import java.util.function.ToIntFunction;
+import java.util.function.ToLongFunction;
 import java.util.stream.Collector;
-import java.util.stream.Collectors;
+import java.util.stream.Collector.Characteristics;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
+import org.jetbrains.annotations.NotNull;
+import org.jetbrains.annotations.Nullable;
 
 public class StreamUtils {
 
+	public static final ForkJoinPool LUCENE_SCHEDULER = new ForkJoinPool();
+	public static final ForkJoinPool ROCKSDB_SCHEDULER = new ForkJoinPool();
+
+	private static final Collector<?, ?, ?> TO_LIST_FAKE_COLLECTOR = new FakeCollector();
+	private static final Collector<?, ?, ?> COUNT_FAKE_COLLECTOR = new FakeCollector();
+
+	private static final Set<Characteristics> CH_NOID = Collections.emptySet();
+	private static final Set<Characteristics> CH_CONCURRENT_NOID = Collections.unmodifiableSet(EnumSet.of(
+			Characteristics.CONCURRENT,
+			Characteristics.UNORDERED
+	));
+	private static final Object NULL = new Object();
+	private static final Supplier<Object> NULL_SUPPLIER = () -> NULL;
+	private static final BinaryOperator<Object> COMBINER = (a, b) -> NULL;
+	private static final Function<Object, Void> FINISHER = x -> null;
+	private static final Collector<Long, ?, Long> SUMMING_LONG_COLLECTOR = new SummingLongCollector();
+
+	public static <T> Collector<T, ?, List<T>> fastListing() {
+		//noinspection unchecked
+		return (Collector<T, ?, List<T>>) TO_LIST_FAKE_COLLECTOR;
+	}
+
+	public static <T> Collector<T, ?, Long> fastCounting() {
+		//noinspection unchecked
+		return (Collector<T, ?, Long>) COUNT_FAKE_COLLECTOR;
+	}
+
 	@SafeVarargs
 	@SuppressWarnings("UnstableApiUsage")
 	public static <T> Stream<T> mergeComparing(Comparator<? super T> comparator, Stream<T>... streams) {
@@ -51,13 +88,19 @@
 	}
 
 	public static <T> Stream<List<T>> batches(Stream<T> stream, int batchSize) {
-		return batchSize <= 0
-				? Stream.of(stream.collect(Collectors.toList()))
-				: StreamSupport.stream(new BatchSpliterator<>(stream.spliterator(), batchSize), stream.isParallel());
+		if (batchSize <= 0) {
+			return Stream.of(toList(stream));
+		} else if (batchSize == 1) {
+			return stream.map(Collections::singletonList);
+		} else {
+			return StreamSupport
+					.stream(new BatchSpliterator<>(stream.spliterator(), batchSize), stream.isParallel())
+					.onClose(stream::close);
+		}
 	}
 
 	@SuppressWarnings("UnstableApiUsage")
-	public static <X> Stream<X> streamWhileNonNull(Supplier<X> supplier) {
+	public static <X> Stream<X> streamWhileNonNull(Supplier<? extends X> supplier) {
 		var it = new Iterator<X>() {
 
 			private boolean nextSet = false;
@@ -81,30 +124,116 @@
 		return Streams.stream(it);
 	}
 
-	public static <T> List<T> toListClose(Stream<T> stream) {
-		try (stream) {
-			return stream.toList();
+	public static <T> List<T> toList(Stream<T> stream) {
+		return collect(stream, fastListing());
+	}
+
+	@SuppressWarnings("DataFlowIssue")
+	public static <T> long count(Stream<T> stream) {
+		return collect(stream, fastCounting());
+	}
+
+	public static <T> List<T> toListOn(ForkJoinPool forkJoinPool, Stream<T> stream) {
+		return collectOn(forkJoinPool, stream, fastListing());
+	}
+
+	public static <T> long countOn(ForkJoinPool forkJoinPool, Stream<T> stream) {
+		return collectOn(forkJoinPool, stream, fastCounting());
+	}
+
+	/**
+	 * Collects and closes the stream on the specified pool
+	 */
+	public static <I, X, R> R collectOn(@Nullable ForkJoinPool pool,
+			@Nullable Stream<I> stream,
+			@NotNull Collector<I, X, R> collector) {
+		if (stream == null) {
+			return null;
+		}
+		if (pool != null) {
+			try {
+				return pool.submit(() -> collect(stream.parallel(), collector)).get();
+			} catch (InterruptedException | ExecutionException e) {
+				throw new RuntimeException(e);
+			}
+		} else {
+			return collect(stream, collector);
 		}
 	}
 
-	public static <T, R> R collectClose(Stream<T> stream, Collector<T, ?, R> collector) {
+	/**
+	 * Collects and closes the stream on the specified pool
+	 */
+	@SuppressWarnings("unchecked")
+	public static <I, X, R> R collect(@Nullable Stream<I> stream, @NotNull Collector<I, X, R> collector) {
 		try (stream) {
-			return stream.collect(collector);
+			if (collector == TO_LIST_FAKE_COLLECTOR) {
+				if (stream != null) {
+					return (R) stream.toList();
+				} else {
+					return (R) List.of();
+				}
+			} else if (collector == COUNT_FAKE_COLLECTOR) {
+				if (stream != null) {
+					return (R) (Long) stream.count();
+				} else {
+					return (R) (Long) 0L;
+				}
+			} else if (stream == null) {
+				throw new NullPointerException("Stream is null");
+			} else if (collector == SUMMING_LONG_COLLECTOR) {
+				LongAdder sum = new LongAdder();
+				((Stream<Long>) stream).forEach(sum::add);
+				return (R) (Long) sum.sum();
+			} else if (collector.getClass() == ExecutingCollector.class) {
+				stream.forEach(((ExecutingCollector<? super I>) collector).getConsumer());
+				return null;
+			} else if (collector.getClass() == IteratingCollector.class) {
+				stream.forEachOrdered(((IteratingCollector<? super I>) collector).getConsumer());
+				return null;
+			} else {
+				return stream.collect(collector);
+			}
 		}
 	}
 
-	public static <T> long countClose(Stream<T> stream) {
-		try (stream) {
-			return stream.count();
-		}
+	public static <T> Collector<T, ?, Void> executing(Consumer<? super T> consumer) {
+		return new ExecutingCollector<>(consumer);
 	}
 
-	public static <X> X scheduleOnPool(ForkJoinPool pool, Callable<X> supplier) {
-		try {
-			return pool.submit(supplier).get();
-		} catch (InterruptedException | ExecutionException e) {
-			throw new RuntimeException(e);
-		}
+	public static <T> Collector<T, ?, Void> iterating(Consumer<? super T> consumer) {
+		return new IteratingCollector<>(consumer);
+	}
+
+	/**
+	 * @param op must be fast and non-blocking!
+	 */
+	public static <T> Collector<T, ?, T> fastReducing(T identity, BinaryOperator<T> op) {
+		return new ConcurrentUnorderedReducingCollector<>(identity, Function.identity(), op);
+	}
+
+	/**
+	 * @param mapper must be fast and non-blocking!
+	 * @param op must be fast and non-blocking!
+	 */
+	public static <T, U> Collector<T, ?, U> fastReducing(T identity, Function<T, U> mapper, BinaryOperator<U> op) {
+		return new ConcurrentUnorderedReducingCollector<>(identity, mapper, op);
+	}
+
+	public static <K, E> Stream<Partition<K, E>> partitionBy(Function<? super E, K> partitionBy, Stream<E> in) {
+		return StreamSupport
+				.stream(new PartitionBySpliterator<>(in.spliterator(), partitionBy), in.isParallel())
+				.onClose(in::close);
+	}
+
+	public static <E> Stream<IntPartition<E>> partitionByInt(ToIntFunction<? super E> partitionBy, Stream<E> in) {
+		return StreamSupport
+				.stream(new PartitionByIntSpliterator<>(in.spliterator(), partitionBy), in.isParallel())
+				.onClose(in::close);
+	}
+
+	public static Collector<Long, ?, Long> fastSummingLong() {
+		return SUMMING_LONG_COLLECTOR;
 	}
 
 	private record BatchSpliterator<E>(Spliterator<E> base, int batchSize) implements Spliterator<List<E>> {
@@ -144,4 +273,157 @@
 		}
 	}
+
+	private static final class FakeCollector implements Collector<Object, Object, Object> {
+
+		@Override
+		public Supplier<Object> supplier() {
+			throw new IllegalStateException();
+		}
+
+		@Override
+		public BiConsumer<Object, Object> accumulator() {
+			throw new IllegalStateException();
+		}
+
+		@Override
+		public BinaryOperator<Object> combiner() {
+			throw new IllegalStateException();
+		}
+
+		@Override
+		public Function<Object, Object> finisher() {
+			throw new IllegalStateException();
+		}
+
+		@Override
+		public Set<Characteristics> characteristics() {
+			throw new IllegalStateException();
+		}
+	}
+
+	private abstract static sealed class AbstractExecutingCollector<T> implements Collector<T, Object, Void> {
+
+		private final Consumer<? super T> consumer;
+
+		public AbstractExecutingCollector(Consumer<? super T> consumer) {
+			this.consumer = consumer;
+		}
+
+		@Override
+		public Supplier<Object> supplier() {
+			return NULL_SUPPLIER;
+		}
+
+		@Override
+		public BiConsumer<Object, T> accumulator() {
+			return (o, i) -> consumer.accept(i);
+		}
+
+		@Override
+		public BinaryOperator<Object> combiner() {
+			return COMBINER;
+		}
+
+		@Override
+		public Function<Object, Void> finisher() {
+			return FINISHER;
+		}
+
+		public Consumer<? super T> getConsumer() {
+			return consumer;
+		}
+	}
+
+	private static final class ExecutingCollector<T> extends AbstractExecutingCollector<T> {
+
+		public ExecutingCollector(Consumer<? super T> consumer) {
+			super(consumer);
+		}
+
+		@Override
+		public Set<Characteristics> characteristics() {
+			return CH_CONCURRENT_NOID;
+		}
+	}
+
+	private static final class IteratingCollector<T> extends AbstractExecutingCollector<T> {
+
+		public IteratingCollector(Consumer<? super T> consumer) {
+			super(consumer);
+		}
+
+		@Override
+		public Set<Characteristics> characteristics() {
+			return CH_NOID;
+		}
+	}
+
+	private record ConcurrentUnorderedReducingCollector<T, U>(T identity, Function<T, U> mapper,
+			BinaryOperator<U> op) implements
+			Collector<T, AtomicReference<U>, U> {
+
+		@Override
+		public Supplier<AtomicReference<U>> supplier() {
+			return () -> new AtomicReference<>(mapper.apply(identity));
+		}
+
+		// Can be called from multiple threads!
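+		// The accumulator below CAS-loops on the shared AtomicReference (updateAndGet),
+		// which is what lets this collector declare CONCURRENT | UNORDERED; on contention
+		// the update function may be re-applied, so op must be associative and both op
+		// and mapper must be fast and side-effect-free.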
+		@Override
+		public BiConsumer<AtomicReference<U>, T> accumulator() {
+			return (a, t) -> a.updateAndGet(v1 -> op.apply(v1, mapper.apply(t)));
+		}
+
+		@Override
+		public BinaryOperator<AtomicReference<U>> combiner() {
+			return (a, b) -> {
+				a.set(op.apply(a.get(), b.get()));
+				return a;
+			};
+		}
+
+		@Override
+		public Function<AtomicReference<U>, U> finisher() {
+			return AtomicReference::get;
+		}
+
+		@Override
+		public Set<Characteristics> characteristics() {
+			return CH_CONCURRENT_NOID;
+		}
+	}
+
+	private static final class SummingLongCollector implements Collector<Long, LongAdder, Long> {
+
+		public SummingLongCollector() {
+		}
+
+		@Override
+		public Supplier<LongAdder> supplier() {
+			return LongAdder::new;
+		}
+
+		@Override
+		public BiConsumer<LongAdder, Long> accumulator() {
+			return LongAdder::add;
+		}
+
+		@Override
+		public BinaryOperator<LongAdder> combiner() {
+			return (la1, la2) -> {
+				la1.add(la2.sum());
+				return la1;
+			};
+		}
+
+		@Override
+		public Function<LongAdder, Long> finisher() {
+			return LongAdder::sum;
+		}
+
+		@Override
+		public Set<Characteristics> characteristics() {
+			return CH_CONCURRENT_NOID;
+		}
+	}
 }
diff --git a/src/test/java/it/cavallium/dbengine/tests/TestDictionaryMap.java b/src/test/java/it/cavallium/dbengine/tests/TestDictionaryMap.java
index 780f165..8b4d115 100644
--- a/src/test/java/it/cavallium/dbengine/tests/TestDictionaryMap.java
+++ b/src/test/java/it/cavallium/dbengine/tests/TestDictionaryMap.java
@@ -1,12 +1,10 @@
 package it.cavallium.dbengine.tests;
 
 import static it.cavallium.dbengine.tests.DbTestUtils.*;
-import static it.cavallium.dbengine.utils.StreamUtils.toListClose;
+import static it.cavallium.dbengine.utils.StreamUtils.toList;
 
 import com.google.common.collect.Streams;
-import it.cavallium.dbengine.database.LLUtils;
 import it.cavallium.dbengine.database.UpdateMode;
-import it.cavallium.dbengine.utils.StreamUtils;
 import it.unimi.dsi.fastutil.objects.Object2ObjectSortedMap;
 import it.unimi.dsi.fastutil.objects.Object2ObjectSortedMaps;
 import java.io.IOException;
@@ -19,16 +17,13 @@
 import java.util.Map.Entry;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 import java.util.stream.Stream;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.MethodSource;
@@ -401,8 +396,8 @@ public abstract class TestDictionaryMap {
 		var remainingEntries = new ArrayList<Entry<String, String>>();
 		var stpVer = run(shouldFail, () -> tempDb(getTempDbGenerator(), db -> {
 			var map = tempDatabaseMapDictionaryMap(tempDictionary(db, updateMode), mapType, 5);
-			return Arrays.asList(toListClose(map.setAllValuesAndGetPrevious(entries.entrySet().stream())),
-					toListClose(map.setAllValuesAndGetPrevious(entries.entrySet().stream()))
+			return Arrays.asList(toList(map.setAllValuesAndGetPrevious(entries.entrySet().stream())),
+					toList(map.setAllValuesAndGetPrevious(entries.entrySet().stream()))
 			);
 		}));
 		if (shouldFail) {
@@ -425,7 +420,7 @@
 			var entriesFlux = entries.entrySet();
 			var keysFlux = entriesFlux.stream().map(Entry::getKey).toList();
 			map.set(entries);
-			var resultsFlux = toListClose(map.getMulti(null, entries.keySet().stream()));
+			var resultsFlux = toList(map.getMulti(null, entries.keySet().stream()));
 			return Streams
 					.zip(keysFlux.stream(), resultsFlux.stream(), Map::entry)
 					.filter(k -> k.getValue().isPresent())
@@ -504,7 +499,7 @@
 		var stpVer = run(shouldFail, () -> tempDb(getTempDbGenerator(), db -> {
 			var map = tempDatabaseMapDictionaryMap(tempDictionary(db, updateMode), mapType, 5);
 			map.putMulti(entries.entrySet().stream());
-			return toListClose(map.getAllValues(null, false));
+			return toList(map.getAllValues(null, false));
 		}));
 		if (shouldFail) {
 			this.checkLeaks = false;
@@ -535,7 +530,7 @@
 		var stpVer = run(shouldFail, () -> tempDb(getTempDbGenerator(), db -> {
 			var map = tempDatabaseMapDictionaryMap(tempDictionary(db, updateMode), mapType, 5);
 			map.putMulti(entries.entrySet().stream());
-			return toListClose(map.getAllStages(null, false).map(stage -> {
+			return toList(map.getAllStages(null, false).map(stage -> {
 				var v = stage.getValue().get(null);
 				if (v == null) {
 					return null;
diff --git a/src/test/java/it/cavallium/dbengine/tests/TestLLDictionary.java b/src/test/java/it/cavallium/dbengine/tests/TestLLDictionary.java
index 6ce7b3a..90b972a 100644
--- a/src/test/java/it/cavallium/dbengine/tests/TestLLDictionary.java
+++ b/src/test/java/it/cavallium/dbengine/tests/TestLLDictionary.java
@@ -2,7 +2,7 @@ package it.cavallium.dbengine.tests;
 
 import static it.cavallium.dbengine.tests.DbTestUtils.ensureNoLeaks;
 import static it.cavallium.dbengine.tests.DbTestUtils.runVoid;
-import static it.cavallium.dbengine.utils.StreamUtils.toListClose;
+import static it.cavallium.dbengine.utils.StreamUtils.toList;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
 import it.cavallium.dbengine.tests.DbTestUtils.TempDb;
@@ -13,7 +13,6 @@ import it.cavallium.dbengine.database.LLKeyValueDatabase;
 import it.cavallium.dbengine.database.LLRange;
 import it.cavallium.dbengine.database.UpdateMode;
 import it.cavallium.dbengine.database.UpdateReturnMode;
-import it.cavallium.dbengine.utils.StreamUtils;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
@@ -165,8 +164,8 @@
 		var afterSize = dict.sizeRange(null, LLRange.all(), false);
 		Assertions.assertEquals(1, afterSize - beforeSize);
 
-		Assertions.assertTrue(toListClose(dict.getRangeKeys(null, RANGE_ALL, false, false).map(this::toString)).contains("test-nonexistent"));
-		Assertions.assertTrue(toListClose(dict.getRangeKeys(null, RANGE_ALL, true, false).map(this::toString)).contains("test-nonexistent"));
+		Assertions.assertTrue(toList(dict.getRangeKeys(null, RANGE_ALL, false, false).map(this::toString)).contains("test-nonexistent"));
+		Assertions.assertTrue(toList(dict.getRangeKeys(null, RANGE_ALL, true, false).map(this::toString)).contains("test-nonexistent"));
 	}
 
 	@ParameterizedTest
@@ -213,11 +212,11 @@
 		assertEquals(expected, afterSize - beforeSize);
 
 		if (updateMode != UpdateMode.DISALLOW) {
-			Assertions.assertTrue(toListClose(dict
+			Assertions.assertTrue(toList(dict
 					.getRangeKeys(null, RANGE_ALL, false, false)
 					.map(this::toString))
 					.contains("test-nonexistent"));
-			Assertions.assertTrue(toListClose(dict
+			Assertions.assertTrue(toList(dict
 					.getRangeKeys(null, RANGE_ALL, true, false)
 					.map(this::toString))
 					.contains("test-nonexistent"));
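
A minimal usage sketch of the StreamUtils API introduced above (the Row record and the
sample values are hypothetical; only the pools, collectors and partitioning helpers are
from this patch):

	import it.cavallium.dbengine.utils.PartitionByIntSpliterator.IntPartition;
	import it.cavallium.dbengine.utils.StreamUtils;
	import java.util.stream.Stream;

	class StreamUtilsExample {

		record Row(int bucket, long size) {}

		public static void main(String[] args) {
			// Sum the sizes on the shared RocksDB pool; collectOn closes the stream when done.
			long total = StreamUtils.collectOn(StreamUtils.ROCKSDB_SCHEDULER,
					Stream.of(new Row(1, 10), new Row(1, 5), new Row(2, 7)).map(Row::size),
					StreamUtils.fastSummingLong());

			// Group consecutive rows that share a bucket id without buffering the whole stream.
			Stream<IntPartition<Row>> buckets = StreamUtils.partitionByInt(Row::bucket,
					Stream.of(new Row(1, 10), new Row(1, 5), new Row(2, 7)));

			System.out.println(total + " " + StreamUtils.toList(buckets));
		}
	}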