diff --git a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java index 3487c159ef..04c744811f 100644 --- a/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java +++ b/codec-haproxy/src/main/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoder.java @@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.ProtocolDetectionResult; import io.netty.util.CharsetUtil; import java.util.List; @@ -72,11 +73,31 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { (byte) 0x0A }; + private static final byte[] TEXT_PREFIX = { + (byte) 'P', + (byte) 'R', + (byte) 'O', + (byte) 'X', + (byte) 'Y', + }; + /** * Binary header prefix length */ private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length; + /** + * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V1}. + */ + private static final ProtocolDetectionResult DETECTION_RESULT_V1 = + ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1); + + /** + * {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V2}. + */ + private static final ProtocolDetectionResult DETECTION_RESULT_V2 = + ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2); + /** * {@code true} if we're discarding input because we're already over maxLength */ @@ -147,15 +168,7 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { } int idx = buffer.readerIndex(); - - for (int i = 0; i < BINARY_PREFIX_LENGTH; i++) { - final byte b = buffer.getByte(idx + i); - if (b != BINARY_PREFIX[i]) { - return 1; - } - } - - return buffer.getByte(idx + BINARY_PREFIX_LENGTH); + return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1; } /** @@ -357,4 +370,33 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder { } throw ppex; } + + /** + * Returns the {@link ProtocolDetectionResult} for the given {@link ByteBuf}. + */ + public static ProtocolDetectionResult detectProtocol(ByteBuf buffer) { + if (buffer.readableBytes() < 12) { + return ProtocolDetectionResult.needsMoreData(); + } + + int idx = buffer.readerIndex(); + + if (match(BINARY_PREFIX, buffer, idx)) { + return DETECTION_RESULT_V2; + } + if (match(TEXT_PREFIX, buffer, idx)) { + return DETECTION_RESULT_V1; + } + return ProtocolDetectionResult.invalid(); + } + + private static boolean match(byte[] prefix, ByteBuf buffer, int idx) { + for (int i = 0; i < prefix.length; i++) { + final byte b = buffer.getByte(idx + i); + if (b != prefix[i]) { + return false; + } + } + return true; + } } diff --git a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java index beb7311b2c..6e833f6440 100644 --- a/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java +++ b/codec-haproxy/src/test/java/io/netty/handler/codec/haproxy/HAProxyMessageDecoderTest.java @@ -15,8 +15,11 @@ */ package io.netty.handler.codec.haproxy; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.ProtocolDetectionResult; +import io.netty.handler.codec.ProtocolDetectionState; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.TransportProtocol; import io.netty.util.CharsetUtil; @@ -896,4 +899,51 @@ public class HAProxyMessageDecoderTest { assertNull(ch.readInbound()); assertFalse(ch.finish()); } + + @Test + public void testDetectProtocol() { + final ByteBuf validHeaderV1 = copiedBuffer("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n", + CharsetUtil.US_ASCII); + ProtocolDetectionResult result = HAProxyMessageDecoder.detectProtocol(validHeaderV1); + assertEquals(ProtocolDetectionState.DETECTED, result.state()); + assertEquals(HAProxyProtocolVersion.V1, result.detectedProtocol()); + validHeaderV1.release(); + + final ByteBuf invalidHeader = copiedBuffer("Invalid header", CharsetUtil.US_ASCII); + result = HAProxyMessageDecoder.detectProtocol(invalidHeader); + assertEquals(ProtocolDetectionState.INVALID, result.state()); + assertNull(result.detectedProtocol()); + invalidHeader.release(); + + final ByteBuf validHeaderV2 = buffer(); + validHeaderV2.writeByte(0x0D); + validHeaderV2.writeByte(0x0A); + validHeaderV2.writeByte(0x0D); + validHeaderV2.writeByte(0x0A); + validHeaderV2.writeByte(0x00); + validHeaderV2.writeByte(0x0D); + validHeaderV2.writeByte(0x0A); + validHeaderV2.writeByte(0x51); + validHeaderV2.writeByte(0x55); + validHeaderV2.writeByte(0x49); + validHeaderV2.writeByte(0x54); + validHeaderV2.writeByte(0x0A); + result = HAProxyMessageDecoder.detectProtocol(validHeaderV2); + assertEquals(ProtocolDetectionState.DETECTED, result.state()); + assertEquals(HAProxyProtocolVersion.V2, result.detectedProtocol()); + validHeaderV2.release(); + + final ByteBuf incompleteHeader = buffer(); + incompleteHeader.writeByte(0x0D); + incompleteHeader.writeByte(0x0A); + incompleteHeader.writeByte(0x0D); + incompleteHeader.writeByte(0x0A); + incompleteHeader.writeByte(0x00); + incompleteHeader.writeByte(0x0D); + incompleteHeader.writeByte(0x0A); + result = HAProxyMessageDecoder.detectProtocol(incompleteHeader); + assertEquals(ProtocolDetectionState.NEEDS_MORE_DATA, result.state()); + assertNull(result.detectedProtocol()); + incompleteHeader.release(); + } } diff --git a/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java b/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java new file mode 100644 index 0000000000..d4b435939c --- /dev/null +++ b/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Result of detecting a protocol. + * + * @param the type of the protocol + */ +public final class ProtocolDetectionResult { + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static final ProtocolDetectionResult NEEDS_MORE_DATE = + new ProtocolDetectionResult(ProtocolDetectionState.NEEDS_MORE_DATA, null); + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static final ProtocolDetectionResult INVALID = + new ProtocolDetectionResult(ProtocolDetectionState.INVALID, null); + + private final ProtocolDetectionState state; + private final T result; + + /** + * Returns a {@link ProtocolDetectionResult} that signals that more data is needed to detect the protocol. + */ + @SuppressWarnings("unchecked") + public static ProtocolDetectionResult needsMoreData() { + return NEEDS_MORE_DATE; + } + + /** + * Returns a {@link ProtocolDetectionResult} that signals the data was invalid for the protocol. + */ + @SuppressWarnings("unchecked") + public static ProtocolDetectionResult invalid() { + return INVALID; + } + + /** + * Returns a {@link ProtocolDetectionResult} which holds the detected protocol. + */ + @SuppressWarnings("unchecked") + public static ProtocolDetectionResult detected(T protocol) { + return new ProtocolDetectionResult(ProtocolDetectionState.DETECTED, checkNotNull(protocol, "protocol")); + } + + private ProtocolDetectionResult(ProtocolDetectionState state, T result) { + this.state = state; + this.result = result; + } + + /** + * Return the {@link ProtocolDetectionState}. If the state is {@link ProtocolDetectionState#DETECTED} you + * can retrieve the protocol via {@link #detectedProtocol()}. + */ + public ProtocolDetectionState state() { + return state; + } + + /** + * Returns the protocol if {@link #state()} returns {@link ProtocolDetectionState#DETECTED}, otherwise {@code null}. + */ + public T detectedProtocol() { + return result; + } +} diff --git a/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java b/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java new file mode 100644 index 0000000000..c98e9804ab --- /dev/null +++ b/codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * The state of the current detection. + */ +public enum ProtocolDetectionState { + /** + * Need more data to detect the protocol. + */ + NEEDS_MORE_DATA, + + /** + * The data was invalid. + */ + INVALID, + + /** + * Protocol was detected, + */ + DETECTED +}