De-duplicate UnpooledDirectByteBuf/UnpooledUnsafeDirectByteBuf (#9085)

Motivation

While digging around looking at something else I noticed that these
share a lot of logic and it would be nice to reduce that duplication.

Modifications

Have UnpooledUnsafeDirectByteBuf extend UnpooledDirectByteBuf and make
adjustments to ensure existing behaviour remains unchanged.

The most significant addition needed to UnpooledUnsafeDirectByteBuf was
re-overriding the getPrimitive/setPrimitive methods to revert back to
the AbstractByteBuf versions which include bounds checks
(UnpooledDirectByteBuf excludes these as an optimization, relying on
those done by underlying ByteBuffer).

Result

~200 fewer lines, less duplicate logic.
This commit is contained in:
Nick Hill 2019-06-03 04:04:10 -07:00 committed by Norman Maurer
parent 4b8db65b16
commit 4ba75b99af
2 changed files with 95 additions and 300 deletions

View File

@ -39,7 +39,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
private final ByteBufAllocator alloc;
private ByteBuffer buffer;
ByteBuffer buffer; // accessed by UnpooledUnsafeNoCleanerDirectByteBuf.reallocateDirect()
private ByteBuffer tmpNioBuf;
private int capacity;
private boolean doNotFree;
@ -61,7 +61,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
}
this.alloc = alloc;
setByteBuffer(allocateDirect(initialCapacity));
setByteBuffer(allocateDirect(initialCapacity), false);
}
/**
@ -70,6 +70,11 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
* @param maxCapacity the maximum capacity of the underlying direct buffer
*/
protected UnpooledDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer, int maxCapacity) {
this(alloc, initialBuffer, maxCapacity, false, true);
}
UnpooledDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer,
int maxCapacity, boolean doFree, boolean slice) {
super(maxCapacity);
requireNonNull(alloc, "alloc");
requireNonNull(initialBuffer, "initialBuffer");
@ -87,8 +92,8 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
}
this.alloc = alloc;
doNotFree = true;
setByteBuffer(initialBuffer.slice().order(ByteOrder.BIG_ENDIAN));
doNotFree = !doFree;
setByteBuffer((slice ? initialBuffer.slice() : initialBuffer).order(ByteOrder.BIG_ENDIAN), false);
writerIndex(initialCapacity);
}
@ -106,13 +111,15 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
PlatformDependent.freeDirectBuffer(buffer);
}
private void setByteBuffer(ByteBuffer buffer) {
ByteBuffer oldBuffer = this.buffer;
if (oldBuffer != null) {
if (doNotFree) {
doNotFree = false;
} else {
freeDirect(oldBuffer);
void setByteBuffer(ByteBuffer buffer, boolean tryFree) {
if (tryFree) {
ByteBuffer oldBuffer = this.buffer;
if (oldBuffer != null) {
if (doNotFree) {
doNotFree = false;
} else {
freeDirect(oldBuffer);
}
}
}
@ -146,7 +153,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
newBuffer.position(0).limit(oldBuffer.capacity());
newBuffer.put(oldBuffer);
newBuffer.clear();
setByteBuffer(newBuffer);
setByteBuffer(newBuffer, true);
} else if (newCapacity < oldCapacity) {
ByteBuffer oldBuffer = buffer;
ByteBuffer newBuffer = allocateDirect(newCapacity);
@ -161,7 +168,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
} else {
setIndex(newCapacity, newCapacity);
}
setByteBuffer(newBuffer);
setByteBuffer(newBuffer, true);
}
return this;
}
@ -303,7 +310,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
return this;
}
private void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) {
void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) {
checkDstIndex(index, length, dstIndex, dst.length);
ByteBuffer tmpBuf;
@ -330,7 +337,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
return this;
}
private void getBytes(int index, ByteBuffer dst, boolean internal) {
void getBytes(int index, ByteBuffer dst, boolean internal) {
checkIndex(index, dst.remaining());
ByteBuffer tmpBuf;
@ -479,7 +486,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
return this;
}
private void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException {
void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException {
ensureAccessible();
if (length == 0) {
return;
@ -572,7 +579,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
ByteBuffer tmpBuf = internalNioBuffer();
tmpBuf.clear().position(index).limit(index + length);
try {
return in.read(tmpNioBuf);
return in.read(tmpBuf);
} catch (ClosedChannelException ignored) {
return -1;
}
@ -584,7 +591,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {
ByteBuffer tmpBuf = internalNioBuffer();
tmpBuf.clear().position(index).limit(index + length);
try {
return in.read(tmpNioBuf, position);
return in.read(tmpBuf, position);
} catch (ClosedChannelException ignored) {
return -1;
}

View File

@ -15,34 +15,20 @@
*/
package io.netty.buffer;
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
import static java.util.Objects.requireNonNull;
import io.netty.util.internal.PlatformDependent;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;
/**
* A NIO {@link ByteBuffer} based buffer. It is recommended to use
* {@link UnpooledByteBufAllocator#directBuffer(int, int)}, {@link Unpooled#directBuffer(int)} and
* {@link Unpooled#wrappedBuffer(ByteBuffer)} instead of calling the constructor explicitly.}
*/
public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf {
public class UnpooledUnsafeDirectByteBuf extends UnpooledDirectByteBuf {
private final ByteBufAllocator alloc;
private ByteBuffer tmpNioBuf;
private int capacity;
private boolean doNotFree;
ByteBuffer buffer;
long memoryAddress;
/**
@ -52,17 +38,7 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
* @param maxCapacity the maximum capacity of the underlying direct buffer
*/
public UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
super(maxCapacity);
requireNonNull(alloc, "alloc");
checkPositiveOrZero(initialCapacity, "initialCapacity");
checkPositiveOrZero(maxCapacity, "maxCapacity");
if (initialCapacity > maxCapacity) {
throw new IllegalArgumentException(String.format(
"initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity));
}
this.alloc = alloc;
setByteBuffer(allocateDirect(initialCapacity), false);
super(alloc, initialCapacity, maxCapacity);
}
/**
@ -80,131 +56,16 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
// sun/misc/Unsafe.java#l1250
//
// We also call slice() explicitly here to preserve behaviour with previous netty releases.
this(alloc, initialBuffer.slice(), maxCapacity, false);
super(alloc, initialBuffer, maxCapacity, /* doFree = */ false, /* slice = */ true);
}
UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer, int maxCapacity, boolean doFree) {
super(maxCapacity);
requireNonNull(alloc, "alloc");
requireNonNull(initialBuffer, "initialBuffer");
if (!initialBuffer.isDirect()) {
throw new IllegalArgumentException("initialBuffer is not a direct buffer.");
}
if (initialBuffer.isReadOnly()) {
throw new IllegalArgumentException("initialBuffer is a read-only buffer.");
}
int initialCapacity = initialBuffer.remaining();
if (initialCapacity > maxCapacity) {
throw new IllegalArgumentException(String.format(
"initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity));
}
this.alloc = alloc;
doNotFree = !doFree;
setByteBuffer(initialBuffer.order(ByteOrder.BIG_ENDIAN), false);
writerIndex(initialCapacity);
}
/**
* Allocate a new direct {@link ByteBuffer} with the given initialCapacity.
*/
protected ByteBuffer allocateDirect(int initialCapacity) {
return ByteBuffer.allocateDirect(initialCapacity);
}
/**
* Free a direct {@link ByteBuffer}
*/
protected void freeDirect(ByteBuffer buffer) {
PlatformDependent.freeDirectBuffer(buffer);
super(alloc, initialBuffer, maxCapacity, doFree, false);
}
final void setByteBuffer(ByteBuffer buffer, boolean tryFree) {
if (tryFree) {
ByteBuffer oldBuffer = this.buffer;
if (oldBuffer != null) {
if (doNotFree) {
doNotFree = false;
} else {
freeDirect(oldBuffer);
}
}
}
this.buffer = buffer;
super.setByteBuffer(buffer, tryFree);
memoryAddress = PlatformDependent.directBufferAddress(buffer);
tmpNioBuf = null;
capacity = buffer.remaining();
}
@Override
public boolean isDirect() {
return true;
}
@Override
public int capacity() {
return capacity;
}
@Override
public ByteBuf capacity(int newCapacity) {
checkNewCapacity(newCapacity);
int readerIndex = readerIndex();
int writerIndex = writerIndex();
int oldCapacity = capacity;
if (newCapacity > oldCapacity) {
ByteBuffer oldBuffer = buffer;
ByteBuffer newBuffer = allocateDirect(newCapacity);
oldBuffer.position(0).limit(oldBuffer.capacity());
newBuffer.position(0).limit(oldBuffer.capacity());
newBuffer.put(oldBuffer);
newBuffer.clear();
setByteBuffer(newBuffer, true);
} else if (newCapacity < oldCapacity) {
ByteBuffer oldBuffer = buffer;
ByteBuffer newBuffer = allocateDirect(newCapacity);
if (readerIndex < newCapacity) {
if (writerIndex > newCapacity) {
writerIndex(writerIndex = newCapacity);
}
oldBuffer.position(readerIndex).limit(writerIndex);
newBuffer.position(readerIndex).limit(writerIndex);
newBuffer.put(oldBuffer);
newBuffer.clear();
} else {
setIndex(newCapacity, newCapacity);
}
setByteBuffer(newBuffer, true);
}
return this;
}
@Override
public ByteBufAllocator alloc() {
return alloc;
}
@Override
public ByteOrder order() {
return ByteOrder.BIG_ENDIAN;
}
@Override
public boolean hasArray() {
return false;
}
@Override
public byte[] array() {
throw new UnsupportedOperationException("direct buffer");
}
@Override
public int arrayOffset() {
throw new UnsupportedOperationException("direct buffer");
}
@Override
@ -218,11 +79,23 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
return memoryAddress;
}
@Override
public byte getByte(int index) {
checkIndex(index);
return _getByte(index);
}
@Override
protected byte _getByte(int index) {
return UnsafeByteBufUtil.getByte(addr(index));
}
@Override
public short getShort(int index) {
checkIndex(index, 2);
return _getShort(index);
}
@Override
protected short _getShort(int index) {
return UnsafeByteBufUtil.getShort(addr(index));
@ -233,6 +106,12 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
return UnsafeByteBufUtil.getShortLE(addr(index));
}
@Override
public int getUnsignedMedium(int index) {
checkIndex(index, 3);
return _getUnsignedMedium(index);
}
@Override
protected int _getUnsignedMedium(int index) {
return UnsafeByteBufUtil.getUnsignedMedium(addr(index));
@ -243,6 +122,12 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
return UnsafeByteBufUtil.getUnsignedMediumLE(addr(index));
}
@Override
public int getInt(int index) {
checkIndex(index, 4);
return _getInt(index);
}
@Override
protected int _getInt(int index) {
return UnsafeByteBufUtil.getInt(addr(index));
@ -253,6 +138,12 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
return UnsafeByteBufUtil.getIntLE(addr(index));
}
@Override
public long getLong(int index) {
checkIndex(index, 8);
return _getLong(index);
}
@Override
protected long _getLong(int index) {
return UnsafeByteBufUtil.getLong(addr(index));
@ -270,23 +161,19 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
}
@Override
public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) {
UnsafeByteBufUtil.getBytes(this, addr(index), index, dst, dstIndex, length);
return this;
}
@Override
public ByteBuf getBytes(int index, ByteBuffer dst) {
void getBytes(int index, ByteBuffer dst, boolean internal) {
UnsafeByteBufUtil.getBytes(this, addr(index), index, dst);
return this;
}
@Override
public ByteBuf readBytes(ByteBuffer dst) {
int length = dst.remaining();
checkReadableBytes(length);
getBytes(readerIndex, dst);
readerIndex += length;
public ByteBuf setByte(int index, int value) {
checkIndex(index);
_setByte(index, value);
return this;
}
@ -295,6 +182,13 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
UnsafeByteBufUtil.setByte(addr(index), value);
}
@Override
public ByteBuf setShort(int index, int value) {
checkIndex(index, 2);
_setShort(index, value);
return this;
}
@Override
protected void _setShort(int index, int value) {
UnsafeByteBufUtil.setShort(addr(index), value);
@ -305,6 +199,13 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
UnsafeByteBufUtil.setShortLE(addr(index), value);
}
@Override
public ByteBuf setMedium(int index, int value) {
checkIndex(index, 3);
_setMedium(index, value);
return this;
}
@Override
protected void _setMedium(int index, int value) {
UnsafeByteBufUtil.setMedium(addr(index), value);
@ -315,6 +216,13 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
UnsafeByteBufUtil.setMediumLE(addr(index), value);
}
@Override
public ByteBuf setInt(int index, int value) {
checkIndex(index, 4);
_setInt(index, value);
return this;
}
@Override
protected void _setInt(int index, int value) {
UnsafeByteBufUtil.setInt(addr(index), value);
@ -325,6 +233,13 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
UnsafeByteBufUtil.setIntLE(addr(index), value);
}
@Override
public ByteBuf setLong(int index, long value) {
checkIndex(index, 8);
_setLong(index, value);
return this;
}
@Override
protected void _setLong(int index, long value) {
UnsafeByteBufUtil.setLong(addr(index), value);
@ -354,62 +269,8 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
}
@Override
public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException {
void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException {
UnsafeByteBufUtil.getBytes(this, addr(index), index, out, length);
return this;
}
@Override
public int getBytes(int index, GatheringByteChannel out, int length) throws IOException {
return getBytes(index, out, length, false);
}
private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException {
ensureAccessible();
if (length == 0) {
return 0;
}
ByteBuffer tmpBuf;
if (internal) {
tmpBuf = internalNioBuffer();
} else {
tmpBuf = buffer.duplicate();
}
tmpBuf.clear().position(index).limit(index + length);
return out.write(tmpBuf);
}
@Override
public int getBytes(int index, FileChannel out, long position, int length) throws IOException {
return getBytes(index, out, position, length, false);
}
private int getBytes(int index, FileChannel out, long position, int length, boolean internal) throws IOException {
ensureAccessible();
if (length == 0) {
return 0;
}
ByteBuffer tmpBuf = internal ? internalNioBuffer() : buffer.duplicate();
tmpBuf.clear().position(index).limit(index + length);
return out.write(tmpBuf, position);
}
@Override
public int readBytes(GatheringByteChannel out, int length) throws IOException {
checkReadableBytes(length);
int readBytes = getBytes(readerIndex, out, length, true);
readerIndex += readBytes;
return readBytes;
}
@Override
public int readBytes(FileChannel out, long position, int length) throws IOException {
checkReadableBytes(length);
int readBytes = getBytes(readerIndex, out, position, length, true);
readerIndex += readBytes;
return readBytes;
}
@Override
@ -417,85 +278,12 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf
return UnsafeByteBufUtil.setBytes(this, addr(index), index, in, length);
}
@Override
public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException {
ensureAccessible();
ByteBuffer tmpBuf = internalNioBuffer();
tmpBuf.clear().position(index).limit(index + length);
try {
return in.read(tmpBuf);
} catch (ClosedChannelException ignored) {
return -1;
}
}
@Override
public int setBytes(int index, FileChannel in, long position, int length) throws IOException {
ensureAccessible();
ByteBuffer tmpBuf = internalNioBuffer();
tmpBuf.clear().position(index).limit(index + length);
try {
return in.read(tmpBuf, position);
} catch (ClosedChannelException ignored) {
return -1;
}
}
@Override
public int nioBufferCount() {
return 1;
}
@Override
public ByteBuffer[] nioBuffers(int index, int length) {
return new ByteBuffer[] { nioBuffer(index, length) };
}
@Override
public ByteBuf copy(int index, int length) {
return UnsafeByteBufUtil.copy(this, addr(index), index, length);
}
@Override
public ByteBuffer internalNioBuffer(int index, int length) {
checkIndex(index, length);
return (ByteBuffer) internalNioBuffer().clear().position(index).limit(index + length);
}
private ByteBuffer internalNioBuffer() {
ByteBuffer tmpNioBuf = this.tmpNioBuf;
if (tmpNioBuf == null) {
this.tmpNioBuf = tmpNioBuf = buffer.duplicate();
}
return tmpNioBuf;
}
@Override
public ByteBuffer nioBuffer(int index, int length) {
checkIndex(index, length);
return ((ByteBuffer) buffer.duplicate().position(index).limit(index + length)).slice();
}
@Override
protected void deallocate() {
ByteBuffer buffer = this.buffer;
if (buffer == null) {
return;
}
this.buffer = null;
if (!doNotFree) {
freeDirect(buffer);
}
}
@Override
public ByteBuf unwrap() {
return null;
}
long addr(int index) {
final long addr(int index) {
return memoryAddress + index;
}