package it.cavallium.dbengine.utils; import com.google.common.collect.Iterators; import com.google.common.collect.Streams; import it.cavallium.dbengine.utils.PartitionByIntSpliterator.IntPartition; import it.cavallium.dbengine.utils.PartitionBySpliterator.Partition; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.EnumSet; import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.Spliterator; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; import java.util.concurrent.ForkJoinWorkerThread; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.BinaryOperator; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.function.ToIntFunction; import java.util.stream.Collector; import java.util.stream.Collector.Characteristics; import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.apache.commons.lang3.function.FailableFunction; import org.apache.commons.lang3.function.FailableSupplier; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; public class StreamUtils { public static final ForkJoinPool LUCENE_POOL = newNamedForkJoinPool("Lucene", false); public static final ForkJoinPool GRAPH_POOL = newNamedForkJoinPool("Graph", false); public static final ForkJoinPool ROCKSDB_POOL = newNamedForkJoinPool("RocksDB", false); private static final Collector TO_LIST_FAKE_COLLECTOR = new FakeCollector(); private static final Collector COUNT_FAKE_COLLECTOR = new FakeCollector(); private static final Collector FIRST_FAKE_COLLECTOR = new FakeCollector(); private static final Collector ANY_FAKE_COLLECTOR = new FakeCollector(); private static final Set CH_NOID = Collections.emptySet(); private static final Set CH_CONCURRENT_NOID = Collections.unmodifiableSet(EnumSet.of( Characteristics.CONCURRENT, Characteristics.UNORDERED )); private static final Object NULL = new Object(); private static final Supplier NULL_SUPPLIER = () -> NULL; private static final BinaryOperator COMBINER = (a, b) -> NULL; private static final Function FINISHER = x -> null; private static final Collector SUMMING_LONG_COLLECTOR = new SummingLongCollector(); private static final Consumer NOOP_CONSUMER = x -> {}; public static ForkJoinPool newNamedForkJoinPool(String name, boolean async) { final int MAX_CAP = 0x7fff; // max #workers - 1 return new ForkJoinPool( Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()), new NamedForkJoinWorkerThreadFactory(name), null, async, 0, MAX_CAP, 1, null, 60_000L, TimeUnit.MILLISECONDS ); } public static Collector> fastListing() { //noinspection unchecked return (Collector>) TO_LIST_FAKE_COLLECTOR; } public static Collector fastCounting() { //noinspection unchecked return (Collector) COUNT_FAKE_COLLECTOR; } public static Collector> fastFirst() { //noinspection unchecked return (Collector>) FIRST_FAKE_COLLECTOR; } public static Collector> fastAny() { //noinspection unchecked return (Collector>) ANY_FAKE_COLLECTOR; } @SafeVarargs @SuppressWarnings("UnstableApiUsage") public static Stream mergeComparing(Comparator comparator, Stream... streams) { List> iterators = new ArrayList<>(streams.length); for (Stream stream : streams) { var it = stream.iterator(); if (it.hasNext()) { iterators.add(it); } } Stream resultStream; if (iterators.isEmpty()) { resultStream = Stream.empty(); } else if (iterators.size() == 1) { resultStream = Streams.stream(iterators.get(0)); } else { resultStream = Streams.stream(Iterators.mergeSorted(iterators, comparator)); } return resultStream.onClose(() -> { for (Stream stream : streams) { stream.close(); } }); } public static Stream> batches(Stream stream, int batchSize) { if (batchSize <= 0) { return Stream.of(toList(stream)); } else if (batchSize == 1) { return stream.map(Collections::singletonList); } else { return StreamSupport .stream(new BatchSpliterator<>(stream.spliterator(), batchSize), stream.isParallel()) .onClose(stream::close); } } public static Stream streamWhileNonNull(Supplier supplier) { return streamWhile(supplier, Objects::nonNull); } @SuppressWarnings("UnstableApiUsage") public static Stream streamWhile(Supplier supplier, Predicate endPredicate) { var it = new Iterator() { private boolean nextSet = false; private X next; @Override public boolean hasNext() { if (!nextSet) { next = supplier.get(); nextSet = true; } return endPredicate.test(next); } @Override public X next() { nextSet = false; return next; } }; return Streams.stream(it); } @SuppressWarnings("UnstableApiUsage") public static Stream streamUntil(Supplier supplier, Predicate endPredicate) { var it = new Iterator() { private boolean nextSet = false; private byte state = (byte) 0; private X next; @Override public boolean hasNext() { if (state == (byte) 2) { return false; } else { if (!nextSet) { next = supplier.get(); state = endPredicate.test(next) ? (byte) 1 : 0; nextSet = true; } return true; } } @Override public X next() { if (state == (byte) 1) { state = (byte) 2; } nextSet = false; return next; } }; return Streams.stream(it); } @SuppressWarnings("DataFlowIssue") @NotNull public static List toList(Stream stream) { return StreamUtils.collect(stream, fastListing()); } @SuppressWarnings("DataFlowIssue") @NotNull public static Optional toFirst(Stream stream) { return StreamUtils.collect(stream, fastFirst()); } @SuppressWarnings("DataFlowIssue") @NotNull public static Optional toAny(Stream stream) { return StreamUtils.collect(stream, fastAny()); } @SuppressWarnings("DataFlowIssue") public static long count(Stream stream) { return collect(stream, fastCounting()); } public static List toListOn(ForkJoinPool forkJoinPool, Stream stream) { return collectOn(forkJoinPool, stream, fastListing()); } public static long countOn(ForkJoinPool forkJoinPool, Stream stream) { return collectOn(forkJoinPool, stream, fastCounting()); } @NotNull public static Optional toFirstOn(ForkJoinPool forkJoinPool, Stream stream) { return collectOn(forkJoinPool, stream, fastFirst()); } @NotNull public static Optional toAnyOn(ForkJoinPool forkJoinPool, Stream stream) { return collectOn(forkJoinPool, stream, fastAny()); } /** * Collects and closes the stream on the specified pool */ public static R collectOn(@Nullable ForkJoinPool pool, @Nullable Stream stream, @NotNull Collector collector) { if (stream == null) { return null; } if (pool != null) { try { return pool.submit(() -> collect(stream.parallel(), collector)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(e); } } else { return collect(stream, collector); } } /** * Collects and closes the stream on the specified pool */ @SuppressWarnings("unchecked") public static R collect(@Nullable Stream stream, @NotNull Collector collector) { try (stream) { if (collector == TO_LIST_FAKE_COLLECTOR) { if (stream != null) { return (R) stream.toList(); } else { return (R) List.of(); } } else if (collector == COUNT_FAKE_COLLECTOR) { if (stream != null) { return (R) (Long) stream.count(); } else { return (R) (Long) 0L; } } else if (collector == FIRST_FAKE_COLLECTOR) { if (stream != null) { return (R) stream.findFirst(); } else { return (R) Optional.empty(); } } else if (collector == ANY_FAKE_COLLECTOR) { if (stream != null) { return (R) stream.findAny(); } else { return (R) Optional.empty(); } } else if (stream == null) { throw new NullPointerException("Stream is null"); } else if (collector == SUMMING_LONG_COLLECTOR) { LongAdder sum = new LongAdder(); ((Stream) stream).forEach(sum::add); return (R) (Long) sum.sum(); } else if (collector.getClass() == CountingExecutingCollector.class) { LongAdder sum = new LongAdder(); var consumer = ((CountingExecutingCollector) collector).getConsumer(); stream.forEach(v -> { sum.increment(); consumer.accept(v); }); return (R) (Long) sum.sum(); } else if (collector.getClass() == ExecutingCollector.class) { stream.forEach(((ExecutingCollector) collector).getConsumer()); return null; } else if (collector.getClass() == IteratingCollector.class) { stream.forEachOrdered(((IteratingCollector) collector).getConsumer()); return null; } else { return stream.collect(collector); } } } public static Collector executing(Consumer consumer) { return new ExecutingCollector<>(consumer); } public static Collector executing() { //noinspection unchecked return new ExecutingCollector<>((Consumer) NOOP_CONSUMER); } public static Collector countingExecuting(Consumer consumer) { return new CountingExecutingCollector<>(consumer); } public static Collector iterating(Consumer consumer) { return new IteratingCollector<>(consumer); } /** * @param op must be fast and non-blocking! */ public static Collector fastReducing(T identity, BinaryOperator op) { return new ConcurrentUnorderedReducingCollector<>(identity, Function.identity(), op); } /** * @param mapper must be fast and non-blocking! * @param op must be fast and non-blocking! */ public static Collector fastReducing(T identity, Function mapper, BinaryOperator op) { return new ConcurrentUnorderedReducingCollector<>(identity, mapper, op); } public static Stream> partitionBy(Function partitionBy, Stream in) { return StreamSupport .stream(new PartitionBySpliterator<>(in.spliterator(), partitionBy), in.isParallel()) .onClose(in::close); } public static Stream> partitionByInt(ToIntFunction partitionBy, Stream in) { return StreamSupport .stream(new PartitionByIntSpliterator<>(in.spliterator(), partitionBy), in.isParallel()) .onClose(in::close); } public static Collector fastSummingLong() { return SUMMING_LONG_COLLECTOR; } public static Stream indexed(Stream stream, BiFunction mapper) { return Streams.mapWithIndex(stream, mapper::apply); } /** * Checks if stream.onClose() will be called during the stream lifetime */ public static Stream resourceStream(Stream stream) { var sr = new StreamResource(null); return stream.onClose(sr::close); } /** * Checks if stream.onClose() will be called during the stream lifetime */ public static Stream resourceStream( FailableSupplier resourceInitializer, FailableFunction, ? extends EX> streamInitializer) throws EX { SR resource = resourceInitializer.get(); var sr = new StreamResource(resource::close); try { Stream stream = streamInitializer.apply(resource); return stream.onClose(sr::close); } catch (Throwable ex) { sr.close(); throw ex; } } /** * Checks if stream.onClose() will be called during the stream lifetime */ public static Stream resourceStream(Stream stream, Runnable finalization) { var sr = new StreamResource(finalization); try { return stream.onClose(sr::close); } catch (Throwable ex) { sr.close(); throw ex; } } /** * Checks if stream.onClose() will be called during the stream lifetime */ public static Stream resourceStream(Supplier> stream, Runnable finalization) { var sr = new StreamResource(finalization); try { return stream.get().onClose(sr::close); } catch (Throwable ex) { sr.close(); throw ex; } } public static Predicate and(Collection> predicateList) { Predicate result = null; for (Predicate predicate : predicateList) { if (result == null) { //noinspection unchecked result = (Predicate) predicate; } else { result = result.and(predicate); } } return result; } public static Predicate or(Collection> predicateList) { Predicate result = null; for (Predicate predicate : predicateList) { if (result == null) { //noinspection unchecked result = (Predicate) predicate; } else { result = result.and(predicate); } } return result; } private record BatchSpliterator(Spliterator base, int batchSize) implements Spliterator> { @Override public boolean tryAdvance(Consumer> action) { final List batch = new ArrayList<>(batchSize); //noinspection StatementWithEmptyBody for (int i = 0; i < batchSize && base.tryAdvance(batch::add); i++) { } if (batch.isEmpty()) { return false; } action.accept(batch); return true; } @Override public Spliterator> trySplit() { if (base.estimateSize() <= batchSize) { return null; } final Spliterator splitBase = this.base.trySplit(); return splitBase == null ? null : new BatchSpliterator<>(splitBase, batchSize); } @Override public long estimateSize() { final double baseSize = base.estimateSize(); return baseSize == 0 ? 0 : (long) Math.ceil(baseSize / (double) batchSize); } @Override public int characteristics() { return base.characteristics(); } } private static final class FakeCollector implements Collector { @Override public Supplier supplier() { throw new IllegalStateException(); } @Override public BiConsumer accumulator() { throw new IllegalStateException(); } @Override public BinaryOperator combiner() { throw new IllegalStateException(); } @Override public Function finisher() { throw new IllegalStateException(); } @Override public Set characteristics() { throw new IllegalStateException(); } } private abstract static sealed class AbstractExecutingCollector implements Collector { private final Consumer consumer; public AbstractExecutingCollector(Consumer consumer) { this.consumer = consumer; } @Override public Supplier supplier() { return NULL_SUPPLIER; } @Override public BiConsumer accumulator() { return (o, i) -> consumer.accept(i); } @Override public BinaryOperator combiner() { return COMBINER; } @Override public abstract Function finisher(); public Consumer getConsumer() { return consumer; } } private static final class ExecutingCollector extends AbstractExecutingCollector { public ExecutingCollector(Consumer consumer) { super(consumer); } @Override public Set characteristics() { return CH_CONCURRENT_NOID; } @Override public Function finisher() { return FINISHER; } } private static final class CountingExecutingCollector extends AbstractExecutingCollector { public CountingExecutingCollector(Consumer consumer) { super(consumer); } @Override public Function finisher() { throw new UnsupportedOperationException("This is a custom collector, do not use with the regular stream api"); } @Override public Set characteristics() { return CH_CONCURRENT_NOID; } } private static final class IteratingCollector extends AbstractExecutingCollector { public IteratingCollector(Consumer consumer) { super(consumer); } @Override public Function finisher() { return FINISHER; } @Override public Set characteristics() { return CH_NOID; } } private record ConcurrentUnorderedReducingCollector(T identity, Function mapper, BinaryOperator op) implements Collector, U> { @Override public Supplier> supplier() { return () -> new AtomicReference<>(mapper.apply(identity)); } // Can be called from multiple threads! @Override public BiConsumer, T> accumulator() { return (a, t) -> a.updateAndGet(v1 -> op.apply(v1, mapper.apply(t))); } @Override public BinaryOperator> combiner() { return (a, b) -> { a.set(op.apply(a.get(), b.get())); return a; }; } @Override public Function, U> finisher() { return AtomicReference::get; } @Override public Set characteristics() { return CH_CONCURRENT_NOID; } } private static final class SummingLongCollector implements Collector { public SummingLongCollector() { } @Override public Supplier supplier() { return LongAdder::new; } @Override public BiConsumer accumulator() { return LongAdder::add; } @Override public BinaryOperator combiner() { return (la1, la2) -> { la1.add(la2.sum()); return la1; }; } @Override public Function finisher() { return LongAdder::sum; } @Override public Set characteristics() { return CH_CONCURRENT_NOID; } } private static class NamedForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory { private final AtomicInteger nextWorkerId = new AtomicInteger(0); private final String name; public NamedForkJoinWorkerThreadFactory(String name) { this.name = name; } @Override public ForkJoinWorkerThread newThread(ForkJoinPool pool) { final ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); worker.setName("ForkJoinPool-" + name + "-worker-" + nextWorkerId.getAndIncrement()); return worker; } } private static class StreamResource extends SimpleResource { private final Runnable finalization; public StreamResource(Runnable finalization) { this.finalization = finalization; } @Override protected void onClose() { if (finalization != null) { finalization.run(); } } } }