Ensure "full" ownership of msgs passed to EmbeddedChannel.writeInbound() (#9058)

Motivation

Pipeline handlers are free to "take control" of input buffers if they have singular refcount - in particular to mutate their raw data if non-readonly via discarding of read bytes, etc.

However there are various places (primarily unit tests) where a wrapped byte-array buffer is passed in and the wrapped array is assumed not to change (used after the wrapped buffer is passed to EmbeddedChannel.writeInbound()). This invalid assumption could result in unexpected errors, such as those exposed by #8931.

Modifications

Anywhere that the data passed to writeInbound() might be used again, ensure that either:
- A copy is used rather than wrapping a shared byte array, or
- The buffer is otherwise protected from modification by making it read-only

For the tests, copying is preferred since it still allows the "mutating" optimizations to be exercised.

Results

Avoid possible errors when pipeline assumes it has full control of input buffer.
This commit is contained in:
Nick Hill 2019-05-22 03:08:49 -07:00 committed by Norman Maurer
parent dea4e33c52
commit 2ca526fac6
10 changed files with 39 additions and 37 deletions

View File

@ -40,7 +40,9 @@ import static io.netty.util.internal.ObjectUtil.*;
*/
abstract class DeflateDecoder extends WebSocketExtensionDecoder {
static final byte[] FRAME_TAIL = new byte[] {0x00, 0x00, (byte) 0xff, (byte) 0xff};
static final ByteBuf FRAME_TAIL = Unpooled.unreleasableBuffer(
Unpooled.wrappedBuffer(new byte[] {0x00, 0x00, (byte) 0xff, (byte) 0xff}))
.asReadOnly();
private final boolean noContext;
private final WebSocketExtensionFilter extensionDecoderFilter;
@ -81,7 +83,7 @@ abstract class DeflateDecoder extends WebSocketExtensionDecoder {
boolean readable = msg.content().isReadable();
decoder.writeInbound(msg.content().retain());
if (appendFrameTail(msg)) {
decoder.writeInbound(Unpooled.wrappedBuffer(FRAME_TAIL));
decoder.writeInbound(FRAME_TAIL.duplicate());
}
CompositeByteBuf compositeUncompressedContent = ctx.alloc().compositeBuffer();

View File

@ -114,7 +114,7 @@ abstract class DeflateEncoder extends WebSocketExtensionEncoder {
ByteBuf compressedContent;
if (removeFrameTail(msg)) {
int realLength = fullCompressedContent.readableBytes() - FRAME_TAIL.length;
int realLength = fullCompressedContent.readableBytes() - FRAME_TAIL.readableBytes();
compressedContent = fullCompressedContent.slice(0, realLength);
} else {
compressedContent = fullCompressedContent;

View File

@ -178,7 +178,7 @@ public class HttpContentDecoderTest {
assertThat(o, is(instanceOf(FullHttpResponse.class)));
FullHttpResponse r = (FullHttpResponse) o;
assertEquals(100, r.status().code());
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(GZ_HELLO_WORLD)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD)));
r.release();
assertHasInboundMessages(channel, true);
@ -205,7 +205,7 @@ public class HttpContentDecoderTest {
FullHttpResponse r = (FullHttpResponse) o;
assertEquals(100, r.status().code());
r.release();
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(GZ_HELLO_WORLD)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD)));
assertHasInboundMessages(channel, true);
assertHasOutboundMessages(channel, false);
@ -232,7 +232,7 @@ public class HttpContentDecoderTest {
FullHttpResponse r = (FullHttpResponse) o;
assertEquals(100, r.status().code());
r.release();
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(GZ_HELLO_WORLD)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD)));
assertHasInboundMessages(channel, true);
assertHasOutboundMessages(channel, false);
@ -259,7 +259,7 @@ public class HttpContentDecoderTest {
FullHttpResponse r = (FullHttpResponse) o;
assertEquals(100, r.status().code());
r.release();
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(GZ_HELLO_WORLD)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD)));
assertHasInboundMessages(channel, true);
assertHasOutboundMessages(channel, false);
@ -566,7 +566,7 @@ public class HttpContentDecoderTest {
private static byte[] gzDecompress(byte[] input) {
ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP);
EmbeddedChannel channel = new EmbeddedChannel(decoder);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(input)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(input)));
assertTrue(channel.finish()); // close the channel to indicate end-of-data
int outputSize = 0;

View File

@ -81,7 +81,7 @@ public class HttpRequestDecoderTest {
private static void testDecodeWholeRequestAtOnce(byte[] content) {
EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder());
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(content)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(content)));
HttpRequest req = channel.readInbound();
assertNotNull(req);
checkHeaders(req.headers());
@ -146,13 +146,13 @@ public class HttpRequestDecoderTest {
}
// if header is done it should produce a HttpRequest
channel.writeInbound(Unpooled.wrappedBuffer(content, a, amount));
channel.writeInbound(Unpooled.copiedBuffer(content, a, amount));
a += amount;
}
for (int i = CONTENT_LENGTH; i > 0; i --) {
// Should produce HttpContent
channel.writeInbound(Unpooled.wrappedBuffer(content, content.length - i, 1));
channel.writeInbound(Unpooled.copiedBuffer(content, content.length - i, 1));
}
HttpRequest req = channel.readInbound();

View File

@ -122,7 +122,7 @@ public class HttpResponseDecoderTest {
for (int i = 0; i < 10; i++) {
assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n",
CharsetUtil.US_ASCII)));
assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));
assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data)));
HttpContent content = ch.readInbound();
assertEquals(data.length, content.content().readableBytes());
@ -164,7 +164,7 @@ public class HttpResponseDecoderTest {
for (int i = 0; i < 10; i++) {
assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n",
CharsetUtil.US_ASCII)));
assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data)));
assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data)));
byte[] decodedData = new byte[data.length];
HttpContent content = ch.readInbound();
@ -448,11 +448,11 @@ public class HttpResponseDecoderTest {
// if header is done it should produce a HttpRequest
boolean headerDone = a + amount == headerLength;
assertEquals(headerDone, ch.writeInbound(Unpooled.wrappedBuffer(content, a, amount)));
assertEquals(headerDone, ch.writeInbound(Unpooled.copiedBuffer(content, a, amount)));
a += amount;
}
ch.writeInbound(Unpooled.wrappedBuffer(content, headerLength, content.length - headerLength));
ch.writeInbound(Unpooled.copiedBuffer(content, headerLength, content.length - headerLength));
HttpResponse res = ch.readInbound();
assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
assertThat(res.status(), is(HttpResponseStatus.OK));
@ -483,8 +483,8 @@ public class HttpResponseDecoderTest {
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
ch.writeInbound(Unpooled.wrappedBuffer(data, 0, data.length / 2));
ch.writeInbound(Unpooled.wrappedBuffer(data, 5, data.length / 2));
ch.writeInbound(Unpooled.copiedBuffer(data, 0, data.length / 2));
ch.writeInbound(Unpooled.copiedBuffer(data, 5, data.length / 2));
HttpResponse res = ch.readInbound();
assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
@ -492,12 +492,12 @@ public class HttpResponseDecoderTest {
HttpContent firstContent = ch.readInbound();
assertThat(firstContent.content().readableBytes(), is(5));
assertEquals(Unpooled.wrappedBuffer(data, 0, 5), firstContent.content());
assertEquals(Unpooled.copiedBuffer(data, 0, 5), firstContent.content());
firstContent.release();
LastHttpContent lastContent = ch.readInbound();
assertEquals(5, lastContent.content().readableBytes());
assertEquals(Unpooled.wrappedBuffer(data, 5, 5), lastContent.content());
assertEquals(Unpooled.copiedBuffer(data, 5, 5), lastContent.content());
lastContent.release();
assertThat(ch.finish(), is(false));
@ -524,15 +524,15 @@ public class HttpResponseDecoderTest {
amount = header.length - a;
}
ch.writeInbound(Unpooled.wrappedBuffer(header, a, amount));
ch.writeInbound(Unpooled.copiedBuffer(header, a, amount));
a += amount;
}
byte[] data = new byte[10];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) i;
}
ch.writeInbound(Unpooled.wrappedBuffer(data, 0, data.length / 2));
ch.writeInbound(Unpooled.wrappedBuffer(data, 5, data.length / 2));
ch.writeInbound(Unpooled.copiedBuffer(data, 0, data.length / 2));
ch.writeInbound(Unpooled.copiedBuffer(data, 5, data.length / 2));
HttpResponse res = ch.readInbound();
assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
@ -589,7 +589,7 @@ public class HttpResponseDecoderTest {
byte[] otherData = {1, 2, 3, 4};
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.wrappedBuffer(data, otherData));
ch.writeInbound(Unpooled.copiedBuffer(data, otherData));
HttpResponse res = ch.readInbound();
assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1));
@ -625,7 +625,7 @@ public class HttpResponseDecoderTest {
EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
ch.writeInbound(Unpooled.wrappedBuffer(data));
ch.writeInbound(Unpooled.copiedBuffer(data));
// Garbage input should generate the 999 Unknown response.
HttpResponse res = ch.readInbound();
@ -636,7 +636,7 @@ public class HttpResponseDecoderTest {
assertThat(ch.readInbound(), is(nullValue()));
// More garbage should not generate anything (i.e. the decoder discards anything beyond this point.)
ch.writeInbound(Unpooled.wrappedBuffer(data));
ch.writeInbound(Unpooled.copiedBuffer(data));
assertThat(ch.readInbound(), is(nullValue()));
// Closing the connection should not generate anything since the protocol has been violated.

View File

@ -57,7 +57,7 @@ public class PerFrameDeflateEncoderTest {
assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame.rsv());
assertTrue(decoderChannel.writeInbound(compressedFrame.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload = decoderChannel.readInbound();
assertEquals(300, uncompressedPayload.readableBytes());
@ -135,7 +135,7 @@ public class PerFrameDeflateEncoderTest {
assertTrue(compressedFrame3.isFinalFragment());
assertTrue(decoderChannel.writeInbound(compressedFrame1.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload1 = decoderChannel.readInbound();
byte[] finalPayload1 = new byte[100];
uncompressedPayload1.readBytes(finalPayload1);
@ -143,7 +143,7 @@ public class PerFrameDeflateEncoderTest {
uncompressedPayload1.release();
assertTrue(decoderChannel.writeInbound(compressedFrame2.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload2 = decoderChannel.readInbound();
byte[] finalPayload2 = new byte[100];
uncompressedPayload2.readBytes(finalPayload2);
@ -151,7 +151,7 @@ public class PerFrameDeflateEncoderTest {
uncompressedPayload2.release();
assertTrue(decoderChannel.writeInbound(compressedFrame3.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload3 = decoderChannel.readInbound();
byte[] finalPayload3 = new byte[100];
uncompressedPayload3.readBytes(finalPayload3);

View File

@ -63,7 +63,7 @@ public class PerMessageDeflateEncoderTest {
assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame.rsv());
assertTrue(decoderChannel.writeInbound(compressedFrame.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload = decoderChannel.readInbound();
assertEquals(300, uncompressedPayload.readableBytes());
@ -160,7 +160,7 @@ public class PerMessageDeflateEncoderTest {
uncompressedPayload2.release();
assertTrue(decoderChannel.writeInbound(compressedFrame3.content()));
assertTrue(decoderChannel.writeInbound(Unpooled.wrappedBuffer(DeflateDecoder.FRAME_TAIL)));
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate()));
ByteBuf uncompressedPayload3 = decoderChannel.readInbound();
byte[] finalPayload3 = new byte[100];
uncompressedPayload3.readBytes(finalPayload3);

View File

@ -245,7 +245,7 @@ public class ByteToMessageDecoderTest {
byte[] bytes = new byte[1024];
PlatformDependent.threadLocalRandom().nextBytes(bytes);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(bytes)));
assertBuffer(Unpooled.wrappedBuffer(bytes), (ByteBuf) channel.readInbound());
assertNull(channel.readInbound());
assertFalse(channel.finish());
@ -277,7 +277,7 @@ public class ByteToMessageDecoderTest {
byte[] bytes = new byte[1024];
PlatformDependent.threadLocalRandom().nextBytes(bytes);
assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes)));
assertTrue(channel.writeInbound(Unpooled.copiedBuffer(bytes)));
assertBuffer(Unpooled.wrappedBuffer(bytes, 0, bytes.length - 1), (ByteBuf) channel.readInbound());
assertNull(channel.readInbound());
assertTrue(channel.finish());

View File

@ -223,7 +223,7 @@ public abstract class ZlibTest {
// Test for https://github.com/netty/netty/issues/2572
private void testDecompressOnly(ZlibWrapper decoderWrapper, byte[] compressed, byte[] data) throws Exception {
EmbeddedChannel chDecoder = new EmbeddedChannel(createDecoder(decoderWrapper));
chDecoder.writeInbound(Unpooled.wrappedBuffer(compressed));
chDecoder.writeInbound(Unpooled.copiedBuffer(compressed));
assertTrue(chDecoder.finish());
ByteBuf decoded = Unpooled.buffer(data.length);
@ -236,7 +236,7 @@ public abstract class ZlibTest {
decoded.writeBytes(buf);
buf.release();
}
assertEquals(Unpooled.wrappedBuffer(data), decoded);
assertEquals(Unpooled.copiedBuffer(data), decoded);
decoded.release();
}

View File

@ -99,13 +99,13 @@ public class HttpRequestDecoderBenchmark extends AbstractMicrobenchmark {
}
// if header is done it should produce a HttpRequest
channel.writeInbound(Unpooled.wrappedBuffer(content, a, amount));
channel.writeInbound(Unpooled.wrappedBuffer(content, a, amount).asReadOnly());
a += amount;
}
for (int i = CONTENT_LENGTH; i > 0; i --) {
// Should produce HttpContent
channel.writeInbound(Unpooled.wrappedBuffer(content, content.length - i, 1));
channel.writeInbound(Unpooled.wrappedBuffer(content, content.length - i, 1).asReadOnly());
}
}
}