diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java index 4a14028314..96af599743 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java @@ -153,7 +153,11 @@ public final class MqttDecoder extends ReplayingDecoder { } /** - * Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length. + * Decodes the fixed header. It's one byte for the flags and then variable + * bytes for the remaining length. + * + * @see + * https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html#_Toc442180841 * * @param buffer the buffer to decode from * @return the fixed header @@ -166,6 +170,59 @@ public final class MqttDecoder extends ReplayingDecoder { int qosLevel = (b1 & 0x06) >> 1; boolean retain = (b1 & 0x01) != 0; + switch (messageType) { + case PUBLISH: + if (qosLevel == 3) { + throw new DecoderException("Illegal QOS Level in fixed header of PUBLISH message (" + + qosLevel + ')'); + } + break; + + case PUBREL: + case SUBSCRIBE: + case UNSUBSCRIBE: + if (dupFlag) { + throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + if (qosLevel != 1) { + throw new DecoderException("Illegal QOS Level in fixed header of " + messageType.toString() + + " message, must be 1, found " + qosLevel); + } + if (retain) { + throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + break; + + case AUTH: + case CONNACK: + case CONNECT: + case DISCONNECT: + case PINGREQ: + case PINGRESP: + case PUBACK: + case PUBCOMP: + case PUBREC: + case SUBACK: + case UNSUBACK: + if (dupFlag) { + throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + if (qosLevel != 0) { + throw new DecoderException("Illegal BIT 2 or 1 in fixed header of " + messageType.toString() + + " message, must be 0, found " + qosLevel); + } + if (retain) { + throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + break; + default: + throw new DecoderException("Unknown message type, do not know how to validate fixed header"); + } + int remainingLength = 0; int multiplier = 1; short digit; diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java index 2284f48bfa..d72e21a921 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java @@ -156,6 +156,81 @@ public class MqttCodecTest { assertEquals("non-zero reserved flag", cause.getMessage()); } + @Test + public void testConnectMessageNonZeroReservedBit0Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 0 to 1 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + @Test + public void testConnectMessageNonZeroReservedBit1Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 2)); // set bit 1 to 1 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + @Test + public void testConnectMessageNonZeroReservedBit2Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 4)); // set bit 2 to 1 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + @Test + public void testConnectMessageNonZeroReservedBit3Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 8)); // set bit 3 to 1 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + @Test + public void testSubscribeMessageNonZeroReservedBit0Mqtt311() throws Exception { + final MqttSubscribeMessage message = createSubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 1 to 0 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + @Test + public void testSubscribeMessageZeroReservedBit1Mqtt311() throws Exception { + final MqttSubscribeMessage message = createSubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte & ~2)); // set bit 1 to 0 + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + checkForSingleDecoderException(out); + } + + private void checkForSingleDecoderException(final List out) { + assertEquals("Expected one object but got " + out.size(), 1, out.size()); + assertFalse("Message should not be an MqttConnectMessage", + MqttConnectMessage.class.isAssignableFrom(out.get(0).getClass())); + MqttMessage result = (MqttMessage) out.get(0); + assertTrue("Decoding should have resulted in a DecoderException", + result.decoderResult().cause() instanceof DecoderException); + } + @Test public void testConnectMessageNoPassword() throws Exception { final MqttConnectMessage message = createConnectMessage( @@ -867,7 +942,7 @@ public class MqttCodecTest { private static MqttSubscribeMessage createSubscribeMessage() { MqttFixedHeader mqttFixedHeader = - new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0); + new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345); List topicSubscriptions = new LinkedList(); @@ -893,7 +968,7 @@ public class MqttCodecTest { private static MqttUnsubscribeMessage createUnsubscribeMessage() { MqttFixedHeader mqttFixedHeader = - new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0); + new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345); List topics = new LinkedList();