diff --git a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java index df2a663e89..87bea65a71 100644 --- a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java +++ b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java @@ -18,7 +18,6 @@ package io.netty.handler.codec.haproxy; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; -import io.netty.handler.codec.LineBasedFrameDecoder; import io.netty.handler.codec.ProtocolDetectionResult; import io.netty.util.CharsetUtil; @@ -50,11 +49,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { */ private static final int V2_MAX_TLV = 65535 - 216; - /** - * Version 1 header delimiter is always '\r\n' per spec - */ - private static final int DELIMITER_LENGTH = 2; - /** * Binary header prefix */ @@ -98,6 +92,11 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { private static final ProtocolDetectionResult DETECTION_RESULT_V2 = ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2); + /** + * Used to extract a header frame out of the {@link ByteBuf} and return it. + */ + private HeaderExtractor headerExtractor; + /** * {@code true} if we're discarding input because we're already over maxLength */ @@ -108,6 +107,11 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { */ private int discardedBytes; + /** + * Whether or not to throw an exception as soon as we exceed maxLength. + */ + private final boolean failFast; + /** * {@code true} if we're finished decoding the proxy protocol header */ @@ -125,14 +129,27 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { private final int v2MaxHeaderSize; /** - * Creates a new decoder with no additional data (TLV) restrictions + * Creates a new decoder with no additional data (TLV) restrictions, and should throw an exception as soon as + * we exceed maxLength. */ public HAProxyMessageDecoder() { - v2MaxHeaderSize = V2_MAX_LENGTH; + this(true); } /** - * Creates a new decoder with restricted additional data (TLV) size + * Creates a new decoder with no additional data (TLV) restrictions, whether or not to throw an exception as soon + * as we exceed maxLength. + * + * @param failFast Whether or not to throw an exception as soon as we exceed maxLength + */ + public HAProxyMessageDecoder(boolean failFast) { + v2MaxHeaderSize = V2_MAX_LENGTH; + this.failFast = failFast; + } + + /** + * Creates a new decoder with restricted additional data (TLV) size, and should throw an exception as soon as + * we exceed maxLength. *

* Note: limiting TLV size only affects processing of v2, binary headers. Also, as allowed by the 1.5 spec * TLV data is currently ignored. For maximum performance it would be best to configure your upstream proxy host to @@ -142,6 +159,17 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { * @param maxTlvSize maximum number of bytes allowed for additional data (Type-Length-Value vectors) in a v2 header */ public HAProxyMessageDecoder(int maxTlvSize) { + this(maxTlvSize, true); + } + + /** + * Creates a new decoder with restricted additional data (TLV) size, whether or not to throw an exception as soon + * as we exceed maxLength. + * + * @param maxTlvSize maximum number of bytes allowed for additional data (Type-Length-Value vectors) in a v2 header + * @param failFast Whether or not to throw an exception as soon as we exceed maxLength + */ + public HAProxyMessageDecoder(int maxTlvSize, boolean failFast) { if (maxTlvSize < 1) { v2MaxHeaderSize = V2_MIN_LENGTH; } else if (maxTlvSize > V2_MAX_TLV) { @@ -154,6 +182,7 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { v2MaxHeaderSize = calcMax; } } + this.failFast = failFast; } /** @@ -259,7 +288,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { /** * Create a frame out of the {@link ByteBuf} and return it. - * Based on code from {@link LineBasedFrameDecoder#decode(ChannelHandlerContext, ByteBuf)}. * * @param ctx the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to * @param buffer the {@link ByteBuf} from which to read data @@ -267,42 +295,14 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { * be created */ private ByteBuf decodeStruct(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { - final int eoh = findEndOfHeader(buffer); - if (!discarding) { - if (eoh >= 0) { - final int length = eoh - buffer.readerIndex(); - if (length > v2MaxHeaderSize) { - buffer.readerIndex(eoh); - failOverLimit(ctx, length); - return null; - } - return buffer.readSlice(length); - } else { - final int length = buffer.readableBytes(); - if (length > v2MaxHeaderSize) { - discardedBytes = length; - buffer.skipBytes(length); - discarding = true; - failOverLimit(ctx, "over " + discardedBytes); - } - return null; - } - } else { - if (eoh >= 0) { - buffer.readerIndex(eoh); - discardedBytes = 0; - discarding = false; - } else { - discardedBytes = buffer.readableBytes(); - buffer.skipBytes(discardedBytes); - } - return null; + if (headerExtractor == null) { + headerExtractor = new StructHeaderExtractor(v2MaxHeaderSize); } + return headerExtractor.extract(ctx, buffer); } /** * Create a frame out of the {@link ByteBuf} and return it. - * Based on code from {@link LineBasedFrameDecoder#decode(ChannelHandlerContext, ByteBuf)}. * * @param ctx the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to * @param buffer the {@link ByteBuf} from which to read data @@ -310,40 +310,10 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { * be created */ private ByteBuf decodeLine(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { - final int eol = findEndOfLine(buffer); - if (!discarding) { - if (eol >= 0) { - final int length = eol - buffer.readerIndex(); - if (length > V1_MAX_LENGTH) { - buffer.readerIndex(eol + DELIMITER_LENGTH); - failOverLimit(ctx, length); - return null; - } - ByteBuf frame = buffer.readSlice(length); - buffer.skipBytes(DELIMITER_LENGTH); - return frame; - } else { - final int length = buffer.readableBytes(); - if (length > V1_MAX_LENGTH) { - discardedBytes = length; - buffer.skipBytes(length); - discarding = true; - failOverLimit(ctx, "over " + discardedBytes); - } - return null; - } - } else { - if (eol >= 0) { - final int delimLength = buffer.getByte(eol) == '\r' ? 2 : 1; - buffer.readerIndex(eol + delimLength); - discardedBytes = 0; - discarding = false; - } else { - discardedBytes = buffer.readableBytes(); - buffer.skipBytes(discardedBytes); - } - return null; + if (headerExtractor == null) { + headerExtractor = new LineHeaderExtractor(V1_MAX_LENGTH); } + return headerExtractor.extract(ctx, buffer); } private void failOverLimit(final ChannelHandlerContext ctx, int length) { @@ -399,4 +369,119 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { } return true; } + + /** + * HeaderExtractor create a header frame out of the {@link ByteBuf}. + */ + private abstract class HeaderExtractor { + /** Header max size */ + private final int maxHeaderSize; + + protected HeaderExtractor(int maxHeaderSize) { + this.maxHeaderSize = maxHeaderSize; + } + + /** + * Create a frame out of the {@link ByteBuf} and return it. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link HAProxyMessageDecoder} belongs to + * @param buffer the {@link ByteBuf} from which to read data + * @return frame the {@link ByteBuf} which represent the frame or {@code null} if no frame could + * be created + * @throws Exception if exceed maxLength + */ + public ByteBuf extract(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { + final int eoh = findEndOfHeader(buffer); + if (!discarding) { + if (eoh >= 0) { + final int length = eoh - buffer.readerIndex(); + if (length > maxHeaderSize) { + buffer.readerIndex(eoh + delimiterLength(buffer, eoh)); + failOverLimit(ctx, length); + return null; + } + ByteBuf frame = buffer.readSlice(length); + buffer.skipBytes(delimiterLength(buffer, eoh)); + return frame; + } else { + final int length = buffer.readableBytes(); + if (length > maxHeaderSize) { + discardedBytes = length; + buffer.skipBytes(length); + discarding = true; + if (failFast) { + failOverLimit(ctx, "over " + discardedBytes); + } + } + return null; + } + } else { + if (eoh >= 0) { + final int length = discardedBytes + eoh - buffer.readerIndex(); + buffer.readerIndex(eoh + delimiterLength(buffer, eoh)); + discardedBytes = 0; + discarding = false; + if (!failFast) { + failOverLimit(ctx, "over " + length); + } + } else { + discardedBytes += buffer.readableBytes(); + buffer.skipBytes(buffer.readableBytes()); + } + return null; + } + } + + /** + * Find the end of the header from the given {@link ByteBuf},the end may be a CRLF, or the length given by the + * header. + * + * @param buffer the buffer to be searched + * @return {@code -1} if can not find the end, otherwise return the buffer index of end + */ + protected abstract int findEndOfHeader(ByteBuf buffer); + + /** + * Get the length of the header delimiter. + * + * @param buffer the buffer where delimiter is located + * @param eoh index of delimiter + * @return length of the delimiter + */ + protected abstract int delimiterLength(ByteBuf buffer, int eoh); + } + + private final class LineHeaderExtractor extends HeaderExtractor { + + LineHeaderExtractor(int maxHeaderSize) { + super(maxHeaderSize); + } + + @Override + protected int findEndOfHeader(ByteBuf buffer) { + return findEndOfLine(buffer); + } + + @Override + protected int delimiterLength(ByteBuf buffer, int eoh) { + return buffer.getByte(eoh) == '\r' ? 2 : 1; + } + } + + private final class StructHeaderExtractor extends HeaderExtractor { + + StructHeaderExtractor(int maxHeaderSize) { + super(maxHeaderSize); + } + + @Override + protected int findEndOfHeader(ByteBuf buffer) { + return HAProxyMessageDecoder.findEndOfHeader(buffer); + } + + @Override + protected int delimiterLength(ByteBuf buffer, int eoh) { + return 0; + } + } } diff --git a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java index 02ff285565..2c323ea155 100644 --- a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java +++ b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java @@ -24,7 +24,9 @@ import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.TransportProtocol; import io.netty.util.CharsetUtil; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import java.util.List; @@ -32,6 +34,8 @@ import static io.netty.buffer.Unpooled.*; import static org.junit.Assert.*; public class HAProxyMessageDecoderTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); private EmbeddedChannel ch; @@ -164,6 +168,43 @@ public class HAProxyMessageDecoderTest { ch.writeInbound(copiedBuffer(header, CharsetUtil.US_ASCII)); } + @Test + public void testFailSlowHeaderTooLong() { + EmbeddedChannel slowFailCh = new EmbeddedChannel(new HAProxyMessageDecoder(false)); + try { + String headerPart1 = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 " + + "000000000000000000000000000000000000000000000000000000000000000000000443"; + // Should not throw exception + assertFalse(slowFailCh.writeInbound(copiedBuffer(headerPart1, CharsetUtil.US_ASCII))); + String headerPart2 = "more header data"; + // Should not throw exception + assertFalse(slowFailCh.writeInbound(copiedBuffer(headerPart2, CharsetUtil.US_ASCII))); + String headerPart3 = "end of header\r\n"; + + int discarded = headerPart1.length() + headerPart2.length() + headerPart3.length() - 2; + // Should throw exception + exceptionRule.expect(HAProxyProtocolException.class); + exceptionRule.expectMessage("over " + discarded); + assertFalse(slowFailCh.writeInbound(copiedBuffer(headerPart3, CharsetUtil.US_ASCII))); + } finally { + assertFalse(slowFailCh.finishAndReleaseAll()); + } + } + + @Test + public void testFailFastHeaderTooLong() { + EmbeddedChannel fastFailCh = new EmbeddedChannel(new HAProxyMessageDecoder(true)); + try { + String headerPart1 = "PROXY TCP4 192.168.0.1 192.168.0.11 56324 " + + "000000000000000000000000000000000000000000000000000000000000000000000443"; + exceptionRule.expect(HAProxyProtocolException.class); // Should throw exception, fail fast + exceptionRule.expectMessage("over " + headerPart1.length()); + assertFalse(fastFailCh.writeInbound(copiedBuffer(headerPart1, CharsetUtil.US_ASCII))); + } finally { + assertFalse(fastFailCh.finishAndReleaseAll()); + } + } + @Test public void testIncompleteHeader() { String header = "PROXY TCP4 192.168.0.1 192.168.0.11 56324";