From d44017189eb73d8b1bdc45da1fca84cd05cc228b Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Tue, 2 Aug 2016 17:55:55 +0200 Subject: [PATCH] Remove extra conditional check in retain Motivation: We not need to do an extra conditional check in retain(...) as we can just check for overflow after we did the increment. Modifications: - Remove extra conditional check - Add test code. Result: One conditional check less. --- .../AbstractReferenceCountedByteBuf.java | 13 +- .../AbstractReferenceCountedByteBufTest.java | 318 ++++++++++++++++++ .../netty/util/AbstractReferenceCounted.java | 13 +- .../util/AbstractReferenceCountedTest.java | 77 +++++ 4 files changed, 413 insertions(+), 8 deletions(-) create mode 100644 buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java create mode 100644 common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java diff --git a/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java b/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java index 257604d774..621af79568 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java @@ -59,10 +59,13 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf { public ByteBuf retain() { for (;;) { int refCnt = this.refCnt; - if (refCnt == 0 || refCnt == Integer.MAX_VALUE) { + final int nextCnt = refCnt + 1; + + // Ensure we not resurrect (which means the refCnt was 0) and also that we encountered an overflow. + if (nextCnt <= 1) { throw new IllegalReferenceCountException(refCnt, 1); } - if (refCntUpdater.compareAndSet(this, refCnt, refCnt + 1)) { + if (refCntUpdater.compareAndSet(this, refCnt, nextCnt)) { break; } } @@ -76,9 +79,11 @@ public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf { } for (;;) { - final int nextCnt; int refCnt = this.refCnt; - if (refCnt == 0 || (nextCnt = refCnt + increment) < 0) { + final int nextCnt = refCnt + increment; + + // Ensure we not resurrect (which means the refCnt was 0) and also that we encountered an overflow. + if (nextCnt <= increment) { throw new IllegalReferenceCountException(refCnt, increment); } if (refCntUpdater.compareAndSet(this, refCnt, nextCnt)) { diff --git a/buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java new file mode 100644 index 0000000000..48fe415332 --- /dev/null +++ b/buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java @@ -0,0 +1,318 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.IllegalReferenceCountException; +import org.junit.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class AbstractReferenceCountedByteBufTest { + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainOverflow() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(Integer.MAX_VALUE); + assertEquals(Integer.MAX_VALUE, referenceCounted.refCnt()); + referenceCounted.retain(); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainOverflow2() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertEquals(1, referenceCounted.refCnt()); + referenceCounted.retain(Integer.MAX_VALUE); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testReleaseOverflow() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(0); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.release(Integer.MAX_VALUE); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainResurrect() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.retain(); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainResurrect2() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.retain(2); + } + + private static AbstractReferenceCountedByteBuf newReferenceCounted() { + return new AbstractReferenceCountedByteBuf(Integer.MAX_VALUE) { + + @Override + protected byte _getByte(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected short _getShort(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected short _getShortLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getUnsignedMedium(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getInt(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getIntLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected long _getLong(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected long _getLongLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setByte(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setShort(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setShortLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setMedium(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setMediumLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setInt(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setIntLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setLong(int index, long value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setLongLE(int index, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public int capacity() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBufAllocator alloc() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteOrder order() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf unwrap() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isDirect() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf copy(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public int nioBufferCount() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasArray() { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] array() { + throw new UnsupportedOperationException(); + } + + @Override + public int arrayOffset() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasMemoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + public long memoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + protected void deallocate() { + // NOOP + } + + @Override + public AbstractReferenceCountedByteBuf touch(Object hint) { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/common/src/main/java/io/netty/util/AbstractReferenceCounted.java b/common/src/main/java/io/netty/util/AbstractReferenceCounted.java index ef4903c17e..cf8112d4ce 100644 --- a/common/src/main/java/io/netty/util/AbstractReferenceCounted.java +++ b/common/src/main/java/io/netty/util/AbstractReferenceCounted.java @@ -53,10 +53,13 @@ public abstract class AbstractReferenceCounted implements ReferenceCounted { public ReferenceCounted retain() { for (;;) { int refCnt = this.refCnt; - if (refCnt == 0 || refCnt == Integer.MAX_VALUE) { + final int nextCnt = refCnt + 1; + + // Ensure we not resurrect (which means the refCnt was 0) and also that we encountered an overflow. + if (nextCnt <= 1) { throw new IllegalReferenceCountException(refCnt, 1); } - if (refCntUpdater.compareAndSet(this, refCnt, refCnt + 1)) { + if (refCntUpdater.compareAndSet(this, refCnt, nextCnt)) { break; } } @@ -70,9 +73,11 @@ public abstract class AbstractReferenceCounted implements ReferenceCounted { } for (;;) { - final int nextCnt; int refCnt = this.refCnt; - if (refCnt == 0 || (nextCnt = refCnt + increment) < 0) { + final int nextCnt = refCnt + increment; + + // Ensure we not resurrect (which means the refCnt was 0) and also that we encountered an overflow. + if (nextCnt <= increment) { throw new IllegalReferenceCountException(refCnt, increment); } if (refCntUpdater.compareAndSet(this, refCnt, nextCnt)) { diff --git a/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java b/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java new file mode 100644 index 0000000000..e004e5df7b --- /dev/null +++ b/common/src/test/java/io/netty/util/AbstractReferenceCountedTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class AbstractReferenceCountedTest { + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainOverflow() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(Integer.MAX_VALUE); + assertEquals(Integer.MAX_VALUE, referenceCounted.refCnt()); + referenceCounted.retain(); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainOverflow2() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertEquals(1, referenceCounted.refCnt()); + referenceCounted.retain(Integer.MAX_VALUE); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testReleaseOverflow() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(0); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.release(Integer.MAX_VALUE); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainResurrect() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.retain(); + } + + @Test(expected = IllegalReferenceCountException.class) + public void testRetainResurrect2() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + referenceCounted.retain(2); + } + + private static AbstractReferenceCounted newReferenceCounted() { + return new AbstractReferenceCounted() { + @Override + protected void deallocate() { + // NOOP + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + } +}