Validate fixed header bits in MQTT (#11389)

Motivation:
The MQTT spec states that the bits in the fixed header must be set to specific values depending on message type. If a client sends a message with the wrong bits, the server must treat the message as malformed. Netty did not check the value of the reserved bits in the fixed header.

See:
MQTT3.1.1: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html#_Toc442180835
MQTT 5.0: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901023

Modification:
Add validation checks to MqttDecoder.java
Add unit tests to MqttCodecTest.java
Fixed two instances where messages were generated for other unit tests with an incorrect fixed header.

Result:
Fixes #11379.
This commit is contained in:
Hylke van der Schaaf 2021-06-16 14:59:15 +02:00 committed by Norman Maurer
parent 98e3605d4d
commit a36d5312c5
2 changed files with 138 additions and 3 deletions

View File

@ -153,7 +153,11 @@ public final class MqttDecoder extends ReplayingDecoder<DecoderState> {
}
/**
* Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length.
* Decodes the fixed header. It's one byte for the flags and then variable
* bytes for the remaining length.
*
* @see
* https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html#_Toc442180841
*
* @param buffer the buffer to decode from
* @return the fixed header
@ -166,6 +170,59 @@ public final class MqttDecoder extends ReplayingDecoder<DecoderState> {
int qosLevel = (b1 & 0x06) >> 1;
boolean retain = (b1 & 0x01) != 0;
switch (messageType) {
case PUBLISH:
if (qosLevel == 3) {
throw new DecoderException("Illegal QOS Level in fixed header of PUBLISH message ("
+ qosLevel + ')');
}
break;
case PUBREL:
case SUBSCRIBE:
case UNSUBSCRIBE:
if (dupFlag) {
throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString()
+ " message, must be 0, found 1");
}
if (qosLevel != 1) {
throw new DecoderException("Illegal QOS Level in fixed header of " + messageType.toString()
+ " message, must be 1, found " + qosLevel);
}
if (retain) {
throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString()
+ " message, must be 0, found 1");
}
break;
case AUTH:
case CONNACK:
case CONNECT:
case DISCONNECT:
case PINGREQ:
case PINGRESP:
case PUBACK:
case PUBCOMP:
case PUBREC:
case SUBACK:
case UNSUBACK:
if (dupFlag) {
throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString()
+ " message, must be 0, found 1");
}
if (qosLevel != 0) {
throw new DecoderException("Illegal BIT 2 or 1 in fixed header of " + messageType.toString()
+ " message, must be 0, found " + qosLevel);
}
if (retain) {
throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString()
+ " message, must be 0, found 1");
}
break;
default:
throw new DecoderException("Unknown message type, do not know how to validate fixed header");
}
int remainingLength = 0;
int multiplier = 1;
short digit;

View File

@ -159,6 +159,84 @@ public class MqttCodecTest {
assertEquals("non-zero reserved flag", cause.getMessage());
}
@Test
public void testConnectMessageNonZeroReservedBit0Mqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 0 to 1
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
@Test
public void testConnectMessageNonZeroReservedBit1Mqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte | 2)); // set bit 1 to 1
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
@Test
public void testConnectMessageNonZeroReservedBit2Mqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte | 4)); // set bit 2 to 1
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
@Test
public void testConnectMessageNonZeroReservedBit3Mqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte | 8)); // set bit 3 to 1
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
@Test
public void testSubscribeMessageNonZeroReservedBit0Mqtt311() throws Exception {
final MqttSubscribeMessage message = createSubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 1 to 0
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
@Test
public void testSubscribeMessageZeroReservedBit1Mqtt311() throws Exception {
final MqttSubscribeMessage message = createSubscribeMessage();
ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message);
byte firstByte = byteBuf.getByte(0);
byteBuf.setByte(0, (byte) (firstByte & ~2)); // set bit 1 to 0
ArgumentCaptor<MqttMessage> captor = ArgumentCaptor.forClass(MqttMessage.class);
mqttDecoder.channelRead(ctx, byteBuf);
verify(ctx).fireChannelRead(captor.capture());
checkForSingleDecoderException(captor);
}
private void checkForSingleDecoderException(ArgumentCaptor<MqttMessage> captor) {
final MqttMessage result = captor.getValue();
assertTrue("Decoding should have resulted in a DecoderException",
result.decoderResult().cause() instanceof DecoderException);
}
@Test
public void testConnectMessageNoPassword() throws Exception {
final MqttConnectMessage message = createConnectMessage(
@ -877,7 +955,7 @@ public class MqttCodecTest {
private static MqttSubscribeMessage createSubscribeMessage() {
MqttFixedHeader mqttFixedHeader =
new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0);
new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0);
MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345);
List<MqttTopicSubscription> topicSubscriptions = new LinkedList<>();
@ -903,7 +981,7 @@ public class MqttCodecTest {
private static MqttUnsubscribeMessage createUnsubscribeMessage() {
MqttFixedHeader mqttFixedHeader =
new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0);
new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0);
MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345);
List<String> topics = new LinkedList<>();