AbstractByteBuf.ensureWritable(...) should check if buffer was released

Motivation:

AbstractByteBuf.ensureWritable(...) should check if buffer was released and if so throw an IllegalReferenceCountException

Modifications:

Ensure we throw in all cases.

Result:

More consistent and correct behaviour
This commit is contained in:
Norman Maurer 2017-07-17 15:44:46 +02:00
parent 96e06aa74d
commit d125adec38
10 changed files with 26 additions and 70 deletions

View File

@ -266,7 +266,8 @@ public abstract class AbstractByteBuf extends ByteBuf {
return this; return this;
} }
private void ensureWritable0(int minWritableBytes) { final void ensureWritable0(int minWritableBytes) {
ensureAccessible();
if (minWritableBytes <= writableBytes()) { if (minWritableBytes <= writableBytes()) {
return; return;
} }
@ -286,6 +287,7 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public int ensureWritable(int minWritableBytes, boolean force) { public int ensureWritable(int minWritableBytes, boolean force) {
ensureAccessible();
if (minWritableBytes < 0) { if (minWritableBytes < 0) {
throw new IllegalArgumentException(String.format( throw new IllegalArgumentException(String.format(
"minWritableBytes: %d (expected: >= 0)", minWritableBytes)); "minWritableBytes: %d (expected: >= 0)", minWritableBytes));
@ -668,16 +670,16 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public int setCharSequence(int index, CharSequence sequence, Charset charset) { public int setCharSequence(int index, CharSequence sequence, Charset charset) {
if (charset.equals(CharsetUtil.UTF_8)) { if (charset.equals(CharsetUtil.UTF_8)) {
ensureWritable(ByteBufUtil.utf8MaxBytes(sequence)); ensureWritable0(ByteBufUtil.utf8MaxBytes(sequence));
return ByteBufUtil.writeUtf8(this, index, sequence, sequence.length()); return ByteBufUtil.writeUtf8(this, index, sequence, sequence.length());
} }
if (charset.equals(CharsetUtil.US_ASCII) || charset.equals(CharsetUtil.ISO_8859_1)) { if (charset.equals(CharsetUtil.US_ASCII) || charset.equals(CharsetUtil.ISO_8859_1)) {
int len = sequence.length(); int len = sequence.length();
ensureWritable(len); ensureWritable0(len);
return ByteBufUtil.writeAscii(this, index, sequence, len); return ByteBufUtil.writeAscii(this, index, sequence, len);
} }
byte[] bytes = sequence.toString().getBytes(charset); byte[] bytes = sequence.toString().getBytes(charset);
ensureWritable(bytes.length); ensureWritable0(bytes.length);
setBytes(index, bytes); setBytes(index, bytes);
return bytes.length; return bytes.length;
} }
@ -934,7 +936,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeByte(int value) { public ByteBuf writeByte(int value) {
ensureAccessible();
ensureWritable0(1); ensureWritable0(1);
_setByte(writerIndex++, value); _setByte(writerIndex++, value);
return this; return this;
@ -942,7 +943,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeShort(int value) { public ByteBuf writeShort(int value) {
ensureAccessible();
ensureWritable0(2); ensureWritable0(2);
_setShort(writerIndex, value); _setShort(writerIndex, value);
writerIndex += 2; writerIndex += 2;
@ -951,7 +951,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeShortLE(int value) { public ByteBuf writeShortLE(int value) {
ensureAccessible();
ensureWritable0(2); ensureWritable0(2);
_setShortLE(writerIndex, value); _setShortLE(writerIndex, value);
writerIndex += 2; writerIndex += 2;
@ -960,7 +959,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeMedium(int value) { public ByteBuf writeMedium(int value) {
ensureAccessible();
ensureWritable0(3); ensureWritable0(3);
_setMedium(writerIndex, value); _setMedium(writerIndex, value);
writerIndex += 3; writerIndex += 3;
@ -969,7 +967,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeMediumLE(int value) { public ByteBuf writeMediumLE(int value) {
ensureAccessible();
ensureWritable0(3); ensureWritable0(3);
_setMediumLE(writerIndex, value); _setMediumLE(writerIndex, value);
writerIndex += 3; writerIndex += 3;
@ -978,7 +975,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeInt(int value) { public ByteBuf writeInt(int value) {
ensureAccessible();
ensureWritable0(4); ensureWritable0(4);
_setInt(writerIndex, value); _setInt(writerIndex, value);
writerIndex += 4; writerIndex += 4;
@ -987,7 +983,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeIntLE(int value) { public ByteBuf writeIntLE(int value) {
ensureAccessible();
ensureWritable0(4); ensureWritable0(4);
_setIntLE(writerIndex, value); _setIntLE(writerIndex, value);
writerIndex += 4; writerIndex += 4;
@ -996,7 +991,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeLong(long value) { public ByteBuf writeLong(long value) {
ensureAccessible();
ensureWritable0(8); ensureWritable0(8);
_setLong(writerIndex, value); _setLong(writerIndex, value);
writerIndex += 8; writerIndex += 8;
@ -1005,7 +999,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeLongLE(long value) { public ByteBuf writeLongLE(long value) {
ensureAccessible();
ensureWritable0(8); ensureWritable0(8);
_setLongLE(writerIndex, value); _setLongLE(writerIndex, value);
writerIndex += 8; writerIndex += 8;
@ -1032,7 +1025,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { public ByteBuf writeBytes(byte[] src, int srcIndex, int length) {
ensureAccessible();
ensureWritable(length); ensureWritable(length);
setBytes(writerIndex, src, srcIndex, length); setBytes(writerIndex, src, srcIndex, length);
writerIndex += length; writerIndex += length;
@ -1064,7 +1056,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) {
ensureAccessible();
ensureWritable(length); ensureWritable(length);
setBytes(writerIndex, src, srcIndex, length); setBytes(writerIndex, src, srcIndex, length);
writerIndex += length; writerIndex += length;
@ -1073,9 +1064,8 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public ByteBuf writeBytes(ByteBuffer src) { public ByteBuf writeBytes(ByteBuffer src) {
ensureAccessible();
int length = src.remaining(); int length = src.remaining();
ensureWritable(length); ensureWritable0(length);
setBytes(writerIndex, src); setBytes(writerIndex, src);
writerIndex += length; writerIndex += length;
return this; return this;
@ -1084,7 +1074,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public int writeBytes(InputStream in, int length) public int writeBytes(InputStream in, int length)
throws IOException { throws IOException {
ensureAccessible();
ensureWritable(length); ensureWritable(length);
int writtenBytes = setBytes(writerIndex, in, length); int writtenBytes = setBytes(writerIndex, in, length);
if (writtenBytes > 0) { if (writtenBytes > 0) {
@ -1095,7 +1084,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public int writeBytes(ScatteringByteChannel in, int length) throws IOException { public int writeBytes(ScatteringByteChannel in, int length) throws IOException {
ensureAccessible();
ensureWritable(length); ensureWritable(length);
int writtenBytes = setBytes(writerIndex, in, length); int writtenBytes = setBytes(writerIndex, in, length);
if (writtenBytes > 0) { if (writtenBytes > 0) {
@ -1106,7 +1094,6 @@ public abstract class AbstractByteBuf extends ByteBuf {
@Override @Override
public int writeBytes(FileChannel in, long position, int length) throws IOException { public int writeBytes(FileChannel in, long position, int length) throws IOException {
ensureAccessible();
ensureWritable(length); ensureWritable(length);
int writtenBytes = setBytes(writerIndex, in, position, length); int writtenBytes = setBytes(writerIndex, in, position, length);
if (writtenBytes > 0) { if (writtenBytes > 0) {
@ -1123,7 +1110,7 @@ public abstract class AbstractByteBuf extends ByteBuf {
ensureWritable(length); ensureWritable(length);
int wIndex = writerIndex; int wIndex = writerIndex;
checkIndex(wIndex, length); checkIndex0(wIndex, length);
int nLong = length >>> 3; int nLong = length >>> 3;
int nBytes = length & 7; int nBytes = length & 7;

View File

@ -122,7 +122,7 @@ abstract class AbstractUnsafeSwappedByteBuf extends SwappedByteBuf {
@Override @Override
public final ByteBuf writeShort(int value) { public final ByteBuf writeShort(int value) {
wrapped.ensureWritable(2); wrapped.ensureWritable0(2);
_setShort(wrapped, wrapped.writerIndex, nativeByteOrder ? (short) value : Short.reverseBytes((short) value)); _setShort(wrapped, wrapped.writerIndex, nativeByteOrder ? (short) value : Short.reverseBytes((short) value));
wrapped.writerIndex += 2; wrapped.writerIndex += 2;
return this; return this;
@ -130,7 +130,7 @@ abstract class AbstractUnsafeSwappedByteBuf extends SwappedByteBuf {
@Override @Override
public final ByteBuf writeInt(int value) { public final ByteBuf writeInt(int value) {
wrapped.ensureWritable(4); wrapped.ensureWritable0(4);
_setInt(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Integer.reverseBytes(value)); _setInt(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Integer.reverseBytes(value));
wrapped.writerIndex += 4; wrapped.writerIndex += 4;
return this; return this;
@ -138,7 +138,7 @@ abstract class AbstractUnsafeSwappedByteBuf extends SwappedByteBuf {
@Override @Override
public final ByteBuf writeLong(long value) { public final ByteBuf writeLong(long value) {
wrapped.ensureWritable(8); wrapped.ensureWritable0(8);
_setLong(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Long.reverseBytes(value)); _setLong(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Long.reverseBytes(value));
wrapped.writerIndex += 8; wrapped.writerIndex += 8;
return this; return this;

View File

@ -482,7 +482,7 @@ public final class ByteBufUtil {
for (;;) { for (;;) {
if (buf instanceof AbstractByteBuf) { if (buf instanceof AbstractByteBuf) {
AbstractByteBuf byteBuf = (AbstractByteBuf) buf; AbstractByteBuf byteBuf = (AbstractByteBuf) buf;
byteBuf.ensureWritable(utf8MaxBytes(seq)); byteBuf.ensureWritable0(utf8MaxBytes(seq));
int written = writeUtf8(byteBuf, byteBuf.writerIndex, seq, seq.length()); int written = writeUtf8(byteBuf, byteBuf.writerIndex, seq, seq.length());
byteBuf.writerIndex += written; byteBuf.writerIndex += written;
return written; return written;
@ -583,7 +583,7 @@ public final class ByteBufUtil {
for (;;) { for (;;) {
if (buf instanceof AbstractByteBuf) { if (buf instanceof AbstractByteBuf) {
AbstractByteBuf byteBuf = (AbstractByteBuf) buf; AbstractByteBuf byteBuf = (AbstractByteBuf) buf;
byteBuf.ensureWritable(len); byteBuf.ensureWritable0(len);
int written = writeAscii(byteBuf, byteBuf.writerIndex, seq, len); int written = writeAscii(byteBuf, byteBuf.writerIndex, seq, len);
byteBuf.writerIndex += written; byteBuf.writerIndex += written;
return written; return written;

View File

@ -374,7 +374,8 @@ final class PooledUnsafeDirectByteBuf extends PooledByteBuf<ByteBuffer> {
@Override @Override
public ByteBuf setZero(int index, int length) { public ByteBuf setZero(int index, int length) {
UnsafeByteBufUtil.setZero(this, addr(index), index, length); checkIndex(index, length);
UnsafeByteBufUtil.setZero(addr(index), length);
return this; return this;
} }
@ -382,7 +383,7 @@ final class PooledUnsafeDirectByteBuf extends PooledByteBuf<ByteBuffer> {
public ByteBuf writeZero(int length) { public ByteBuf writeZero(int length) {
ensureWritable(length); ensureWritable(length);
int wIndex = writerIndex; int wIndex = writerIndex;
setZero(wIndex, length); UnsafeByteBufUtil.setZero(addr(wIndex), length);
writerIndex = wIndex + length; writerIndex = wIndex + length;
return this; return this;
} }

View File

@ -131,8 +131,9 @@ final class PooledUnsafeHeapByteBuf extends PooledHeapByteBuf {
@Override @Override
public ByteBuf setZero(int index, int length) { public ByteBuf setZero(int index, int length) {
if (PlatformDependent.javaVersion() >= 7) { if (PlatformDependent.javaVersion() >= 7) {
checkIndex(index, length);
// Only do on java7+ as the needed Unsafe call was only added there. // Only do on java7+ as the needed Unsafe call was only added there.
_setZero(index, length); UnsafeByteBufUtil.setZero(memory, idx(index), length);
return this; return this;
} }
return super.setZero(index, length); return super.setZero(index, length);
@ -144,18 +145,13 @@ final class PooledUnsafeHeapByteBuf extends PooledHeapByteBuf {
// Only do on java7+ as the needed Unsafe call was only added there. // Only do on java7+ as the needed Unsafe call was only added there.
ensureWritable(length); ensureWritable(length);
int wIndex = writerIndex; int wIndex = writerIndex;
_setZero(wIndex, length); UnsafeByteBufUtil.setZero(memory, idx(wIndex), length);
writerIndex = wIndex + length; writerIndex = wIndex + length;
return this; return this;
} }
return super.writeZero(length); return super.writeZero(length);
} }
private void _setZero(int index, int length) {
checkIndex(index, length);
UnsafeByteBufUtil.setZero(memory, idx(index), length);
}
@Override @Override
@Deprecated @Deprecated
protected SwappedByteBuf newSwappedByteBuf() { protected SwappedByteBuf newSwappedByteBuf() {

View File

@ -517,7 +517,8 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
@Override @Override
public ByteBuf setZero(int index, int length) { public ByteBuf setZero(int index, int length) {
UnsafeByteBufUtil.setZero(this, addr(index), index, length); checkIndex(index, length);
UnsafeByteBufUtil.setZero(addr(index), length);
return this; return this;
} }
@ -525,7 +526,7 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
public ByteBuf writeZero(int length) { public ByteBuf writeZero(int length) {
ensureWritable(length); ensureWritable(length);
int wIndex = writerIndex; int wIndex = writerIndex;
setZero(wIndex, length); UnsafeByteBufUtil.setZero(addr(wIndex), length);
writerIndex = wIndex + length; writerIndex = wIndex + length;
return this; return this;
} }

View File

@ -245,7 +245,8 @@ class UnpooledUnsafeHeapByteBuf extends UnpooledHeapByteBuf {
public ByteBuf setZero(int index, int length) { public ByteBuf setZero(int index, int length) {
if (PlatformDependent.javaVersion() >= 7) { if (PlatformDependent.javaVersion() >= 7) {
// Only do on java7+ as the needed Unsafe call was only added there. // Only do on java7+ as the needed Unsafe call was only added there.
_setZero(index, length); checkIndex(index, length);
UnsafeByteBufUtil.setZero(array, index, length);
return this; return this;
} }
return super.setZero(index, length); return super.setZero(index, length);
@ -257,18 +258,13 @@ class UnpooledUnsafeHeapByteBuf extends UnpooledHeapByteBuf {
// Only do on java7+ as the needed Unsafe call was only added there. // Only do on java7+ as the needed Unsafe call was only added there.
ensureWritable(length); ensureWritable(length);
int wIndex = writerIndex; int wIndex = writerIndex;
_setZero(wIndex, length); UnsafeByteBufUtil.setZero(array, wIndex, length);
writerIndex = wIndex + length; writerIndex = wIndex + length;
return this; return this;
} }
return super.writeZero(length); return super.writeZero(length);
} }
private void _setZero(int index, int length) {
checkIndex(index, length);
UnsafeByteBufUtil.setZero(array, index, length);
}
@Override @Override
@Deprecated @Deprecated
protected SwappedByteBuf newSwappedByteBuf() { protected SwappedByteBuf newSwappedByteBuf() {

View File

@ -581,12 +581,11 @@ final class UnsafeByteBufUtil {
} }
} }
static void setZero(AbstractByteBuf buf, long addr, int index, int length) { static void setZero(long addr, int length) {
if (length == 0) { if (length == 0) {
return; return;
} }
buf.checkIndex(index, length);
PlatformDependent.setMemory(addr, length, ZERO); PlatformDependent.setMemory(addr, length, ZERO);
} }

View File

@ -94,18 +94,6 @@ public class SlicedByteBufTest extends AbstractByteBufTest {
super.testNioBufferExposeOnlyRegion(); super.testNioBufferExposeOnlyRegion();
} }
@Test(expected = IndexOutOfBoundsException.class)
@Override
public void testEnsureWritableAfterRelease() {
super.testEnsureWritableAfterRelease();
}
@Test(expected = IndexOutOfBoundsException.class)
@Override
public void testWriteZeroAfterRelease() throws IOException {
super.testWriteZeroAfterRelease();
}
@Test(expected = IndexOutOfBoundsException.class) @Test(expected = IndexOutOfBoundsException.class)
@Override @Override
public void testGetReadOnlyDirectDst() { public void testGetReadOnlyDirectDst() {

View File

@ -88,18 +88,6 @@ public class WrappedUnpooledUnsafeByteBufTest extends BigEndianUnsafeDirectByteB
super.testNioBufferExposeOnlyRegion(); super.testNioBufferExposeOnlyRegion();
} }
@Test(expected = IndexOutOfBoundsException.class)
@Override
public void testEnsureWritableAfterRelease() {
super.testEnsureWritableAfterRelease();
}
@Test(expected = IndexOutOfBoundsException.class)
@Override
public void testWriteZeroAfterRelease() throws IOException {
super.testWriteZeroAfterRelease();
}
@Test(expected = IndexOutOfBoundsException.class) @Test(expected = IndexOutOfBoundsException.class)
@Override @Override
public void testGetReadOnlyDirectDst() { public void testGetReadOnlyDirectDst() {