Introduce `ByteToMessageDecoderForBuffer` (#11654)

__Motivation__

In order to migrate all codec incrementally to use `Buffer`, we need a version of `ByteToMessageDecoder` that uses `Buffer`.

__Modification__

- Added the new version of `ByteToMessageDecoder` with a new name so that both old and new version can co-exist and we can incrementally migrate all codecs
- Migrated `FixedLengthFrameDecoder` as it was simple and used in tests.

__Result__

We have the basic building block to start migrating all codecs to the new `Buffer` API.
This commit is contained in:
Nitesh Kant 2021-09-14 17:06:49 -07:00 committed by GitHub
parent a76842dcd5
commit 29cae0445a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1367 additions and 13 deletions

View File

@ -0,0 +1,686 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.api.Buffer;
import io.netty.buffer.api.BufferAllocator;
import io.netty.buffer.api.CompositeBuffer;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.StringUtil;
import java.net.SocketAddress;
import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo;
import static java.util.Objects.requireNonNull;
/**
* {@link ChannelHandler} which decodes bytes in a stream-like fashion from one {@link Buffer} to an
* other Message type.
*
* For example here is an implementation which reads all readable bytes from
* the input {@link Buffer}, creates a new {@link Buffer} and forward it to the next {@link ChannelHandler}
* in the {@link ChannelPipeline}.
*
* <pre>
* public class SquareDecoder extends {@link ByteToMessageDecoderForBuffer} {
* {@code @Override}
* public void decode({@link ChannelHandlerContext} ctx, {@link Buffer} in)
* throws {@link Exception} {
* ctx.fireChannelRead(in.readBytes(in.readableBytes()));
* }
* }
* </pre>
*
* <h3>Frame detection</h3>
* <p>
* Generally frame detection should be handled earlier in the pipeline by adding a
* {@link DelimiterBasedFrameDecoder}, {@link FixedLengthFrameDecoder}, {@link LengthFieldBasedFrameDecoder},
* or {@link LineBasedFrameDecoder}.
* <p>
* If a custom frame decoder is required, then one needs to be careful when implementing
* one with {@link ByteToMessageDecoderForBuffer}. Ensure there are enough bytes in the buffer for a
* complete frame by checking {@link Buffer#readableBytes()}. If there are not enough bytes
* for a complete frame, return without modifying the reader index to allow more bytes to arrive.
* <p>
* To check for complete frames without modifying the reader index, use methods like {@link Buffer#getInt(int)}.
* One <strong>MUST</strong> use the reader index when using methods like {@link Buffer#getInt(int)}.
* For example calling <tt>in.getInt(0)</tt> is assuming the frame starts at the beginning of the buffer, which
* is not always the case. Use <tt>in.getInt(in.readerIndex())</tt> instead.
* <h3>Pitfalls</h3>
* <p>
* Be aware that sub-classes of {@link ByteToMessageDecoderForBuffer} <strong>MUST NOT</strong>
* annotated with {@link @Sharable}.
*/
public abstract class ByteToMessageDecoderForBuffer extends ChannelHandlerAdapter {
/**
* Cumulate {@link Buffer}s by merge them into one {@link Buffer}'s, using memory copies.
*/
public static final Cumulator MERGE_CUMULATOR = new MergeCumulator();
/**
* Cumulate {@link Buffer}s by add them to a {@link CompositeBuffer} and so do no memory copy whenever possible.
* Be aware that {@link CompositeBuffer} use a more complex indexing implementation so depending on your use-case
* and the decoder implementation this may be slower then just use the {@link #MERGE_CUMULATOR}.
*/
public static final Cumulator COMPOSITE_CUMULATOR = new CompositeBufferCumulator();
private final int discardAfterReads = 16;
private final Cumulator cumulator;
private Buffer cumulation;
private boolean singleDecode;
private boolean first;
/**
* This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to consume more data
* when {@link ChannelConfig#isAutoRead()} is {@code false}.
*/
private boolean firedChannelRead;
private int numReads;
private ByteToMessageDecoderContext context;
protected ByteToMessageDecoderForBuffer() {
this(MERGE_CUMULATOR);
}
protected ByteToMessageDecoderForBuffer(Cumulator cumulator) {
this.cumulator = requireNonNull(cumulator, "cumulator");
ensureNotSharable();
}
/**
* If set then only one message is decoded on each {@link #channelRead(ChannelHandlerContext, Object)}
* call. This may be useful if you need to do some protocol upgrade and want to make sure nothing is mixed up.
*
* Default is {@code false} as this has performance impacts.
*/
public void setSingleDecode(boolean singleDecode) {
this.singleDecode = singleDecode;
}
/**
* If {@code true} then only one message is decoded on each
* {@link #channelRead(ChannelHandlerContext, Object)} call.
*
* Default is {@code false} as this has performance impacts.
*/
public boolean isSingleDecode() {
return singleDecode;
}
/**
* Returns the actual number of readable bytes in the internal cumulative
* buffer of this decoder. You usually do not need to rely on this value
* to write a decoder. Use it only when you must use it at your own risk.
* This method is a shortcut to {@link #internalBuffer() internalBuffer().readableBytes()}.
*/
protected int actualReadableBytes() {
return internalBuffer().readableBytes();
}
/**
* Returns the internal cumulative buffer of this decoder, if exists, else {@code null}. You usually
* do not need to access the internal buffer directly to write a decoder.
* Use it only when you must use it at your own risk.
*
* @return Internal {@link Buffer} if exists, else {@code null}.
*/
protected Buffer internalBuffer() {
return cumulation;
}
@Override
public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
context = new ByteToMessageDecoderContext(ctx);
handlerAdded0(context);
}
protected void handlerAdded0(ChannelHandlerContext ctx) throws Exception {
}
@Override
public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
Buffer buf = cumulation;
if (buf != null) {
// Directly set this to null so we are sure we not access it in any other method here anymore.
cumulation = null;
numReads = 0;
int readable = buf.readableBytes();
if (readable > 0) {
ctx.fireChannelRead(buf);
ctx.fireChannelReadComplete();
} else {
buf.close();
}
}
handlerRemoved0(context);
}
/**
* Gets called after the {@link ByteToMessageDecoderForBuffer} was removed from the actual context and it doesn't
* handle events anymore.
*/
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { }
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof Buffer) {
try {
Buffer data = (Buffer) msg;
first = cumulation == null;
if (first) {
if (data.readOnly()) {
cumulation = CompositeBuffer.compose(ctx.bufferAllocator(), data.copy().send());
data.close();
} else {
cumulation = data;
}
} else {
cumulation = cumulator.cumulate(ctx.bufferAllocator(), cumulation, data);
}
assert context.ctx == ctx || ctx == context;
callDecode(context, cumulation);
} catch (DecoderException e) {
throw e;
} catch (Exception e) {
throw new DecoderException(e);
} finally {
if (cumulation != null && cumulation.readableBytes() == 0) {
numReads = 0;
if (cumulation.isAccessible()) {
cumulation.close();
}
cumulation = null;
} else if (++numReads >= discardAfterReads) {
// We did enough reads already try to discard some bytes so we not risk to see a OOME.
// See https://github.com/netty/netty/issues/4275
numReads = 0;
discardSomeReadBytes();
}
firedChannelRead |= context.fireChannelReadCallCount() > 0;
context.reset();
}
} else {
ctx.fireChannelRead(msg);
}
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
numReads = 0;
discardSomeReadBytes();
if (!firedChannelRead && !ctx.channel().config().isAutoRead()) {
ctx.read();
}
firedChannelRead = false;
ctx.fireChannelReadComplete();
}
protected final void discardSomeReadBytes() {
if (cumulation != null && !first) {
// discard some bytes if possible to make more room in the buffer.
cumulation.compact();
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
assert context.ctx == ctx || ctx == context;
channelInputClosed(context, true);
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
ctx.fireUserEventTriggered(evt);
if (evt instanceof ChannelInputShutdownEvent) {
// The decodeLast method is invoked when a channelInactive event is encountered.
// This method is responsible for ending requests in some situations and must be called
// when the input has been shutdown.
assert context.ctx == ctx || ctx == context;
channelInputClosed(context, false);
}
}
private void channelInputClosed(ByteToMessageDecoderContext ctx, boolean callChannelInactive) {
try {
channelInputClosed(ctx);
} catch (DecoderException e) {
throw e;
} catch (Exception e) {
throw new DecoderException(e);
} finally {
if (cumulation != null) {
cumulation.close();
cumulation = null;
}
if (ctx.fireChannelReadCallCount() > 0) {
ctx.reset();
// Something was read, call fireChannelReadComplete()
ctx.fireChannelReadComplete();
}
if (callChannelInactive) {
ctx.fireChannelInactive();
}
}
}
/**
* Called when the input of the channel was closed which may be because it changed to inactive or because of
* {@link ChannelInputShutdownEvent}.
*/
void channelInputClosed(ByteToMessageDecoderContext ctx) throws Exception {
if (cumulation != null) {
callDecode(ctx, cumulation);
// If callDecode(...) removed the handle from the pipeline we should not call decodeLast(...) as this would
// be unexpected.
if (!ctx.isRemoved()) {
// Use Unpooled.EMPTY_BUFFER if cumulation become null after calling callDecode(...).
// See https://github.com/netty/netty/issues/10802.
Buffer buffer = cumulation == null ? ctx.bufferAllocator().allocate(0) : cumulation;
decodeLast(ctx, buffer);
}
} else {
decodeLast(ctx, ctx.bufferAllocator().allocate(0));
}
}
/**
* Called once data should be decoded from the given {@link Buffer}. This method will call
* {@link #decode(ChannelHandlerContext, Buffer)} as long as decoding should take place.
*
* @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoderForBuffer} belongs to
* @param in the {@link Buffer} from which to read data
*/
void callDecode(ByteToMessageDecoderContext ctx, Buffer in) {
try {
while (in.readableBytes() > 0 && !ctx.isRemoved()) {
int oldInputLength = in.readableBytes();
int numReadCalled = ctx.fireChannelReadCallCount();
decodeRemovalReentryProtection(ctx, in);
// Check if this handler was removed before continuing the loop.
// If it was removed, it is not safe to continue to operate on the buffer.
//
// See https://github.com/netty/netty/issues/1664
if (ctx.isRemoved()) {
break;
}
if (numReadCalled == ctx.fireChannelReadCallCount()) {
if (oldInputLength == in.readableBytes()) {
break;
} else {
continue;
}
}
if (oldInputLength == in.readableBytes()) {
throw new DecoderException(
StringUtil.simpleClassName(getClass()) +
".decode() did not read anything but decoded a message.");
}
if (isSingleDecode()) {
break;
}
}
} catch (DecoderException e) {
throw e;
} catch (Exception cause) {
throw new DecoderException(cause);
}
}
/**
* Decode the from one {@link Buffer} to another. This method will be called till either the input
* {@link Buffer} has nothing to read when return from this method or till nothing was read from the input
* {@link Buffer}.
*
* @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoderForBuffer} belongs to
* @param in the {@link Buffer} from which to read data
* @throws Exception is thrown if an error occurs
*/
protected abstract void decode(ChannelHandlerContext ctx, Buffer in) throws Exception;
/**
* Decode the from one {@link Buffer} to an other. This method will be called till either the input
* {@link Buffer} has nothing to read when return from this method or till nothing was read from the input
* {@link Buffer}.
*
* @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoderForBuffer} belongs to
* @param in the {@link Buffer} from which to read data
* @throws Exception is thrown if an error occurs
*/
final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, Buffer in)
throws Exception {
decode(ctx, in);
}
/**
* Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the
* {@link #channelInactive(ChannelHandlerContext)} was triggered.
*
* By default this will just call {@link #decode(ChannelHandlerContext, Buffer)} but sub-classes may
* override this for some special cleanup operation.
*/
protected void decodeLast(ChannelHandlerContext ctx, Buffer in) throws Exception {
if (in.readableBytes() > 0) {
// Only call decode() if there is something left in the buffer to decode.
// See https://github.com/netty/netty/issues/4386
decodeRemovalReentryProtection(ctx, in);
}
}
private static Buffer expandCumulationAndWrite(BufferAllocator alloc, Buffer oldCumulation, Buffer in) {
final int newSize = safeFindNextPositivePowerOfTwo(oldCumulation.readableBytes() + in.readableBytes());
Buffer newCumulation = oldCumulation.readOnly() ? alloc.allocate(newSize) :
oldCumulation.ensureWritable(newSize);
try {
if (newCumulation != oldCumulation) {
newCumulation.writeBytes(oldCumulation);
}
newCumulation.writeBytes(in);
return newCumulation;
} finally {
if (newCumulation != oldCumulation) {
oldCumulation.close();
}
}
}
/**
* Cumulate {@link ByteBuf}s.
*/
public interface Cumulator {
/**
* Cumulate the given {@link Buffer}s and return the {@link Buffer} that holds the cumulated bytes.
* The implementation is responsible to correctly handle the life-cycle of the given {@link Buffer}s and so
* call {@link Buffer#close()} if a {@link Buffer} is fully consumed.
*/
Buffer cumulate(BufferAllocator alloc, Buffer cumulation, Buffer in);
}
// Package private so we can also make use of it in ReplayingDecoder.
static final class ByteToMessageDecoderContext implements ChannelHandlerContext {
private final ChannelHandlerContext ctx;
private int fireChannelReadCalled;
private ByteToMessageDecoderContext(ChannelHandlerContext ctx) {
this.ctx = ctx;
}
void reset() {
fireChannelReadCalled = 0;
}
int fireChannelReadCallCount() {
return fireChannelReadCalled;
}
@Override
public Channel channel() {
return ctx.channel();
}
@Override
public EventExecutor executor() {
return ctx.executor();
}
@Override
public String name() {
return ctx.name();
}
@Override
public ChannelHandler handler() {
return ctx.handler();
}
@Override
public boolean isRemoved() {
return ctx.isRemoved();
}
@Override
public ChannelHandlerContext fireChannelRegistered() {
ctx.fireChannelRegistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelUnregistered() {
ctx.fireChannelUnregistered();
return this;
}
@Override
public ChannelHandlerContext fireChannelActive() {
ctx.fireChannelActive();
return this;
}
@Override
public ChannelHandlerContext fireChannelInactive() {
ctx.fireChannelInactive();
return this;
}
@Override
public ChannelHandlerContext fireExceptionCaught(Throwable cause) {
ctx.fireExceptionCaught(cause);
return this;
}
@Override
public ChannelHandlerContext fireUserEventTriggered(Object evt) {
ctx.fireUserEventTriggered(evt);
return this;
}
@Override
public ChannelHandlerContext fireChannelRead(Object msg) {
fireChannelReadCalled ++;
ctx.fireChannelRead(msg);
return this;
}
@Override
public ChannelHandlerContext fireChannelReadComplete() {
ctx.fireChannelReadComplete();
return this;
}
@Override
public ChannelHandlerContext fireChannelWritabilityChanged() {
ctx.fireChannelWritabilityChanged();
return this;
}
@Override
public Future<Void> register() {
return ctx.register();
}
@Override
public ChannelHandlerContext read() {
ctx.read();
return this;
}
@Override
public ChannelHandlerContext flush() {
ctx.flush();
return this;
}
@Override
public ChannelPipeline pipeline() {
return ctx.pipeline();
}
@Override
public ByteBufAllocator alloc() {
return ctx.alloc();
}
@Override
public BufferAllocator bufferAllocator() {
return ctx.bufferAllocator();
}
@Override
@Deprecated
public <T> Attribute<T> attr(AttributeKey<T> key) {
return ctx.attr(key);
}
@Override
@Deprecated
public <T> boolean hasAttr(AttributeKey<T> key) {
return ctx.hasAttr(key);
}
@Override
public Future<Void> bind(SocketAddress localAddress) {
return ctx.bind(localAddress);
}
@Override
public Future<Void> connect(SocketAddress remoteAddress) {
return ctx.connect(remoteAddress);
}
@Override
public Future<Void> connect(SocketAddress remoteAddress, SocketAddress localAddress) {
return ctx.connect(remoteAddress, localAddress);
}
@Override
public Future<Void> disconnect() {
return ctx.disconnect();
}
@Override
public Future<Void> close() {
return ctx.close();
}
@Override
public Future<Void> deregister() {
return ctx.deregister();
}
@Override
public Future<Void> write(Object msg) {
return ctx.write(msg);
}
@Override
public Future<Void> writeAndFlush(Object msg) {
return ctx.writeAndFlush(msg);
}
@Override
public Promise<Void> newPromise() {
return ctx.newPromise();
}
@Override
public Future<Void> newSucceededFuture() {
return ctx.newSucceededFuture();
}
@Override
public Future<Void> newFailedFuture(Throwable cause) {
return ctx.newFailedFuture(cause);
}
}
private static final class CompositeBufferCumulator implements Cumulator {
@Override
public Buffer cumulate(BufferAllocator alloc, Buffer cumulation, Buffer in) {
if (cumulation.readableBytes() == 0) {
cumulation.close();
return in;
}
CompositeBuffer composite;
try (in) {
if (CompositeBuffer.isComposite(cumulation)) {
CompositeBuffer tmp = (CompositeBuffer) cumulation;
// Since we are extending the composite buffer below, we have to make sure there is no space to
// write in the existing cumulation.
if (tmp.writerOffset() < tmp.capacity()) {
composite = tmp.split();
tmp.close();
} else {
composite = tmp;
}
} else {
composite = CompositeBuffer.compose(alloc, cumulation.send());
}
composite.extendWith((in.readOnly() ? in.copy() : in).send());
return composite;
}
}
@Override
public String toString() {
return "CompositeBufferCumulator";
}
}
private static final class MergeCumulator implements Cumulator {
@Override
public Buffer cumulate(BufferAllocator alloc, Buffer cumulation, Buffer in) {
if (cumulation.readableBytes() == 0) {
// If cumulation is empty and input buffer is contiguous, use it directly
cumulation.close();
return in;
}
// We must close input Buffer in all cases as otherwise it may produce a leak if writeBytes(...) throw
// for whatever close (for example because of OutOfMemoryError)
try (in) {
final int required = in.readableBytes();
if (required > cumulation.writableBytes() || cumulation.readOnly()) {
return expandCumulationAndWrite(alloc, cumulation, in);
}
cumulation.writeBytes(in);
return cumulation;
}
}
@Override
public String toString() {
return "MergeCumulator";
}
}
}

View File

@ -15,13 +15,13 @@
*/
package io.netty.handler.codec;
import static io.netty.util.internal.ObjectUtil.checkPositive;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.api.Buffer;
import io.netty.channel.ChannelHandlerContext;
import static io.netty.util.internal.ObjectUtil.checkPositive;
/**
* A decoder that splits the received {@link ByteBuf}s by the fixed number
* A decoder that splits the received {@link Buffer}s by the fixed number
* of bytes. For example, if you received the following four fragmented packets:
* <pre>
* +---+----+------+----+
@ -36,7 +36,7 @@ import io.netty.channel.ChannelHandlerContext;
* +-----+-----+-----+
* </pre>
*/
public class FixedLengthFrameDecoder extends ByteToMessageDecoder {
public class FixedLengthFrameDecoder extends ByteToMessageDecoderForBuffer {
private final int frameLength;
@ -50,8 +50,19 @@ public class FixedLengthFrameDecoder extends ByteToMessageDecoder {
this.frameLength = frameLength;
}
/**
* Creates a new instance.
*
* @param frameLength the length of the frame
*/
public FixedLengthFrameDecoder(int frameLength, Cumulator cumulator) {
super(cumulator);
checkPositive(frameLength, "frameLength");
this.frameLength = frameLength;
}
@Override
protected final void decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
protected final void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
Object decoded = decode0(ctx, in);
if (decoded != null) {
ctx.fireChannelRead(decoded);
@ -59,19 +70,19 @@ public class FixedLengthFrameDecoder extends ByteToMessageDecoder {
}
/**
* Create a frame out of the {@link ByteBuf} and return it.
* Create a frame out of the {@link Buffer} and return it.
*
* @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to
* @param in the {@link ByteBuf} from which to read data
* @return frame the {@link ByteBuf} which represent the frame or {@code null} if no frame could
* @param in the {@link Buffer} from which to read data
* @return frame the {@link Buffer} which represent the frame or {@code null} if no frame could
* be created.
*/
protected Object decode0(
@SuppressWarnings("UnusedParameters") ChannelHandlerContext ctx, ByteBuf in) throws Exception {
protected Object decode0(@SuppressWarnings("UnusedParameters") ChannelHandlerContext ctx, Buffer in)
throws Exception {
if (in.readableBytes() < frameLength) {
return null;
} else {
return in.readRetainedSlice(frameLength);
return in.split(in.readerOffset() + frameLength);
}
}
}

View File

@ -0,0 +1,653 @@
/*
* Copyright 2021 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec;
import io.netty.buffer.api.Buffer;
import io.netty.buffer.api.BufferAllocator;
import io.netty.buffer.api.BufferStub;
import io.netty.buffer.api.CompositeBuffer;
import io.netty.buffer.api.Send;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.handler.codec.ByteToMessageDecoderForBuffer.Cumulator;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Stream;
import static io.netty.buffer.api.BufferAllocator.offHeapPooled;
import static io.netty.buffer.api.BufferAllocator.offHeapUnpooled;
import static io.netty.buffer.api.BufferAllocator.onHeapPooled;
import static io.netty.buffer.api.BufferAllocator.onHeapUnpooled;
import static io.netty.buffer.api.CompositeBuffer.compose;
import static io.netty.handler.codec.ByteToMessageDecoderForBuffer.COMPOSITE_CUMULATOR;
import static io.netty.handler.codec.ByteToMessageDecoderForBuffer.MERGE_CUMULATOR;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ByteToMessageDecoderForBufferTest {
private static final String PARAMETERIZED_NAME = "allocator = {0}, cumulator = {1}";
private BufferAllocator allocator;
public static Stream<Arguments> allocators() {
return Stream.of(
arguments(onHeapUnpooled(), MERGE_CUMULATOR),
arguments(onHeapUnpooled(), COMPOSITE_CUMULATOR),
arguments(offHeapUnpooled(), MERGE_CUMULATOR),
arguments(offHeapUnpooled(), COMPOSITE_CUMULATOR),
arguments(onHeapPooled(), MERGE_CUMULATOR),
arguments(onHeapPooled(), COMPOSITE_CUMULATOR),
arguments(offHeapPooled(), MERGE_CUMULATOR),
arguments(offHeapPooled(), COMPOSITE_CUMULATOR)
);
}
@BeforeEach
public void closeAllocator() {
if (allocator != null) {
allocator.close();
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void removeSelf(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
private boolean removed;
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
assertFalse(removed);
in.readByte();
removed = true;
ctx.pipeline().remove(this);
}
});
try (Buffer buf = newBufferWithData(allocator, 'a', 'b', 'c')) {
channel.writeInbound(buf.copy());
try (Buffer b = channel.readInbound()) {
buf.readByte();
assertContentEquals(b, buf);
}
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void removeSelfThenWriteToBuffer(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
try (Buffer buf = newBufferWithData(allocator, 4, 'a', 'b', 'c')) {
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
private boolean removed;
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
assertFalse(removed);
in.readByte();
removed = true;
ctx.pipeline().remove(this);
// This should not let it keep calling decode
buf.writeByte((byte) 'd');
}
});
channel.writeInbound(buf.copy());
try (Buffer expected = allocator.allocate(8); Buffer b = channel.readInbound()) {
expected.writeBytes(new byte[]{'b', 'c'});
assertContentEquals(expected, b);
}
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void internalBufferClearPostReadFully(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
Buffer buf = newBufferWithData(allocator, 'a');
EmbeddedChannel channel = newInternalBufferTestChannel(cumulator, Buffer::readByte);
assertFalse(channel.writeInbound(buf));
assertFalse(channel.finish());
assertFalse(buf.isAccessible());
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void internalBufferClearReadPartly(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
Buffer buf = newBufferWithData(allocator, 'a', 'b');
EmbeddedChannel channel = newInternalBufferTestChannel(cumulator, Buffer::readByte);
assertTrue(channel.writeInbound(buf));
try (Buffer expected = newBufferWithData(allocator, 'b'); Buffer b = channel.readInbound()) {
assertContentEquals(b, expected);
assertNull(channel.readInbound());
assertFalse(channel.finish());
}
assertFalse(buf.isAccessible());
}
private static EmbeddedChannel newInternalBufferTestChannel(
Cumulator cumulator, Consumer<Buffer> readBeforeRemove) {
return new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
Buffer buffer = internalBuffer();
assertNotNull(buffer);
assertTrue(buffer.isAccessible());
readBeforeRemove.accept(in);
// Removal from pipeline should clear internal buffer
ctx.pipeline().remove(this);
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) {
assertCumulationReleased(internalBuffer());
}
});
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
ctx.pipeline().remove(this);
assertTrue(in.isAccessible());
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) {
assertCumulationReleased(internalBuffer());
}
});
Buffer buf = newBufferWithRandomBytes(allocator);
assertTrue(channel.writeInbound(buf));
assertTrue(channel.finishAndReleaseAll());
assertFalse(buf.isAccessible());
}
private static void assertCumulationReleased(Buffer buffer) {
assertTrue(buffer == null || !buffer.isAccessible(), "unexpected value: " + buffer);
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void fireChannelReadCompleteOnInactive(BufferAllocator allocator, Cumulator cumulator) throws Exception {
this.allocator = allocator;
final BlockingQueue<Integer> queue = new LinkedBlockingDeque<>();
final Buffer buf = newBufferWithData(allocator, 'a', 'b');
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
int readable = in.readableBytes();
assertTrue(readable > 0);
in.readerOffset(in.readerOffset() + readable);
}
@Override
protected void decodeLast(ChannelHandlerContext ctx, Buffer in) {
assertFalse(in.readableBytes() > 0);
ctx.fireChannelRead("data");
}
}, new ChannelHandler() {
@Override
public void channelInactive(ChannelHandlerContext ctx) {
queue.add(3);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
queue.add(1);
}
@Override
public void channelReadComplete(ChannelHandlerContext ctx) {
if (!ctx.channel().isActive()) {
queue.add(2);
}
}
});
assertFalse(channel.writeInbound(buf));
assertFalse(channel.finish());
assertEquals(1, (int) queue.take());
assertEquals(2, (int) queue.take());
assertEquals(3, (int) queue.take());
assertTrue(queue.isEmpty());
assertFalse(buf.isAccessible());
}
// See https://github.com/netty/netty/issues/4635
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void removeWhileInCallDecode(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
final Object upgradeMessage = new Object();
final ByteToMessageDecoderForBuffer decoder = new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
assertEquals('a', in.readByte());
ctx.fireChannelRead(upgradeMessage);
}
};
EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelHandler() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg == upgradeMessage) {
ctx.pipeline().remove(decoder);
return;
}
ctx.fireChannelRead(msg);
}
});
try (Buffer buf = newBufferWithData(allocator, 'a', 'b', 'c')) {
assertTrue(channel.writeInbound(buf.copy()));
try (Buffer b = channel.readInbound()) {
buf.readByte();
assertContentEquals(b, buf);
assertFalse(channel.finish());
}
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void decodeLastEmptyBuffer(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
int readable = in.readableBytes();
assertTrue(readable > 0);
ctx.fireChannelRead(transferBytes(ctx.bufferAllocator(), in, readable));
}
});
try (Buffer buf = newBufferWithRandomBytes(allocator)) {
assertTrue(channel.writeInbound(buf.copy()));
try (Buffer b = channel.readInbound()) {
assertContentEquals(b, buf);
}
assertNull(channel.readInbound());
assertFalse(channel.finish());
assertNull(channel.readInbound());
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void decodeLastNonEmptyBuffer(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
private boolean decodeLast;
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
int readable = in.readableBytes();
assertTrue(readable > 0);
if (!decodeLast && readable == 1) {
return;
}
final int length = decodeLast ? readable : readable - 1;
ctx.fireChannelRead(transferBytes(ctx.bufferAllocator(), in, length));
}
@Override
protected void decodeLast(ChannelHandlerContext ctx, Buffer in) throws Exception {
assertFalse(decodeLast);
decodeLast = true;
super.decodeLast(ctx, in);
}
});
try (Buffer buf = newBufferWithRandomBytes(allocator)) {
assertTrue(channel.writeInbound(buf.copy()));
try (Buffer b = channel.readInbound()) {
assertContentEquals(b, buf.copy(0, buf.readableBytes() - 1));
}
assertNull(channel.readInbound());
assertTrue(channel.finish());
try (Buffer b1 = channel.readInbound()) {
assertContentEquals(b1, buf.copy(buf.readableBytes() - 1, 1));
}
assertNull(channel.readInbound());
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void readOnlyBuffer(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) { }
});
assertFalse(channel.writeInbound(newBufferWithData(allocator, 8, 'a').makeReadOnly()));
assertFalse(channel.writeInbound(newBufferWithData(allocator, 'b')));
assertFalse(channel.writeInbound(newBufferWithData(allocator, 'c').makeReadOnly()));
assertFalse(channel.finish());
}
static class WriteFailingBuffer extends BufferStub {
private final Error error = new Error();
private int untilFailure;
WriteFailingBuffer(BufferAllocator allocator, int untilFailure, int capacity) {
super(allocator.allocate(capacity));
this.untilFailure = untilFailure;
}
@Override
public Buffer writeBytes(Buffer source) {
if (--untilFailure <= 0) {
throw error;
}
return super.writeBytes(source);
}
Error writeError() {
return error;
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void releaseWhenMergeCumulateThrows(BufferAllocator allocator) {
this.allocator = allocator;
try (WriteFailingBuffer oldCumulation = new WriteFailingBuffer(allocator, 1, 64)) {
oldCumulation.writeByte((byte) 0);
Buffer in = newBufferWithRandomBytes(allocator, 12);
final Error err = assertThrows(Error.class, () -> MERGE_CUMULATOR.cumulate(allocator, oldCumulation, in));
assertSame(oldCumulation.writeError(), err);
assertFalse(in.isAccessible());
assertTrue(oldCumulation.isAccessible());
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void releaseWhenMergeCumulateThrowsInExpand(BufferAllocator allocator) {
this.allocator = allocator;
final WriteFailingBuffer cumulation = new WriteFailingBuffer(allocator, 1, 16) {
@Override
public int readableBytes() {
return 1;
}
};
Buffer in = newBufferWithRandomBytes(allocator, 12);
Throwable thrown = null;
try {
BufferAllocator mockAlloc = mock(BufferAllocator.class);
MERGE_CUMULATOR.cumulate(mockAlloc, cumulation, in);
} catch (Throwable t) {
thrown = t;
}
assertFalse(in.isAccessible());
assertSame(cumulation.writeError(), thrown);
assertTrue(cumulation.isAccessible());
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void releaseWhenMergeCumulateThrowsInExpandAndCumulatorIsReadOnly(BufferAllocator allocator) {
this.allocator = allocator;
Buffer oldCumulation = newBufferWithData(allocator, 8, (char) 1).makeReadOnly();
final WriteFailingBuffer newCumulation = new WriteFailingBuffer(allocator, 1, 16) ;
Buffer in = newBufferWithRandomBytes(allocator, 12);
Throwable thrown = null;
try {
BufferAllocator mockAlloc = mock(BufferAllocator.class);
when(mockAlloc.allocate(anyInt())).thenReturn(newCumulation);
MERGE_CUMULATOR.cumulate(mockAlloc, oldCumulation, in);
} catch (Throwable t) {
thrown = t;
}
assertFalse(in.isAccessible());
assertSame(newCumulation.writeError(), thrown);
assertFalse(oldCumulation.isAccessible());
assertTrue(newCumulation.isAccessible());
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void releaseWhenCompositeCumulateThrows(BufferAllocator allocator) {
this.allocator = allocator;
final Error error = new Error();
try (CompositeBuffer cumulation = compose(allocator, newBufferWithRandomBytes(allocator).send())) {
Buffer in = new BufferStub(newBufferWithRandomBytes(allocator, 12)) {
@Override
public Send<Buffer> send() {
return new Send<>() {
@Override
public Buffer receive() {
throw error;
}
@Override
public void close() {
}
@Override
public boolean referentIsInstanceOf(Class<?> cls) {
return Buffer.class.isAssignableFrom(cls);
}
};
}
};
final Error err = assertThrows(Error.class, () -> COMPOSITE_CUMULATOR.cumulate(allocator, cumulation, in));
assertSame(error, err);
assertFalse(in.isAccessible());
}
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void doesNotOverRead(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
class ReadInterceptingHandler implements ChannelHandler {
private int readsTriggered;
@Override
public void read(ChannelHandlerContext ctx) {
readsTriggered++;
ctx.read();
}
}
ReadInterceptingHandler interceptor = new ReadInterceptingHandler();
EmbeddedChannel channel = new EmbeddedChannel();
channel.config().setAutoRead(false);
channel.pipeline().addLast(interceptor, new FixedLengthFrameDecoder(3, cumulator));
assertEquals(0, interceptor.readsTriggered);
// 0 complete frames, 1 partial frame: SHOULD trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[] { 0, 1 }));
assertEquals(1, interceptor.readsTriggered);
// 2 complete frames, 0 partial frames: should NOT trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[]{2}),
newBufferWithData(allocator, new byte[]{3, 4, 5}));
assertEquals(1, interceptor.readsTriggered);
// 1 complete frame, 1 partial frame: should NOT trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[] { 6, 7, 8 }),
newBufferWithData(allocator, new byte[] { 9 }));
assertEquals(1, interceptor.readsTriggered);
// 1 complete frame, 1 partial frame: should NOT trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[] { 10, 11 }),
newBufferWithData(allocator, new byte[] { 12 }));
assertEquals(1, interceptor.readsTriggered);
// 0 complete frames, 1 partial frame: SHOULD trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[] { 13 }));
assertEquals(2, interceptor.readsTriggered);
// 1 complete frame, 0 partial frames: should NOT trigger a read
channel.writeInbound(newBufferWithData(allocator, new byte[] { 14 }));
assertEquals(2, interceptor.readsTriggered);
for (int i = 0; i < 5; i++) {
try (Buffer read = channel.readInbound()) {
assertEquals(i * 3, read.getByte(0));
assertEquals(i * 3 + 1, read.getByte(1));
assertEquals(i * 3 + 2, read.getByte(2));
}
}
assertFalse(channel.finish());
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void testDisorder(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
ByteToMessageDecoderForBuffer decoder = new ByteToMessageDecoderForBuffer(cumulator) {
int count;
//read 4 byte then remove this decoder
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
ctx.fireChannelRead(in.readByte());
if (++count >= 4) {
ctx.pipeline().remove(this);
}
}
};
EmbeddedChannel channel = new EmbeddedChannel(decoder);
assertTrue(channel.writeInbound(newBufferWithData(allocator, new byte[]{1, 2, 3, 4, 5})));
assertEquals((byte) 1, (Byte) channel.readInbound());
assertEquals((byte) 2, (Byte) channel.readInbound());
assertEquals((byte) 3, (Byte) channel.readInbound());
assertEquals((byte) 4, (Byte) channel.readInbound());
Buffer buffer5 = channel.readInbound();
assertNotNull(buffer5);
assertEquals((byte) 5, buffer5.readByte());
assertFalse(buffer5.readableBytes() > 0);
assertTrue(buffer5.isAccessible());
assertFalse(channel.finish());
}
@ParameterizedTest(name = PARAMETERIZED_NAME)
@MethodSource("allocators")
public void testDecodeLast(BufferAllocator allocator, Cumulator cumulator) {
this.allocator = allocator;
final AtomicBoolean removeHandler = new AtomicBoolean();
EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoderForBuffer(cumulator) {
@Override
protected void decode(ChannelHandlerContext ctx, Buffer in) {
if (removeHandler.get()) {
ctx.pipeline().remove(this);
}
}
});
try (Buffer buf = newBufferWithRandomBytes(allocator)) {
assertFalse(channel.writeInbound(buf.copy()));
assertNull(channel.readInbound());
removeHandler.set(true);
// This should trigger channelInputClosed(...)
channel.pipeline().fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE);
assertTrue(channel.finish());
try (Buffer b = channel.readInbound()) {
assertContentEquals(buf, b);
}
assertNull(channel.readInbound());
}
}
private static Buffer newBufferWithRandomBytes(BufferAllocator allocator) {
return newBufferWithRandomBytes(allocator, 1024);
}
private static Buffer newBufferWithRandomBytes(BufferAllocator allocator, int length) {
final Buffer buf = allocator.allocate(length);
byte[] bytes = new byte[length];
ThreadLocalRandom.current().nextBytes(bytes);
buf.writeBytes(bytes);
return buf;
}
private static Buffer newBufferWithData(BufferAllocator allocator, int capacity, char... data) {
final Buffer buf = allocator.allocate(capacity);
for (char datum : data) {
buf.writeByte((byte) datum);
}
return buf;
}
private static Buffer newBufferWithData(BufferAllocator allocator, char... data) {
return newBufferWithData(allocator, data.length, data);
}
private static Buffer newBufferWithData(BufferAllocator allocator, byte... data) {
return allocator.allocate(data.length).writeBytes(data);
}
private static void assertContentEquals(Buffer actual, Buffer expected) {
assertArrayEquals(readByteArray(expected), readByteArray(actual));
}
private static byte[] readByteArray(Buffer buf) {
byte[] bs = new byte[buf.readableBytes()];
buf.copyInto(buf.readerOffset(), bs, 0, bs.length);
buf.readerOffset(buf.writerOffset());
return bs;
}
private static Buffer transferBytes(BufferAllocator allocator, Buffer src, int length) {
final Buffer msg = allocator.allocate(length);
src.copyInto(src.readerOffset(), msg, 0, length);
msg.writerOffset(length);
src.readerOffset(length);
return msg;
}
}

View File

@ -25,8 +25,8 @@ import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.concurrent.BlockingQueue;
@ -432,6 +432,7 @@ public class ByteToMessageDecoderTest {
}
}
@Disabled("FixedLengthFrameDecoder is migrated to use Buffer")
@Test
public void testDoesNotOverRead() {
class ReadInterceptingHandler implements ChannelHandler {

View File

@ -25,6 +25,7 @@ import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.FixedLengthFrameDecoder;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
@ -43,11 +44,13 @@ public class SocketFixedLengthEchoTest extends AbstractSocketTest {
random.nextBytes(data);
}
@Disabled("FixedLengthFrameDecoder is migrated to use Buffer")
@Test
public void testFixedLengthEcho(TestInfo testInfo) throws Throwable {
run(testInfo, this::testFixedLengthEcho);
}
@Disabled("FixedLengthFrameDecoder is migrated to use Buffer")
@Test
public void testFixedLengthEchoNotAutoRead(TestInfo testInfo) throws Throwable {
run(testInfo, this::testFixedLengthEchoNotAutoRead);