Optimize Hpack and AsciiString hashcode and equals (#8902)

Motivation:

While looking at hpack header-processing hotspots I noticed some low
level too-big-to-inline methods which can be shrunk.

Modifications:

Reduce bytecode size and/or runtime operations used for the following
methods:

PlatformDependent0.equals(byte[], ...)
PlatformDependent0.equalsConstantTime(byte[], ...)
PlatformDependent0.hashCodeAscii(byte[],int,int)
PlatformDependent.hashCodeAscii(CharSequence)

Result:

Existing benchmarks show decent improvement

Before

Benchmark                     (size)   Mode  Cnt         Score         Error  Units
HpackUtilBenchmark.newEquals   SMALL  thrpt    5  17200229.374 ± 1701239.198  ops/s
HpackUtilBenchmark.newEquals  MEDIUM  thrpt    5   3386061.629 ±   72264.685  ops/s
HpackUtilBenchmark.newEquals   LARGE  thrpt    5    507579.209 ±   65883.951  ops/s

After

Benchmark                     (size)   Mode  Cnt         Score         Error  Units
HpackUtilBenchmark.newEquals   SMALL  thrpt    5  29221527.058 ± 4805825.836  ops/s
HpackUtilBenchmark.newEquals  MEDIUM  thrpt    5   6556251.645 ±  466115.199  ops/s
HpackUtilBenchmark.newEquals   LARGE  thrpt    5    879828.889 ±  148136.641  ops/s

Before

Benchmark                          (size)  Mode  Cnt     Score     Error  Units
PlatformDepBench.unsafeBytesEqual       4  avgt   10     4.263 ±   0.110  ns/op
PlatformDepBench.unsafeBytesEqual      10  avgt   10     5.206 ±   0.133  ns/op
PlatformDepBench.unsafeBytesEqual      50  avgt   10     8.160 ±   0.320  ns/op
PlatformDepBench.unsafeBytesEqual     100  avgt   10    13.810 ±   0.751  ns/op
PlatformDepBench.unsafeBytesEqual    1000  avgt   10    89.077 ±   7.275  ns/op
PlatformDepBench.unsafeBytesEqual   10000  avgt   10   773.940 ±  24.579  ns/op
PlatformDepBench.unsafeBytesEqual  100000  avgt   10  7546.807 ± 110.395  ns/op

After

Benchmark                          (size)  Mode  Cnt     Score     Error  Units
PlatformDepBench.unsafeBytesEqual       4  avgt   10     3.337 ±   0.087  ns/op
PlatformDepBench.unsafeBytesEqual      10  avgt   10     4.286 ±   0.194  ns/op
PlatformDepBench.unsafeBytesEqual      50  avgt   10     7.817 ±   0.123  ns/op
PlatformDepBench.unsafeBytesEqual     100  avgt   10    11.260 ±   0.412  ns/op
PlatformDepBench.unsafeBytesEqual    1000  avgt   10    84.255 ±   2.596  ns/op
PlatformDepBench.unsafeBytesEqual   10000  avgt   10   591.892 ±   5.136  ns/op
PlatformDepBench.unsafeBytesEqual  100000  avgt   10  6978.859 ± 285.043  ns/op
This commit is contained in:
Nick Hill 2019-03-07 21:55:11 -08:00 committed by Norman Maurer
parent 7e76d02fe7
commit a258cd6f42
2 changed files with 79 additions and 139 deletions

View File

@ -682,83 +682,43 @@ public final class PlatformDependent {
* The resulting hash code will be case insensitive.
*/
public static int hashCodeAscii(CharSequence bytes) {
final int length = bytes.length();
final int remainingBytes = length & 7;
int hash = HASH_CODE_ASCII_SEED;
final int remainingBytes = bytes.length() & 7;
// Benchmarking shows that by just naively looping for inputs 8~31 bytes long we incur a relatively large
// performance penalty (only achieve about 60% performance of loop which iterates over each char). So because
// of this we take special provisions to unroll the looping for these conditions.
switch (bytes.length()) {
case 31:
case 30:
case 29:
case 28:
case 27:
case 26:
case 25:
case 24:
hash = hashCodeAsciiCompute(bytes, bytes.length() - 24,
hashCodeAsciiCompute(bytes, bytes.length() - 16,
hashCodeAsciiCompute(bytes, bytes.length() - 8, hash)));
break;
case 23:
case 22:
case 21:
case 20:
case 19:
case 18:
case 17:
case 16:
hash = hashCodeAsciiCompute(bytes, bytes.length() - 16,
hashCodeAsciiCompute(bytes, bytes.length() - 8, hash));
break;
case 15:
case 14:
case 13:
case 12:
case 11:
case 10:
case 9:
case 8:
hash = hashCodeAsciiCompute(bytes, bytes.length() - 8, hash);
break;
case 7:
case 6:
case 5:
case 4:
case 3:
case 2:
case 1:
case 0:
break;
default:
for (int i = bytes.length() - 8; i >= remainingBytes; i -= 8) {
hash = hashCodeAsciiCompute(bytes, i, hash);
if (length >= 32) {
for (int i = length - 8; i >= remainingBytes; i -= 8) {
hash = hashCodeAsciiCompute(bytes, i, hash);
}
} else if (length >= 8) {
hash = hashCodeAsciiCompute(bytes, length - 8, hash);
if (length >= 16) {
hash = hashCodeAsciiCompute(bytes, length - 16, hash);
if (length >= 24) {
hash = hashCodeAsciiCompute(bytes, length - 24, hash);
}
break;
}
}
switch(remainingBytes) {
case 7:
return ((hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0)))
* HASH_CODE_C2 + hashCodeAsciiSanitizeShort(bytes, 1))
* HASH_CODE_C1 + hashCodeAsciiSanitizeInt(bytes, 3);
case 6:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeShort(bytes, 0))
* HASH_CODE_C2 + hashCodeAsciiSanitizeInt(bytes, 2);
case 5:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0)))
* HASH_CODE_C2 + hashCodeAsciiSanitizeInt(bytes, 1);
case 4:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeInt(bytes, 0);
case 3:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0)))
* HASH_CODE_C2 + hashCodeAsciiSanitizeShort(bytes, 1);
case 2:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeShort(bytes, 0);
case 1:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0));
default:
return hash;
if (remainingBytes == 0) {
return hash;
}
int offset = 0;
if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7
hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0));
offset = 1;
}
if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7
hash = hash * (offset == 0 ? HASH_CODE_C1 : HASH_CODE_C2)
+ hashCodeAsciiSanitize(hashCodeAsciiSanitizeShort(bytes, offset));
offset += 2;
}
if (remainingBytes >= 4) { // 4, 5, 6, 7
return hash * ((offset == 0 | offset == 3) ? HASH_CODE_C1 : HASH_CODE_C2)
+ hashCodeAsciiSanitizeInt(bytes, offset);
}
return hash;
}
private static final class Mpsc {

View File

@ -569,72 +569,57 @@ final class PlatformDependent0 {
}
static boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
if (length <= 0) {
return true;
}
final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1;
final long baseOffset2 = BYTE_ARRAY_BASE_OFFSET + startPos2;
int remainingBytes = length & 7;
final long end = baseOffset1 + remainingBytes;
for (long i = baseOffset1 - 8 + length, j = baseOffset2 - 8 + length; i >= end; i -= 8, j -= 8) {
if (UNSAFE.getLong(bytes1, i) != UNSAFE.getLong(bytes2, j)) {
return false;
final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1;
final long diff = startPos2 - startPos1;
if (length >= 8) {
final long end = baseOffset1 + remainingBytes;
for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) {
if (UNSAFE.getLong(bytes1, i) != UNSAFE.getLong(bytes2, i + diff)) {
return false;
}
}
}
if (remainingBytes >= 4) {
remainingBytes -= 4;
if (UNSAFE.getInt(bytes1, baseOffset1 + remainingBytes) !=
UNSAFE.getInt(bytes2, baseOffset2 + remainingBytes)) {
long pos = baseOffset1 + remainingBytes;
if (UNSAFE.getInt(bytes1, pos) != UNSAFE.getInt(bytes2, pos + diff)) {
return false;
}
}
final long baseOffset2 = baseOffset1 + diff;
if (remainingBytes >= 2) {
return UNSAFE.getChar(bytes1, baseOffset1) == UNSAFE.getChar(bytes2, baseOffset2) &&
(remainingBytes == 2 || bytes1[startPos1 + 2] == bytes2[startPos2 + 2]);
(remainingBytes == 2 ||
UNSAFE.getByte(bytes1, baseOffset1 + 2) == UNSAFE.getByte(bytes2, baseOffset2 + 2));
}
return bytes1[startPos1] == bytes2[startPos2];
return remainingBytes == 0 ||
UNSAFE.getByte(bytes1, baseOffset1) == UNSAFE.getByte(bytes2, baseOffset2);
}
static int equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
long result = 0;
long remainingBytes = length & 7;
final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1;
final long baseOffset2 = BYTE_ARRAY_BASE_OFFSET + startPos2;
final int remainingBytes = length & 7;
final long end = baseOffset1 + remainingBytes;
for (long i = baseOffset1 - 8 + length, j = baseOffset2 - 8 + length; i >= end; i -= 8, j -= 8) {
result |= UNSAFE.getLong(bytes1, i) ^ UNSAFE.getLong(bytes2, j);
final long diff = startPos2 - startPos1;
for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) {
result |= UNSAFE.getLong(bytes1, i) ^ UNSAFE.getLong(bytes2, i + diff);
}
switch (remainingBytes) {
case 7:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getInt(bytes1, baseOffset1 + 3) ^ UNSAFE.getInt(bytes2, baseOffset2 + 3)) |
(UNSAFE.getChar(bytes1, baseOffset1 + 1) ^ UNSAFE.getChar(bytes2, baseOffset2 + 1)) |
(UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0);
case 6:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getInt(bytes1, baseOffset1 + 2) ^ UNSAFE.getInt(bytes2, baseOffset2 + 2)) |
(UNSAFE.getChar(bytes1, baseOffset1) ^ UNSAFE.getChar(bytes2, baseOffset2)), 0);
case 5:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getInt(bytes1, baseOffset1 + 1) ^ UNSAFE.getInt(bytes2, baseOffset2 + 1)) |
(UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0);
case 4:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getInt(bytes1, baseOffset1) ^ UNSAFE.getInt(bytes2, baseOffset2)), 0);
case 3:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getChar(bytes1, baseOffset1 + 1) ^ UNSAFE.getChar(bytes2, baseOffset2 + 1)) |
(UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0);
case 2:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getChar(bytes1, baseOffset1) ^ UNSAFE.getChar(bytes2, baseOffset2)), 0);
case 1:
return ConstantTimeUtils.equalsConstantTime(result |
(UNSAFE.getByte(bytes1, baseOffset1) ^ UNSAFE.getByte(bytes2, baseOffset2)), 0);
default:
return ConstantTimeUtils.equalsConstantTime(result, 0);
if (remainingBytes >= 4) {
result |= UNSAFE.getInt(bytes1, baseOffset1) ^ UNSAFE.getInt(bytes2, baseOffset1 + diff);
remainingBytes -= 4;
}
if (remainingBytes >= 2) {
long pos = end - remainingBytes;
result |= UNSAFE.getChar(bytes1, pos) ^ UNSAFE.getChar(bytes2, pos + diff);
remainingBytes -= 2;
}
if (remainingBytes == 1) {
long pos = end - 1;
result |= UNSAFE.getByte(bytes1, pos) ^ UNSAFE.getByte(bytes2, pos + diff);
}
return ConstantTimeUtils.equalsConstantTime(result, 0);
}
static boolean isZero(byte[] bytes, int startPos, int length) {
@ -665,35 +650,30 @@ final class PlatformDependent0 {
static int hashCodeAscii(byte[] bytes, int startPos, int length) {
int hash = HASH_CODE_ASCII_SEED;
final long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos;
long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos;
final int remainingBytes = length & 7;
final long end = baseOffset + remainingBytes;
for (long i = baseOffset - 8 + length; i >= end; i -= 8) {
hash = hashCodeAsciiCompute(UNSAFE.getLong(bytes, i), hash);
}
switch(remainingBytes) {
case 7:
return ((hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)))
* HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset + 1)))
* HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 3));
case 6:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset)))
* HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 2));
case 5:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)))
* HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset + 1));
case 4:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset));
case 3:
return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)))
* HASH_CODE_C2 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset + 1));
case 2:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset));
case 1:
return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset));
default:
if (remainingBytes == 0) {
return hash;
}
int hcConst = HASH_CODE_C1;
if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7
hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset));
hcConst = HASH_CODE_C2;
baseOffset++;
}
if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7
hash = hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset));
hcConst = hcConst == HASH_CODE_C1 ? HASH_CODE_C2 : HASH_CODE_C1;
baseOffset += 2;
}
if (remainingBytes >= 4) { // 4, 5, 6, 7
return hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset));
}
return hash;
}
static int hashCodeAsciiCompute(long value, int hash) {