The MqttDecoder incorrectly skip bytes before throwing TooLongFrameException (#11362)

Motivation:

Commit c32c520edd incorrectly skip the bytes of the replay decoder buffer. The number of bytes to skip is determined by ByteBuf#readableBytes() instead of using ByteToMessageDecoder#actualReadableBytes(). As result it throws an exception because the ByteBuf provided will return a too large value (Integer.MAX_VALUE - reader index) causing a bound check error in the skipBytes method. This is not detected by the tests because most tests are calling the decode(...) method with a regular ByteBuf. In practice when this method is called with a specialized ByteBuf when channelRead(...) is called. Such tests should actually use channelRead with proper mocking of the ChannelHandlerContext

Modification:

- Rewrite the MqttCodecTest to use channelRead(...) instead of decode(...) and use proper mocking of ChannelHandlerContext to get the message emitted by the decoder.
- Use actualReadableBytes() instead of buff.readableBytes() to compute the number of bytes to skip

Result:

Skip correctly the number of bytes when a too large message is found and improve testing. See #11361

Signed-off-by: Julien Viet <julien@julienviet.com>
This commit is contained in:
Julien Viet 2021-06-10 15:05:25 +02:00 committed by Norman Maurer
parent 4aef1cc77c
commit 7b39415543
2 changed files with 130 additions and 147 deletions

View File

@ -98,7 +98,7 @@ public final class MqttDecoder extends ReplayingDecoder<DecoderState> {
final Result<?> decodedVariableHeader = decodeVariableHeader(ctx, buffer, mqttFixedHeader);
variableHeader = decodedVariableHeader.value;
if (bytesRemainingInVariablePart > maxBytesInMessage) {
buffer.skipBytes(buffer.readableBytes());
buffer.skipBytes(actualReadableBytes());
throw new TooLongFrameException("too large message: " + bytesRemainingInVariablePart + " bytes");
}
bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;

View File

@ -25,12 +25,17 @@ import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.util.Attribute;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
@ -70,6 +75,8 @@ public class MqttCodecTest {
@Mock
private final Attribute<MqttVersion> versionAttrMock = mock(Attribute.class);
private final List<Object> out = new ArrayList<Object>();
private final MqttDecoder mqttDecoder = new MqttDecoder();
/**
@ -78,11 +85,28 @@ public class MqttCodecTest {
private final MqttDecoder mqttDecoderLimitedMessageSize = new MqttDecoder(1);
@Before
public void setup() {
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
when(ctx.channel()).thenReturn(channel);
when(ctx.alloc()).thenReturn(ALLOCATOR);
when(ctx.fireChannelRead(any())).then(new Answer<ChannelHandlerContext>() {
@Override
public ChannelHandlerContext answer(InvocationOnMock invocation) {
out.add(invocation.getArguments()[0]);
return ctx;
}
});
when(channel.attr(MqttCodecUtil.MQTT_VERSION_KEY)).thenReturn(versionAttrMock);
mqttDecoder.handlerAdded(ctx);
mqttDecoderLimitedMessageSize.handlerAdded(ctx);
}
@After
public void after() {
for (Object o : out) {
ReferenceCountUtil.release(o);
}
out.clear();
}
@Test
@ -91,7 +115,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue();
@ -107,7 +131,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue();
@ -121,12 +145,11 @@ public class MqttCodecTest {
public void testConnectMessageWithNonZeroReservedFlagForMqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
// Set the reserved flag in the CONNECT Packet to 1
byteBuf.setByte(9, byteBuf.getByte(9) | 0x1);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -134,9 +157,6 @@ public class MqttCodecTest {
Throwable cause = decodedMessage.decoderResult().cause();
assertTrue(cause instanceof DecoderException);
assertEquals("non-zero reserved flag", cause.getMessage());
} finally {
byteBuf.release();
}
}
@Test
@ -161,7 +181,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttConnAckMessage decodedMessage = captor.getValue();
@ -175,7 +195,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttPublishMessage decodedMessage = captor.getValue();
@ -210,7 +230,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -225,7 +245,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttSubAckMessage decodedMessage = captor.getValue();
@ -246,7 +266,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
MqttSubAckMessage decodedMessage = captor.getValue();
@ -263,7 +283,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttUnsubscribeMessage> captor = ArgumentCaptor.forClass(MqttUnsubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttUnsubscribeMessage decodedMessage = captor.getValue();
@ -298,12 +318,11 @@ public class MqttCodecTest {
final MqttMessage message = createMessageWithFixedHeader(MqttMessageType.PINGREQ);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
// setting an invalid message type (15, reserved and forbidden by MQTT 3.1.1 spec)
byteBuf.setByte(0, 0xF0);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -311,19 +330,14 @@ public class MqttCodecTest {
Throwable cause = decodedMessage.decoderResult().cause();
assertTrue(cause instanceof DecoderException);
assertEquals("AUTH message requires at least MQTT 5", cause.getMessage());
} finally {
byteBuf.release();
}
}
@Test
public void testConnectMessageForMqtt31TooLarge() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -333,9 +347,6 @@ public class MqttCodecTest {
validateConnectVariableHeader(message.variableHeader(),
(MqttConnectVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
@ -343,9 +354,8 @@ public class MqttCodecTest {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -355,19 +365,14 @@ public class MqttCodecTest {
validateConnectVariableHeader(message.variableHeader(),
(MqttConnectVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
public void testConnAckMessageTooLarge() throws Exception {
final MqttConnAckMessage message = createConnAckMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -375,19 +380,14 @@ public class MqttCodecTest {
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
public void testPublishMessageTooLarge() throws Exception {
final MqttPublishMessage message = createPublishMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -397,9 +397,6 @@ public class MqttCodecTest {
validatePublishVariableHeader(message.variableHeader(),
(MqttPublishVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
@ -407,9 +404,8 @@ public class MqttCodecTest {
final MqttSubscribeMessage message = createSubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -419,9 +415,6 @@ public class MqttCodecTest {
validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
@ -429,9 +422,8 @@ public class MqttCodecTest {
final MqttSubAckMessage message = createSubAckMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -441,9 +433,6 @@ public class MqttCodecTest {
validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
@ -451,9 +440,8 @@ public class MqttCodecTest {
final MqttUnsubscribeMessage message = createUnsubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try {
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -463,9 +451,6 @@ public class MqttCodecTest {
validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
}
@Test
@ -480,7 +465,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue();
@ -500,7 +485,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttConnAckMessage decodedMessage = captor.getValue();
@ -529,7 +514,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttPublishMessage decodedMessage = captor.getValue();
@ -549,7 +534,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -566,7 +551,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -584,7 +569,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttSubAckMessage decodedMessage = captor.getValue();
@ -617,7 +602,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -648,10 +633,8 @@ public class MqttCodecTest {
.build();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
final List<Object> out = new LinkedList<Object>();
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -682,7 +665,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttUnsubAckMessage> captor = ArgumentCaptor.forClass(MqttUnsubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttUnsubAckMessage decodedMessage = captor.getValue();
@ -708,7 +691,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -729,7 +712,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -753,7 +736,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -775,7 +758,7 @@ public class MqttCodecTest {
clearInvocations(versionAttrMock);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, connectByteBuf);
mqttDecoder.channelRead(ctx, connectByteBuf);
verify(ctx).fireChannelRead(captor.capture());
verify(versionAttrMock, times(1)).set(MqttVersion.MQTT_5);
@ -792,7 +775,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();
@ -806,7 +789,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue();