Harden ref-counting concurrency semantics (#8583)

Motivation

#8563 highlighted race conditions introduced by the prior optimistic
update optimization in 83a19d5650. These
were known at the time but considered acceptable given the perf
benefit in high contention scenarios.

This PR proposes a modified approach which provides roughly half the
gains but stronger concurrency semantics. Race conditions still exist
but their scope is narrowed to much less likely cases (releases
coinciding with retain overflow), and even in those
cases certain guarantees are still assured. Once release() returns true,
all subsequent release/retains are guaranteed to throw, and in
particular deallocate will be called at most once.

Modifications

- Use even numbers internally (including -ve) for live refcounts
- "Final" release changes to odd number (equivalent to refcount 0)
- Retain still uses faster getAndAdd, release uses CAS loop
- First CAS attempt uses non-volatile read
- Thread.yield() after a failed CAS provides a net gain

Result

More (though not completely) robust concurrency semantics for ref
counting; increased latency under high contention, but still roughly
twice as fast as the original logic. Bench results to follow
This commit is contained in:
Nick Hill 2018-11-28 23:32:32 -08:00 committed by Norman Maurer
parent 057c19f92a
commit fedf3ccecb
3 changed files with 284 additions and 44 deletions

View File

@ -31,7 +31,9 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf {
private static final AtomicIntegerFieldUpdater<AbstractReferenceCountedByteBuf> refCntUpdater =
AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCountedByteBuf.class, "refCnt");
private volatile int refCnt = 1;
// even => "real" refcount is (refCnt >>> 1); odd => "real" refcount is 0
@SuppressWarnings("unused")
private volatile int refCnt = 2;
static {
long refCntFieldOffset = -1;
@ -47,29 +49,37 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf {
REFCNT_FIELD_OFFSET = refCntFieldOffset;
}
private static int realRefCnt(int rawCnt) {
return (rawCnt & 1) != 0 ? 0 : rawCnt >>> 1;
}
protected AbstractReferenceCountedByteBuf(int maxCapacity) {
super(maxCapacity);
}
private int nonVolatileRawCnt() {
// TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles.
return REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET)
: refCntUpdater.get(this);
}
@Override
int internalRefCnt() {
// Try to do non-volatile read for performance as the ensureAccessible() is racy anyway and only provide
// a best-effort guard.
//
// TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles.
return REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET) : refCnt();
return realRefCnt(nonVolatileRawCnt());
}
@Override
public int refCnt() {
return refCnt;
return realRefCnt(refCntUpdater.get(this));
}
/**
* An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly
*/
protected final void setRefCnt(int refCnt) {
refCntUpdater.set(this, refCnt);
protected final void setRefCnt(int newRefCnt) {
refCntUpdater.set(this, newRefCnt << 1); // overflow OK here
}
@Override
@ -83,11 +93,18 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf {
}
private ByteBuf retain0(final int increment) {
int oldRef = refCntUpdater.getAndAdd(this, increment);
if (oldRef <= 0 || oldRef + increment < oldRef) {
// Ensure we don't resurrect (which means the refCnt was 0) and also that we encountered an overflow.
refCntUpdater.getAndAdd(this, -increment);
throw new IllegalReferenceCountException(oldRef, increment);
// all changes to the raw count are 2x the "real" change
int adjustedIncrement = increment << 1; // overflow OK here
int oldRef = refCntUpdater.getAndAdd(this, adjustedIncrement);
if ((oldRef & 1) != 0) {
throw new IllegalReferenceCountException(0, increment);
}
// don't pass 0!
if ((oldRef <= 0 && oldRef + adjustedIncrement >= 0)
|| (oldRef >= 0 && oldRef + adjustedIncrement < oldRef)) {
// overflow case
refCntUpdater.getAndAdd(this, -adjustedIncrement);
throw new IllegalReferenceCountException(realRefCnt(oldRef), increment);
}
return this;
}
@ -113,18 +130,57 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf {
}
private boolean release0(int decrement) {
int oldRef = refCntUpdater.getAndAdd(this, -decrement);
if (oldRef == decrement) {
deallocate();
return true;
int rawCnt = nonVolatileRawCnt(), realCnt = toLiveRealCnt(rawCnt, decrement);
if (decrement == realCnt) {
if (refCntUpdater.compareAndSet(this, rawCnt, 1)) {
deallocate();
return true;
}
return retryRelease0(decrement);
}
if (oldRef < decrement || oldRef - decrement > oldRef) {
// Ensure we don't over-release, and avoid underflow.
refCntUpdater.getAndAdd(this, decrement);
throw new IllegalReferenceCountException(oldRef, -decrement);
}
return false;
return releaseNonFinal0(decrement, rawCnt, realCnt);
}
private boolean releaseNonFinal0(int decrement, int rawCnt, int realCnt) {
if (decrement < realCnt
// all changes to the raw count are 2x the "real" change
&& refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) {
return false;
}
return retryRelease0(decrement);
}
private boolean retryRelease0(int decrement) {
for (;;) {
int rawCnt = refCntUpdater.get(this), realCnt = toLiveRealCnt(rawCnt, decrement);
if (decrement == realCnt) {
if (refCntUpdater.compareAndSet(this, rawCnt, 1)) {
deallocate();
return true;
}
} else if (decrement < realCnt) {
// all changes to the raw count are 2x the "real" change
if (refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) {
return false;
}
} else {
throw new IllegalReferenceCountException(realCnt, -decrement);
}
Thread.yield(); // this benefits throughput under high contention
}
}
/**
* Like {@link #realRefCnt(int)} but throws if refCnt == 0
*/
private static int toLiveRealCnt(int rawCnt, int decrement) {
if ((rawCnt & 1) == 0) {
return rawCnt >>> 1;
}
// odd rawCnt => already deallocated
throw new IllegalReferenceCountException(0, -decrement);
}
/**
* Called once {@link #refCnt()} is equals 0.
*/

View File

@ -15,30 +15,58 @@
*/
package io.netty.util;
import static io.netty.util.internal.ObjectUtil.checkPositive;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import static io.netty.util.internal.ObjectUtil.checkPositive;
import io.netty.util.internal.PlatformDependent;
/**
* Abstract base class for classes wants to implement {@link ReferenceCounted}.
*/
public abstract class AbstractReferenceCounted implements ReferenceCounted {
private static final long REFCNT_FIELD_OFFSET;
private static final AtomicIntegerFieldUpdater<AbstractReferenceCounted> refCntUpdater =
AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCounted.class, "refCnt");
private volatile int refCnt = 1;
// even => "real" refcount is (refCnt >>> 1); odd => "real" refcount is 0
@SuppressWarnings("unused")
private volatile int refCnt = 2;
static {
long refCntFieldOffset = -1;
try {
if (PlatformDependent.hasUnsafe()) {
refCntFieldOffset = PlatformDependent.objectFieldOffset(
AbstractReferenceCounted.class.getDeclaredField("refCnt"));
}
} catch (Throwable ignore) {
refCntFieldOffset = -1;
}
REFCNT_FIELD_OFFSET = refCntFieldOffset;
}
private static int realRefCnt(int rawCnt) {
return (rawCnt & 1) != 0 ? 0 : rawCnt >>> 1;
}
private int nonVolatileRawCnt() {
// TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles.
return REFCNT_FIELD_OFFSET != -1 ? PlatformDependent.getInt(this, REFCNT_FIELD_OFFSET)
: refCntUpdater.get(this);
}
@Override
public final int refCnt() {
return refCnt;
public int refCnt() {
return realRefCnt(refCntUpdater.get(this));
}
/**
* An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly
*/
protected final void setRefCnt(int refCnt) {
refCntUpdater.set(this, refCnt);
protected final void setRefCnt(int newRefCnt) {
refCntUpdater.set(this, newRefCnt << 1); // overflow OK here
}
@Override
@ -51,12 +79,19 @@ public abstract class AbstractReferenceCounted implements ReferenceCounted {
return retain0(checkPositive(increment, "increment"));
}
private ReferenceCounted retain0(int increment) {
int oldRef = refCntUpdater.getAndAdd(this, increment);
if (oldRef <= 0 || oldRef + increment < oldRef) {
// Ensure we don't resurrect (which means the refCnt was 0) and also that we encountered an overflow.
refCntUpdater.getAndAdd(this, -increment);
throw new IllegalReferenceCountException(oldRef, increment);
private ReferenceCounted retain0(final int increment) {
// all changes to the raw count are 2x the "real" change
int adjustedIncrement = increment << 1; // overflow OK here
int oldRef = refCntUpdater.getAndAdd(this, adjustedIncrement);
if ((oldRef & 1) != 0) {
throw new IllegalReferenceCountException(0, increment);
}
// don't pass 0!
if ((oldRef <= 0 && oldRef + adjustedIncrement >= 0)
|| (oldRef >= 0 && oldRef + adjustedIncrement < oldRef)) {
// overflow case
refCntUpdater.getAndAdd(this, -adjustedIncrement);
throw new IllegalReferenceCountException(realRefCnt(oldRef), increment);
}
return this;
}
@ -77,16 +112,55 @@ public abstract class AbstractReferenceCounted implements ReferenceCounted {
}
private boolean release0(int decrement) {
int oldRef = refCntUpdater.getAndAdd(this, -decrement);
if (oldRef == decrement) {
deallocate();
return true;
} else if (oldRef < decrement || oldRef - decrement > oldRef) {
// Ensure we don't over-release, and avoid underflow.
refCntUpdater.getAndAdd(this, decrement);
throw new IllegalReferenceCountException(oldRef, -decrement);
int rawCnt = nonVolatileRawCnt(), realCnt = toLiveRealCnt(rawCnt, decrement);
if (decrement == realCnt) {
if (refCntUpdater.compareAndSet(this, rawCnt, 1)) {
deallocate();
return true;
}
return retryRelease0(decrement);
}
return false;
return releaseNonFinal0(decrement, rawCnt, realCnt);
}
private boolean releaseNonFinal0(int decrement, int rawCnt, int realCnt) {
if (decrement < realCnt
// all changes to the raw count are 2x the "real" change
&& refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) {
return false;
}
return retryRelease0(decrement);
}
private boolean retryRelease0(int decrement) {
for (;;) {
int rawCnt = refCntUpdater.get(this), realCnt = toLiveRealCnt(rawCnt, decrement);
if (decrement == realCnt) {
if (refCntUpdater.compareAndSet(this, rawCnt, 1)) {
deallocate();
return true;
}
} else if (decrement < realCnt) {
// all changes to the raw count are 2x the "real" change
if (refCntUpdater.compareAndSet(this, rawCnt, rawCnt - (decrement << 1))) {
return false;
}
} else {
throw new IllegalReferenceCountException(realCnt, -decrement);
}
Thread.yield(); // this benefits throughput under high contention
}
}
/**
* Like {@link #realRefCnt(int)} but throws if refCnt == 0
*/
private static int toLiveRealCnt(int rawCnt, int decrement) {
if ((rawCnt & 1) == 0) {
return rawCnt >>> 1;
}
// odd rawCnt => already deallocated
throw new IllegalReferenceCountException(0, -decrement);
}
/**

View File

@ -15,8 +15,17 @@
*/
package io.netty.util;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test;
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -74,6 +83,107 @@ public class AbstractReferenceCountedTest {
referenceCounted.retain(2);
}
@Test(timeout = 30000)
public void testRetainFromMultipleThreadsThrowsReferenceCountException() throws Exception {
int threads = 4;
Queue<Future<?>> futures = new ArrayDeque<Future<?>>(threads);
ExecutorService service = Executors.newFixedThreadPool(threads);
final AtomicInteger refCountExceptions = new AtomicInteger();
try {
for (int i = 0; i < 10000; i++) {
final AbstractReferenceCounted referenceCounted = newReferenceCounted();
final CountDownLatch retainLatch = new CountDownLatch(1);
assertTrue(referenceCounted.release());
for (int a = 0; a < threads; a++) {
final int retainCnt = ThreadLocalRandom.current().nextInt(1, Integer.MAX_VALUE);
futures.add(service.submit(new Runnable() {
@Override
public void run() {
try {
retainLatch.await();
try {
referenceCounted.retain(retainCnt);
} catch (IllegalReferenceCountException e) {
refCountExceptions.incrementAndGet();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}));
}
retainLatch.countDown();
for (;;) {
Future<?> f = futures.poll();
if (f == null) {
break;
}
f.get();
}
assertEquals(4, refCountExceptions.get());
refCountExceptions.set(0);
}
} finally {
service.shutdown();
}
}
@Test(timeout = 30000)
public void testReleaseFromMultipleThreadsThrowsReferenceCountException() throws Exception {
int threads = 4;
Queue<Future<?>> futures = new ArrayDeque<Future<?>>(threads);
ExecutorService service = Executors.newFixedThreadPool(threads);
final AtomicInteger refCountExceptions = new AtomicInteger();
try {
for (int i = 0; i < 10000; i++) {
final AbstractReferenceCounted referenceCounted = newReferenceCounted();
final CountDownLatch releaseLatch = new CountDownLatch(1);
final AtomicInteger releasedCount = new AtomicInteger();
for (int a = 0; a < threads; a++) {
final AtomicInteger releaseCnt = new AtomicInteger(0);
futures.add(service.submit(new Runnable() {
@Override
public void run() {
try {
releaseLatch.await();
try {
if (referenceCounted.release(releaseCnt.incrementAndGet())) {
releasedCount.incrementAndGet();
}
} catch (IllegalReferenceCountException e) {
refCountExceptions.incrementAndGet();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}));
}
releaseLatch.countDown();
for (;;) {
Future<?> f = futures.poll();
if (f == null) {
break;
}
f.get();
}
assertEquals(3, refCountExceptions.get());
assertEquals(1, releasedCount.get());
refCountExceptions.set(0);
}
} finally {
service.shutdown();
}
}
private static AbstractReferenceCounted newReferenceCounted() {
return new AbstractReferenceCounted() {
@Override