IovArray#add return value resulted in more ByteBufs being added during iteration

Motivation:
IovArray implements MessageProcessor, and the processMessage method will continue to be called during iteration until it returns true. A recent commit b215794de3 changed the return value to only return true if any component of a CompositeByteBuf was added as a result of the method call. However this results in the iteration continuing, and potentially subsequent smaller buffers maybe added, which will result in out of order writes and generally corrupts data.

Modifications:
- IovArray#add should return false so that the MessageProcessor#processMessage will stop iterating.

Result:
Native transports which use IovArray will not corrupt data during gathering writes of CompositeByteBuf objects.
This commit is contained in:
Scott Mitchell 2018-01-04 08:04:32 -08:00 committed by GitHub
parent ab9f0a0fda
commit 33ddb83dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 176 additions and 8 deletions

View File

@ -21,14 +21,18 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import org.junit.Test;
import java.io.IOException;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
@ -133,6 +137,147 @@ public class CompositeBufferGatheringWriteTest extends AbstractSocketTest {
}
}
@Test(timeout = 10000)
public void testCompositeBufferPartialWriteDoesNotCorruptData() throws Throwable {
run();
}
protected void compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ChannelConfig config,
int soSndBuf) {
}
public void testCompositeBufferPartialWriteDoesNotCorruptData(ServerBootstrap sb, Bootstrap cb) throws Throwable {
// The scenario is the following:
// Limit SO_SNDBUF so that a single buffer can be written, and part of a CompositeByteBuf at the same time.
// We then write the single buffer, the CompositeByteBuf, and another single buffer and verify the data is not
// corrupted when we read it on the other side.
Channel serverChannel = null;
Channel clientChannel = null;
try {
Random r = new Random();
final int soSndBuf = 1024;
ByteBufAllocator alloc = ByteBufAllocator.DEFAULT;
final ByteBuf expectedContent = alloc.buffer(soSndBuf * 2);
expectedContent.writeBytes(newRandomBytes(expectedContent.writableBytes(), r));
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Object> clientReceived = new AtomicReference<Object>();
sb.childOption(ChannelOption.SO_SNDBUF, soSndBuf)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ctx.channel().config(),
soSndBuf);
// First single write
int offset = soSndBuf - 100;
ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex(), offset));
// Build and write CompositeByteBuf
CompositeByteBuf compositeByteBuf = ctx.alloc().compositeBuffer();
compositeByteBuf.addComponent(true,
expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
offset += 50;
compositeByteBuf.addComponent(true,
expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 200));
offset += 200;
ctx.write(compositeByteBuf);
// Write a single buffer that is smaller than the second component of the CompositeByteBuf
// above but small enough to fit in the remaining space allowed by the soSndBuf amount.
ctx.write(expectedContent.retainedSlice(expectedContent.readerIndex() + offset, 50));
offset += 50;
// Write the remainder of the content
ctx.writeAndFlush(expectedContent.retainedSlice(expectedContent.readerIndex() + offset,
expectedContent.readableBytes() - expectedContent.readerIndex() - offset))
.addListener(ChannelFutureListener.CLOSE);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// IOException is fine as it will also close the channel and may just be a connection reset.
if (!(cause instanceof IOException)) {
clientReceived.set(cause);
latch.countDown();
}
}
});
}
});
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private ByteBuf aggregator;
@Override
public void handlerAdded(ChannelHandlerContext ctx) {
aggregator = ctx.alloc().buffer(expectedContent.readableBytes());
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
if (msg instanceof ByteBuf) {
aggregator.writeBytes((ByteBuf) msg);
}
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
// IOException is fine as it will also close the channel and may just be a connection reset.
if (!(cause instanceof IOException)) {
clientReceived.set(cause);
latch.countDown();
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (clientReceived.compareAndSet(null, aggregator)) {
try {
assertEquals(expectedContent.readableBytes(), aggregator.readableBytes());
} catch (Throwable cause) {
aggregator.release();
aggregator = null;
clientReceived.set(cause);
} finally {
latch.countDown();
}
}
}
});
}
});
serverChannel = sb.bind().syncUninterruptibly().channel();
clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
latch.await();
Object received = clientReceived.get();
if (received instanceof ByteBuf) {
ByteBuf actual = (ByteBuf) received;
assertEquals(expectedContent, actual);
expectedContent.release();
actual.release();
} else {
expectedContent.release();
throw (Throwable) received;
}
} finally {
if (clientChannel != null) {
clientChannel.close().sync();
}
if (serverChannel != null) {
serverChannel.close().sync();
}
}
}
private static ByteBuf newCompositeBuffer(ByteBufAllocator alloc) {
CompositeByteBuf compositeByteBuf = alloc.compositeBuffer();
compositeByteBuf.addComponent(true, alloc.directBuffer(4).writeInt(100));
@ -141,4 +286,10 @@ public class CompositeBufferGatheringWriteTest extends AbstractSocketTest {
assertEquals(EXPECTED_BYTES, compositeByteBuf.readableBytes());
return compositeByteBuf;
}
private static byte[] newRandomBytes(int size, Random r) {
byte[] bytes = new byte[size];
r.nextBytes(bytes);
return bytes;
}
}

View File

@ -17,6 +17,7 @@ package io.netty.channel.epoll;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelConfig;
import io.netty.testsuite.transport.TestsuitePermutation;
import io.netty.testsuite.transport.socket.CompositeBufferGatheringWriteTest;
@ -27,4 +28,12 @@ public class EpollCompositeBufferGatheringWriteTest extends CompositeBufferGathe
protected List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories() {
return EpollSocketTestPermutation.INSTANCE.socket();
}
@Override
protected void compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ChannelConfig config,
int soSndBuf) {
if (config instanceof EpollChannelConfig) {
((EpollChannelConfig) config).setMaxBytesPerGatheringWrite(soSndBuf);
}
}
}

View File

@ -17,6 +17,7 @@ package io.netty.channel.kqueue;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelConfig;
import io.netty.testsuite.transport.TestsuitePermutation;
import io.netty.testsuite.transport.socket.CompositeBufferGatheringWriteTest;
@ -27,4 +28,12 @@ public class KQueueCompositeBufferGatheringWriteTest extends CompositeBufferGath
protected List<TestsuitePermutation.BootstrapComboFactory<ServerBootstrap, Bootstrap>> newFactories() {
return KQueueSocketTestPermutation.INSTANCE.socket();
}
@Override
protected void compositeBufferPartialWriteDoesNotCorruptDataInitServerConfig(ChannelConfig config,
int soSndBuf) {
if (config instanceof KQueueChannelConfig) {
((KQueueChannelConfig) config).setMaxBytesPerGatheringWrite(soSndBuf);
}
}
}

View File

@ -16,6 +16,7 @@
package io.netty.channel.unix;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelOutboundBuffer.MessageProcessor;
import io.netty.util.internal.PlatformDependent;
@ -80,8 +81,11 @@ public final class IovArray implements MessageProcessor {
}
/**
* Try to add the given {@link ByteBuf}. Returns {@code true} on success,
* {@code false} otherwise.
* Add a {@link ByteBuf} to this {@link IovArray}.
* @param buf The {@link ByteBuf} to add.
* @return {@code true} if the entire {@link ByteBuf} has been added to this {@link IovArray}. Note in the event
* that {@link ByteBuf} is a {@link CompositeByteBuf} {@code false} may be returned even if some of the components
* have been added.
*/
public boolean add(ByteBuf buf) {
if (count == IOV_MAX) {
@ -95,7 +99,7 @@ public final class IovArray implements MessageProcessor {
for (ByteBuffer nioBuffer : buffers) {
final int len = nioBuffer.remaining();
if (len != 0 && (!add(directBufferAddress(nioBuffer), nioBuffer.position(), len) || count == IOV_MAX)) {
break;
return false;
}
}
return true;
@ -103,11 +107,6 @@ public final class IovArray implements MessageProcessor {
}
private boolean add(long addr, int offset, int len) {
if (len == 0) {
// No need to add an empty buffer.
return true;
}
final long baseOffset = memoryAddress(count);
final long lengthOffset = baseOffset + ADDRESS_SIZE;