Add ProtocolDetectionResult and use it in HAProxyMessageDecoder for allow detect HAProxy protocol.

Motivation:

Sometimes it is useful to detect if a ByteBuf contains a HAProxy header, for example if you want to write something like the PortUnification example.

Modifications:

- Add ProtocolDetectionResult which can be used as a return type for detecting different protocol.
- Add new method which allows to detect HA Proxy messages.

Result:

Easier to detect protocol.
This commit is contained in:
Norman Maurer 2015-05-29 09:26:57 +02:00
parent bad8e0d6ab
commit c53dbb748e
4 changed files with 217 additions and 9 deletions

View File

@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LineBasedFrameDecoder; import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.codec.ProtocolDetectionResult;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import java.util.List; import java.util.List;
@ -72,11 +73,31 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
(byte) 0x0A (byte) 0x0A
}; };
private static final byte[] TEXT_PREFIX = {
(byte) 'P',
(byte) 'R',
(byte) 'O',
(byte) 'X',
(byte) 'Y',
};
/** /**
* Binary header prefix length * Binary header prefix length
*/ */
private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length; private static final int BINARY_PREFIX_LENGTH = BINARY_PREFIX.length;
/**
* {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V1}.
*/
private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V1 =
ProtocolDetectionResult.detected(HAProxyProtocolVersion.V1);
/**
* {@link ProtocolDetectionResult} for {@link HAProxyProtocolVersion#V2}.
*/
private static final ProtocolDetectionResult<HAProxyProtocolVersion> DETECTION_RESULT_V2 =
ProtocolDetectionResult.detected(HAProxyProtocolVersion.V2);
/** /**
* {@code true} if we're discarding input because we're already over maxLength * {@code true} if we're discarding input because we're already over maxLength
*/ */
@ -147,15 +168,7 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
} }
int idx = buffer.readerIndex(); int idx = buffer.readerIndex();
return match(BINARY_PREFIX, buffer, idx) ? buffer.getByte(idx + BINARY_PREFIX_LENGTH) : 1;
for (int i = 0; i < BINARY_PREFIX_LENGTH; i++) {
final byte b = buffer.getByte(idx + i);
if (b != BINARY_PREFIX[i]) {
return 1;
}
}
return buffer.getByte(idx + BINARY_PREFIX_LENGTH);
} }
/** /**
@ -357,4 +370,33 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
} }
throw ppex; throw ppex;
} }
/**
* Returns the {@link ProtocolDetectionResult} for the given {@link ByteBuf}.
*/
public static ProtocolDetectionResult<HAProxyProtocolVersion> detectProtocol(ByteBuf buffer) {
if (buffer.readableBytes() < 12) {
return ProtocolDetectionResult.needsMoreData();
}
int idx = buffer.readerIndex();
if (match(BINARY_PREFIX, buffer, idx)) {
return DETECTION_RESULT_V2;
}
if (match(TEXT_PREFIX, buffer, idx)) {
return DETECTION_RESULT_V1;
}
return ProtocolDetectionResult.invalid();
}
private static boolean match(byte[] prefix, ByteBuf buffer, int idx) {
for (int i = 0; i < prefix.length; i++) {
final byte b = buffer.getByte(idx + i);
if (b != prefix[i]) {
return false;
}
}
return true;
}
} }

View File

@ -15,8 +15,11 @@
*/ */
package io.netty.handler.codec.haproxy; package io.netty.handler.codec.haproxy;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.ProtocolDetectionResult;
import io.netty.handler.codec.ProtocolDetectionState;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.TransportProtocol; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.TransportProtocol;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
@ -896,4 +899,51 @@ public class HAProxyMessageDecoderTest {
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
} }
@Test
public void testDetectProtocol() {
final ByteBuf validHeaderV1 = copiedBuffer("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n",
CharsetUtil.US_ASCII);
ProtocolDetectionResult<HAProxyProtocolVersion> result = HAProxyMessageDecoder.detectProtocol(validHeaderV1);
assertEquals(ProtocolDetectionState.DETECTED, result.state());
assertEquals(HAProxyProtocolVersion.V1, result.detectedProtocol());
validHeaderV1.release();
final ByteBuf invalidHeader = copiedBuffer("Invalid header", CharsetUtil.US_ASCII);
result = HAProxyMessageDecoder.detectProtocol(invalidHeader);
assertEquals(ProtocolDetectionState.INVALID, result.state());
assertNull(result.detectedProtocol());
invalidHeader.release();
final ByteBuf validHeaderV2 = buffer();
validHeaderV2.writeByte(0x0D);
validHeaderV2.writeByte(0x0A);
validHeaderV2.writeByte(0x0D);
validHeaderV2.writeByte(0x0A);
validHeaderV2.writeByte(0x00);
validHeaderV2.writeByte(0x0D);
validHeaderV2.writeByte(0x0A);
validHeaderV2.writeByte(0x51);
validHeaderV2.writeByte(0x55);
validHeaderV2.writeByte(0x49);
validHeaderV2.writeByte(0x54);
validHeaderV2.writeByte(0x0A);
result = HAProxyMessageDecoder.detectProtocol(validHeaderV2);
assertEquals(ProtocolDetectionState.DETECTED, result.state());
assertEquals(HAProxyProtocolVersion.V2, result.detectedProtocol());
validHeaderV2.release();
final ByteBuf incompleteHeader = buffer();
incompleteHeader.writeByte(0x0D);
incompleteHeader.writeByte(0x0A);
incompleteHeader.writeByte(0x0D);
incompleteHeader.writeByte(0x0A);
incompleteHeader.writeByte(0x00);
incompleteHeader.writeByte(0x0D);
incompleteHeader.writeByte(0x0A);
result = HAProxyMessageDecoder.detectProtocol(incompleteHeader);
assertEquals(ProtocolDetectionState.NEEDS_MORE_DATA, result.state());
assertNull(result.detectedProtocol());
incompleteHeader.release();
}
} }

View File

@ -0,0 +1,80 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
/**
* Result of detecting a protocol.
*
* @param <T> the type of the protocol
*/
public final class ProtocolDetectionResult<T> {
@SuppressWarnings({ "rawtypes", "unchecked" })
private static final ProtocolDetectionResult NEEDS_MORE_DATE =
new ProtocolDetectionResult(ProtocolDetectionState.NEEDS_MORE_DATA, null);
@SuppressWarnings({ "rawtypes", "unchecked" })
private static final ProtocolDetectionResult INVALID =
new ProtocolDetectionResult(ProtocolDetectionState.INVALID, null);
private final ProtocolDetectionState state;
private final T result;
/**
* Returns a {@link ProtocolDetectionResult} that signals that more data is needed to detect the protocol.
*/
@SuppressWarnings("unchecked")
public static <T> ProtocolDetectionResult<T> needsMoreData() {
return NEEDS_MORE_DATE;
}
/**
* Returns a {@link ProtocolDetectionResult} that signals the data was invalid for the protocol.
*/
@SuppressWarnings("unchecked")
public static <T> ProtocolDetectionResult<T> invalid() {
return INVALID;
}
/**
* Returns a {@link ProtocolDetectionResult} which holds the detected protocol.
*/
@SuppressWarnings("unchecked")
public static <T> ProtocolDetectionResult<T> detected(T protocol) {
return new ProtocolDetectionResult<T>(ProtocolDetectionState.DETECTED, checkNotNull(protocol, "protocol"));
}
private ProtocolDetectionResult(ProtocolDetectionState state, T result) {
this.state = state;
this.result = result;
}
/**
* Return the {@link ProtocolDetectionState}. If the state is {@link ProtocolDetectionState#DETECTED} you
* can retrieve the protocol via {@link #detectedProtocol()}.
*/
public ProtocolDetectionState state() {
return state;
}
/**
* Returns the protocol if {@link #state()} returns {@link ProtocolDetectionState#DETECTED}, otherwise {@code null}.
*/
public T detectedProtocol() {
return result;
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright 2015 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec;
/**
* The state of the current detection.
*/
public enum ProtocolDetectionState {
/**
* Need more data to detect the protocol.
*/
NEEDS_MORE_DATA,
/**
* The data was invalid.
*/
INVALID,
/**
* Protocol was detected,
*/
DETECTED
}