Make all compression codecs support buffers that don't have arrays (#11383)

Motivation:
Various compression codecs are currently hard-coded to only support buffers that are backed by byte-arrays that they are willing to expose.
This is efficient for most of the codecs, but compatibility suffers, as we are not able to freely choose our buffer implementations when compression codecs are involved.

Modification:
Add code to the compression codecs, that allow them to handle buffers that don't have arrays.
For many of the codecs, this unfortunately involves allocating temporary byte-arrays, and copying back-and-forth.
We have to do it that way since some codecs can _only_ work with byte-arrays.
Also add tests to verify that this works.

Result:
It is now possible to use all of our compression codecs with both on-heap and off-heap buffers.
The default buffer choice has not changed, however, so performance should be unaffected.
This commit is contained in:
Chris Vest 2021-06-14 10:55:35 +02:00 committed by GitHub
parent e0940fed7a
commit 98e3605d4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 202 additions and 149 deletions

View File

@ -16,6 +16,7 @@
package io.netty.handler.codec.compression; package io.netty.handler.codec.compression;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
@ -145,70 +146,50 @@ public class FastLzFrameDecoder extends ByteToMessageDecoder {
final int idx = in.readerIndex(); final int idx = in.readerIndex();
final int originalLength = this.originalLength; final int originalLength = this.originalLength;
final ByteBuf uncompressed; final byte[] output = originalLength == 0? EmptyArrays.EMPTY_BYTES : new byte[originalLength];
final byte[] output; final int outputPtr = 0;
final int outputPtr;
if (originalLength != 0) { if (isCompressed) {
uncompressed = ctx.alloc().heapBuffer(originalLength, originalLength); final byte[] input;
output = uncompressed.array(); final int inputPtr;
outputPtr = uncompressed.arrayOffset() + uncompressed.writerIndex(); if (in.hasArray()) {
} else { input = in.array();
uncompressed = null; inputPtr = in.arrayOffset() + idx;
output = EmptyArrays.EMPTY_BYTES;
outputPtr = 0;
}
boolean success = false;
try {
if (isCompressed) {
final byte[] input;
final int inputPtr;
if (in.hasArray()) {
input = in.array();
inputPtr = in.arrayOffset() + idx;
} else {
input = new byte[chunkLength];
in.getBytes(idx, input);
inputPtr = 0;
}
final int decompressedBytes = decompress(input, inputPtr, chunkLength,
output, outputPtr, originalLength);
if (originalLength != decompressedBytes) {
throw new DecompressionException(String.format(
"stream corrupted: originalLength(%d) and actual length(%d) mismatch",
originalLength, decompressedBytes));
}
} else { } else {
in.getBytes(idx, output, outputPtr, chunkLength); input = new byte[chunkLength];
in.getBytes(idx, input);
inputPtr = 0;
} }
final Checksum checksum = this.checksum; final int decompressedBytes = decompress(input, inputPtr, chunkLength,
if (hasChecksum && checksum != null) { output, outputPtr, originalLength);
checksum.reset(); if (originalLength != decompressedBytes) {
checksum.update(output, outputPtr, originalLength); throw new DecompressionException(String.format(
final int checksumResult = (int) checksum.getValue(); "stream corrupted: originalLength(%d) and actual length(%d) mismatch",
if (checksumResult != currentChecksum) { originalLength, decompressedBytes));
throw new DecompressionException(String.format(
"stream corrupted: mismatching checksum: %d (expected: %d)",
checksumResult, currentChecksum));
}
} }
} else {
in.getBytes(idx, output, outputPtr, chunkLength);
}
if (uncompressed != null) { final Checksum checksum = this.checksum;
uncompressed.writerIndex(uncompressed.writerIndex() + originalLength); if (hasChecksum && checksum != null) {
ctx.fireChannelRead(uncompressed); checksum.reset();
} checksum.update(output, outputPtr, originalLength);
in.skipBytes(chunkLength); final int checksumResult = (int) checksum.getValue();
if (checksumResult != currentChecksum) {
currentState = State.INIT_BLOCK; throw new DecompressionException(String.format(
success = true; "stream corrupted: mismatching checksum: %d (expected: %d)",
} finally { checksumResult, currentChecksum));
if (!success && uncompressed != null) {
uncompressed.release();
} }
} }
if (output.length > 0) {
ctx.fireChannelRead(Unpooled.wrappedBuffer(output).writerIndex(originalLength));
}
in.skipBytes(chunkLength);
currentState = State.INIT_BLOCK;
break; break;
case CORRUPTED: case CORRUPTED:
in.skipBytes(in.readableBytes()); in.skipBytes(in.readableBytes());

View File

@ -115,8 +115,15 @@ public class FastLzFrameEncoder extends MessageToByteEncoder<ByteBuf> {
blockType = BLOCK_TYPE_NON_COMPRESSED; blockType = BLOCK_TYPE_NON_COMPRESSED;
out.ensureWritable(outputOffset + 2 + length); out.ensureWritable(outputOffset + 2 + length);
final byte[] output = out.array(); final byte[] output;
final int outputPtr = out.arrayOffset() + outputOffset + 2; final int outputPtr;
if (out.hasArray()) {
output = out.array();
outputPtr = out.arrayOffset() + outputOffset + 2;
} else {
output = new byte[length];
outputPtr = 0;
}
if (checksum != null) { if (checksum != null) {
final byte[] input; final byte[] input;
@ -138,6 +145,9 @@ public class FastLzFrameEncoder extends MessageToByteEncoder<ByteBuf> {
} else { } else {
in.getBytes(idx, output, outputPtr, length); in.getBytes(idx, output, outputPtr, length);
} }
if (!out.hasArray()) {
out.setBytes(outputOffset + 2, output);
}
chunkLength = length; chunkLength = length;
} else { } else {
// try to compress // try to compress
@ -160,9 +170,19 @@ public class FastLzFrameEncoder extends MessageToByteEncoder<ByteBuf> {
final int maxOutputLength = calculateOutputBufferLength(length); final int maxOutputLength = calculateOutputBufferLength(length);
out.ensureWritable(outputOffset + 4 + maxOutputLength); out.ensureWritable(outputOffset + 4 + maxOutputLength);
final byte[] output = out.array(); final byte[] output;
final int outputPtr = out.arrayOffset() + outputOffset + 4; final int outputPtr;
if (out.hasArray()) {
output = out.array();
outputPtr = out.arrayOffset() + outputOffset + 4;
} else {
output = new byte[maxOutputLength];
outputPtr = 0;
}
final int compressedLength = compress(input, inputPtr, length, output, outputPtr, level); final int compressedLength = compress(input, inputPtr, length, output, outputPtr, level);
if (!out.hasArray()) {
out.setBytes(outputOffset + 4, output, 0, compressedLength);
}
if (compressedLength < length) { if (compressedLength < length) {
blockType = BLOCK_TYPE_COMPRESSED; blockType = BLOCK_TYPE_COMPRESSED;
chunkLength = compressedLength; chunkLength = compressedLength;
@ -171,7 +191,13 @@ public class FastLzFrameEncoder extends MessageToByteEncoder<ByteBuf> {
outputOffset += 2; outputOffset += 2;
} else { } else {
blockType = BLOCK_TYPE_NON_COMPRESSED; blockType = BLOCK_TYPE_NON_COMPRESSED;
System.arraycopy(input, inputPtr, output, outputPtr - 2, length); if (out.hasArray()) {
System.arraycopy(input, inputPtr, output, outputPtr - 2, length);
} else {
for (int i = 0; i < length; i++) {
out.setByte(outputOffset + 2 + i, input[inputPtr + i]);
}
}
chunkLength = length; chunkLength = length;
} }
} }

View File

@ -169,13 +169,24 @@ public class LzfDecoder extends ByteToMessageDecoder {
} }
ByteBuf uncompressed = ctx.alloc().heapBuffer(originalLength, originalLength); ByteBuf uncompressed = ctx.alloc().heapBuffer(originalLength, originalLength);
final byte[] outputArray = uncompressed.array(); final byte[] outputArray;
final int outPos = uncompressed.arrayOffset() + uncompressed.writerIndex(); final int outPos;
if (uncompressed.hasArray()) {
outputArray = uncompressed.array();
outPos = uncompressed.arrayOffset() + uncompressed.writerIndex();
} else {
outputArray = new byte[originalLength];
outPos = 0;
}
boolean success = false; boolean success = false;
try { try {
decoder.decodeChunk(inputArray, inPos, outputArray, outPos, outPos + originalLength); decoder.decodeChunk(inputArray, inPos, outputArray, outPos, outPos + originalLength);
uncompressed.writerIndex(uncompressed.writerIndex() + originalLength); if (uncompressed.hasArray()) {
uncompressed.writerIndex(uncompressed.writerIndex() + originalLength);
} else {
uncompressed.writeBytes(outputArray);
}
ctx.fireChannelRead(uncompressed); ctx.fireChannelRead(uncompressed);
in.skipBytes(chunkLength); in.skipBytes(chunkLength);
success = true; success = true;

View File

@ -159,10 +159,18 @@ public class LzfEncoder extends MessageToByteEncoder<ByteBuf> {
inputPtr = 0; inputPtr = 0;
} }
final int maxOutputLength = LZFEncoder.estimateMaxWorkspaceSize(length); // Estimate may apparently under-count by one in some cases.
final int maxOutputLength = LZFEncoder.estimateMaxWorkspaceSize(length) + 1;
out.ensureWritable(maxOutputLength); out.ensureWritable(maxOutputLength);
final byte[] output = out.array(); final byte[] output;
final int outputPtr = out.arrayOffset() + out.writerIndex(); final int outputPtr;
if (out.hasArray()) {
output = out.array();
outputPtr = out.arrayOffset() + out.writerIndex();
} else {
output = new byte[maxOutputLength];
outputPtr = 0;
}
final int outputLength; final int outputLength;
if (length >= compressThreshold) { if (length >= compressThreshold) {
@ -173,7 +181,12 @@ public class LzfEncoder extends MessageToByteEncoder<ByteBuf> {
outputLength = encodeNonCompress(input, inputPtr, length, output, outputPtr); outputLength = encodeNonCompress(input, inputPtr, length, output, outputPtr);
} }
out.writerIndex(out.writerIndex() + outputLength); if (out.hasArray()) {
out.writerIndex(out.writerIndex() + outputLength);
} else {
out.writeBytes(output, 0, outputLength);
}
in.skipBytes(length); in.skipBytes(length);
if (!in.hasArray()) { if (!in.hasArray()) {

View File

@ -45,14 +45,12 @@ public abstract class AbstractIntegrationTest {
protected abstract EmbeddedChannel createEncoder(); protected abstract EmbeddedChannel createEncoder();
protected abstract EmbeddedChannel createDecoder(); protected abstract EmbeddedChannel createDecoder();
@BeforeEach public void initChannels() {
public void initChannels() throws Exception {
encoder = createEncoder(); encoder = createEncoder();
decoder = createDecoder(); decoder = createDecoder();
} }
@AfterEach public void closeChannels() {
public void closeChannels() throws Exception {
encoder.close(); encoder.close();
for (;;) { for (;;) {
Object msg = encoder.readOutbound(); Object msg = encoder.readOutbound();
@ -74,19 +72,22 @@ public abstract class AbstractIntegrationTest {
@Test @Test
public void testEmpty() throws Exception { public void testEmpty() throws Exception {
testIdentity(EmptyArrays.EMPTY_BYTES); testIdentity(EmptyArrays.EMPTY_BYTES, true);
testIdentity(EmptyArrays.EMPTY_BYTES, false);
} }
@Test @Test
public void testOneByte() throws Exception { public void testOneByte() throws Exception {
final byte[] data = { 'A' }; final byte[] data = { 'A' };
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void testTwoBytes() throws Exception { public void testTwoBytes() throws Exception {
final byte[] data = { 'B', 'A' }; final byte[] data = { 'B', 'A' };
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
@ -94,14 +95,16 @@ public abstract class AbstractIntegrationTest {
final byte[] data = ("Netty is a NIO client server framework which enables " + final byte[] data = ("Netty is a NIO client server framework which enables " +
"quick and easy development of network applications such as protocol " + "quick and easy development of network applications such as protocol " +
"servers and clients.").getBytes(CharsetUtil.UTF_8); "servers and clients.").getBytes(CharsetUtil.UTF_8);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void testLargeRandom() throws Exception { public void testLargeRandom() throws Exception {
final byte[] data = new byte[1024 * 1024]; final byte[] data = new byte[1024 * 1024];
rand.nextBytes(data); rand.nextBytes(data);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
@ -111,7 +114,8 @@ public abstract class AbstractIntegrationTest {
for (int i = 0; i < 1024; i++) { for (int i = 0; i < 1024; i++) {
data[i] = 2; data[i] = 2;
} }
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
@ -120,20 +124,23 @@ public abstract class AbstractIntegrationTest {
for (int i = 0; i < data.length; i++) { for (int i = 0; i < data.length; i++) {
data[i] = i % 4 != 0 ? 0 : (byte) rand.nextInt(); data[i] = i % 4 != 0 ? 0 : (byte) rand.nextInt();
} }
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void testLongBlank() throws Exception { public void testLongBlank() throws Exception {
final byte[] data = new byte[102400]; final byte[] data = new byte[102400];
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void testLongSame() throws Exception { public void testLongSame() throws Exception {
final byte[] data = new byte[102400]; final byte[] data = new byte[102400];
Arrays.fill(data, (byte) 123); Arrays.fill(data, (byte) 123);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
@ -142,32 +149,39 @@ public abstract class AbstractIntegrationTest {
for (int i = 0; i < data.length; i++) { for (int i = 0; i < data.length; i++) {
data[i] = (byte) i; data[i] = (byte) i;
} }
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
protected void testIdentity(final byte[] data) { protected void testIdentity(final byte[] data, boolean heapBuffer) {
final ByteBuf in = Unpooled.wrappedBuffer(data); initChannels();
assertTrue(encoder.writeOutbound(in.retain())); final ByteBuf in = heapBuffer? Unpooled.wrappedBuffer(data) :
assertTrue(encoder.finish()); Unpooled.directBuffer(data.length).setBytes(0, data);
final CompositeByteBuf compressed = Unpooled.compositeBuffer(); final CompositeByteBuf compressed = Unpooled.compositeBuffer();
ByteBuf msg;
while ((msg = encoder.readOutbound()) != null) {
compressed.addComponent(true, msg);
}
assertThat(compressed, is(notNullValue()));
decoder.writeInbound(compressed.retain());
assertFalse(compressed.isReadable());
final CompositeByteBuf decompressed = Unpooled.compositeBuffer(); final CompositeByteBuf decompressed = Unpooled.compositeBuffer();
while ((msg = decoder.readInbound()) != null) {
decompressed.addComponent(true, msg);
}
in.readerIndex(0);
assertEquals(in, decompressed);
compressed.release(); try {
decompressed.release(); assertTrue(encoder.writeOutbound(in.retain()));
in.release(); assertTrue(encoder.finish());
ByteBuf msg;
while ((msg = encoder.readOutbound()) != null) {
compressed.addComponent(true, msg);
}
assertThat(compressed, is(notNullValue()));
decoder.writeInbound(compressed.retain());
assertFalse(compressed.isReadable());
while ((msg = decoder.readInbound()) != null) {
decompressed.addComponent(true, msg);
}
in.readerIndex(0);
assertEquals(in, decompressed);
} finally {
compressed.release();
decompressed.release();
in.release();
closeChannels();
}
} }
} }

View File

@ -34,20 +34,23 @@ public class Bzip2IntegrationTest extends AbstractIntegrationTest {
public void test3Tables() throws Exception { public void test3Tables() throws Exception {
byte[] data = new byte[500]; byte[] data = new byte[500];
rand.nextBytes(data); rand.nextBytes(data);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void test4Tables() throws Exception { public void test4Tables() throws Exception {
byte[] data = new byte[1100]; byte[] data = new byte[1100];
rand.nextBytes(data); rand.nextBytes(data);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
@Test @Test
public void test5Tables() throws Exception { public void test5Tables() throws Exception {
byte[] data = new byte[2300]; byte[] data = new byte[2300];
rand.nextBytes(data); rand.nextBytes(data);
testIdentity(data); testIdentity(data, true);
testIdentity(data, false);
} }
} }

View File

@ -64,49 +64,54 @@ public class FastLzIntegrationTest extends AbstractIntegrationTest {
} }
@Override // test batched flow of data @Override // test batched flow of data
protected void testIdentity(final byte[] data) { protected void testIdentity(final byte[] data, boolean heapBuffer) {
final ByteBuf original = Unpooled.wrappedBuffer(data); initChannels();
final ByteBuf original = heapBuffer? Unpooled.wrappedBuffer(data) :
int written = 0, length = rand.nextInt(100); Unpooled.directBuffer(data.length).writeBytes(data);
while (written + length < data.length) {
ByteBuf in = Unpooled.wrappedBuffer(data, written, length);
encoder.writeOutbound(in);
written += length;
length = rand.nextInt(100);
}
ByteBuf in = Unpooled.wrappedBuffer(data, written, data.length - written);
encoder.writeOutbound(in);
encoder.finish();
ByteBuf msg;
final CompositeByteBuf compressed = Unpooled.compositeBuffer(); final CompositeByteBuf compressed = Unpooled.compositeBuffer();
while ((msg = encoder.readOutbound()) != null) {
compressed.addComponent(true, msg);
}
assertThat(compressed, is(notNullValue()));
final byte[] compressedArray = new byte[compressed.readableBytes()];
compressed.readBytes(compressedArray);
written = 0;
length = rand.nextInt(100);
while (written + length < compressedArray.length) {
in = Unpooled.wrappedBuffer(compressedArray, written, length);
decoder.writeInbound(in);
written += length;
length = rand.nextInt(100);
}
in = Unpooled.wrappedBuffer(compressedArray, written, compressedArray.length - written);
decoder.writeInbound(in);
assertFalse(compressed.isReadable());
final CompositeByteBuf decompressed = Unpooled.compositeBuffer(); final CompositeByteBuf decompressed = Unpooled.compositeBuffer();
while ((msg = decoder.readInbound()) != null) {
decompressed.addComponent(true, msg);
}
assertEquals(original, decompressed);
compressed.release(); try {
decompressed.release(); int written = 0, length = rand.nextInt(100);
original.release(); while (written + length < data.length) {
ByteBuf in = Unpooled.wrappedBuffer(data, written, length);
encoder.writeOutbound(in);
written += length;
length = rand.nextInt(100);
}
ByteBuf in = Unpooled.wrappedBuffer(data, written, data.length - written);
encoder.writeOutbound(in);
encoder.finish();
ByteBuf msg;
while ((msg = encoder.readOutbound()) != null) {
compressed.addComponent(true, msg);
}
assertThat(compressed, is(notNullValue()));
final byte[] compressedArray = new byte[compressed.readableBytes()];
compressed.readBytes(compressedArray);
written = 0;
length = rand.nextInt(100);
while (written + length < compressedArray.length) {
in = Unpooled.wrappedBuffer(compressedArray, written, length);
decoder.writeInbound(in);
written += length;
length = rand.nextInt(100);
}
in = Unpooled.wrappedBuffer(compressedArray, written, compressedArray.length - written);
decoder.writeInbound(in);
assertFalse(compressed.isReadable());
while ((msg = decoder.readInbound()) != null) {
decompressed.addComponent(true, msg);
}
assertEquals(original, decompressed);
} finally {
compressed.release();
decompressed.release();
original.release();
closeChannels();
}
} }
} }

View File

@ -64,7 +64,7 @@ public class SnappyIntegrationTest extends AbstractIntegrationTest {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1 -1, -1
}; };
testIdentity(data); testIdentity(data, true);
} }
// These tests were found using testRandom() with large RANDOM_RUNS. // These tests were found using testRandom() with large RANDOM_RUNS.
@ -104,7 +104,7 @@ public class SnappyIntegrationTest extends AbstractIntegrationTest {
private void testWithSeed(long seed) { private void testWithSeed(long seed) {
byte[] data = new byte[16 * 1048576]; byte[] data = new byte[16 * 1048576];
new Random(seed).nextBytes(data); new Random(seed).nextBytes(data);
testIdentity(data); testIdentity(data, true);
} }
private static void printSeedAsTest(long l) { private static void printSeedAsTest(long l) {