Fixed the haproxy message mem leak issue (#9250)

Motivation:

HAProxyMessage should be released as it contains a list of TLV which hold a ByteBuf, otherwise, it may cause memory leaks.

Modification:

- Let HAProxyMessage extend AbstractReferenceCounted
- Adjust tests.

Result:

Fixes #9201
This commit is contained in:
秦世成 2019-06-24 16:38:58 +08:00 committed by Norman Maurer
parent bf72d6d2d9
commit dd6ba294e0
2 changed files with 148 additions and 51 deletions

View File

@ -19,9 +19,13 @@ import static java.util.Objects.requireNonNull;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol.AddressFamily;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ByteProcessor; import io.netty.util.ByteProcessor;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil; import io.netty.util.NetUtil;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetectorFactory;
import io.netty.util.ResourceLeakTracker;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -30,29 +34,11 @@ import java.util.List;
/** /**
* Message container for decoded HAProxy proxy protocol parameters * Message container for decoded HAProxy proxy protocol parameters
*/ */
public final class HAProxyMessage { public final class HAProxyMessage extends AbstractReferenceCounted {
private static final ResourceLeakDetector<HAProxyMessage> leakDetector =
/** ResourceLeakDetectorFactory.instance().newResourceLeakDetector(HAProxyMessage.class);
* Version 1 proxy protocol message for 'UNKNOWN' proxied protocols. Per spec, when the proxied protocol is
* 'UNKNOWN' we must discard all other header values.
*/
private static final HAProxyMessage V1_UNKNOWN_MSG = new HAProxyMessage(
HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNKNOWN, null, null, 0, 0);
/**
* Version 2 proxy protocol message for 'UNKNOWN' proxied protocols. Per spec, when the proxied protocol is
* 'UNKNOWN' we must discard all other header values.
*/
private static final HAProxyMessage V2_UNKNOWN_MSG = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.UNKNOWN, null, null, 0, 0);
/**
* Version 2 proxy protocol message for local requests. Per spec, we should use an unspecified protocol and family
* for 'LOCAL' commands. Per spec, when the proxied protocol is 'UNKNOWN' we must discard all other header values.
*/
private static final HAProxyMessage V2_LOCAL_MSG = new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.LOCAL, HAProxyProxiedProtocol.UNKNOWN, null, null, 0, 0);
private final ResourceLeakTracker<HAProxyMessage> leak;
private final HAProxyProtocolVersion protocolVersion; private final HAProxyProtocolVersion protocolVersion;
private final HAProxyCommand command; private final HAProxyCommand command;
private final HAProxyProxiedProtocol proxiedProtocol; private final HAProxyProxiedProtocol proxiedProtocol;
@ -107,6 +93,8 @@ public final class HAProxyMessage {
this.sourcePort = sourcePort; this.sourcePort = sourcePort;
this.destinationPort = destinationPort; this.destinationPort = destinationPort;
this.tlvs = Collections.unmodifiableList(tlvs); this.tlvs = Collections.unmodifiableList(tlvs);
leak = leakDetector.track(this);
} }
/** /**
@ -147,7 +135,7 @@ public final class HAProxyMessage {
} }
if (cmd == HAProxyCommand.LOCAL) { if (cmd == HAProxyCommand.LOCAL) {
return V2_LOCAL_MSG; return unknownMsg(HAProxyProtocolVersion.V2, HAProxyCommand.LOCAL);
} }
// Per spec, the 14th byte is the protocol and address family byte // Per spec, the 14th byte is the protocol and address family byte
@ -159,7 +147,7 @@ public final class HAProxyMessage {
} }
if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) { if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) {
return V2_UNKNOWN_MSG; return unknownMsg(HAProxyProtocolVersion.V2, HAProxyCommand.PROXY);
} }
int addressInfoLen = header.readUnsignedShort(); int addressInfoLen = header.readUnsignedShort();
@ -334,7 +322,7 @@ public final class HAProxyMessage {
} }
if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) { if (protAndFam == HAProxyProxiedProtocol.UNKNOWN) {
return V1_UNKNOWN_MSG; return unknownMsg(HAProxyProtocolVersion.V1, HAProxyCommand.PROXY);
} }
if (numParts != 6) { if (numParts != 6) {
@ -346,6 +334,14 @@ public final class HAProxyMessage {
protAndFam, parts[2], parts[3], parts[4], parts[5]); protAndFam, parts[2], parts[3], parts[4], parts[5]);
} }
/**
* Proxy protocol message for 'UNKNOWN' proxied protocols. Per spec, when the proxied protocol is
* 'UNKNOWN' we must discard all other header values.
*/
private static HAProxyMessage unknownMsg(HAProxyProtocolVersion version, HAProxyCommand command) {
return new HAProxyMessage(version, command, HAProxyProxiedProtocol.UNKNOWN, null, null, 0, 0);
}
/** /**
* Convert ip address bytes to string representation * Convert ip address bytes to string representation
* *
@ -355,31 +351,20 @@ public final class HAProxyMessage {
*/ */
private static String ipBytesToString(ByteBuf header, int addressLen) { private static String ipBytesToString(ByteBuf header, int addressLen) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
if (addressLen == 4) { final int ipv4Len = 4;
sb.append(header.readByte() & 0xff); final int ipv6Len = 8;
sb.append('.'); if (addressLen == ipv4Len) {
sb.append(header.readByte() & 0xff); for (int i = 0; i < ipv4Len; i++) {
sb.append('.'); sb.append(header.readByte() & 0xff);
sb.append(header.readByte() & 0xff); sb.append('.');
sb.append('.'); }
sb.append(header.readByte() & 0xff);
} else { } else {
sb.append(Integer.toHexString(header.readUnsignedShort())); for (int i = 0; i < ipv6Len; i++) {
sb.append(':'); sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(Integer.toHexString(header.readUnsignedShort())); sb.append(':');
sb.append(':'); }
sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(':');
sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(':');
sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(':');
sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(':');
sb.append(Integer.toHexString(header.readUnsignedShort()));
sb.append(':');
sb.append(Integer.toHexString(header.readUnsignedShort()));
} }
sb.setLength(sb.length() - 1);
return sb.toString(); return sb.toString();
} }
@ -512,4 +497,63 @@ public final class HAProxyMessage {
public List<HAProxyTLV> tlvs() { public List<HAProxyTLV> tlvs() {
return tlvs; return tlvs;
} }
@Override
public HAProxyMessage touch() {
tryRecord();
return (HAProxyMessage) super.touch();
}
@Override
public HAProxyMessage touch(Object hint) {
if (leak != null) {
leak.record(hint);
}
return this;
}
@Override
public HAProxyMessage retain() {
tryRecord();
return (HAProxyMessage) super.retain();
}
@Override
public HAProxyMessage retain(int increment) {
tryRecord();
return (HAProxyMessage) super.retain(increment);
}
@Override
public boolean release() {
tryRecord();
return super.release();
}
@Override
public boolean release(int decrement) {
tryRecord();
return super.release(decrement);
}
private void tryRecord() {
if (leak != null) {
leak.record();
}
}
@Override
protected void deallocate() {
try {
for (HAProxyTLV tlv : tlvs) {
tlv.release();
}
} finally {
final ResourceLeakTracker<HAProxyMessage> leak = this.leak;
if (leak != null) {
boolean closed = leak.close(this);
assert closed;
}
}
}
} }

View File

@ -58,6 +58,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(443, msg.destinationPort()); assertEquals(443, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -78,6 +79,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(443, msg.destinationPort()); assertEquals(443, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -98,6 +100,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(0, msg.destinationPort()); assertEquals(0, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test(expected = HAProxyProtocolException.class) @Test(expected = HAProxyProtocolException.class)
@ -264,6 +267,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(443, msg.destinationPort()); assertEquals(443, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -319,6 +323,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(443, msg.destinationPort()); assertEquals(443, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -398,6 +403,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(443, msg.destinationPort()); assertEquals(443, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -476,6 +482,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(0, msg.destinationPort()); assertEquals(0, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -531,6 +538,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(0, msg.destinationPort()); assertEquals(0, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -586,6 +594,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(0, msg.destinationPort()); assertEquals(0, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test @Test
@ -642,9 +651,7 @@ public class HAProxyMessageDecoderTest {
assertTrue(0 < firstTlv.refCnt()); assertTrue(0 < firstTlv.refCnt());
assertTrue(0 < secondTlv.refCnt()); assertTrue(0 < secondTlv.refCnt());
assertTrue(0 < thirdTLV.refCnt()); assertTrue(0 < thirdTLV.refCnt());
assertFalse(thirdTLV.release()); assertTrue(msg.release());
assertFalse(secondTlv.release());
assertTrue(firstTlv.release());
assertEquals(0, firstTlv.refCnt()); assertEquals(0, firstTlv.refCnt());
assertEquals(0, secondTlv.refCnt()); assertEquals(0, secondTlv.refCnt());
assertEquals(0, thirdTLV.refCnt()); assertEquals(0, thirdTLV.refCnt());
@ -653,6 +660,51 @@ public class HAProxyMessageDecoderTest {
assertFalse(ch.finish()); assertFalse(ch.finish());
} }
@Test
public void testReleaseHAProxyMessage() {
ch = new EmbeddedChannel(new HAProxyMessageDecoder());
final byte[] bytes = {
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 35, 127, 0, 0, 1, 127, 0, 0, 1,
-55, -90, 7, 89, 32, 0, 20, 5, 0, 0, 0, 0, 33, 0, 5, 84, 76, 83, 118, 49, 34, 0, 4, 76, 69, 65, 70
};
int startChannels = ch.pipeline().names().size();
assertTrue(ch.writeInbound(copiedBuffer(bytes)));
Object msgObj = ch.readInbound();
assertEquals(startChannels - 1, ch.pipeline().names().size());
HAProxyMessage msg = (HAProxyMessage) msgObj;
final List<HAProxyTLV> tlvs = msg.tlvs();
assertEquals(3, tlvs.size());
assertEquals(1, msg.refCnt());
for (HAProxyTLV tlv : tlvs) {
assertEquals(3, tlv.refCnt());
}
// Retain the haproxy message
msg.retain();
assertEquals(2, msg.refCnt());
for (HAProxyTLV tlv : tlvs) {
assertEquals(3, tlv.refCnt());
}
// Decrease the haproxy message refCnt
msg.release();
assertEquals(1, msg.refCnt());
for (HAProxyTLV tlv : tlvs) {
assertEquals(3, tlv.refCnt());
}
// Release haproxy message, TLVs will be released with it
msg.release();
assertEquals(0, msg.refCnt());
for (HAProxyTLV tlv : tlvs) {
assertEquals(0, tlv.refCnt());
}
}
@Test @Test
public void testV2WithTLV() { public void testV2WithTLV() {
ch = new EmbeddedChannel(new HAProxyMessageDecoder(4)); ch = new EmbeddedChannel(new HAProxyMessageDecoder(4));
@ -738,6 +790,7 @@ public class HAProxyMessageDecoderTest {
assertEquals(0, msg.destinationPort()); assertEquals(0, msg.destinationPort());
assertNull(ch.readInbound()); assertNull(ch.readInbound());
assertFalse(ch.finish()); assertFalse(ch.finish());
assertTrue(msg.release());
} }
@Test(expected = HAProxyProtocolException.class) @Test(expected = HAProxyProtocolException.class)