Add support for HAProxyMessageEncoder (#10175)

Motivation:

Add support for HAProxyMessageEncoder.
This should help java based HAProxy server implementations propagate proxy information.

Modification:

Add public constructors for `HAProxyMessage`, `HAProxyTLV`, `HAProxySSLTLV`.
Add additional argument checks for `HAProxyMessage` and modify exceptions thrown when creating via public constructors directly.
Introduce a `@Sharable` `HAProxyMessageEncoder` which encodes a `HAProxyMessage` into a byte array.
Add an example `HAProxyServer` and `HAProxyClient` to `io.netty.example`

Result:

Fixes #10164
This commit is contained in:
jrhee17 2020-04-16 16:35:06 +09:00 committed by Norman Maurer
parent 9077acb6ab
commit 9ac005222b
12 changed files with 1053 additions and 51 deletions

View File

@ -56,5 +56,31 @@ final class HAProxyConstants {
static final byte TPAF_UNIX_STREAM_BYTE = 0x31;
static final byte TPAF_UNIX_DGRAM_BYTE = 0x32;
/**
* V2 protocol binary header prefix
*/
static final byte[] BINARY_PREFIX = {
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x00,
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x51,
(byte) 0x55,
(byte) 0x49,
(byte) 0x54,
(byte) 0x0A
};
static final byte[] TEXT_PREFIX = {
(byte) 'P',
(byte) 'R',
(byte) 'O',
(byte) 'X',
(byte) 'Y',
};
private HAProxyConstants() { }
}

View File

@ -26,6 +26,7 @@ import io.netty.util.NetUtil;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory;
import io.netty.util.ResourceLeakTracker;
import io.netty.util.internal.StringUtil;
import java.util.ArrayList;
import java.util.Collections;
@ -60,9 +61,16 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
}
/**
* Creates a new instance
* Creates a new instance of HAProxyMessage.
* @param protocolVersion the protocol version.
* @param command the command.
* @param proxiedProtocol the protocol containing the address family and transport protocol.
* @param sourceAddress the source address.
* @param destinationAddress the destination address.
* @param sourcePort the source port. This value must be 0 for unix, unspec addresses.
* @param destinationPort the destination port. This value must be 0 for unix, unspec addresses.
*/
private HAProxyMessage(
public HAProxyMessage(
HAProxyProtocolVersion protocolVersion, HAProxyCommand command, HAProxyProxiedProtocol proxiedProtocol,
String sourceAddress, String destinationAddress, int sourcePort, int destinationPort) {
@ -71,19 +79,29 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
}
/**
* Creates a new instance
* Creates a new instance of HAProxyMessage.
* @param protocolVersion the protocol version.
* @param command the command.
* @param proxiedProtocol the protocol containing the address family and transport protocol.
* @param sourceAddress the source address.
* @param destinationAddress the destination address.
* @param sourcePort the source port. This value must be 0 for unix, unspec addresses.
* @param destinationPort the destination port. This value must be 0 for unix, unspec addresses.
* @param tlvs the list of tlvs.
*/
private HAProxyMessage(
public HAProxyMessage(
HAProxyProtocolVersion protocolVersion, HAProxyCommand command, HAProxyProxiedProtocol proxiedProtocol,
String sourceAddress, String destinationAddress, int sourcePort, int destinationPort,
List<HAProxyTLV> tlvs) {
List<? extends HAProxyTLV> tlvs) {
requireNonNull(protocolVersion, "protocolVersion");
requireNonNull(proxiedProtocol, "proxiedProtocol");
requireNonNull(tlvs, "tlvs");
AddressFamily addrFamily = proxiedProtocol.addressFamily();
checkAddress(sourceAddress, addrFamily);
checkAddress(destinationAddress, addrFamily);
checkPort(sourcePort);
checkPort(destinationPort);
checkPort(sourcePort, addrFamily);
checkPort(destinationPort, addrFamily);
this.protocolVersion = protocolVersion;
this.command = command;
@ -329,9 +347,13 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
throw new HAProxyProtocolException("invalid TCP4/6 header: " + header + " (expected: 6 parts)");
}
try {
return new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY,
protAndFam, parts[2], parts[3], parts[4], parts[5]);
} catch (RuntimeException e) {
throw new HAProxyProtocolException("invalid HAProxy message", e);
}
}
/**
@ -373,18 +395,18 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
*
* @param value the port
* @return port as an integer
* @throws HAProxyProtocolException if port is not a valid integer
* @throws IllegalArgumentException if port is not a valid integer
*/
private static int portStringToInt(String value) {
int port;
try {
port = Integer.parseInt(value);
} catch (NumberFormatException e) {
throw new HAProxyProtocolException("invalid port: " + value, e);
throw new IllegalArgumentException("invalid port: " + value, e);
}
if (port <= 0 || port > 65535) {
throw new HAProxyProtocolException("invalid port: " + value + " (expected: 1 ~ 65535)");
throw new IllegalArgumentException("invalid port: " + value + " (expected: 1 ~ 65535)");
}
return port;
@ -395,7 +417,7 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
*
* @param address human-readable address
* @param addrFamily the {@link AddressFamily} to check the address against
* @throws HAProxyProtocolException if the address is invalid
* @throws IllegalArgumentException if the address is invalid
*/
private static void checkAddress(String address, AddressFamily addrFamily) {
requireNonNull(addrFamily, "addrFamily");
@ -403,10 +425,14 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
switch (addrFamily) {
case AF_UNSPEC:
if (address != null) {
throw new HAProxyProtocolException("unable to validate an AF_UNSPEC address: " + address);
throw new IllegalArgumentException("unable to validate an AF_UNSPEC address: " + address);
}
return;
case AF_UNIX:
requireNonNull(address, "address");
if (address.getBytes(CharsetUtil.US_ASCII).length > 108) {
throw new IllegalArgumentException("invalid AF_UNIX address: " + address);
}
return;
}
@ -415,28 +441,41 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
switch (addrFamily) {
case AF_IPv4:
if (!NetUtil.isValidIpV4Address(address)) {
throw new HAProxyProtocolException("invalid IPv4 address: " + address);
throw new IllegalArgumentException("invalid IPv4 address: " + address);
}
break;
case AF_IPv6:
if (!NetUtil.isValidIpV6Address(address)) {
throw new HAProxyProtocolException("invalid IPv6 address: " + address);
throw new IllegalArgumentException("invalid IPv6 address: " + address);
}
break;
default:
throw new Error();
throw new IllegalArgumentException("unexpected addrFamily: " + addrFamily);
}
}
/**
* Validate a UDP/TCP port
* Validate the port depending on the addrFamily.
*
* @param port the UDP/TCP port
* @throws HAProxyProtocolException if the port is out of range (0-65535 inclusive)
* @throws IllegalArgumentException if the port is out of range (0-65535 inclusive)
*/
private static void checkPort(int port) {
private static void checkPort(int port, AddressFamily addrFamily) {
switch (addrFamily) {
case AF_IPv6:
case AF_IPv4:
if (port < 0 || port > 65535) {
throw new HAProxyProtocolException("invalid port: " + port + " (expected: 1 ~ 65535)");
throw new IllegalArgumentException("invalid port: " + port + " (expected: 0 ~ 65535)");
}
break;
case AF_UNIX:
case AF_UNSPEC:
if (port != 0) {
throw new IllegalArgumentException("port cannot be specified with addrFamily: " + addrFamily);
}
break;
default:
throw new IllegalArgumentException("unexpected addrFamily: " + addrFamily);
}
}
@ -498,6 +537,14 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
return tlvs;
}
int tlvNumBytes() {
int tlvNumBytes = 0;
for (int i = 0; i < tlvs.size(); i++) {
tlvNumBytes += tlvs.get(i).totalNumBytes();
}
return tlvNumBytes;
}
@Override
public HAProxyMessage touch() {
tryRecord();
@ -556,4 +603,26 @@ public final class HAProxyMessage extends AbstractReferenceCounted {
}
}
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder(256)
.append(StringUtil.simpleClassName(this))
.append("(protocolVersion: ").append(protocolVersion)
.append(", command: ").append(command)
.append(", proxiedProtocol: ").append(proxiedProtocol)
.append(", sourceAddress: ").append(sourceAddress)
.append(", destinationAddress: ").append(destinationAddress)
.append(", sourcePort: ").append(sourcePort)
.append(", destinationPort: ").append(destinationPort)
.append(", tlvs: [");
if (!tlvs.isEmpty()) {
for (HAProxyTLV tlv: tlvs) {
sb.append(tlv).append(", ");
}
sb.setLength(sb.length() - 2);
}
sb.append("])");
return sb.toString();
}
}

View File

@ -21,6 +21,9 @@ import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.ProtocolDetectionResult;
import io.netty.util.CharsetUtil;
import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
/**
* Decodes an HAProxy proxy protocol header
*
@ -47,32 +50,6 @@ public class HAProxyMessageDecoder extends ByteToMessageDecoder {
*/
private static final int V2_MAX_TLV = 65535 - 216;
/**
* Binary header prefix
*/
private static final byte[] BINARY_PREFIX = {
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x00,
(byte) 0x0D,
(byte) 0x0A,
(byte) 0x51,
(byte) 0x55,
(byte) 0x49,
(byte) 0x54,
(byte) 0x0A
};
private static final byte[] TEXT_PREFIX = {
(byte) 'P',
(byte) 'R',
(byte) 'O',
(byte) 'X',
(byte) 'Y',
};
/**
* Binary header prefix length
*/

View File

@ -0,0 +1,134 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.haproxy;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import java.util.List;
import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
/**
* Encodes an HAProxy proxy protocol message
*
* @see <a href="http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt">Proxy Protocol Specification</a>
*/
@Sharable
public final class HAProxyMessageEncoder extends MessageToByteEncoder<HAProxyMessage> {
private static final int V2_VERSION_BITMASK = 0x02 << 4;
// Length for source/destination addresses for the UNIX family must be 108 bytes each.
static final int UNIX_ADDRESS_BYTES_LENGTH = 108;
static final int TOTAL_UNIX_ADDRESS_BYTES_LENGTH = UNIX_ADDRESS_BYTES_LENGTH * 2;
public static final HAProxyMessageEncoder INSTANCE = new HAProxyMessageEncoder();
private HAProxyMessageEncoder() {
}
@Override
protected void encode(ChannelHandlerContext ctx, HAProxyMessage msg, ByteBuf out) throws Exception {
switch (msg.protocolVersion()) {
case V1:
encodeV1(msg, out);
break;
case V2:
encodeV2(msg, out);
break;
default:
throw new HAProxyProtocolException("Unsupported version: " + msg.protocolVersion());
}
}
private static void encodeV1(HAProxyMessage msg, ByteBuf out) {
out.writeBytes(TEXT_PREFIX);
out.writeByte((byte) ' ');
out.writeCharSequence(msg.proxiedProtocol().name(), CharsetUtil.US_ASCII);
out.writeByte((byte) ' ');
out.writeCharSequence(msg.sourceAddress(), CharsetUtil.US_ASCII);
out.writeByte((byte) ' ');
out.writeCharSequence(msg.destinationAddress(), CharsetUtil.US_ASCII);
out.writeByte((byte) ' ');
out.writeCharSequence(String.valueOf(msg.sourcePort()), CharsetUtil.US_ASCII);
out.writeByte((byte) ' ');
out.writeCharSequence(String.valueOf(msg.destinationPort()), CharsetUtil.US_ASCII);
out.writeByte((byte) '\r');
out.writeByte((byte) '\n');
}
private static void encodeV2(HAProxyMessage msg, ByteBuf out) {
out.writeBytes(BINARY_PREFIX);
out.writeByte(V2_VERSION_BITMASK | msg.command().byteValue());
out.writeByte(msg.proxiedProtocol().byteValue());
switch (msg.proxiedProtocol().addressFamily()) {
case AF_IPv4:
case AF_IPv6:
byte[] srcAddrBytes = NetUtil.createByteArrayFromIpAddressString(msg.sourceAddress());
byte[] dstAddrBytes = NetUtil.createByteArrayFromIpAddressString(msg.destinationAddress());
// srcAddrLen + dstAddrLen + 4 (srcPort + dstPort) + numTlvBytes
out.writeShort(srcAddrBytes.length + dstAddrBytes.length + 4 + msg.tlvNumBytes());
out.writeBytes(srcAddrBytes);
out.writeBytes(dstAddrBytes);
out.writeShort(msg.sourcePort());
out.writeShort(msg.destinationPort());
encodeTlvs(msg.tlvs(), out);
break;
case AF_UNIX:
out.writeShort(TOTAL_UNIX_ADDRESS_BYTES_LENGTH + msg.tlvNumBytes());
int srcAddrBytesWritten = out.writeCharSequence(msg.sourceAddress(), CharsetUtil.US_ASCII);
out.writeZero(UNIX_ADDRESS_BYTES_LENGTH - srcAddrBytesWritten);
int dstAddrBytesWritten = out.writeCharSequence(msg.destinationAddress(), CharsetUtil.US_ASCII);
out.writeZero(UNIX_ADDRESS_BYTES_LENGTH - dstAddrBytesWritten);
encodeTlvs(msg.tlvs(), out);
break;
case AF_UNSPEC:
out.writeShort(0);
break;
default:
throw new HAProxyProtocolException("unexpected addrFamily");
}
}
private static void encodeTlv(HAProxyTLV haProxyTLV, ByteBuf out) {
if (haProxyTLV instanceof HAProxySSLTLV) {
HAProxySSLTLV ssltlv = (HAProxySSLTLV) haProxyTLV;
out.writeByte(haProxyTLV.typeByteValue());
out.writeShort(ssltlv.contentNumBytes());
out.writeByte(ssltlv.client());
out.writeInt(ssltlv.verify());
encodeTlvs(ssltlv.encapsulatedTLVs(), out);
} else {
out.writeByte(haProxyTLV.typeByteValue());
ByteBuf value = haProxyTLV.content();
int readableBytes = value.readableBytes();
out.writeShort(readableBytes);
out.writeBytes(value.readSlice(readableBytes));
}
}
private static void encodeTlvs(List<HAProxyTLV> haProxyTLVs, ByteBuf out) {
for (int i = 0; i < haProxyTLVs.size(); i++) {
encodeTlv(haProxyTLVs.get(i), out);
}
}
}

View File

@ -17,6 +17,8 @@
package io.netty.handler.codec.haproxy;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.StringUtil;
import java.util.Collections;
import java.util.List;
@ -35,7 +37,19 @@ public final class HAProxySSLTLV extends HAProxyTLV {
* Creates a new HAProxySSLTLV
*
* @param verify the verification result as defined in the specification for the pp2_tlv_ssl struct (see
* http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt)
* http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
* @param clientBitField the bitfield with client information
* @param tlvs the encapsulated {@link HAProxyTLV}s
*/
public HAProxySSLTLV(final int verify, final byte clientBitField, final List<HAProxyTLV> tlvs) {
this(verify, clientBitField, tlvs, Unpooled.EMPTY_BUFFER);
}
/**
* Creates a new HAProxySSLTLV
*
* @param verify the verification result as defined in the specification for the pp2_tlv_ssl struct (see
* http://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
* @param clientBitField the bitfield with client information
* @param tlvs the encapsulated {@link HAProxyTLV}s
* @param rawContent the raw TLV content
@ -69,6 +83,13 @@ public final class HAProxySSLTLV extends HAProxyTLV {
return (clientBitField & 0x4) != 0;
}
/**
* Returns the client bit field
*/
public byte client() {
return clientBitField;
}
/**
* Returns the verification result
*/
@ -83,4 +104,22 @@ public final class HAProxySSLTLV extends HAProxyTLV {
return tlvs;
}
@Override
int contentNumBytes() {
int tlvNumBytes = 0;
for (int i = 0; i < tlvs.size(); i++) {
tlvNumBytes += tlvs.get(i).totalNumBytes();
}
return 5 + tlvNumBytes; // clientBit(1) + verify(4) + tlvs
}
@Override
public String toString() {
return StringUtil.simpleClassName(this) +
"(type: " + type() +
", typeByteValue: " + typeByteValue() +
", client: " + client() +
", verify: " + verify() +
", numEncapsulatedTlvs: " + tlvs.size() + ')';
}
}

View File

@ -18,6 +18,7 @@ package io.netty.handler.codec.haproxy;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
import io.netty.util.internal.StringUtil;
import static java.util.Objects.requireNonNull;
@ -32,6 +33,18 @@ public class HAProxyTLV extends DefaultByteBufHolder {
private final Type type;
private final byte typeByteValue;
/**
* The size of this tlv in bytes.
* @return the number of bytes.
*/
int totalNumBytes() {
return 3 + contentNumBytes(); // type(1) + length(2) + content
}
int contentNumBytes() {
return content().readableBytes();
}
/**
* The registered types a TLV can have regarding the PROXY protocol 1.5 spec
*/
@ -56,7 +69,7 @@ public class HAProxyTLV extends DefaultByteBufHolder {
*
* @return the {@link Type} of a TLV
*/
public static Type typeForByteValue(final byte byteValue) {
public static Type typeForByteValue(byte byteValue) {
switch (byteValue) {
case 0x01:
return PP2_TYPE_ALPN;
@ -74,6 +87,52 @@ public class HAProxyTLV extends DefaultByteBufHolder {
return OTHER;
}
}
/**
* Returns the byte value for the {@link Type} as defined in the PROXY protocol 1.5 spec.
*
* @param type the {@link Type}
*
* @return the byte value of the {@link Type}.
*/
public static byte byteValueForType(Type type) {
switch (type) {
case PP2_TYPE_ALPN:
return 0x01;
case PP2_TYPE_AUTHORITY:
return 0x02;
case PP2_TYPE_SSL:
return 0x20;
case PP2_TYPE_SSL_VERSION:
return 0x21;
case PP2_TYPE_SSL_CN:
return 0x22;
case PP2_TYPE_NETNS:
return 0x30;
default:
throw new IllegalArgumentException("unknown type: " + type);
}
}
}
/**
* Creates a new HAProxyTLV
*
* @param typeByteValue the byteValue of the TLV. This is especially important if non-standard TLVs are used
* @param content the raw content of the TLV
*/
public HAProxyTLV(byte typeByteValue, ByteBuf content) {
this(Type.typeForByteValue(typeByteValue), typeByteValue, content);
}
/**
* Creates a new HAProxyTLV
*
* @param type the {@link Type} of the TLV
* @param content the raw content of the TLV
*/
public HAProxyTLV(Type type, ByteBuf content) {
this(type, Type.byteValueForType(type), content);
}
/**
@ -148,4 +207,12 @@ public class HAProxyTLV extends DefaultByteBufHolder {
super.touch(hint);
return this;
}
@Override
public String toString() {
return StringUtil.simpleClassName(this) +
"(type: " + type() +
", typeByteValue: " + typeByteValue() +
", content: " + contentToString() + ')';
}
}

View File

@ -0,0 +1,96 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.haproxy;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultithreadEventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalHandler;
import io.netty.channel.local.LocalServerChannel;
import org.junit.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
public class HAProxyIntegrationTest {
@Test
public void testBasicCase() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<HAProxyMessage> msgHolder = new AtomicReference<>();
LocalAddress localAddress = new LocalAddress("HAProxyIntegrationTest");
EventLoopGroup group = new MultithreadEventLoopGroup(LocalHandler.newFactory());
ServerBootstrap sb = new ServerBootstrap();
sb.channel(LocalServerChannel.class)
.group(group)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
ch.pipeline().addLast(new HAProxyMessageDecoder());
ch.pipeline().addLast(new SimpleChannelInboundHandler<HAProxyMessage>() {
@Override
protected void messageReceived(ChannelHandlerContext ctx, HAProxyMessage msg) {
msgHolder.set(msg.retain());
latch.countDown();
}
});
}
});
Channel serverChannel = sb.bind(localAddress).sync().channel();
Bootstrap b = new Bootstrap();
Channel clientChannel = b.channel(LocalChannel.class)
.handler(HAProxyMessageEncoder.INSTANCE)
.group(group)
.connect(localAddress).sync().channel();
try {
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"192.168.0.1", "192.168.0.11", 56324, 443);
clientChannel.writeAndFlush(message).sync();
assertTrue(latch.await(5, TimeUnit.SECONDS));
HAProxyMessage readMessage = msgHolder.get();
assertEquals(message.protocolVersion(), readMessage.protocolVersion());
assertEquals(message.command(), readMessage.command());
assertEquals(message.proxiedProtocol(), readMessage.proxiedProtocol());
assertEquals(message.sourceAddress(), readMessage.sourceAddress());
assertEquals(message.destinationAddress(), readMessage.destinationAddress());
assertEquals(message.sourcePort(), readMessage.sourcePort());
assertEquals(message.destinationPort(), readMessage.destinationPort());
readMessage.release();
} finally {
clientChannel.close().sync();
serverChannel.close().sync();
group.shutdownGracefully().sync();
}
}
}

View File

@ -0,0 +1,404 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.haproxy;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.haproxy.HAProxyTLV.Type;
import io.netty.util.ByteProcessor;
import io.netty.util.CharsetUtil;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static io.netty.handler.codec.haproxy.HAProxyConstants.*;
import static io.netty.handler.codec.haproxy.HAProxyMessageEncoder.*;
import static org.junit.Assert.*;
public class HaProxyMessageEncoderTest {
private static final int V2_HEADER_BYTES_LENGTH = 16;
private static final int IPv4_ADDRESS_BYTES_LENGTH = 12;
private static final int IPv6_ADDRESS_BYTES_LENGTH = 36;
@Test
public void testIPV4EncodeProxyV1() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"192.168.0.1", "192.168.0.11", 56324, 443);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
assertEquals("PROXY TCP4 192.168.0.1 192.168.0.11 56324 443\r\n",
byteBuf.toString(CharsetUtil.US_ASCII));
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testIPV6EncodeProxyV1() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP6,
"2001:0db8:85a3:0000:0000:8a2e:0370:7334", "1050:0:0:0:5:600:300c:326b", 56324, 443);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
assertEquals("PROXY TCP6 2001:0db8:85a3:0000:0000:8a2e:0370:7334 1050:0:0:0:5:600:300c:326b 56324 443\r\n",
byteBuf.toString(CharsetUtil.US_ASCII));
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testIPv4EncodeProxyV2() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"192.168.0.1", "192.168.0.11", 56324, 443);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
// header
byte[] headerBytes = ByteBufUtil.getBytes(byteBuf, 0, 12);
assertArrayEquals(BINARY_PREFIX, headerBytes);
// command
byte commandByte = byteBuf.getByte(12);
assertEquals(0x02, (commandByte & 0xf0) >> 4);
assertEquals(0x01, commandByte & 0x0f);
// transport protocol, address family
byte transportByte = byteBuf.getByte(13);
assertEquals(0x01, (transportByte & 0xf0) >> 4);
assertEquals(0x01, transportByte & 0x0f);
// source address length
int sourceAddrLength = byteBuf.getUnsignedShort(14);
assertEquals(12, sourceAddrLength);
// source address
byte[] sourceAddr = ByteBufUtil.getBytes(byteBuf, 16, 4);
assertArrayEquals(new byte[] { (byte) 0xc0, (byte) 0xa8, 0x00, 0x01 }, sourceAddr);
// destination address
byte[] destAddr = ByteBufUtil.getBytes(byteBuf, 20, 4);
assertArrayEquals(new byte[] { (byte) 0xc0, (byte) 0xa8, 0x00, 0x0b }, destAddr);
// source port
int sourcePort = byteBuf.getUnsignedShort(24);
assertEquals(56324, sourcePort);
// destination port
int destPort = byteBuf.getUnsignedShort(26);
assertEquals(443, destPort);
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testIPv6EncodeProxyV2() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP6,
"2001:0db8:85a3:0000:0000:8a2e:0370:7334", "1050:0:0:0:5:600:300c:326b", 56324, 443);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
// header
byte[] headerBytes = ByteBufUtil.getBytes(byteBuf, 0, 12);
assertArrayEquals(BINARY_PREFIX, headerBytes);
// command
byte commandByte = byteBuf.getByte(12);
assertEquals(0x02, (commandByte & 0xf0) >> 4);
assertEquals(0x01, commandByte & 0x0f);
// transport protocol, address family
byte transportByte = byteBuf.getByte(13);
assertEquals(0x02, (transportByte & 0xf0) >> 4);
assertEquals(0x01, transportByte & 0x0f);
// source address length
int sourceAddrLength = byteBuf.getUnsignedShort(14);
assertEquals(IPv6_ADDRESS_BYTES_LENGTH, sourceAddrLength);
// source address
byte[] sourceAddr = ByteBufUtil.getBytes(byteBuf, 16, 16);
assertArrayEquals(new byte[] {
(byte) 0x20, (byte) 0x01, 0x0d, (byte) 0xb8,
(byte) 0x85, (byte) 0xa3, 0x00, 0x00, 0x00, 0x00, (byte) 0x8a, 0x2e,
0x03, 0x70, 0x73, 0x34
}, sourceAddr);
// destination address
byte[] destAddr = ByteBufUtil.getBytes(byteBuf, 32, 16);
assertArrayEquals(new byte[] {
(byte) 0x10, (byte) 0x50, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x05, 0x06, 0x00, 0x30, 0x0c, 0x32, 0x6b
}, destAddr);
// source port
int sourcePort = byteBuf.getUnsignedShort(48);
assertEquals(56324, sourcePort);
// destination port
int destPort = byteBuf.getUnsignedShort(50);
assertEquals(443, destPort);
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testUnixEncodeProxyV2() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNIX_STREAM,
"/var/run/src.sock", "/var/run/dst.sock", 0, 0);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
// header
byte[] headerBytes = ByteBufUtil.getBytes(byteBuf, 0, 12);
assertArrayEquals(BINARY_PREFIX, headerBytes);
// command
byte commandByte = byteBuf.getByte(12);
assertEquals(0x02, (commandByte & 0xf0) >> 4);
assertEquals(0x01, commandByte & 0x0f);
// transport protocol, address family
byte transportByte = byteBuf.getByte(13);
assertEquals(0x03, (transportByte & 0xf0) >> 4);
assertEquals(0x01, transportByte & 0x0f);
// address length
int addrLength = byteBuf.getUnsignedShort(14);
assertEquals(TOTAL_UNIX_ADDRESS_BYTES_LENGTH, addrLength);
// source address
int srcAddrEnd = byteBuf.forEachByte(16, 108, ByteProcessor.FIND_NUL);
assertEquals("/var/run/src.sock",
byteBuf.slice(16, srcAddrEnd - 16).toString(CharsetUtil.US_ASCII));
// destination address
int dstAddrEnd = byteBuf.forEachByte(124, 108, ByteProcessor.FIND_NUL);
assertEquals("/var/run/dst.sock",
byteBuf.slice(124, dstAddrEnd - 124).toString(CharsetUtil.US_ASCII));
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testTLVEncodeProxy() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
List<HAProxyTLV> tlvs = new ArrayList<HAProxyTLV>();
ByteBuf helloWorld = Unpooled.copiedBuffer("hello world", CharsetUtil.US_ASCII);
HAProxyTLV alpnTlv = new HAProxyTLV(Type.PP2_TYPE_ALPN, (byte) 0x01, helloWorld.copy());
tlvs.add(alpnTlv);
ByteBuf arbitrary = Unpooled.copiedBuffer("an arbitrary string", CharsetUtil.US_ASCII);
HAProxyTLV authorityTlv = new HAProxyTLV(Type.PP2_TYPE_AUTHORITY, (byte) 0x01, arbitrary.copy());
tlvs.add(authorityTlv);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"192.168.0.1", "192.168.0.11", 56324, 443, tlvs);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
// length
assertEquals(byteBuf.getUnsignedShort(14), byteBuf.readableBytes() - V2_HEADER_BYTES_LENGTH);
// skip to tlv section
ByteBuf tlv = byteBuf.skipBytes(V2_HEADER_BYTES_LENGTH + IPv4_ADDRESS_BYTES_LENGTH);
// alpn tlv
assertEquals(alpnTlv.typeByteValue(), tlv.readByte());
short bufLength = tlv.readShort();
assertEquals(helloWorld.array().length, bufLength);
assertEquals(helloWorld, tlv.readBytes(bufLength));
// authority tlv
assertEquals(authorityTlv.typeByteValue(), tlv.readByte());
bufLength = tlv.readShort();
assertEquals(arbitrary.array().length, bufLength);
assertEquals(arbitrary, tlv.readBytes(bufLength));
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testSslTLVEncodeProxy() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
List<HAProxyTLV> tlvs = new ArrayList<HAProxyTLV>();
ByteBuf helloWorld = Unpooled.copiedBuffer("hello world", CharsetUtil.US_ASCII);
HAProxyTLV alpnTlv = new HAProxyTLV(Type.PP2_TYPE_ALPN, (byte) 0x01, helloWorld.copy());
tlvs.add(alpnTlv);
ByteBuf arbitrary = Unpooled.copiedBuffer("an arbitrary string", CharsetUtil.US_ASCII);
HAProxyTLV authorityTlv = new HAProxyTLV(Type.PP2_TYPE_AUTHORITY, (byte) 0x01, arbitrary.copy());
tlvs.add(authorityTlv);
ByteBuf sslContent = Unpooled.copiedBuffer("some ssl content", CharsetUtil.US_ASCII);
HAProxySSLTLV haProxySSLTLV = new HAProxySSLTLV(1, (byte) 0x01, tlvs, sslContent.copy());
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"192.168.0.1", "192.168.0.11", 56324, 443,
Collections.<HAProxyTLV>singletonList(haProxySSLTLV));
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
assertEquals(byteBuf.getUnsignedShort(14), byteBuf.readableBytes() - V2_HEADER_BYTES_LENGTH);
ByteBuf tlv = byteBuf.skipBytes(V2_HEADER_BYTES_LENGTH + IPv4_ADDRESS_BYTES_LENGTH);
// ssl tlv type
assertEquals(haProxySSLTLV.typeByteValue(), tlv.readByte());
// length
int bufLength = tlv.readUnsignedShort();
assertEquals(bufLength, tlv.readableBytes());
// client, verify
assertEquals(0x01, byteBuf.readByte());
assertEquals(1, byteBuf.readInt());
// alpn tlv
assertEquals(alpnTlv.typeByteValue(), tlv.readByte());
bufLength = tlv.readShort();
assertEquals(helloWorld.array().length, bufLength);
assertEquals(helloWorld, tlv.readBytes(bufLength));
// authority tlv
assertEquals(authorityTlv.typeByteValue(), tlv.readByte());
bufLength = tlv.readShort();
assertEquals(arbitrary.array().length, bufLength);
assertEquals(arbitrary, tlv.readBytes(bufLength));
byteBuf.release();
assertFalse(ch.finish());
}
@Test
public void testEncodeLocalProxyV2() {
EmbeddedChannel ch = new EmbeddedChannel(INSTANCE);
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.LOCAL, HAProxyProxiedProtocol.UNKNOWN,
null, null, 0, 0);
assertTrue(ch.writeOutbound(message));
ByteBuf byteBuf = ch.readOutbound();
// header
byte[] headerBytes = new byte[12];
byteBuf.readBytes(headerBytes);
assertArrayEquals(BINARY_PREFIX, headerBytes);
// command
byte commandByte = byteBuf.readByte();
assertEquals(0x02, (commandByte & 0xf0) >> 4);
assertEquals(0x00, commandByte & 0x0f);
// transport protocol, address family
byte transportByte = byteBuf.readByte();
assertEquals(0x00, transportByte);
// source address length
int sourceAddrLength = byteBuf.readUnsignedShort();
assertEquals(0, sourceAddrLength);
assertFalse(byteBuf.isReadable());
byteBuf.release();
assertFalse(ch.finish());
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidIpV4Address() {
String invalidIpv4Address = "192.168.0.1234";
new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
invalidIpv4Address, "192.168.0.11", 56324, 443);
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidIpV6Address() {
String invalidIpv6Address = "2001:0db8:85a3:0000:0000:8a2e:0370:73345";
new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP6,
invalidIpv6Address, "1050:0:0:0:5:600:300c:326b", 56324, 443);
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidUnixAddress() {
String invalidUnixAddress = new String(new byte[UNIX_ADDRESS_BYTES_LENGTH + 1]);
new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNIX_STREAM,
invalidUnixAddress, "/var/run/dst.sock", 0, 0);
}
@Test(expected = NullPointerException.class)
public void testNullUnixAddress() {
new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNIX_STREAM,
null, null, 0, 0);
}
@Test(expected = IllegalArgumentException.class)
public void testLongUnixAddress() {
String longUnixAddress = new String(new char[109]).replace("\0", "a");
new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNIX_STREAM,
"source", longUnixAddress, 0, 0);
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidUnixPort() {
new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNIX_STREAM,
"/var/run/src.sock", "/var/run/dst.sock", 80, 443);
}
}

View File

@ -103,6 +103,11 @@
<artifactId>netty-codec-mqtt</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>netty-codec-haproxy</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>

View File

@ -0,0 +1,60 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.example.haproxy;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.util.CharsetUtil;
import static io.netty.example.haproxy.HAProxyServer.*;
public final class HAProxyClient {
private static final String HOST = System.getProperty("host", "127.0.0.1");
public static void main(String[] args) throws Exception {
EventLoopGroup group = new NioEventLoopGroup();
try {
Bootstrap b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.handler(new HAProxyHandler());
// Start the connection attempt.
Channel ch = b.connect(HOST, PORT).sync().channel();
HAProxyMessage message = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4,
"127.0.0.1", "127.0.0.2", 8000, 9000);
ch.writeAndFlush(message).sync();
ch.writeAndFlush(Unpooled.copiedBuffer("Hello World!", CharsetUtil.US_ASCII)).sync();
ch.writeAndFlush(Unpooled.copiedBuffer("Bye now!", CharsetUtil.US_ASCII)).sync();
ch.close().sync();
} finally {
group.shutdownGracefully();
}
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.example.haproxy;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
public class HAProxyHandler extends ChannelOutboundHandlerAdapter {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
ctx.pipeline().addBefore(ctx.name(), null, HAProxyMessageEncoder.INSTANCE);
super.handlerAdded(ctx);
}
@Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
ChannelFuture future = ctx.write(msg, promise);
if (msg instanceof HAProxyMessage) {
future.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
ctx.pipeline().remove(HAProxyMessageEncoder.INSTANCE);
ctx.pipeline().remove(HAProxyHandler.this);
} else {
ctx.close();
}
}
});
}
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright 2020 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.example.haproxy;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultithreadEventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
public final class HAProxyServer {
static final int PORT = Integer.parseInt(System.getProperty("port", "8080"));
public static void main(String[] args) throws Exception {
EventLoopGroup bossGroup = new MultithreadEventLoopGroup(1, NioHandler.newFactory());
EventLoopGroup workerGroup = new MultithreadEventLoopGroup(NioHandler.newFactory());
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new HAProxyServerInitializer());
b.bind(PORT).sync().channel().closeFuture().sync();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
static class HAProxyServerInitializer extends ChannelInitializer<SocketChannel> {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast(
new LoggingHandler(LogLevel.DEBUG),
new HAProxyMessageDecoder(),
new SimpleChannelInboundHandler<Object>() {
@Override
protected void messageReceived(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof HAProxyMessage) {
System.out.println("proxy message: " + msg);
} else if (msg instanceof ByteBuf) {
System.out.println("bytebuf message: " + ByteBufUtil.prettyHexDump((ByteBuf) msg));
}
}
});
}
}
}