Fix a bug where LineBasedFrameDecoder does not handle too long lines correctly

- Related: #1287
This commit is contained in:
Trustin Lee 2013-04-19 13:04:20 +09:00
parent 6bb00cea6f
commit 4a5dc32224
2 changed files with 108 additions and 42 deletions

View File

@ -35,6 +35,7 @@ public class LineBasedFrameDecoder extends ByteToMessageDecoder {
/** True if we're discarding input because we're already over maxLength. */
private boolean discarding;
private int discardedBytes;
/**
* Creates a new decoder.
@ -77,50 +78,64 @@ public class LineBasedFrameDecoder extends ByteToMessageDecoder {
protected Object decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
final int eol = findEndOfLine(buffer);
if (eol != -1) {
final ByteBuf frame;
final int length = eol - buffer.readerIndex();
assert length >= 0: "Invalid length=" + length;
if (discarding) {
frame = null;
buffer.skipBytes(length);
if (!failFast) {
fail(ctx, "over " + (maxLength + length) + " bytes");
}
} else {
int delimLength;
final byte delim = buffer.getByte(buffer.readerIndex() + length);
if (delim == '\r') {
delimLength = 2; // Skip the \r\n.
} else {
delimLength = 1;
if (!discarding) {
if (eol >= 0) {
final ByteBuf frame;
final int length = eol - buffer.readerIndex();
final int delimLength = buffer.getByte(eol) == '\r'? 2 : 1;
if (length > maxLength) {
buffer.readerIndex(eol + delimLength);
fail(ctx, length);
return null;
}
if (stripDelimiter) {
frame = buffer.readBytes(length);
buffer.skipBytes(delimLength);
} else {
frame = buffer.readBytes(length + delimLength);
}
}
return frame;
}
final int buffered = buffer.readableBytes();
if (!discarding && buffered > maxLength) {
discarding = true;
if (failFast) {
fail(ctx, buffered + " bytes buffered already");
return frame;
} else {
final int length = buffer.readableBytes();
if (length > maxLength) {
discardedBytes = length;
buffer.readerIndex(buffer.writerIndex());
discarding = true;
if (failFast) {
fail(ctx, "over " + discardedBytes);
}
}
return null;
}
} else {
if (eol >= 0) {
final int length = discardedBytes + eol - buffer.readerIndex();
final int delimLength = buffer.getByte(eol) == '\r'? 2 : 1;
buffer.readerIndex(eol + delimLength);
discardedBytes = 0;
discarding = false;
if (!failFast) {
fail(ctx, length);
}
} else {
discardedBytes = buffer.readableBytes();
buffer.readerIndex(buffer.writerIndex());
}
return null;
}
if (discarding) {
buffer.skipBytes(buffer.readableBytes());
}
return null;
}
private void fail(final ChannelHandlerContext ctx, final String msg) {
ctx.fireExceptionCaught(new TooLongFrameException("Frame length exceeds " + maxLength + " ("
+ msg + ')'));
private void fail(final ChannelHandlerContext ctx, int length) {
fail(ctx, String.valueOf(length));
}
private void fail(final ChannelHandlerContext ctx, String length) {
ctx.fireExceptionCaught(
new TooLongFrameException(
"frame length (" + length + ") exceeds the allowed maximum (" + maxLength + ')'));
}
/**

View File

@ -16,29 +16,80 @@
package io.netty.handler.codec;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedByteChannel;
import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test;
import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;
public class LineBasedFrameDecoderTest {
@Test
public void testDecodeWithStrip() throws Exception {
EmbeddedByteChannel ch = new EmbeddedByteChannel(new LineBasedFrameDecoder(8192, true, false));
ch.writeInbound(Unpooled.copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII));
Assert.assertEquals("first", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
Assert.assertEquals("second", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
Assert.assertNull(ch.readInbound());
ch.writeInbound(copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII));
assertEquals("first", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
assertEquals("second", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
assertNull(ch.readInbound());
}
@Test
public void testDecodeWithoutStrip() throws Exception {
EmbeddedByteChannel ch = new EmbeddedByteChannel(new LineBasedFrameDecoder(8192, false, false));
ch.writeInbound(Unpooled.copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII));
Assert.assertEquals("first\r\n", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
Assert.assertEquals("second\n", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
Assert.assertNull(ch.readInbound());
ch.writeInbound(copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII));
assertEquals("first\r\n", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
assertEquals("second\n", ((ByteBuf) ch.readInbound()).toString(CharsetUtil.US_ASCII));
assertNull(ch.readInbound());
}
@Test
public void testTooLongLine1() throws Exception {
EmbeddedByteChannel ch = new EmbeddedByteChannel(new LineBasedFrameDecoder(16, false, false));
try {
ch.writeInbound(copiedBuffer("12345678901234567890\r\nfirst\nsecond", CharsetUtil.US_ASCII));
fail();
} catch (Exception e) {
assertThat(e, is(instanceOf(TooLongFrameException.class)));
}
assertThat((ByteBuf) ch.readInbound(), is(copiedBuffer("first\n", CharsetUtil.US_ASCII)));
assertThat(ch.finish(), is(false));
}
@Test
public void testTooLongLine2() throws Exception {
EmbeddedByteChannel ch = new EmbeddedByteChannel(new LineBasedFrameDecoder(16, false, false));
assertFalse(ch.writeInbound(copiedBuffer("12345678901234567", CharsetUtil.US_ASCII)));
try {
ch.writeInbound(copiedBuffer("890\r\nfirst\r\n", CharsetUtil.US_ASCII));
fail();
} catch (Exception e) {
assertThat(e, is(instanceOf(TooLongFrameException.class)));
}
assertThat((ByteBuf) ch.readInbound(), is(copiedBuffer("first\r\n", CharsetUtil.US_ASCII)));
assertThat(ch.finish(), is(false));
}
@Test
public void testTooLongLineWithFailFast() throws Exception {
EmbeddedByteChannel ch = new EmbeddedByteChannel(new LineBasedFrameDecoder(16, false, true));
try {
ch.writeInbound(copiedBuffer("12345678901234567", CharsetUtil.US_ASCII));
fail();
} catch (Exception e) {
assertThat(e, is(instanceOf(TooLongFrameException.class)));
}
assertThat(ch.writeInbound(copiedBuffer("890", CharsetUtil.US_ASCII)), is(false));
assertThat(ch.writeInbound(copiedBuffer("123\r\nfirst\r\n", CharsetUtil.US_ASCII)), is(true));
assertThat((ByteBuf) ch.readInbound(), is(copiedBuffer("first\r\n", CharsetUtil.US_ASCII)));
assertThat(ch.finish(), is(false));
}
}