Add supporting MQTT 3.1.1

Motivation:

MQTT 3.1.1 became an OASIS Standard at 13 Nov 2014.
http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/mqtt-v3.1.1.html
MQTT 3.1.1 is a minor update of 3.1. But, previous codec-mqtt supported only MQTT 3.1.

Modifications:

- Add protocol name `MQTT` with previous `MQIsdp` for `CONNECT`’s variable header.
- Update client identifier validation for 3.1 with 3.1.1.
- Add `FAILURE (0x80)` for `SUBACK`’s new error code.
- Add a test for encode/decode `CONNECT` of 3.1.1.

Result:

MqttEncoder/MqttDecoder can encode/decode frames of 3.1 or 3.1.1.
This commit is contained in:
Jongyeol Choi 2014-11-14 19:59:45 +09:00 committed by Norman Maurer
parent c8a1d077b5
commit 28c388e525
6 changed files with 94 additions and 40 deletions

View File

@ -38,9 +38,16 @@ final class MqttCodecUtil {
return messageId != 0;
}
static boolean isValidClientId(String clientId) {
static boolean isValidClientId(MqttVersion mqttVersion, String clientId) {
if (mqttVersion == MqttVersion.MQTT_3_1) {
return clientId != null && clientId.length() >= MIN_CLIENT_ID_LENGTH &&
clientId.length() <= MAX_CLIENT_ID_LENGTH;
} else if (mqttVersion == MqttVersion.MQTT_3_1_1) {
// In 3.1.3.1 Client Identifier of MQTT 3.1.1 specification, The Server MAY allow ClientIds
// that contain more than 23 encoded bytes. And, The Server MAY allow zero-length ClientId.
return clientId != null;
}
throw new IllegalArgumentException(mqttVersion + " is unknown mqtt version");
}
static MqttFixedHeader validateFixedHeader(MqttFixedHeader mqttFixedHeader) {

View File

@ -27,7 +27,6 @@ import java.util.ArrayList;
import java.util.List;
import static io.netty.handler.codec.mqtt.MqttCodecUtil.*;
import static io.netty.handler.codec.mqtt.MqttVersion.*;
/**
* Decodes Mqtt messages from bytes, following
@ -203,15 +202,15 @@ public class MqttDecoder extends ReplayingDecoder<DecoderState> {
private static Result<MqttConnectVariableHeader> decodeConnectionVariableHeader(ByteBuf buffer) {
final Result<String> protoString = decodeString(buffer);
if (!PROTOCOL_NAME.equals(protoString.value)) {
throw new MqttUnacceptableProtocolVersionException("missing " + PROTOCOL_NAME + " signature");
}
int numberOfBytesConsumed = protoString.numberOfBytesConsumed;
final byte version = buffer.readByte();
final byte protocolLevel = buffer.readByte();
numberOfBytesConsumed += 1;
final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel);
final int b1 = buffer.readUnsignedByte();
numberOfBytesConsumed += 2;
numberOfBytesConsumed += 1;
final Result<Integer> keepAlive = decodeMsbLsb(buffer);
numberOfBytesConsumed += keepAlive.numberOfBytesConsumed;
@ -224,8 +223,8 @@ public class MqttDecoder extends ReplayingDecoder<DecoderState> {
final boolean cleanSession = (b1 & 0x02) == 0x02;
final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader(
PROTOCOL_NAME,
version,
mqttVersion.protocolName(),
mqttVersion.protocolLevel(),
hasUserName,
hasPassword,
willRetain,
@ -321,7 +320,9 @@ public class MqttDecoder extends ReplayingDecoder<DecoderState> {
MqttConnectVariableHeader mqttConnectVariableHeader) {
final Result<String> decodedClientId = decodeString(buffer);
final String decodedClientIdValue = decodedClientId.value;
if (!isValidClientId(decodedClientIdValue)) {
final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
(byte) mqttConnectVariableHeader.version());
if (!isValidClientId(mqttVersion, decodedClientIdValue)) {
throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
}
int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;

View File

@ -24,6 +24,8 @@ import io.netty.util.CharsetUtil;
import java.util.List;
import static io.netty.handler.codec.mqtt.MqttCodecUtil.*;
/**
* Encodes Mqtt messages into bytes following the protocl specification v3.1
* as described here <a href="http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html">MQTTV3.1</a>
@ -34,8 +36,6 @@ public class MqttEncoder extends MessageToMessageEncoder<MqttMessage> {
private static final byte[] EMPTY = new byte[0];
private static final byte[] CONNECT_VARIABLE_HEADER_START = {0, 6, 'M', 'Q', 'I', 's', 'd', 'p'};
@Override
protected void encode(ChannelHandlerContext ctx, MqttMessage msg, List<Object> out) throws Exception {
out.add(doEncode(ctx.alloc(), msg));
@ -91,18 +91,18 @@ public class MqttEncoder extends MessageToMessageEncoder<MqttMessage> {
private static ByteBuf encodeConnectMessage(
ByteBufAllocator byteBufAllocator,
MqttConnectMessage message) {
int variableHeaderBufferSize = 12;
int payloadBufferSize = 0;
MqttFixedHeader mqttFixedHeader = message.fixedHeader();
MqttConnectVariableHeader variableHeader = message.variableHeader();
MqttConnectPayload payload = message.payload();
MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(variableHeader.name(),
(byte) variableHeader.version());
// Client id
String clientIdentifier = payload.clientIdentifier();
if (!isValidClientIdentifier(clientIdentifier)) {
throw new IllegalArgumentException(
"invalid clientIdentifier: " + clientIdentifier + " (expected: less than 23 chars long)");
if (!isValidClientId(mqttVersion, clientIdentifier)) {
throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + clientIdentifier);
}
byte[] clientIdentifierBytes = encodeStringUtf8(clientIdentifier);
payloadBufferSize += 2 + clientIdentifierBytes.length;
@ -130,13 +130,16 @@ public class MqttEncoder extends MessageToMessageEncoder<MqttMessage> {
}
// Fixed header
byte[] protocolNameBytes = mqttVersion.protocolNameBytes();
int variableHeaderBufferSize = 2 + protocolNameBytes.length + 4;
int variablePartSize = variableHeaderBufferSize + payloadBufferSize;
int fixedHeaderBufferSize = 1 + getVariableLengthInt(variablePartSize);
ByteBuf buf = byteBufAllocator.buffer(fixedHeaderBufferSize + variablePartSize);
buf.writeByte(getFixedHeaderByte1(mqttFixedHeader));
writeVariableLengthInt(buf, variablePartSize);
buf.writeBytes(CONNECT_VARIABLE_HEADER_START);
buf.writeShort(protocolNameBytes.length);
buf.writeBytes(protocolNameBytes);
buf.writeByte(variableHeader.version());
buf.writeByte(getConnVariableHeaderFlag(variableHeader));
@ -382,12 +385,4 @@ public class MqttEncoder extends MessageToMessageEncoder<MqttMessage> {
private static byte[] encodeStringUtf8(String s) {
return s.getBytes(CharsetUtil.UTF_8);
}
private static boolean isValidClientIdentifier(String clientIdentifier) {
if (clientIdentifier == null) {
return false;
}
int length = clientIdentifier.length();
return length >= 1 && length <= 23;
}
}

View File

@ -18,7 +18,8 @@ package io.netty.handler.codec.mqtt;
public enum MqttQoS {
AT_MOST_ONCE(0),
AT_LEAST_ONCE(1),
EXACTLY_ONCE(2);
EXACTLY_ONCE(2),
FAILURE(0x80);
private final int value;

View File

@ -16,12 +16,47 @@
package io.netty.handler.codec.mqtt;
import io.netty.util.CharsetUtil;
/**
* Holds Constant values used by multiple classes in mqtt-codec.
* Mqtt version specific constant values used by multiple classes in mqtt-codec.
*/
final class MqttVersion {
public enum MqttVersion {
MQTT_3_1("MQIsdp", (byte) 3),
MQTT_3_1_1("MQTT", (byte) 4);
static final String PROTOCOL_NAME = "MQIsdp";
private String name;
private byte level;
private MqttVersion(String protocolName, byte protocolLevel) {
this.name = protocolName;
this.level = protocolLevel;
}
public String protocolName() {
return name;
}
public byte[] protocolNameBytes() {
return name.getBytes(CharsetUtil.UTF_8);
}
public byte protocolLevel() {
return level;
}
public static MqttVersion fromProtocolNameAndLevel(String protocolName, byte protocolLevel) {
for (MqttVersion mv : values()) {
if (mv.name.equals(protocolName)) {
if (mv.level == protocolLevel) {
return mv;
} else {
throw new MqttUnacceptableProtocolVersionException(protocolName + " and " +
protocolLevel + " are not match");
}
}
}
throw new MqttUnacceptableProtocolVersionException(protocolName + "is unknown protocol name");
}
private MqttVersion() { }
}

View File

@ -30,7 +30,6 @@ import org.mockito.MockitoAnnotations;
import java.util.LinkedList;
import java.util.List;
import static io.netty.handler.codec.mqtt.MqttVersion.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@ -45,7 +44,6 @@ public class MqttCodecTest {
private static final String USER_NAME = "happy_user";
private static final String PASSWORD = "123_or_no_pwd";
private static final int PROTOCOL_VERSION = 3;
private static final int KEEP_ALIVE_SECONDS = 600;
private static final ByteBufAllocator ALLOCATOR = new UnpooledByteBufAllocator(false);
@ -65,8 +63,25 @@ public class MqttCodecTest {
}
@Test
public void testConnectMessage() throws Exception {
final MqttConnectMessage message = createConnectMessage();
public void testConnectMessageForMqtt31() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message);
final List<Object> out = new LinkedList<Object>();
mqttDecoder.decode(ctx, byteBuf, out);
assertEquals("Expected one object bout got " + out.size(), 1, out.size());
final MqttConnectMessage decodedMessage = (MqttConnectMessage) out.get(0);
validateFixedHeaders(message.fixedHeader(), decodedMessage.fixedHeader());
vlidateConnectVariableHeader(message.variableHeader(), decodedMessage.variableHeader());
validateConnectPayload(message.payload(), decodedMessage.payload());
}
@Test
public void testConnectMessageForMqtt311() throws Exception {
final MqttConnectMessage message = createConnectMessage(MqttVersion.MQTT_3_1_1);
ByteBuf byteBuf = MqttEncoder.doEncode(ALLOCATOR, message);
final List<Object> out = new LinkedList<Object>();
@ -250,13 +265,13 @@ public class MqttCodecTest {
return new MqttMessage(mqttFixedHeader, mqttMessageIdVariableHeader);
}
private static MqttConnectMessage createConnectMessage() {
private static MqttConnectMessage createConnectMessage(MqttVersion mqttVersion) {
MqttFixedHeader mqttFixedHeader =
new MqttFixedHeader(MqttMessageType.CONNECT, false, MqttQoS.AT_MOST_ONCE, false, 0);
MqttConnectVariableHeader mqttConnectVariableHeader =
new MqttConnectVariableHeader(
PROTOCOL_NAME,
PROTOCOL_VERSION,
mqttVersion.protocolName(),
mqttVersion.protocolLevel(),
true,
true,
true,