diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java index b4854cbb9e..9530f73fca 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java @@ -153,7 +153,11 @@ public final class MqttDecoder extends ReplayingDecoder { } /** - * Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length. + * Decodes the fixed header. It's one byte for the flags and then variable + * bytes for the remaining length. + * + * @see + * https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/errata01/os/mqtt-v3.1.1-errata01-os-complete.html#_Toc442180841 * * @param buffer the buffer to decode from * @return the fixed header @@ -166,6 +170,59 @@ public final class MqttDecoder extends ReplayingDecoder { int qosLevel = (b1 & 0x06) >> 1; boolean retain = (b1 & 0x01) != 0; + switch (messageType) { + case PUBLISH: + if (qosLevel == 3) { + throw new DecoderException("Illegal QOS Level in fixed header of PUBLISH message (" + + qosLevel + ')'); + } + break; + + case PUBREL: + case SUBSCRIBE: + case UNSUBSCRIBE: + if (dupFlag) { + throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + if (qosLevel != 1) { + throw new DecoderException("Illegal QOS Level in fixed header of " + messageType.toString() + + " message, must be 1, found " + qosLevel); + } + if (retain) { + throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + break; + + case AUTH: + case CONNACK: + case CONNECT: + case DISCONNECT: + case PINGREQ: + case PINGRESP: + case PUBACK: + case PUBCOMP: + case PUBREC: + case SUBACK: + case UNSUBACK: + if (dupFlag) { + throw new DecoderException("Illegal BIT 3 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + if (qosLevel != 0) { + throw new DecoderException("Illegal BIT 2 or 1 in fixed header of " + messageType.toString() + + " message, must be 0, found " + qosLevel); + } + if (retain) { + throw new DecoderException("Illegal BIT 0 in fixed header of " + messageType.toString() + + " message, must be 0, found 1"); + } + break; + default: + throw new DecoderException("Unknown message type, do not know how to validate fixed header"); + } + int remainingLength = 0; int multiplier = 1; short digit; diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java index b820e7f895..17fbb8bf41 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java @@ -159,6 +159,84 @@ public class MqttCodecTest { assertEquals("non-zero reserved flag", cause.getMessage()); } + @Test + public void testConnectMessageNonZeroReservedBit0Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 0 to 1 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + @Test + public void testConnectMessageNonZeroReservedBit1Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 2)); // set bit 1 to 1 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + @Test + public void testConnectMessageNonZeroReservedBit2Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 4)); // set bit 2 to 1 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + @Test + public void testConnectMessageNonZeroReservedBit3Mqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 8)); // set bit 3 to 1 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + @Test + public void testSubscribeMessageNonZeroReservedBit0Mqtt311() throws Exception { + final MqttSubscribeMessage message = createSubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte | 1)); // set bit 1 to 0 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + @Test + public void testSubscribeMessageZeroReservedBit1Mqtt311() throws Exception { + final MqttSubscribeMessage message = createSubscribeMessage(); + ByteBuf byteBuf = MqttEncoder.doEncode(ctx, message); + byte firstByte = byteBuf.getByte(0); + byteBuf.setByte(0, (byte) (firstByte & ~2)); // set bit 1 to 0 + ArgumentCaptor captor = ArgumentCaptor.forClass(MqttMessage.class); + mqttDecoder.channelRead(ctx, byteBuf); + verify(ctx).fireChannelRead(captor.capture()); + checkForSingleDecoderException(captor); + } + + private void checkForSingleDecoderException(ArgumentCaptor captor) { + final MqttMessage result = captor.getValue(); + assertTrue("Decoding should have resulted in a DecoderException", + result.decoderResult().cause() instanceof DecoderException); + } + @Test public void testConnectMessageNoPassword() throws Exception { final MqttConnectMessage message = createConnectMessage( @@ -877,7 +955,7 @@ public class MqttCodecTest { private static MqttSubscribeMessage createSubscribeMessage() { MqttFixedHeader mqttFixedHeader = - new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0); + new MqttFixedHeader(MqttMessageType.SUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345); List topicSubscriptions = new LinkedList<>(); @@ -903,7 +981,7 @@ public class MqttCodecTest { private static MqttUnsubscribeMessage createUnsubscribeMessage() { MqttFixedHeader mqttFixedHeader = - new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, true, 0); + new MqttFixedHeader(MqttMessageType.UNSUBSCRIBE, false, MqttQoS.AT_LEAST_ONCE, false, 0); MqttMessageIdVariableHeader mqttMessageIdVariableHeader = MqttMessageIdVariableHeader.from(12345); List topics = new LinkedList<>();