diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttCodecUtil.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttCodecUtil.java index b689e81ba4..46bc486acc 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttCodecUtil.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttCodecUtil.java @@ -38,9 +38,16 @@ final class MqttCodecUtil { return messageId != 0; } - static boolean isValidClientId(String clientId) { - return clientId != null && clientId.length() >= MIN_CLIENT_ID_LENGTH && - clientId.length() <= MAX_CLIENT_ID_LENGTH; + static boolean isValidClientId(MqttVersion mqttVersion, String clientId) { + if (mqttVersion == MqttVersion.MQTT_3_1) { + return clientId != null && clientId.length() >= MIN_CLIENT_ID_LENGTH && + clientId.length() <= MAX_CLIENT_ID_LENGTH; + } else if (mqttVersion == MqttVersion.MQTT_3_1_1) { + // In 3.1.3.1 Client Identifier of MQTT 3.1.1 specification, The Server MAY allow ClientId’s + // that contain more than 23 encoded bytes. And, The Server MAY allow zero-length ClientId. + return clientId != null; + } + throw new IllegalArgumentException(mqttVersion + " is unknown mqtt version"); } static MqttFixedHeader validateFixedHeader(MqttFixedHeader mqttFixedHeader) { diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java index 1f1a81c6c1..bcb770a354 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttDecoder.java @@ -27,7 +27,6 @@ import java.util.ArrayList; import java.util.List; import static io.netty.handler.codec.mqtt.MqttCodecUtil.*; -import static io.netty.handler.codec.mqtt.MqttVersion.*; /** * Decodes Mqtt messages from bytes, following @@ -203,15 +202,15 @@ public class MqttDecoder extends ReplayingDecoder { private static Result decodeConnectionVariableHeader(ByteBuf buffer) { final Result protoString = decodeString(buffer); - if (!PROTOCOL_NAME.equals(protoString.value)) { - throw new MqttUnacceptableProtocolVersionException("missing " + PROTOCOL_NAME + " signature"); - } - int numberOfBytesConsumed = protoString.numberOfBytesConsumed; - final byte version = buffer.readByte(); + final byte protocolLevel = buffer.readByte(); + numberOfBytesConsumed += 1; + + final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel); + final int b1 = buffer.readUnsignedByte(); - numberOfBytesConsumed += 2; + numberOfBytesConsumed += 1; final Result keepAlive = decodeMsbLsb(buffer); numberOfBytesConsumed += keepAlive.numberOfBytesConsumed; @@ -224,8 +223,8 @@ public class MqttDecoder extends ReplayingDecoder { final boolean cleanSession = (b1 & 0x02) == 0x02; final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader( - PROTOCOL_NAME, - version, + mqttVersion.protocolName(), + mqttVersion.protocolLevel(), hasUserName, hasPassword, willRetain, @@ -321,7 +320,9 @@ public class MqttDecoder extends ReplayingDecoder { MqttConnectVariableHeader mqttConnectVariableHeader) { final Result decodedClientId = decodeString(buffer); final String decodedClientIdValue = decodedClientId.value; - if (!isValidClientId(decodedClientIdValue)) { + final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(), + (byte) mqttConnectVariableHeader.version()); + if (!isValidClientId(mqttVersion, decodedClientIdValue)) { throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue); } int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed; diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java index 817bd86b47..db2d871f9d 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttEncoder.java @@ -24,6 +24,8 @@ import io.netty.util.CharsetUtil; import java.util.List; +import static io.netty.handler.codec.mqtt.MqttCodecUtil.*; + /** * Encodes Mqtt messages into bytes following the protocl specification v3.1 * as described here MQTTV3.1 @@ -34,8 +36,6 @@ public class MqttEncoder extends MessageToMessageEncoder { private static final byte[] EMPTY = new byte[0]; - private static final byte[] CONNECT_VARIABLE_HEADER_START = {0, 6, 'M', 'Q', 'I', 's', 'd', 'p'}; - @Override protected void encode(ChannelHandlerContext ctx, MqttMessage msg, List out) throws Exception { out.add(doEncode(ctx.alloc(), msg)); @@ -91,18 +91,18 @@ public class MqttEncoder extends MessageToMessageEncoder { private static ByteBuf encodeConnectMessage( ByteBufAllocator byteBufAllocator, MqttConnectMessage message) { - int variableHeaderBufferSize = 12; int payloadBufferSize = 0; MqttFixedHeader mqttFixedHeader = message.fixedHeader(); MqttConnectVariableHeader variableHeader = message.variableHeader(); MqttConnectPayload payload = message.payload(); + MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(variableHeader.name(), + (byte) variableHeader.version()); // Client id String clientIdentifier = payload.clientIdentifier(); - if (!isValidClientIdentifier(clientIdentifier)) { - throw new IllegalArgumentException( - "invalid clientIdentifier: " + clientIdentifier + " (expected: less than 23 chars long)"); + if (!isValidClientId(mqttVersion, clientIdentifier)) { + throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + clientIdentifier); } byte[] clientIdentifierBytes = encodeStringUtf8(clientIdentifier); payloadBufferSize += 2 + clientIdentifierBytes.length; @@ -130,13 +130,16 @@ public class MqttEncoder extends MessageToMessageEncoder { } // Fixed header + byte[] protocolNameBytes = mqttVersion.protocolNameBytes(); + int variableHeaderBufferSize = 2 + protocolNameBytes.length + 4; int variablePartSize = variableHeaderBufferSize + payloadBufferSize; int fixedHeaderBufferSize = 1 + getVariableLengthInt(variablePartSize); ByteBuf buf = byteBufAllocator.buffer(fixedHeaderBufferSize + variablePartSize); buf.writeByte(getFixedHeaderByte1(mqttFixedHeader)); writeVariableLengthInt(buf, variablePartSize); - buf.writeBytes(CONNECT_VARIABLE_HEADER_START); + buf.writeShort(protocolNameBytes.length); + buf.writeBytes(protocolNameBytes); buf.writeByte(variableHeader.version()); buf.writeByte(getConnVariableHeaderFlag(variableHeader)); @@ -382,12 +385,4 @@ public class MqttEncoder extends MessageToMessageEncoder { private static byte[] encodeStringUtf8(String s) { return s.getBytes(CharsetUtil.UTF_8); } - - private static boolean isValidClientIdentifier(String clientIdentifier) { - if (clientIdentifier == null) { - return false; - } - int length = clientIdentifier.length(); - return length >= 1 && length <= 23; - } } diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttQoS.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttQoS.java index 2e654e141a..ba5430a6b5 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttQoS.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttQoS.java @@ -18,7 +18,8 @@ package io.netty.handler.codec.mqtt; public enum MqttQoS { AT_MOST_ONCE(0), AT_LEAST_ONCE(1), - EXACTLY_ONCE(2); + EXACTLY_ONCE(2), + FAILURE(0x80); private final int value; diff --git a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttVersion.java b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttVersion.java index 2608362058..1206a5c5f7 100644 --- a/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttVersion.java +++ b/codec-mqtt/src/main/java/io/netty/handler/codec/mqtt/MqttVersion.java @@ -16,12 +16,47 @@ package io.netty.handler.codec.mqtt; +import io.netty.util.CharsetUtil; + /** - * Holds Constant values used by multiple classes in mqtt-codec. + * Mqtt version specific constant values used by multiple classes in mqtt-codec. */ -final class MqttVersion { +public enum MqttVersion { + MQTT_3_1("MQIsdp", (byte) 3), + MQTT_3_1_1("MQTT", (byte) 4); - static final String PROTOCOL_NAME = "MQIsdp"; + private String name; + private byte level; + + private MqttVersion(String protocolName, byte protocolLevel) { + this.name = protocolName; + this.level = protocolLevel; + } + + public String protocolName() { + return name; + } + + public byte[] protocolNameBytes() { + return name.getBytes(CharsetUtil.UTF_8); + } + + public byte protocolLevel() { + return level; + } + + public static MqttVersion fromProtocolNameAndLevel(String protocolName, byte protocolLevel) { + for (MqttVersion mv : values()) { + if (mv.name.equals(protocolName)) { + if (mv.level == protocolLevel) { + return mv; + } else { + throw new MqttUnacceptableProtocolVersionException(protocolName + " and " + + protocolLevel + " are not match"); + } + } + } + throw new MqttUnacceptableProtocolVersionException(protocolName + "is unknown protocol name"); + } - private MqttVersion() { } } diff --git a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java index b4d8c00791..f16602a64e 100644 --- a/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java +++ b/codec-mqtt/src/test/java/io/netty/handler/codec/mqtt/MqttCodecTest.java @@ -30,7 +30,6 @@ import org.mockito.MockitoAnnotations; import java.util.LinkedList; import java.util.List; -import static io.netty.handler.codec.mqtt.MqttVersion.*; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -45,7 +44,6 @@ public class MqttCodecTest { private static final String USER_NAME = "happy_user"; private static final String PASSWORD = "123_or_no_pwd"; - private static final int PROTOCOL_VERSION = 3; private static final int KEEP_ALIVE_SECONDS = 600; private static final ByteBufAllocator ALLOCATOR = new UnpooledByteBufAllocator(false); @@ -65,8 +63,25 @@ public class MqttCodecTest { } @Test - public void testConnectMessage() throws Exception { - final MqttConnectMessage message = createConnectMessage(); + public void testConnectMessageForMqtt31() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1); + ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); + + final List out = new LinkedList(); + mqttDecoder.decode(ctx, byteBuf, out); + + assertEquals("Expected one object bout got " + out.size(), 1, out.size()); + + final MqttConnectMessage decodedMessage = (MqttConnectMessage) out.get(0); + + validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader()); + vlidateConnectVariableHeader(message.variableHeader(), decodedMessage.variableHeader()); + validateConnectPayload(message.payload(), decodedMessage.payload()); + } + + @Test + public void testConnectMessageForMqtt311() throws Exception { + final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1); ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message); final List out = new LinkedList(); @@ -250,13 +265,13 @@ public class MqttCodecTest { return new MqttMessage(mqttFixedHeader, mqttMessageIdVariableHeader); } - private static MqttConnectMessage createConnectMessage() { + private static MqttConnectMessage createConnectMessage(MqttVersion mqttVersion) { MqttFixedHeader mqttFixedHeader = new MqttFixedHeader(MqttMessageType.CONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0); MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader( - PROTOCOL_NAME, - PROTOCOL_VERSION, + mqttVersion.protocolName(), + mqttVersion.protocolLevel(), true, true, true,