The MqttDecoder incorrectly skip bytes before throwing TooLongFrameException (#11362)

Motivation:

Commit c32c520edd incorrectly skip the bytes of the replay decoder buffer. The number of bytes to skip is determined by ByteBuf#readableBytes() instead of using ByteToMessageDecoder#actualReadableBytes(). As result it throws an exception because the ByteBuf provided will return a too large value (Integer.MAX_VALUE - reader index) causing a bound check error in the skipBytes method. This is not detected by the tests because most tests are calling the decode(...) method with a regular ByteBuf. In practice when this method is called with a specialized ByteBuf when channelRead(...) is called. Such tests should actually use channelRead with proper mocking of the ChannelHandlerContext

Modification:

- Rewrite the MqttCodecTest to use channelRead(...) instead of decode(...) and use proper mocking of ChannelHandlerContext to get the message emitted by the decoder.
- Use actualReadableBytes() instead of buff.readableBytes() to compute the number of bytes to skip

Result:

Skip correctly the number of bytes when a too large message is found and improve testing. See #11361

Signed-off-by: Julien Viet <julien@julienviet.com>
This commit is contained in:
Julien Viet 2021-06-10 15:05:25 +02:00 committed by Norman Maurer
parent 4aef1cc77c
commit 7b39415543
2 changed files with 130 additions and 147 deletions

View File

@ -98,7 +98,7 @@ public final class MqttDecoder extends ReplayingDecoder<DecoderState> {
final Result<?> decodedVariableHeader = decodeVariableHeader(ctx, buffer, mqttFixedHeader); final Result<?> decodedVariableHeader = decodeVariableHeader(ctx, buffer, mqttFixedHeader);
variableHeader = decodedVariableHeader.value; variableHeader = decodedVariableHeader.value;
if (bytesRemainingInVariablePart > maxBytesInMessage) { if (bytesRemainingInVariablePart > maxBytesInMessage) {
buffer.skipBytes(buffer.readableBytes()); buffer.skipBytes(actualReadableBytes());
throw new TooLongFrameException("too large message: " + bytesRemainingInVariablePart + " bytes"); throw new TooLongFrameException("too large message: " + bytesRemainingInVariablePart + " bytes");
} }
bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed; bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;

View File

@ -25,12 +25,17 @@ import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.EncoderException;
import io.netty.util.Attribute; import io.netty.util.Attribute;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -70,6 +75,8 @@ public class MqttCodecTest {
@Mock @Mock
private final Attribute<MqttVersion> versionAttrMock = mock(Attribute.class); private final Attribute<MqttVersion> versionAttrMock = mock(Attribute.class);
private final List<Object> out = new ArrayList<Object>();
private final MqttDecoder mqttDecoder = new MqttDecoder(); private final MqttDecoder mqttDecoder = new MqttDecoder();
/** /**
@ -78,11 +85,28 @@ public class MqttCodecTest {
private final MqttDecoder mqttDecoderLimitedMessageSize = new MqttDecoder(1); private final MqttDecoder mqttDecoderLimitedMessageSize = new MqttDecoder(1);
@Before @Before
public void setup() { public void setup() throws Exception {
MockitoAnnotations.initMocks(this); MockitoAnnotations.initMocks(this);
when(ctx.channel()).thenReturn(channel); when(ctx.channel()).thenReturn(channel);
when(ctx.alloc()).thenReturn(ALLOCATOR); when(ctx.alloc()).thenReturn(ALLOCATOR);
when(ctx.fireChannelRead(any())).then(new Answer<ChannelHandlerContext>() {
@Override
public ChannelHandlerContext answer(InvocationOnMock invocation) {
out.add(invocation.getArguments()[0]);
return ctx;
}
});
when(channel.attr(MqttCodecUtil.MQTT_VERSION_KEY)).thenReturn(versionAttrMock); when(channel.attr(MqttCodecUtil.MQTT_VERSION_KEY)).thenReturn(versionAttrMock);
mqttDecoder.handlerAdded(ctx);
mqttDecoderLimitedMessageSize.handlerAdded(ctx);
}
@After
public void after() {
for (Object o : out) {
ReferenceCountUtil.release(o);
}
out.clear();
} }
@Test @Test
@ -91,7 +115,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class); ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue(); final MqttConnectMessage decodedMessage = captor.getValue();
@ -107,7 +131,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class); ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue(); final MqttConnectMessage decodedMessage = captor.getValue();
@ -121,22 +145,18 @@ public class MqttCodecTest {
public void testConnectMessageWithNonZeroReservedFlagForMqtt311() throws Exception { public void testConnectMessageWithNonZeroReservedFlagForMqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { // Set the reserved flag in the CONNECT Packet to 1
// Set the reserved flag in the CONNECT Packet to 1 byteBuf.setByte(9, byteBuf.getByte(9) | 0x1);
byteBuf.setByte(9, byteBuf.getByte(9) | 0x1);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertTrue(decodedMessage.decoderResult().isFailure()); assertTrue(decodedMessage.decoderResult().isFailure());
Throwable cause = decodedMessage.decoderResult().cause(); Throwable cause = decodedMessage.decoderResult().cause();
assertTrue(cause instanceof DecoderException); assertTrue(cause instanceof DecoderException);
assertEquals("non-zero reserved flag", cause.getMessage()); assertEquals("non-zero reserved flag", cause.getMessage());
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -161,7 +181,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class); ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttConnAckMessage decodedMessage = captor.getValue(); final MqttConnAckMessage decodedMessage = captor.getValue();
@ -175,7 +195,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class); ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttPublishMessage decodedMessage = captor.getValue(); final MqttPublishMessage decodedMessage = captor.getValue();
@ -210,7 +230,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class); ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue(); final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -225,7 +245,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class); ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttSubAckMessage decodedMessage = captor.getValue(); final MqttSubAckMessage decodedMessage = captor.getValue();
@ -246,7 +266,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class); ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
MqttSubAckMessage decodedMessage = captor.getValue(); MqttSubAckMessage decodedMessage = captor.getValue();
@ -263,7 +283,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttUnsubscribeMessage> captor = ArgumentCaptor.forClass(MqttUnsubscribeMessage.class); ArgumentCaptor<MqttUnsubscribeMessage> captor = ArgumentCaptor.forClass(MqttUnsubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttUnsubscribeMessage decodedMessage = captor.getValue(); final MqttUnsubscribeMessage decodedMessage = captor.getValue();
@ -298,44 +318,35 @@ public class MqttCodecTest {
final MqttMessage message = createMessageWithFixedHeader(MqttMessageType.PINGREQ); final MqttMessage message = createMessageWithFixedHeader(MqttMessageType.PINGREQ);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { // setting an invalid message type (15, reserved and forbidden by MQTT 3.1.1 spec)
// setting an invalid message type (15, reserved and forbidden by MQTT 3.1.1 spec) byteBuf.setByte(0, 0xF0);
byteBuf.setByte(0, 0xF0);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertTrue(decodedMessage.decoderResult().isFailure()); assertTrue(decodedMessage.decoderResult().isFailure());
Throwable cause = decodedMessage.decoderResult().cause(); Throwable cause = decodedMessage.decoderResult().cause();
assertTrue(cause instanceof DecoderException); assertTrue(cause instanceof DecoderException);
assertEquals("AUTH message requires at least MQTT 5", cause.getMessage()); assertEquals("AUTH message requires at least MQTT 5", cause.getMessage());
} finally {
byteBuf.release();
}
} }
@Test @Test
public void testConnectMessageForMqtt31TooLarge() throws Exception { public void testConnectMessageForMqtt31TooLarge() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1); final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
try { final MqttMessage decodedMessage = captor.getValue();
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); assertEquals(0, byteBuf.readableBytes());
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
assertEquals(0, byteBuf.readableBytes()); validateConnectVariableHeader(message.variableHeader(),
(MqttConnectVariableHeader) decodedMessage.variableHeader());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateDecoderExceptionTooLargeMessage(decodedMessage);
validateConnectVariableHeader(message.variableHeader(),
(MqttConnectVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -343,63 +354,49 @@ public class MqttCodecTest {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf); verify(ctx).fireChannelRead(captor.capture());
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertEquals(0, byteBuf.readableBytes()); assertEquals(0, byteBuf.readableBytes());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateConnectVariableHeader(message.variableHeader(), validateConnectVariableHeader(message.variableHeader(),
(MqttConnectVariableHeader) decodedMessage.variableHeader()); (MqttConnectVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage); validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
public void testConnAckMessageTooLarge() throws Exception { public void testConnAckMessageTooLarge() throws Exception {
final MqttConnAckMessage message = createConnAckMessage(); final MqttConnAckMessage message = createConnAckMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
try { final MqttMessage decodedMessage = captor.getValue();
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); assertEquals(0, byteBuf.readableBytes());
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
assertEquals(0, byteBuf.readableBytes()); validateDecoderExceptionTooLargeMessage(decodedMessage);
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
public void testPublishMessageTooLarge() throws Exception { public void testPublishMessageTooLarge() throws Exception {
final MqttPublishMessage message = createPublishMessage(); final MqttPublishMessage message = createPublishMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
try { final MqttMessage decodedMessage = captor.getValue();
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); assertEquals(0, byteBuf.readableBytes());
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
assertEquals(0, byteBuf.readableBytes()); validatePublishVariableHeader(message.variableHeader(),
(MqttPublishVariableHeader) decodedMessage.variableHeader());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateDecoderExceptionTooLargeMessage(decodedMessage);
validatePublishVariableHeader(message.variableHeader(),
(MqttPublishVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -407,21 +404,17 @@ public class MqttCodecTest {
final MqttSubscribeMessage message = createSubscribeMessage(); final MqttSubscribeMessage message = createSubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf); verify(ctx).fireChannelRead(captor.capture());
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertEquals(0, byteBuf.readableBytes()); assertEquals(0, byteBuf.readableBytes());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateMessageIdVariableHeader(message.variableHeader(), validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader()); (MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage); validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -429,21 +422,17 @@ public class MqttCodecTest {
final MqttSubAckMessage message = createSubAckMessage(); final MqttSubAckMessage message = createSubAckMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf); verify(ctx).fireChannelRead(captor.capture());
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertEquals(0, byteBuf.readableBytes()); assertEquals(0, byteBuf.readableBytes());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateMessageIdVariableHeader(message.variableHeader(), validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader()); (MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage); validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -451,21 +440,17 @@ public class MqttCodecTest {
final MqttUnsubscribeMessage message = createUnsubscribeMessage(); final MqttUnsubscribeMessage message = createUnsubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
try { ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); mqttDecoderLimitedMessageSize.channelRead(ctx, byteBuf);
mqttDecoderLimitedMessageSize.decode(ctx, byteBuf); verify(ctx).fireChannelRead(captor.capture());
verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
assertEquals(0, byteBuf.readableBytes()); assertEquals(0, byteBuf.readableBytes());
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
validateMessageIdVariableHeader(message.variableHeader(), validateMessageIdVariableHeader(message.variableHeader(),
(MqttMessageIdVariableHeader) decodedMessage.variableHeader()); (MqttMessageIdVariableHeader) decodedMessage.variableHeader());
validateDecoderExceptionTooLargeMessage(decodedMessage); validateDecoderExceptionTooLargeMessage(decodedMessage);
} finally {
byteBuf.release();
}
} }
@Test @Test
@ -480,7 +465,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class); ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttConnectMessage decodedMessage = captor.getValue(); final MqttConnectMessage decodedMessage = captor.getValue();
@ -500,7 +485,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class); ArgumentCaptor<MqttConnAckMessage> captor = ArgumentCaptor.forClass(MqttConnAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttConnAckMessage decodedMessage = captor.getValue(); final MqttConnAckMessage decodedMessage = captor.getValue();
@ -529,7 +514,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class); ArgumentCaptor<MqttPublishMessage> captor = ArgumentCaptor.forClass(MqttPublishMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttPublishMessage decodedMessage = captor.getValue(); final MqttPublishMessage decodedMessage = captor.getValue();
@ -549,7 +534,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -566,7 +551,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -584,7 +569,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class); ArgumentCaptor<MqttSubAckMessage> captor = ArgumentCaptor.forClass(MqttSubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttSubAckMessage decodedMessage = captor.getValue(); final MqttSubAckMessage decodedMessage = captor.getValue();
@ -617,7 +602,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class); ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue(); final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -648,10 +633,8 @@ public class MqttCodecTest {
.build(); .build();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
final List<Object> out = new LinkedList<Object>();
ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class); ArgumentCaptor<MqttSubscribeMessage> captor = ArgumentCaptor.forClass(MqttSubscribeMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttSubscribeMessage decodedMessage = captor.getValue(); final MqttSubscribeMessage decodedMessage = captor.getValue();
@ -682,7 +665,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttUnsubAckMessage> captor = ArgumentCaptor.forClass(MqttUnsubAckMessage.class); ArgumentCaptor<MqttUnsubAckMessage> captor = ArgumentCaptor.forClass(MqttUnsubAckMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttUnsubAckMessage decodedMessage = captor.getValue(); final MqttUnsubAckMessage decodedMessage = captor.getValue();
@ -708,7 +691,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -729,7 +712,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -753,7 +736,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -775,7 +758,7 @@ public class MqttCodecTest {
clearInvocations(versionAttrMock); clearInvocations(versionAttrMock);
ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class); ArgumentCaptor<MqttConnectMessage> captor = ArgumentCaptor.forClass(MqttConnectMessage.class);
mqttDecoder.decode(ctx, connectByteBuf); mqttDecoder.channelRead(ctx, connectByteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
verify(versionAttrMock, times(1)).set(MqttVersion.MQTT_5); verify(versionAttrMock, times(1)).set(MqttVersion.MQTT_5);
@ -792,7 +775,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();
@ -806,7 +789,7 @@ public class MqttCodecTest {
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class); ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.decode(ctx, byteBuf); mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture()); verify(ctx).fireChannelRead(captor.capture());
final MqttMessage decodedMessage = captor.getValue(); final MqttMessage decodedMessage = captor.getValue();