diff --git a/src/main/java/it/cavallium/rockserver/core/common/Callback.java b/src/main/java/it/cavallium/rockserver/core/common/Callback.java index fe5e006..cf89933 100644 --- a/src/main/java/it/cavallium/rockserver/core/common/Callback.java +++ b/src/main/java/it/cavallium/rockserver/core/common/Callback.java @@ -1,5 +1,6 @@ package it.cavallium.rockserver.core.common; +import it.cavallium.rockserver.core.common.Callback.CallbackPreviousPresence; import java.util.List; import org.jetbrains.annotations.Nullable; @@ -11,6 +12,10 @@ public sealed interface Callback { || callback instanceof CallbackChanged; } + static boolean requiresGettingPreviousPresence(PutCallback callback) { + return callback instanceof Callback.CallbackPreviousPresence; + } + static boolean requiresGettingCurrentValue(GetCallback callback) { return callback instanceof CallbackCurrent; } @@ -50,6 +55,11 @@ public sealed interface Callback { return (CallbackChanged) CallbackChanged.INSTANCE; } + @SuppressWarnings("unchecked") + static CallbackPreviousPresence previousPresence() { + return (CallbackPreviousPresence) CallbackPreviousPresence.INSTANCE; + } + @SuppressWarnings("unchecked") static CallbackVoid none() { return (CallbackVoid) CallbackVoid.INSTANCE; @@ -97,4 +107,9 @@ public sealed interface Callback { private static final CallbackChanged INSTANCE = new CallbackChanged<>(); } + + record CallbackPreviousPresence() implements PutCallback, PatchCallback { + + private static final CallbackPreviousPresence INSTANCE = new CallbackPreviousPresence<>(); + } } diff --git a/src/main/java/it/cavallium/rockserver/core/common/Utils.java b/src/main/java/it/cavallium/rockserver/core/common/Utils.java index 7d3dde1..05b7a61 100644 --- a/src/main/java/it/cavallium/rockserver/core/common/Utils.java +++ b/src/main/java/it/cavallium/rockserver/core/common/Utils.java @@ -17,9 +17,12 @@ import java.util.List; import java.util.function.Function; import java.util.stream.Stream; +import org.jetbrains.annotations.Contract; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class Utils { + @SuppressWarnings("resource") private static final MemorySegment DUMMY_EMPTY_VALUE = Arena .global() @@ -62,24 +65,26 @@ public class Utils { } } - @NotNull + @Contract(value = "!null -> !null; null -> null", pure = true) public static MemorySegment toMemorySegment(byte... array) { if (array != null) { return MemorySegment.ofArray(array); } else { - return MemorySegment.NULL; + return null; } } - public static MemorySegment toMemorySegment(Arena arena, byte... array) { + @Contract("null, !null -> fail; _, null -> null") + public static @Nullable MemorySegment toMemorySegment(Arena arena, byte... array) { if (array != null) { + assert arena != null; return arena.allocateArray(ValueLayout.JAVA_BYTE, array); } else { - return MemorySegment.NULL; + return null; } } - @NotNull + @Contract(value = "!null -> !null; null -> null", pure = true) public static MemorySegment toMemorySegmentSimple(int... array) { if (array != null) { var newArray = new byte[array.length]; @@ -88,26 +93,27 @@ public class Utils { } return MemorySegment.ofArray(newArray); } else { - return MemorySegment.NULL; + return null; } } - @NotNull + @Contract("null, !null -> fail; _, null -> null") public static MemorySegment toMemorySegmentSimple(Arena arena, int... array) { if (array != null) { + assert arena != null; var newArray = new byte[array.length]; for (int i = 0; i < array.length; i++) { newArray[i] = (byte) array[i]; } return arena.allocateArray(ValueLayout.JAVA_BYTE, newArray); } else { - return MemorySegment.NULL; + return null; } } - public static byte[] toByteArray(MemorySegment memorySegment) { - return memorySegment.toArray(BIG_ENDIAN_BYTES); - } + public static byte @NotNull[] toByteArray(@NotNull MemorySegment memorySegment) { + return memorySegment.toArray(BIG_ENDIAN_BYTES); + } public static List mapList(Collection input, Function mapper) { var result = new ArrayList(input.size()); @@ -132,6 +138,7 @@ public class Utils { public static boolean valueEquals(MemorySegment previousValue, MemorySegment currentValue) { previousValue = requireNonNullElse(previousValue, NULL); currentValue = requireNonNullElse(currentValue, NULL); - return MemorySegment.mismatch(previousValue, 0, previousValue.byteSize(), currentValue, 0, currentValue.byteSize()) == -1; + return MemorySegment.mismatch(previousValue, 0, previousValue.byteSize(), currentValue, 0, currentValue.byteSize()) + == -1; } } diff --git a/src/main/java/it/cavallium/rockserver/core/impl/EmbeddedDB.java b/src/main/java/it/cavallium/rockserver/core/impl/EmbeddedDB.java index 7ca031e..b9cf2ba 100644 --- a/src/main/java/it/cavallium/rockserver/core/impl/EmbeddedDB.java +++ b/src/main/java/it/cavallium/rockserver/core/impl/EmbeddedDB.java @@ -6,6 +6,7 @@ import static org.rocksdb.KeyMayExist.KeyMayExistEnum.kExistsWithValue; import static org.rocksdb.KeyMayExist.KeyMayExistEnum.kExistsWithoutValue; import it.cavallium.rockserver.core.common.Callback; +import it.cavallium.rockserver.core.common.Callback.CallbackExists; import it.cavallium.rockserver.core.common.Callback.GetCallback; import it.cavallium.rockserver.core.common.Callback.IteratorCallback; import it.cavallium.rockserver.core.common.Callback.PutCallback; @@ -34,6 +35,8 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import org.cliffc.high_scale_lib.NonBlockingHashMapLong; @@ -314,8 +317,45 @@ public class EmbeddedDB implements RocksDBAPI, Closeable { } } + /** + * @param txConsumer this can be called multiple times, if the optimistic transaction failed + */ + public R wrapWithTransactionIfNeeded(@Nullable REntry tx, boolean needTransaction, + ExFunction<@Nullable REntry, R> txConsumer) throws Exception { + if (needTransaction) { + return ensureWrapWithTransaction(tx, txConsumer); + } else { + return txConsumer.apply(tx); + } + } + + + /** + * @param txConsumer this can be called multiple times, if the optimistic transaction failed + */ + public R ensureWrapWithTransaction(@Nullable REntry tx, + ExFunction<@NotNull REntry, R> txConsumer) throws Exception { + R result; + if (tx == null) { + // Retry using a transaction: transactions are required to handle this kind of data + var newTx = this.openTransactionInternal(Long.MAX_VALUE); + try { + boolean committed; + do { + result = txConsumer.apply(newTx); + committed = this.closeTransaction(newTx, true); + } while (!committed); + } finally { + this.closeTransaction(newTx, false); + } + } else { + result = txConsumer.apply(tx); + } + return result; + } + private U put(Arena arena, - @Nullable REntry tx, + @Nullable REntry optionalTx, ColumnInstance col, @NotNull MemorySegment @NotNull[] keys, @NotNull MemorySegment value, @@ -323,64 +363,64 @@ public class EmbeddedDB implements RocksDBAPI, Closeable { // Check for null value col.checkNullableValue(value); try { - MemorySegment previousValue; - MemorySegment calculatedKey = col.calculateKey(arena, keys); - if (col.hasBuckets()) { - if (tx != null) { + boolean needsTx = col.hasBuckets() + || Callback.requiresGettingPreviousValue(callback) + || Callback.requiresGettingPreviousPresence(callback); + return wrapWithTransactionIfNeeded(optionalTx, needsTx, tx -> { + MemorySegment previousValue; + MemorySegment calculatedKey = col.calculateKey(arena, keys); + if (col.hasBuckets()) { + assert tx != null; var bucketElementKeys = col.getBucketElementKeys(keys); try (var readOptions = new ReadOptions()) { var previousRawBucketByteArray = tx.val().getForUpdate(readOptions, col.cfh(), calculatedKey.toArray(BIG_ENDIAN_BYTES), true); MemorySegment previousRawBucket = toMemorySegment(arena, previousRawBucketByteArray); - var bucket = new Bucket(col, previousRawBucket); - previousValue = bucket.addElement(bucketElementKeys, value); - tx.val().put(col.cfh(), Utils.toByteArray(calculatedKey), Utils.toByteArray(bucket.toSegment(arena))); + var bucket = previousRawBucket != null ? new Bucket(col, previousRawBucket) : new Bucket(col); + previousValue = transformResultValue(col, bucket.addElement(bucketElementKeys, value)); + tx.val().put(col.cfh(), Utils.toByteArray(calculatedKey), Utils.toByteArray(bucket.toSegment(arena))); } catch (RocksDBException e) { throw it.cavallium.rockserver.core.common.RocksDBException.of(RocksDBErrorType.PUT_1, e); } } else { - // Retry using a transaction: transactions are required to handle this kind of data - var newTx = this.openTransactionInternal(Long.MAX_VALUE); - try { - boolean committed; - do { - previousValue = put(arena, newTx, col, keys, value, Callback.previous()); - committed = this.closeTransaction(newTx, true); - } while (!committed); - } finally { - this.closeTransaction(newTx, false); - } - } - } else { - if (Callback.requiresGettingPreviousValue(callback)) { - try (var readOptions = new ReadOptions()) { - byte[] previousValueByteArray; - if (tx != null) { - previousValueByteArray = tx.val().get(col.cfh(), readOptions, calculatedKey.toArray(BIG_ENDIAN_BYTES)); - } else { - previousValueByteArray = db.get().get(col.cfh(), readOptions, calculatedKey.toArray(BIG_ENDIAN_BYTES)); + if (Callback.requiresGettingPreviousValue(callback)) { + assert tx != null; + try (var readOptions = new ReadOptions()) { + byte[] previousValueByteArray; + previousValueByteArray = tx.val().getForUpdate(readOptions, col.cfh(), calculatedKey.toArray(BIG_ENDIAN_BYTES), true); + previousValue = transformResultValue(col, toMemorySegment(arena, previousValueByteArray)); + } catch (RocksDBException e) { + throw it.cavallium.rockserver.core.common.RocksDBException.of(RocksDBErrorType.PUT_2, e); } - previousValue = toMemorySegment(arena, previousValueByteArray); - } catch (RocksDBException e) { - throw it.cavallium.rockserver.core.common.RocksDBException.of(RocksDBErrorType.PUT_2, e); + } else if (Callback.requiresGettingPreviousPresence(callback)) { + // todo: in the future this should be replaced with just keyExists + assert tx != null; + try (var readOptions = new ReadOptions()) { + byte[] previousValueByteArray; + previousValueByteArray = tx.val().getForUpdate(readOptions, col.cfh(), calculatedKey.toArray(BIG_ENDIAN_BYTES), true); + previousValue = previousValueByteArray != null ? MemorySegment.NULL : null; + } catch (RocksDBException e) { + throw it.cavallium.rockserver.core.common.RocksDBException.of(RocksDBErrorType.PUT_2, e); + } + } else { + previousValue = null; } - } else { - previousValue = null; - } - if (tx != null) { - tx.val().put(col.cfh(), Utils.toByteArray(calculatedKey), Utils.toByteArray(value)); - } else { - try (var w = new WriteOptions()) { - var keyBB = calculatedKey.asByteBuffer(); - ByteBuffer valueBB = (col.schema().hasValue() ? value : Utils.dummyEmptyValue()).asByteBuffer(); - db.get().put(col.cfh(), w, keyBB, valueBB); + if (tx != null) { + tx.val().put(col.cfh(), Utils.toByteArray(calculatedKey), Utils.toByteArray(value)); + } else { + try (var w = new WriteOptions()) { + var keyBB = calculatedKey.asByteBuffer(); + ByteBuffer valueBB = (col.schema().hasValue() ? value : Utils.dummyEmptyValue()).asByteBuffer(); + db.get().put(col.cfh(), w, keyBB, valueBB); + } } } - } - return Callback.safeCast(switch (callback) { - case Callback.CallbackVoid ignored -> null; - case Callback.CallbackPrevious ignored -> previousValue; - case Callback.CallbackChanged ignored -> Utils.valueEquals(previousValue, value); - case Callback.CallbackDelta ignored -> new Delta<>(previousValue, value); + return Callback.safeCast(switch (callback) { + case Callback.CallbackVoid ignored -> null; + case Callback.CallbackPrevious ignored -> previousValue; + case Callback.CallbackPreviousPresence ignored -> previousValue != null; + case Callback.CallbackChanged ignored -> !Utils.valueEquals(previousValue, value); + case Callback.CallbackDelta ignored -> new Delta<>(previousValue, value); + }); }); } catch (it.cavallium.rockserver.core.common.RocksDBException ex) { throw ex; @@ -389,6 +429,10 @@ public class EmbeddedDB implements RocksDBAPI, Closeable { } } + private MemorySegment transformResultValue(ColumnInstance col, MemorySegment realPreviousValue) { + return col.schema().hasValue() ? realPreviousValue : (realPreviousValue != null ? MemorySegment.NULL : null); + } + @Override public T get(Arena arena, long transactionId, @@ -400,6 +444,11 @@ public class EmbeddedDB implements RocksDBAPI, Closeable { // Column id var col = getColumn(columnId); + if (!col.schema().hasValue() && Callback.requiresGettingCurrentValue(callback)) { + throw it.cavallium.rockserver.core.common.RocksDBException.of(RocksDBErrorType.VALUE_MUST_BE_NULL, + "The specified callback requires a return value, but this column does not have values!"); + } + MemorySegment foundValue; boolean existsValue; REntry tx; @@ -438,7 +487,7 @@ public class EmbeddedDB implements RocksDBAPI, Closeable { //noinspection ConstantValue assert tx == null; foundValue = null; - existsValue = db.get().keyExists(calculatedKey.asByteBuffer()); + existsValue = db.get().keyExists(col.cfh(), calculatedKey.asByteBuffer()); } else { foundValue = null; existsValue = false; diff --git a/src/main/java/it/cavallium/rockserver/core/impl/ExFunction.java b/src/main/java/it/cavallium/rockserver/core/impl/ExFunction.java new file mode 100644 index 0000000..dcb20de --- /dev/null +++ b/src/main/java/it/cavallium/rockserver/core/impl/ExFunction.java @@ -0,0 +1,7 @@ +package it.cavallium.rockserver.core.impl; + +@FunctionalInterface +public interface ExFunction { + + U apply(T input) throws Exception; +} diff --git a/src/test/java/it/cavallium/rockserver/core/impl/test/EmbeddedDBTest.java b/src/test/java/it/cavallium/rockserver/core/impl/test/EmbeddedDBTest.java index 9fd1f49..d72c642 100644 --- a/src/test/java/it/cavallium/rockserver/core/impl/test/EmbeddedDBTest.java +++ b/src/test/java/it/cavallium/rockserver/core/impl/test/EmbeddedDBTest.java @@ -11,6 +11,7 @@ import it.cavallium.rockserver.core.common.Utils; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.objects.ObjectList; import java.lang.foreign.Arena; +import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.Assertions; @@ -156,6 +157,8 @@ abstract class EmbeddedDBTest { if (!getHasValues()) { Assertions.assertThrows(Exception.class, () -> db.put(arena, 0, colId, key, toMemorySegmentSimple(arena, 123), Callback.delta())); + } else { + Assertions.assertThrows(Exception.class, () -> db.put(arena, 0, colId, key, MemorySegment.NULL, Callback.delta())); } Assertions.assertThrows(Exception.class, () -> db.put(arena, 0, colId, key, null, Callback.delta())); @@ -176,11 +179,11 @@ abstract class EmbeddedDBTest { delta = db.put(arena, 0, colId, key, value1, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value1, delta.current()); delta = db.put(arena, 0, colId, key, value2, Callback.delta()); - Assertions.assertTrue(Utils.valueEquals(delta.previous(), value1)); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value2)); + assertSegmentEquals(value1, delta.previous()); + assertSegmentEquals(value2, delta.current()); } @Test @@ -198,17 +201,61 @@ abstract class EmbeddedDBTest { Delta delta; + Assertions.assertFalse(db.put(arena, 0, colId, getKeyI(3), value2, Callback.previousPresence())); + Assertions.assertFalse(db.put(arena, 0, colId, getKeyI(4), value2, Callback.previousPresence())); + + delta = db.put(arena, 0, colId, key1, value1, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value1, delta.current()); delta = db.put(arena, 0, colId, key2, value2, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value2)); + assertSegmentEquals(value2, delta.current()); delta = db.put(arena, 0, colId, key2, value1, Callback.delta()); - Assertions.assertTrue(Utils.valueEquals(delta.previous(), value2)); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value2, delta.previous()); + assertSegmentEquals(value1, delta.current()); + + delta = db.put(arena, 0, colId, key2, value1, Callback.delta()); + assertSegmentEquals(value1, delta.previous()); + assertSegmentEquals(value1, delta.current()); + + + Assertions.assertTrue(db.put(arena, 0, colId, key1, value2, Callback.previousPresence())); + Assertions.assertTrue(db.put(arena, 0, colId, key2, value2, Callback.previousPresence())); + + + delta = db.put(arena, 0, colId, key1, value1, Callback.delta()); + assertSegmentEquals(value2, delta.previous()); + assertSegmentEquals(value1, delta.current()); + + delta = db.put(arena, 0, colId, key2, value1, Callback.delta()); + assertSegmentEquals(value2, delta.previous()); + assertSegmentEquals(value1, delta.current()); + + + Assertions.assertNull(db.put(arena, 0, colId, key1, value2, Callback.none())); + Assertions.assertNull(db.put(arena, 0, colId, key2, value2, Callback.none())); + + + assertSegmentEquals(value2, db.put(arena, 0, colId, key1, value1, Callback.previous())); + assertSegmentEquals(value2, db.put(arena, 0, colId, key2, value1, Callback.previous())); + + assertSegmentEquals(value1, db.put(arena, 0, colId, key1, value1, Callback.previous())); + assertSegmentEquals(value1, db.put(arena, 0, colId, key2, value1, Callback.previous())); + + if (!Utils.valueEquals(value1, value2)) { + Assertions.assertTrue(db.put(arena, 0, colId, key1, value2, Callback.changed())); + Assertions.assertTrue(db.put(arena, 0, colId, key2, value2, Callback.changed())); + } + + Assertions.assertFalse(db.put(arena, 0, colId, key1, value2, Callback.changed())); + Assertions.assertFalse(db.put(arena, 0, colId, key2, value2, Callback.changed())); + + + assertSegmentEquals(value2, db.put(arena, 0, colId, key1, value1, Callback.previous())); + assertSegmentEquals(value2, db.put(arena, 0, colId, key2, value1, Callback.previous())); } protected ColumnSchema getSchema() { @@ -237,35 +284,58 @@ abstract class EmbeddedDBTest { var delta = db.put(arena, 0, colId, key1, value1, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value1, delta.current()); delta = db.put(arena, 0, colId, collidingKey1, value2, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value2)); + assertSegmentEquals(value2, delta.current()); delta = db.put(arena, 0, colId, collidingKey1, value1, Callback.delta()); - Assertions.assertTrue(Utils.valueEquals(delta.previous(), value2)); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value2, delta.previous()); + assertSegmentEquals(value1, delta.current()); delta = db.put(arena, 0, colId, key2, value1, Callback.delta()); Assertions.assertNull(delta.previous()); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value1)); + assertSegmentEquals(value1, delta.current()); delta = db.put(arena, 0, colId, key2, value2, Callback.delta()); - Assertions.assertTrue(Utils.valueEquals(delta.previous(), value1)); - Assertions.assertTrue(Utils.valueEquals(delta.current(), value2)); + assertSegmentEquals(value1, delta.previous()); + assertSegmentEquals(value2, delta.current()); } @Test void get() { - Assertions.assertNull(db.get(arena, 0, colId, key1, Callback.current())); - Assertions.assertNull(db.get(arena, 0, colId, collidingKey1, Callback.current())); - Assertions.assertNull(db.get(arena, 0, colId, key2, Callback.current())); + if (getHasValues()) { + Assertions.assertNull(db.get(arena, 0, colId, key1, Callback.current())); + Assertions.assertNull(db.get(arena, 0, colId, collidingKey1, Callback.current())); + Assertions.assertNull(db.get(arena, 0, colId, key2, Callback.current())); + } + Assertions.assertFalse(db.get(arena, 0, colId, key1, Callback.exists())); + Assertions.assertFalse(db.get(arena, 0, colId, collidingKey1, Callback.exists())); + Assertions.assertFalse(db.get(arena, 0, colId, key2, Callback.exists())); + fillSomeKeys(); - Assertions.assertTrue(Utils.valueEquals(value1, db.get(arena, 0, colId, key1, Callback.current()))); - Assertions.assertNull(db.get(arena, 0, colId, getNotFoundKeyI(0), Callback.current())); - Assertions.assertTrue(Utils.valueEquals(value2, db.get(arena, 0, colId, collidingKey1, Callback.current()))); - Assertions.assertTrue(Utils.valueEquals(bigValue, db.get(arena, 0, colId, key2, Callback.current()))); + + if (getHasValues()) { + assertSegmentEquals(value1, db.get(arena, 0, colId, key1, Callback.current())); + Assertions.assertNull(db.get(arena, 0, colId, getNotFoundKeyI(0), Callback.current())); + assertSegmentEquals(value2, db.get(arena, 0, colId, collidingKey1, Callback.current())); + assertSegmentEquals(bigValue, db.get(arena, 0, colId, key2, Callback.current())); + } + + Assertions.assertTrue(db.get(arena, 0, colId, key1, Callback.exists())); + Assertions.assertFalse(db.get(arena, 0, colId, getNotFoundKeyI(0), Callback.exists())); + Assertions.assertTrue(db.get(arena, 0, colId, collidingKey1, Callback.exists())); + Assertions.assertTrue(db.get(arena, 0, colId, key2, Callback.exists())); + } + + public static void assertSegmentEquals(MemorySegment expected, MemorySegment input) { + if (!Utils.valueEquals(expected, input)) { + Assertions.fail( + "Memory segments are not equal! Expected: " + + (expected != null ? Arrays.toString(Utils.toByteArray(expected)) : "null") + + ", Input: " + (input != null ? Arrays.toString(Utils.toByteArray(input)) : "null")); + } } @SuppressWarnings("DataFlowIssue")