Zero allocation String encoder

This commit is contained in:
Andrea Cavalli 2024-09-25 19:27:24 +02:00
parent ea4f6db7a9
commit a06ddc7d0a
9 changed files with 353 additions and 11 deletions

View File

@ -10,6 +10,7 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Objects;
import org.jetbrains.annotations.NotNull;
public class BufDataOutput implements SafeDataOutput {
@ -216,21 +217,46 @@ public class BufDataOutput implements SafeDataOutput {
@Override
public void writeShortText(String s, Charset charset) {
var out = s.getBytes(charset);
if (out.length > Short.MAX_VALUE) {
throw new IndexOutOfBoundsException("String too long: " + out.length + " bytes");
if (charset == StandardCharsets.UTF_8) {
var beforeWrite = this.buf.position();
this.buf.position(beforeWrite + Short.BYTES);
ZeroAllocationEncoder.INSTANCE.encodeTo(s, this);
var afterWrite = this.buf.position();
this.buf.position(beforeWrite);
var len = Math.toIntExact(afterWrite - beforeWrite - Short.BYTES);
if (len > Short.MAX_VALUE) {
this.buf.position(beforeWrite);
throw new IndexOutOfBoundsException("String too long: " + len + " bytes");
}
this.writeShort(len);
this.buf.position(afterWrite);
} else {
var out = s.getBytes(charset);
if (out.length > Short.MAX_VALUE) {
throw new IndexOutOfBoundsException("String too long: " + out.length + " bytes");
}
checkOutOfBounds(Short.BYTES + out.length);
dOut.writeShort(out.length);
dOut.write(out);
}
checkOutOfBounds(Short.BYTES + out.length);
dOut.writeShort(out.length);
dOut.write(out);
}
@Override
public void writeMediumText(String s, Charset charset) {
var out = s.getBytes(charset);
checkOutOfBounds(Integer.BYTES + out.length);
dOut.writeInt(out.length);
dOut.write(out);
if (charset == StandardCharsets.UTF_8) {
var beforeWrite = this.buf.position();
this.buf.position(beforeWrite + Integer.BYTES);
ZeroAllocationEncoder.INSTANCE.encodeTo(s, this);
var afterWrite = this.buf.position();
this.buf.position(beforeWrite);
this.writeInt(Math.toIntExact(afterWrite - beforeWrite - Integer.BYTES));
this.buf.position(afterWrite);
} else {
var out = s.getBytes(charset);
checkOutOfBounds(Integer.BYTES + out.length);
dOut.writeInt(out.length);
dOut.write(out);
}
}
public Buf asList() {

View File

@ -0,0 +1,111 @@
package it.cavallium.buffer;
import it.cavallium.stream.SafeDataInput;
import it.cavallium.stream.SafeDataOutput;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.*;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicReference;
public class ZeroAllocationEncoder {
public static final ZeroAllocationEncoder INSTANCE = new ZeroAllocationEncoder(8192);
private static final ThreadLocal<CharsetEncoder> CHARSET_ENCODER_UTF8 = ThreadLocal.withInitial(() ->
StandardCharsets.UTF_8.newEncoder()
.onMalformedInput(CodingErrorAction.REPLACE)
.onUnmappableCharacter(CodingErrorAction.REPLACE));
private static final ThreadLocal<CharsetDecoder> CHARSET_DECODER_UTF8 = ThreadLocal.withInitial(() ->
StandardCharsets.UTF_8.newDecoder()
.onMalformedInput(CodingErrorAction.REPLACE)
.onUnmappableCharacter(CodingErrorAction.REPLACE));
private final ThreadLocal<ByteBuffer> bufferThreadLocal;
private final ThreadLocal<AtomicReference<CharBuffer>> charBufferRefThreadLocal;
public ZeroAllocationEncoder(int outBufferSize) {
bufferThreadLocal = ThreadLocal.withInitial(() -> ByteBuffer.allocate(outBufferSize));
charBufferRefThreadLocal = ThreadLocal.withInitial(() -> new AtomicReference<>(CharBuffer.allocate(outBufferSize)));
}
public void encodeTo(String s, SafeDataOutput bufDataOutput) {
var encoder = CHARSET_ENCODER_UTF8.get();
var buf = bufferThreadLocal.get();
var charBuffer = CharBuffer.wrap(s);
CoderResult result;
do {
buf.clear();
result = encoder.encode(charBuffer, buf, true);
buf.flip();
var bufArray = buf.array();
var bufArrayOffset = buf.arrayOffset();
bufDataOutput.write(bufArray, bufArrayOffset + buf.position(), buf.remaining());
if (result.isUnderflow()) {
break;
} else if (result.isOverflow()) {
continue;
} else if (result.isError()) {
try {
result.throwException();
} catch (CharacterCodingException e) {
// This should not happen
throw new Error(e);
}
}
} while (true);
}
public String decodeFrom(SafeDataInput bufDataInput, int length) {
var decoder = CHARSET_DECODER_UTF8.get();
var byteBuf = bufferThreadLocal.get();
var charBufRef = charBufferRefThreadLocal.get();
var charBuf = charBufRef.get();
if (charBuf.capacity() < length) {
charBuf = CharBuffer.allocate(length);
charBufRef.set(charBuf);
} else {
charBuf.clear();
}
var remainingLengthToRead = length;
CoderResult result;
do {
byteBuf.clear();
bufDataInput.readFully(byteBuf, Math.min(remainingLengthToRead, byteBuf.limit()));
byteBuf.flip();
remainingLengthToRead -= byteBuf.remaining();
result = decoder.decode(byteBuf, charBuf, true);
if (result.isUnderflow()) {
if (remainingLengthToRead > 0) {
continue;
} else {
charBuf.flip();
return charBuf.toString();
}
} else if (result.isOverflow()) {
throw new UnsupportedOperationException();
} else if (result.isError()) {
try {
result.throwException();
} catch (CharacterCodingException e) {
// This should not happen
throw new Error(e);
}
}
} while (true);
}
private CharBuffer getNextCharBuf(ArrayList<CharBuffer> charBufs, int charBufIndex) {
if (charBufIndex == 0) return charBufs.getFirst();
if (charBufIndex >= charBufs.size()) {
var b = charBufs.getFirst().duplicate();
charBufs.add(b);
return b;
} else {
return charBufs.get(charBufIndex);
}
}
}

View File

@ -16,7 +16,10 @@
package it.cavallium.stream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Objects;
/** Simple, fast and repositionable byte-array input stream.
*
@ -124,6 +127,38 @@ public class SafeByteArrayInputStream extends SafeMeasurableInputStream implemen
return n;
}
@Override
public void readNBytes(int length, ByteBuffer buffer) {
Objects.checkFromIndexSize(0, length, buffer.remaining());
if (this.available() < length) {
throw new IndexOutOfBoundsException(this.length);
}
buffer.put(array, offset + this.position, length);
position += length;
}
@Override
public int readNBytes(byte[] b, int off, int length) {
Objects.checkFromIndexSize(off, length, b.length);
var cappedLength = Math.min(this.available(), length);
if (cappedLength < 0) {
return 0;
}
System.arraycopy(array, this.offset + this.position, b, off, cappedLength);
position += cappedLength;
return cappedLength;
}
@Override
public byte[] readNBytes(int length) {
if (this.available() < length) {
throw new IndexOutOfBoundsException(this.length);
}
var result = Arrays.copyOfRange(this.array, this.offset + position, this.offset + position + length);
position += length;
return result;
}
@Override
public String readString(int length, Charset charset) {
if (this.available() < length) {

View File

@ -2,6 +2,7 @@ package it.cavallium.stream;
import java.io.Closeable;
import java.io.DataInput;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
@ -102,6 +103,10 @@ public interface SafeDataInput extends Closeable, DataInput {
*/
int read(byte[] b, int off, int len);
void readFully(ByteBuffer dst);
void readFully(ByteBuffer dst, int len);
void readFully(byte @NotNull [] b);
void readFully(byte @NotNull [] b, int off, int len);

View File

@ -28,6 +28,8 @@ package it.cavallium.stream;
import it.cavallium.buffer.IgnoreCoverage;
import org.jetbrains.annotations.NotNull;
import java.nio.ByteBuffer;
public class SafeDataInputStream extends SafeFilterInputStream implements SafeDataInput {
/**
@ -92,6 +94,18 @@ public class SafeDataInputStream extends SafeFilterInputStream implements SafeDa
}
}
@Override
public void readFully(ByteBuffer dst) {
readFully(dst, dst.remaining());
}
@Override
public final void readFully(ByteBuffer dst, int len) {
if (len < 0)
throw new IndexOutOfBoundsException();
in.readNBytes(len, dst);
}
/**
* See the general contract of the {@code skipBytes}
* method of {@code DataInput}.

View File

@ -3,6 +3,8 @@ package it.cavallium.stream;
import it.cavallium.buffer.IgnoreCoverage;
import org.jetbrains.annotations.NotNull;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
/**
@ -83,7 +85,7 @@ public class SafeFilterInputStream extends SafeInputStream {
*/
@IgnoreCoverage
public int read(byte @NotNull [] b) {
return read(b, 0, b.length);
return in.read(b, 0, b.length);
}
/**
@ -112,6 +114,26 @@ public class SafeFilterInputStream extends SafeInputStream {
return in.read(b, off, len);
}
@Override
public void readNBytes(int len, ByteBuffer buffer) {
in.readNBytes(len, buffer);
}
@Override
public byte[] readAllBytes() {
return in.readAllBytes();
}
@Override
public int readNBytes(byte[] b, int off, int len) {
return in.readNBytes(b, off, len);
}
@Override
public byte[] readNBytes(int len) {
return in.readNBytes(len);
}
/**
* Skips over and discards {@code n} bytes of data from the
* input stream. The {@code skip} method may, for a variety of
@ -227,4 +249,19 @@ public class SafeFilterInputStream extends SafeInputStream {
public @NotNull String readString(int length, Charset charset) {
return in.readString(length, charset);
}
@Override
public long transferTo(OutputStream out) {
return in.transferTo(out);
}
@Override
public void skipNBytes(long n) {
in.skipNBytes(n);
}
@Override
public String toString() {
return in.toString();
}
}

View File

@ -6,6 +6,7 @@ import org.jetbrains.annotations.NotNull;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
@ -122,6 +123,12 @@ public abstract class SafeInputStream extends InputStream {
return result;
}
@IgnoreCoverage
public void readNBytes(int len, ByteBuffer buffer) {
var b = readNBytes(len);
buffer.put(b);
}
@IgnoreCoverage
public int readNBytes(byte[] b, int off, int len) {
Objects.checkFromIndexSize(off, len, b.length);

View File

@ -0,0 +1,21 @@
package it.cavallium.buffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.*;
class BufDataOutputTest {
@Test
public void writeMediumText() {
var bdo = BufDataOutput.create();
bdo.writeInt(5);
bdo.writeMediumText("ciao", StandardCharsets.UTF_8);
var buf2 = bdo.toList();
var bdi = BufDataInput.create(buf2);
bdi.skipNBytes(Integer.BYTES);
Assertions.assertEquals("ciao", bdi.readMediumText(StandardCharsets.UTF_8));
}
}

View File

@ -0,0 +1,86 @@
package it.cavallium.buffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.nio.charset.StandardCharsets;
import static org.junit.jupiter.api.Assertions.*;
class ZeroAllocationEncoderTest {
private static final ZeroAllocationEncoder INSTANCE = new ZeroAllocationEncoder(16);
@Test
void encodeToEmpty() {
testEncodeString("");
}
@Test
void decodeEmpty() {
testDecodeString("");
}
@Test
void encodeTo1Underflow() {
testEncodeString("ciao");
}
@Test
void decode1Underflow() {
testDecodeString("ciao");
}
@Test
void encodeToExact1() {
testEncodeString("lorem ipsum dolo");
}
@Test
void decodeExact1() {
testDecodeString("lorem ipsum dolo");
}
@Test
void encodeToOverflow1() {
testEncodeString("lorem ipsum dolor sit amet");
}
@Test
void decodeOverflow1() {
testDecodeString("lorem ipsum dolor sit amet");
}
@Test
void encodeToExact2() {
testEncodeString("lorem ipsum dolor sit amet my na");
}
@Test
void decodeExact2() {
testDecodeString("lorem ipsum dolor sit amet my na");
}
@Test
void encodeToOverflow2() {
testEncodeString("lorem ipsum dolor sit amet my name is giovanni");
}
@Test
void decodeOverflow2() {
testDecodeString("lorem ipsum dolor sit amet my name is giovanni");
}
public void testEncodeString(String s) {
var bdo = BufDataOutput.create();
INSTANCE.encodeTo(s, bdo);
var out = bdo.toList().toString(StandardCharsets.UTF_8);
Assertions.assertEquals(s, out);
}
private void testDecodeString(String s) {
var in = BufDataInput.create(Buf.wrap(s.getBytes(StandardCharsets.UTF_8)));
var out = INSTANCE.decodeFrom(in, in.available());
Assertions.assertEquals(s, out);
}
}