Fix encoding/decoding for UTF-8 stomp commands and headers (#9740)

Motivation:

According STOMP spec (https://stomp.github.io/stomp-specification-1.2.html#Value_Encoding) we have to encode and decode commands and headers to UTF-8

Modification:

Provide ability for StompSubframeDecoder and StompSubframeEncoder work with UTF-8
This commit is contained in:
Andrey Mizurov 2019-11-06 14:07:38 +03:00 committed by Norman Maurer
parent 625981a296
commit af532d2b7e
5 changed files with 249 additions and 133 deletions

View File

@ -15,9 +15,6 @@
*/
package io.netty.handler.codec.stomp;
import java.util.List;
import java.util.Locale;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
@ -26,35 +23,32 @@ import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.stomp.StompSubframeDecoder.State;
import io.netty.util.ByteProcessor;
import io.netty.util.internal.AppendableCharSequence;
import io.netty.util.internal.StringUtil;
import static io.netty.buffer.ByteBufUtil.indexOf;
import static io.netty.buffer.ByteBufUtil.readBytes;
import static io.netty.util.internal.ObjectUtil.checkPositive;
import java.util.List;
import static io.netty.buffer.ByteBufUtil.*;
import static io.netty.util.internal.ObjectUtil.*;
/**
* Decodes {@link ByteBuf}s into {@link StompHeadersSubframe}s and
* {@link StompContentSubframe}s.
* Decodes {@link ByteBuf}s into {@link StompHeadersSubframe}s and {@link StompContentSubframe}s.
*
* <h3>Parameters to control memory consumption: </h3>
* {@code maxLineLength} the maximum length of line -
* restricts length of command and header lines
* If the length of the initial line exceeds this value, a
* {@link TooLongFrameException} will be raised.
* {@code maxLineLength} the maximum length of line - restricts length of command and header lines If the length of the
* initial line exceeds this value, a {@link TooLongFrameException} will be raised.
* <br>
* {@code maxChunkSize}
* The maximum length of the content or each chunk. If the content length
* (or the length of each chunk) exceeds this value, the content or chunk
* ill be split into multiple {@link StompContentSubframe}s whose length is
* {@code maxChunkSize} at maximum.
* {@code maxChunkSize} The maximum length of the content or each chunk. If the content length (or the length of each
* chunk) exceeds this value, the content or chunk ill be split into multiple {@link StompContentSubframe}s whose length
* is {@code maxChunkSize} at maximum.
*
* <h3>Chunked Content</h3>
*
* If the content of a stomp message is greater than {@code maxChunkSize}
* the transfer encoding of the HTTP message is 'chunked', this decoder
* generates multiple {@link StompContentSubframe} instances to avoid excessive memory
* consumption. Note, that every message, even with no content decodes with
* {@link LastStompContentSubframe} at the end to simplify upstream message parsing.
* <p>
* If the content of a stomp message is greater than {@code maxChunkSize} the transfer encoding of the HTTP message is
* 'chunked', this decoder generates multiple {@link StompContentSubframe} instances to avoid excessive memory
* consumption. Note, that every message, even with no content decodes with {@link LastStompContentSubframe} at the end
* to simplify upstream message parsing.
*/
public class StompSubframeDecoder extends ReplayingDecoder<State> {
@ -70,9 +64,9 @@ public class StompSubframeDecoder extends ReplayingDecoder<State> {
INVALID_CHUNK
}
private final int maxLineLength;
private final Utf8LineParser commandParser;
private final HeaderParser headerParser;
private final int maxChunkSize;
private final boolean validateHeaders;
private int alreadyReadChunkSize;
private LastStompContentSubframe lastContent;
private long contentLength = -1;
@ -94,8 +88,8 @@ public class StompSubframeDecoder extends ReplayingDecoder<State> {
checkPositive(maxLineLength, "maxLineLength");
checkPositive(maxChunkSize, "maxChunkSize");
this.maxChunkSize = maxChunkSize;
this.maxLineLength = maxLineLength;
this.validateHeaders = validateHeaders;
commandParser = new Utf8LineParser(new AppendableCharSequence(16), maxLineLength);
headerParser = new HeaderParser(new AppendableCharSequence(128), maxLineLength, validateHeaders);
}
@Override
@ -189,34 +183,24 @@ public class StompSubframeDecoder extends ReplayingDecoder<State> {
}
private StompCommand readCommand(ByteBuf in) {
String commandStr = readLine(in, 16);
StompCommand command = null;
CharSequence commandSequence = commandParser.parse(in);
if (commandSequence == null) {
throw new DecoderException("Failed to read command from channel");
}
String commandStr = commandSequence.toString();
try {
command = StompCommand.valueOf(commandStr);
return StompCommand.valueOf(commandStr);
} catch (IllegalArgumentException iae) {
//do nothing
throw new DecoderException("Cannot to parse command " + commandStr);
}
if (command == null) {
commandStr = commandStr.toUpperCase(Locale.US);
try {
command = StompCommand.valueOf(commandStr);
} catch (IllegalArgumentException iae) {
//do nothing
}
}
if (command == null) {
throw new DecoderException("failed to read command from channel");
}
return command;
}
private State readHeaders(ByteBuf buffer, StompHeaders headers) {
AppendableCharSequence buf = new AppendableCharSequence(128);
for (;;) {
boolean headerRead = readHeader(headers, buf, buffer);
boolean headerRead = headerParser.parseHeader(headers, buffer);
if (!headerRead) {
if (headers.contains(StompHeaders.CONTENT_LENGTH)) {
contentLength = getContentLength(headers, 0);
contentLength = getContentLength(headers);
if (contentLength == 0) {
return State.FINALIZE_FRAME_READ;
}
@ -226,8 +210,8 @@ public class StompSubframeDecoder extends ReplayingDecoder<State> {
}
}
private static long getContentLength(StompHeaders headers, long defaultValue) {
long contentLength = headers.getLong(StompHeaders.CONTENT_LENGTH, defaultValue);
private static long getContentLength(StompHeaders headers) {
long contentLength = headers.getLong(StompHeaders.CONTENT_LENGTH, 0L);
if (contentLength < 0) {
throw new DecoderException(StompHeaders.CONTENT_LENGTH + " must be non-negative");
}
@ -252,75 +236,147 @@ public class StompSubframeDecoder extends ReplayingDecoder<State> {
}
}
private String readLine(ByteBuf buffer, int initialBufferSize) {
AppendableCharSequence buf = new AppendableCharSequence(initialBufferSize);
int lineLength = 0;
for (;;) {
byte nextByte = buffer.readByte();
if (nextByte == StompConstants.CR) {
//do nothing
} else if (nextByte == StompConstants.LF) {
return buf.toString();
} else {
if (lineLength >= maxLineLength) {
invalidLineLength();
}
lineLength ++;
buf.append((char) nextByte);
}
}
}
private boolean readHeader(StompHeaders headers, AppendableCharSequence buf, ByteBuf buffer) {
buf.reset();
int lineLength = 0;
String key = null;
boolean valid = false;
for (;;) {
byte nextByte = buffer.readByte();
if (nextByte == StompConstants.COLON && key == null) {
key = buf.toString();
valid = true;
buf.reset();
} else if (nextByte == StompConstants.CR) {
//do nothing
} else if (nextByte == StompConstants.LF) {
if (key == null && lineLength == 0) {
return false;
} else if (valid) {
headers.add(key, buf.toString());
} else if (validateHeaders) {
invalidHeader(key, buf.toString());
}
return true;
} else {
if (lineLength >= maxLineLength) {
invalidLineLength();
}
if (nextByte == StompConstants.COLON && key != null) {
valid = false;
}
lineLength ++;
buf.append((char) nextByte);
}
}
}
private void invalidHeader(String key, String value) {
String line = key != null ? key + ":" + value : value;
throw new IllegalArgumentException("a header value or name contains a prohibited character ':'"
+ ", " + line);
}
private void invalidLineLength() {
throw new TooLongFrameException("An STOMP line is larger than " + maxLineLength + " bytes.");
}
private void resetDecoder() {
checkpoint(State.SKIP_CONTROL_CHARACTERS);
contentLength = -1;
alreadyReadChunkSize = 0;
lastContent = null;
}
private static class Utf8LineParser implements ByteProcessor {
private final AppendableCharSequence charSeq;
private final int maxLineLength;
private int lineLength;
private char interim;
private boolean nextRead;
Utf8LineParser(AppendableCharSequence charSeq, int maxLineLength) {
this.charSeq = checkNotNull(charSeq, "charSeq");
this.maxLineLength = maxLineLength;
}
AppendableCharSequence parse(ByteBuf byteBuf) {
reset();
int offset = byteBuf.forEachByte(this);
if (offset == -1) {
return null;
}
byteBuf.readerIndex(offset + 1);
return charSeq;
}
AppendableCharSequence charSequence() {
return charSeq;
}
@Override
public boolean process(byte nextByte) throws Exception {
if (nextByte == StompConstants.CR) {
++lineLength;
return true;
}
if (nextByte == StompConstants.LF) {
return false;
}
if (++lineLength > maxLineLength) {
throw new TooLongFrameException("An STOMP line is larger than " + maxLineLength + " bytes.");
}
// 1 byte - 0xxxxxxx - 7 bits
// 2 byte - 110xxxxx 10xxxxxx - 11 bits
// 3 byte - 1110xxxx 10xxxxxx 10xxxxxx - 16 bits
if (nextRead) {
interim |= (nextByte & 0x3F) << 6;
nextRead = false;
} else if (interim != 0) { // flush 2 or 3 byte
charSeq.append((char) (interim | (nextByte & 0x3F)));
interim = 0;
} else if (nextByte >= 0) { // INITIAL BRANCH
// The first 128 characters (US-ASCII) need one byte.
charSeq.append((char) nextByte);
} else if ((nextByte & 0xE0) == 0xC0) {
// The next 1920 characters need two bytes and we can define
// a first byte by mask 110xxxxx.
interim = (char) ((nextByte & 0x1F) << 6);
} else {
// The rest of characters need three bytes.
interim = (char) ((nextByte & 0x0F) << 12);
nextRead = true;
}
return true;
}
protected void reset() {
charSeq.reset();
lineLength = 0;
interim = 0;
nextRead = false;
}
}
private static final class HeaderParser extends Utf8LineParser {
private final boolean validateHeaders;
private String name;
private boolean valid;
HeaderParser(AppendableCharSequence charSeq, int maxLineLength, boolean validateHeaders) {
super(charSeq, maxLineLength);
this.validateHeaders = validateHeaders;
}
boolean parseHeader(StompHeaders headers, ByteBuf buf) {
AppendableCharSequence value = super.parse(buf);
if (value == null || (name == null && value.length() == 0)) {
return false;
}
if (valid) {
headers.add(name, value.toString());
} else if (validateHeaders) {
if (StringUtil.isNullOrEmpty(name)) {
throw new IllegalArgumentException("received an invalid header line '" + value.toString() + '\'');
}
String line = name + ':' + value.toString();
throw new IllegalArgumentException("a header value or name contains a prohibited character ':'"
+ ", " + line);
}
return true;
}
@Override
public boolean process(byte nextByte) throws Exception {
if (nextByte == StompConstants.COLON) {
if (name == null) {
AppendableCharSequence charSeq = charSequence();
if (charSeq.length() != 0) {
name = charSeq.substring(0, charSeq.length());
charSeq.reset();
valid = true;
return true;
} else {
name = StringUtil.EMPTY_STRING;
}
} else {
valid = false;
}
}
return super.process(nextByte);
}
@Override
protected void reset() {
name = null;
valid = false;
super.reset();
}
}
}

View File

@ -15,17 +15,15 @@
*/
package io.netty.handler.codec.stomp;
import java.util.List;
import java.util.Map.Entry;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.AsciiHeadersEncoder;
import io.netty.handler.codec.AsciiHeadersEncoder.NewlineType;
import io.netty.handler.codec.AsciiHeadersEncoder.SeparatorType;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.util.CharsetUtil;
import java.util.List;
import java.util.Map.Entry;
/**
* Encodes a {@link StompFrame} or a {@link StompSubframe} into a {@link ByteBuf}.
*/
@ -64,11 +62,13 @@ public class StompSubframeEncoder extends MessageToMessageEncoder<StompSubframe>
private static ByteBuf encodeFrame(StompHeadersSubframe frame, ChannelHandlerContext ctx) {
ByteBuf buf = ctx.alloc().buffer();
buf.writeCharSequence(frame.command().toString(), CharsetUtil.US_ASCII);
buf.writeCharSequence(frame.command().toString(), CharsetUtil.UTF_8);
buf.writeByte(StompConstants.LF);
AsciiHeadersEncoder headersEncoder = new AsciiHeadersEncoder(buf, SeparatorType.COLON, NewlineType.LF);
for (Entry<CharSequence, CharSequence> entry : frame.headers()) {
headersEncoder.encode(entry);
ByteBufUtil.writeUtf8(buf, entry.getKey());
buf.writeByte(StompConstants.COLON);
ByteBufUtil.writeUtf8(buf, entry.getValue());
buf.writeByte(StompConstants.LF);
}
buf.writeByte(StompConstants.LF);
return buf;

View File

@ -22,15 +22,9 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static io.netty.handler.codec.stomp.StompTestConstants.FRAME_WITH_INVALID_HEADER;
import static io.netty.util.CharsetUtil.US_ASCII;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static io.netty.handler.codec.stomp.StompTestConstants.*;
import static io.netty.util.CharsetUtil.*;
import static org.junit.Assert.*;
public class StompSubframeDecoderTest {
@ -165,7 +159,7 @@ public class StompSubframeDecoderTest {
@Test
public void testValidateHeadersDecodingDisabled() {
ByteBuf invalidIncoming = Unpooled.copiedBuffer(FRAME_WITH_INVALID_HEADER.getBytes(US_ASCII));
ByteBuf invalidIncoming = Unpooled.copiedBuffer(FRAME_WITH_INVALID_HEADER.getBytes(UTF_8));
assertTrue(channel.writeInbound(invalidIncoming));
StompHeadersSubframe frame = channel.readInbound();
@ -185,7 +179,7 @@ public class StompSubframeDecoderTest {
public void testValidateHeadersDecodingEnabled() {
channel = new EmbeddedChannel(new StompSubframeDecoder(true));
ByteBuf invalidIncoming = Unpooled.copiedBuffer(FRAME_WITH_INVALID_HEADER.getBytes(US_ASCII));
ByteBuf invalidIncoming = Unpooled.wrappedBuffer(FRAME_WITH_INVALID_HEADER.getBytes(UTF_8));
assertTrue(channel.writeInbound(invalidIncoming));
StompHeadersSubframe frame = channel.readInbound();
@ -194,4 +188,37 @@ public class StompSubframeDecoderTest {
assertEquals("a header value or name contains a prohibited character ':', current-time:2000-01-01T00:00:00",
frame.decoderResult().cause().getMessage());
}
@Test
public void testNotValidFrameWithEmptyHeaderName() {
channel = new EmbeddedChannel(new StompSubframeDecoder(true));
ByteBuf invalidIncoming = Unpooled.wrappedBuffer(FRAME_WITH_EMPTY_HEADER_NAME.getBytes(UTF_8));
assertTrue(channel.writeInbound(invalidIncoming));
StompHeadersSubframe frame = channel.readInbound();
assertNotNull(frame);
assertTrue(frame.decoderResult().isFailure());
assertEquals("received an invalid header line ':header-value'",
frame.decoderResult().cause().getMessage());
}
@Test
public void testUtf8FrameDecoding() {
channel = new EmbeddedChannel(new StompSubframeDecoder(true));
ByteBuf incoming = Unpooled.wrappedBuffer(SEND_FRAME_UTF8.getBytes(UTF_8));
assertTrue(channel.writeInbound(incoming));
StompHeadersSubframe headersSubFrame = channel.readInbound();
assertNotNull(headersSubFrame);
assertFalse(headersSubFrame.decoderResult().isFailure());
assertEquals("/queue/№11±♛нетти♕", headersSubFrame.headers().getAsString("destination"));
assertTrue(headersSubFrame.headers().contains("content-type"));
StompContentSubframe contentSubFrame = channel.readInbound();
assertNotNull(contentSubFrame);
assertEquals("body", contentSubFrame.content().toString(UTF_8));
assertTrue(contentSubFrame.release());
}
}

View File

@ -18,11 +18,13 @@ package io.netty.handler.codec.stomp;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static io.netty.handler.codec.stomp.StompTestConstants.*;
import static org.junit.Assert.*;
public class StompSubframeEncoderTest {
@ -63,4 +65,22 @@ public class StompSubframeEncoderTest {
assertEquals(StompTestConstants.CONNECT_FRAME, content);
aggregatedBuffer.release();
}
@Test
public void testUtf8FrameEncoding() {
StompFrame frame = new DefaultStompFrame(StompCommand.SEND,
Unpooled.wrappedBuffer("body".getBytes(CharsetUtil.UTF_8)));
StompHeaders incoming = frame.headers();
incoming.set(StompHeaders.DESTINATION, "/queue/№11±♛нетти♕");
incoming.set(StompHeaders.CONTENT_TYPE, AsciiString.of("text/plain"));
channel.writeOutbound(frame);
ByteBuf headers = channel.readOutbound();
ByteBuf content = channel.readOutbound();
ByteBuf fullFrame = Unpooled.wrappedBuffer(headers, content);
assertEquals(SEND_FRAME_UTF8, fullFrame.toString(CharsetUtil.UTF_8));
assertTrue(fullFrame.release());
}
}

View File

@ -64,5 +64,18 @@ public final class StompTestConstants {
'\n' +
"some body\0";
public static final String FRAME_WITH_EMPTY_HEADER_NAME = "SEND\n" +
"destination:/some-destination\n" +
"content-type:text/plain\n" +
":header-value\n" +
'\n' +
"some body\0";
public static final String SEND_FRAME_UTF8 = "SEND\n" +
"destination:/queue/№11±♛нетти♕\n" +
"content-type:text/plain\n" +
'\n' +
"body\0";
private StompTestConstants() { }
}