From a06ddc7d0ad305888364ef0384de46e0ef2dbf34 Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Wed, 25 Sep 2024 19:27:24 +0200 Subject: [PATCH] Zero allocation String encoder --- .../it/cavallium/buffer/BufDataOutput.java | 46 ++++++-- .../buffer/ZeroAllocationEncoder.java | 111 ++++++++++++++++++ .../stream/SafeByteArrayInputStream.java | 35 ++++++ .../it/cavallium/stream/SafeDataInput.java | 5 + .../cavallium/stream/SafeDataInputStream.java | 14 +++ .../stream/SafeFilterInputStream.java | 39 +++++- .../it/cavallium/stream/SafeInputStream.java | 7 ++ .../cavallium/buffer/BufDataOutputTest.java | 21 ++++ .../buffer/ZeroAllocationEncoderTest.java | 86 ++++++++++++++ 9 files changed, 353 insertions(+), 11 deletions(-) create mode 100644 datagen/src/main/java/it/cavallium/buffer/ZeroAllocationEncoder.java create mode 100644 datagen/src/test/java/it/cavallium/buffer/BufDataOutputTest.java create mode 100644 datagen/src/test/java/it/cavallium/buffer/ZeroAllocationEncoderTest.java diff --git a/datagen/src/main/java/it/cavallium/buffer/BufDataOutput.java b/datagen/src/main/java/it/cavallium/buffer/BufDataOutput.java index 0bd62eb..de204d6 100644 --- a/datagen/src/main/java/it/cavallium/buffer/BufDataOutput.java +++ b/datagen/src/main/java/it/cavallium/buffer/BufDataOutput.java @@ -10,6 +10,7 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Objects; + import org.jetbrains.annotations.NotNull; public class BufDataOutput implements SafeDataOutput { @@ -216,21 +217,46 @@ public class BufDataOutput implements SafeDataOutput { @Override public void writeShortText(String s, Charset charset) { - var out = s.getBytes(charset); - if (out.length > Short.MAX_VALUE) { - throw new IndexOutOfBoundsException("String too long: " + out.length + " bytes"); + if (charset == StandardCharsets.UTF_8) { + var beforeWrite = this.buf.position(); + this.buf.position(beforeWrite + Short.BYTES); + ZeroAllocationEncoder.INSTANCE.encodeTo(s, this); + var afterWrite = this.buf.position(); + this.buf.position(beforeWrite); + var len = Math.toIntExact(afterWrite - beforeWrite - Short.BYTES); + if (len > Short.MAX_VALUE) { + this.buf.position(beforeWrite); + throw new IndexOutOfBoundsException("String too long: " + len + " bytes"); + } + this.writeShort(len); + this.buf.position(afterWrite); + } else { + var out = s.getBytes(charset); + if (out.length > Short.MAX_VALUE) { + throw new IndexOutOfBoundsException("String too long: " + out.length + " bytes"); + } + checkOutOfBounds(Short.BYTES + out.length); + dOut.writeShort(out.length); + dOut.write(out); } - checkOutOfBounds(Short.BYTES + out.length); - dOut.writeShort(out.length); - dOut.write(out); } @Override public void writeMediumText(String s, Charset charset) { - var out = s.getBytes(charset); - checkOutOfBounds(Integer.BYTES + out.length); - dOut.writeInt(out.length); - dOut.write(out); + if (charset == StandardCharsets.UTF_8) { + var beforeWrite = this.buf.position(); + this.buf.position(beforeWrite + Integer.BYTES); + ZeroAllocationEncoder.INSTANCE.encodeTo(s, this); + var afterWrite = this.buf.position(); + this.buf.position(beforeWrite); + this.writeInt(Math.toIntExact(afterWrite - beforeWrite - Integer.BYTES)); + this.buf.position(afterWrite); + } else { + var out = s.getBytes(charset); + checkOutOfBounds(Integer.BYTES + out.length); + dOut.writeInt(out.length); + dOut.write(out); + } } public Buf asList() { diff --git a/datagen/src/main/java/it/cavallium/buffer/ZeroAllocationEncoder.java b/datagen/src/main/java/it/cavallium/buffer/ZeroAllocationEncoder.java new file mode 100644 index 0000000..962c225 --- /dev/null +++ b/datagen/src/main/java/it/cavallium/buffer/ZeroAllocationEncoder.java @@ -0,0 +1,111 @@ +package it.cavallium.buffer; + +import it.cavallium.stream.SafeDataInput; +import it.cavallium.stream.SafeDataOutput; + +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.*; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicReference; + +public class ZeroAllocationEncoder { + + public static final ZeroAllocationEncoder INSTANCE = new ZeroAllocationEncoder(8192); + + private static final ThreadLocal CHARSET_ENCODER_UTF8 = ThreadLocal.withInitial(() -> + StandardCharsets.UTF_8.newEncoder() + .onMalformedInput(CodingErrorAction.REPLACE) + .onUnmappableCharacter(CodingErrorAction.REPLACE)); + + private static final ThreadLocal CHARSET_DECODER_UTF8 = ThreadLocal.withInitial(() -> + StandardCharsets.UTF_8.newDecoder() + .onMalformedInput(CodingErrorAction.REPLACE) + .onUnmappableCharacter(CodingErrorAction.REPLACE)); + + private final ThreadLocal bufferThreadLocal; + + private final ThreadLocal> charBufferRefThreadLocal; + + public ZeroAllocationEncoder(int outBufferSize) { + bufferThreadLocal = ThreadLocal.withInitial(() -> ByteBuffer.allocate(outBufferSize)); + charBufferRefThreadLocal = ThreadLocal.withInitial(() -> new AtomicReference<>(CharBuffer.allocate(outBufferSize))); + } + + public void encodeTo(String s, SafeDataOutput bufDataOutput) { + var encoder = CHARSET_ENCODER_UTF8.get(); + var buf = bufferThreadLocal.get(); + var charBuffer = CharBuffer.wrap(s); + CoderResult result; + do { + buf.clear(); + result = encoder.encode(charBuffer, buf, true); + buf.flip(); + var bufArray = buf.array(); + var bufArrayOffset = buf.arrayOffset(); + bufDataOutput.write(bufArray, bufArrayOffset + buf.position(), buf.remaining()); + if (result.isUnderflow()) { + break; + } else if (result.isOverflow()) { + continue; + } else if (result.isError()) { + try { + result.throwException(); + } catch (CharacterCodingException e) { + // This should not happen + throw new Error(e); + } + } + } while (true); + } + + public String decodeFrom(SafeDataInput bufDataInput, int length) { + var decoder = CHARSET_DECODER_UTF8.get(); + var byteBuf = bufferThreadLocal.get(); + var charBufRef = charBufferRefThreadLocal.get(); + var charBuf = charBufRef.get(); + if (charBuf.capacity() < length) { + charBuf = CharBuffer.allocate(length); + charBufRef.set(charBuf); + } else { + charBuf.clear(); + } + var remainingLengthToRead = length; + CoderResult result; + do { + byteBuf.clear(); + bufDataInput.readFully(byteBuf, Math.min(remainingLengthToRead, byteBuf.limit())); + byteBuf.flip(); + remainingLengthToRead -= byteBuf.remaining(); + result = decoder.decode(byteBuf, charBuf, true); + if (result.isUnderflow()) { + if (remainingLengthToRead > 0) { + continue; + } else { + charBuf.flip(); + return charBuf.toString(); + } + } else if (result.isOverflow()) { + throw new UnsupportedOperationException(); + } else if (result.isError()) { + try { + result.throwException(); + } catch (CharacterCodingException e) { + // This should not happen + throw new Error(e); + } + } + } while (true); + } + + private CharBuffer getNextCharBuf(ArrayList charBufs, int charBufIndex) { + if (charBufIndex == 0) return charBufs.getFirst(); + if (charBufIndex >= charBufs.size()) { + var b = charBufs.getFirst().duplicate(); + charBufs.add(b); + return b; + } else { + return charBufs.get(charBufIndex); + } + } +} diff --git a/datagen/src/main/java/it/cavallium/stream/SafeByteArrayInputStream.java b/datagen/src/main/java/it/cavallium/stream/SafeByteArrayInputStream.java index 0d4229b..f4ef2e5 100644 --- a/datagen/src/main/java/it/cavallium/stream/SafeByteArrayInputStream.java +++ b/datagen/src/main/java/it/cavallium/stream/SafeByteArrayInputStream.java @@ -16,7 +16,10 @@ package it.cavallium.stream; +import java.nio.ByteBuffer; import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Objects; /** Simple, fast and repositionable byte-array input stream. * @@ -124,6 +127,38 @@ public class SafeByteArrayInputStream extends SafeMeasurableInputStream implemen return n; } + @Override + public void readNBytes(int length, ByteBuffer buffer) { + Objects.checkFromIndexSize(0, length, buffer.remaining()); + if (this.available() < length) { + throw new IndexOutOfBoundsException(this.length); + } + buffer.put(array, offset + this.position, length); + position += length; + } + + @Override + public int readNBytes(byte[] b, int off, int length) { + Objects.checkFromIndexSize(off, length, b.length); + var cappedLength = Math.min(this.available(), length); + if (cappedLength < 0) { + return 0; + } + System.arraycopy(array, this.offset + this.position, b, off, cappedLength); + position += cappedLength; + return cappedLength; + } + + @Override + public byte[] readNBytes(int length) { + if (this.available() < length) { + throw new IndexOutOfBoundsException(this.length); + } + var result = Arrays.copyOfRange(this.array, this.offset + position, this.offset + position + length); + position += length; + return result; + } + @Override public String readString(int length, Charset charset) { if (this.available() < length) { diff --git a/datagen/src/main/java/it/cavallium/stream/SafeDataInput.java b/datagen/src/main/java/it/cavallium/stream/SafeDataInput.java index 37f6e3f..5602c9f 100644 --- a/datagen/src/main/java/it/cavallium/stream/SafeDataInput.java +++ b/datagen/src/main/java/it/cavallium/stream/SafeDataInput.java @@ -2,6 +2,7 @@ package it.cavallium.stream; import java.io.Closeable; import java.io.DataInput; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -102,6 +103,10 @@ public interface SafeDataInput extends Closeable, DataInput { */ int read(byte[] b, int off, int len); + void readFully(ByteBuffer dst); + + void readFully(ByteBuffer dst, int len); + void readFully(byte @NotNull [] b); void readFully(byte @NotNull [] b, int off, int len); diff --git a/datagen/src/main/java/it/cavallium/stream/SafeDataInputStream.java b/datagen/src/main/java/it/cavallium/stream/SafeDataInputStream.java index 7b4435a..ef7582b 100644 --- a/datagen/src/main/java/it/cavallium/stream/SafeDataInputStream.java +++ b/datagen/src/main/java/it/cavallium/stream/SafeDataInputStream.java @@ -28,6 +28,8 @@ package it.cavallium.stream; import it.cavallium.buffer.IgnoreCoverage; import org.jetbrains.annotations.NotNull; +import java.nio.ByteBuffer; + public class SafeDataInputStream extends SafeFilterInputStream implements SafeDataInput { /** @@ -92,6 +94,18 @@ public class SafeDataInputStream extends SafeFilterInputStream implements SafeDa } } + @Override + public void readFully(ByteBuffer dst) { + readFully(dst, dst.remaining()); + } + + @Override + public final void readFully(ByteBuffer dst, int len) { + if (len < 0) + throw new IndexOutOfBoundsException(); + in.readNBytes(len, dst); + } + /** * See the general contract of the {@code skipBytes} * method of {@code DataInput}. diff --git a/datagen/src/main/java/it/cavallium/stream/SafeFilterInputStream.java b/datagen/src/main/java/it/cavallium/stream/SafeFilterInputStream.java index 6bcb1ca..6d0ce15 100644 --- a/datagen/src/main/java/it/cavallium/stream/SafeFilterInputStream.java +++ b/datagen/src/main/java/it/cavallium/stream/SafeFilterInputStream.java @@ -3,6 +3,8 @@ package it.cavallium.stream; import it.cavallium.buffer.IgnoreCoverage; import org.jetbrains.annotations.NotNull; +import java.io.OutputStream; +import java.nio.ByteBuffer; import java.nio.charset.Charset; /** @@ -83,7 +85,7 @@ public class SafeFilterInputStream extends SafeInputStream { */ @IgnoreCoverage public int read(byte @NotNull [] b) { - return read(b, 0, b.length); + return in.read(b, 0, b.length); } /** @@ -112,6 +114,26 @@ public class SafeFilterInputStream extends SafeInputStream { return in.read(b, off, len); } + @Override + public void readNBytes(int len, ByteBuffer buffer) { + in.readNBytes(len, buffer); + } + + @Override + public byte[] readAllBytes() { + return in.readAllBytes(); + } + + @Override + public int readNBytes(byte[] b, int off, int len) { + return in.readNBytes(b, off, len); + } + + @Override + public byte[] readNBytes(int len) { + return in.readNBytes(len); + } + /** * Skips over and discards {@code n} bytes of data from the * input stream. The {@code skip} method may, for a variety of @@ -227,4 +249,19 @@ public class SafeFilterInputStream extends SafeInputStream { public @NotNull String readString(int length, Charset charset) { return in.readString(length, charset); } + + @Override + public long transferTo(OutputStream out) { + return in.transferTo(out); + } + + @Override + public void skipNBytes(long n) { + in.skipNBytes(n); + } + + @Override + public String toString() { + return in.toString(); + } } diff --git a/datagen/src/main/java/it/cavallium/stream/SafeInputStream.java b/datagen/src/main/java/it/cavallium/stream/SafeInputStream.java index 26245c1..c853521 100644 --- a/datagen/src/main/java/it/cavallium/stream/SafeInputStream.java +++ b/datagen/src/main/java/it/cavallium/stream/SafeInputStream.java @@ -6,6 +6,7 @@ import org.jetbrains.annotations.NotNull; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Arrays; @@ -122,6 +123,12 @@ public abstract class SafeInputStream extends InputStream { return result; } + @IgnoreCoverage + public void readNBytes(int len, ByteBuffer buffer) { + var b = readNBytes(len); + buffer.put(b); + } + @IgnoreCoverage public int readNBytes(byte[] b, int off, int len) { Objects.checkFromIndexSize(off, len, b.length); diff --git a/datagen/src/test/java/it/cavallium/buffer/BufDataOutputTest.java b/datagen/src/test/java/it/cavallium/buffer/BufDataOutputTest.java new file mode 100644 index 0000000..c963904 --- /dev/null +++ b/datagen/src/test/java/it/cavallium/buffer/BufDataOutputTest.java @@ -0,0 +1,21 @@ +package it.cavallium.buffer; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; + +class BufDataOutputTest { + @Test + public void writeMediumText() { + var bdo = BufDataOutput.create(); + bdo.writeInt(5); + bdo.writeMediumText("ciao", StandardCharsets.UTF_8); + var buf2 = bdo.toList(); + var bdi = BufDataInput.create(buf2); + bdi.skipNBytes(Integer.BYTES); + Assertions.assertEquals("ciao", bdi.readMediumText(StandardCharsets.UTF_8)); + } +} \ No newline at end of file diff --git a/datagen/src/test/java/it/cavallium/buffer/ZeroAllocationEncoderTest.java b/datagen/src/test/java/it/cavallium/buffer/ZeroAllocationEncoderTest.java new file mode 100644 index 0000000..e1efaa5 --- /dev/null +++ b/datagen/src/test/java/it/cavallium/buffer/ZeroAllocationEncoderTest.java @@ -0,0 +1,86 @@ +package it.cavallium.buffer; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; + +class ZeroAllocationEncoderTest { + + private static final ZeroAllocationEncoder INSTANCE = new ZeroAllocationEncoder(16); + + @Test + void encodeToEmpty() { + testEncodeString(""); + } + + @Test + void decodeEmpty() { + testDecodeString(""); + } + + @Test + void encodeTo1Underflow() { + testEncodeString("ciao"); + } + + @Test + void decode1Underflow() { + testDecodeString("ciao"); + } + + @Test + void encodeToExact1() { + testEncodeString("lorem ipsum dolo"); + } + + @Test + void decodeExact1() { + testDecodeString("lorem ipsum dolo"); + } + + @Test + void encodeToOverflow1() { + testEncodeString("lorem ipsum dolor sit amet"); + } + + @Test + void decodeOverflow1() { + testDecodeString("lorem ipsum dolor sit amet"); + } + + @Test + void encodeToExact2() { + testEncodeString("lorem ipsum dolor sit amet my na"); + } + + @Test + void decodeExact2() { + testDecodeString("lorem ipsum dolor sit amet my na"); + } + + @Test + void encodeToOverflow2() { + testEncodeString("lorem ipsum dolor sit amet my name is giovanni"); + } + + @Test + void decodeOverflow2() { + testDecodeString("lorem ipsum dolor sit amet my name is giovanni"); + } + + public void testEncodeString(String s) { + var bdo = BufDataOutput.create(); + INSTANCE.encodeTo(s, bdo); + var out = bdo.toList().toString(StandardCharsets.UTF_8); + Assertions.assertEquals(s, out); + } + + private void testDecodeString(String s) { + var in = BufDataInput.create(Buf.wrap(s.getBytes(StandardCharsets.UTF_8))); + var out = INSTANCE.decodeFrom(in, in.available()); + Assertions.assertEquals(s, out); + } +} \ No newline at end of file