NullPointerException in Lz4FrameEncoder

Motivation:
Lz4FrameEncoder maintains internal state, but the life cycle of the buffer is not consistently managed. The buffer is allocated in handlerAdded but freed in close, but the buffer can still be used until handlerRemoved is called.

Modifications:
- Move the cleanup of the buffer from close to handlerRemoved
- Explicitly throw an EncoderException from Lz4FrameEncoder if the encode operation has finished and there isn't enough space to write data

Result:
No more NPE in Lz4FrameEncoder on the buffer.
This commit is contained in:
Scott Mitchell 2017-06-19 11:01:31 -07:00
parent 7460d90a67
commit 14ea69cdc1
2 changed files with 130 additions and 36 deletions

View File

@ -37,7 +37,20 @@ import java.nio.ByteBuffer;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.zip.Checksum; import java.util.zip.Checksum;
import static io.netty.handler.codec.compression.Lz4Constants.*; import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_COMPRESSED;
import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_NON_COMPRESSED;
import static io.netty.handler.codec.compression.Lz4Constants.CHECKSUM_OFFSET;
import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSED_LENGTH_OFFSET;
import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSION_LEVEL_BASE;
import static io.netty.handler.codec.compression.Lz4Constants.DECOMPRESSED_LENGTH_OFFSET;
import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_BLOCK_SIZE;
import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED;
import static io.netty.handler.codec.compression.Lz4Constants.HEADER_LENGTH;
import static io.netty.handler.codec.compression.Lz4Constants.MAGIC_NUMBER;
import static io.netty.handler.codec.compression.Lz4Constants.MAX_BLOCK_SIZE;
import static io.netty.handler.codec.compression.Lz4Constants.MIN_BLOCK_SIZE;
import static io.netty.handler.codec.compression.Lz4Constants.TOKEN_OFFSET;
import static io.netty.util.internal.ThrowableUtil.unknownStackTrace;
/** /**
* Compresses a {@link ByteBuf} using the LZ4 format. * Compresses a {@link ByteBuf} using the LZ4 format.
@ -56,6 +69,9 @@ import static io.netty.handler.codec.compression.Lz4Constants.*;
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
*/ */
public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> { public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> {
private static final EncoderException ENCODE_FINSHED_EXCEPTION = unknownStackTrace(new EncoderException(
new IllegalStateException("encode finished and not enough space to write remaining data")),
Lz4FrameEncoder.class, "encode");
static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE; static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE;
private final int blockSize; private final int blockSize;
@ -63,12 +79,12 @@ public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> {
/** /**
* Underlying compressor in use. * Underlying compressor in use.
*/ */
private LZ4Compressor compressor; private final LZ4Compressor compressor;
/** /**
* Underlying checksum calculator in use. * Underlying checksum calculator in use.
*/ */
private ByteBufChecksum checksum; private final ByteBufChecksum checksum;
/** /**
* Compression level of current LZ4 encoder (depends on {@link #blockSize}). * Compression level of current LZ4 encoder (depends on {@link #blockSize}).
@ -228,6 +244,10 @@ public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> {
@Override @Override
protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception {
if (finished) { if (finished) {
if (!out.isWritable(in.readableBytes())) {
// out should be EMPTY_BUFFER because we should have allocated enough space above in allocateBuffer.
throw ENCODE_FINSHED_EXCEPTION;
}
out.writeBytes(in); out.writeBytes(in);
return; return;
} }
@ -301,33 +321,20 @@ public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> {
} }
finished = true; finished = true;
try { final ByteBuf footer = ctx.alloc().heapBuffer(
final ByteBuf footer = ctx.alloc().heapBuffer( compressor.maxCompressedLength(buffer.readableBytes()) + HEADER_LENGTH);
compressor.maxCompressedLength(buffer.readableBytes()) + HEADER_LENGTH); flushBufferedData(footer);
flushBufferedData(footer);
final int idx = footer.writerIndex(); final int idx = footer.writerIndex();
footer.setLong(idx, MAGIC_NUMBER); footer.setLong(idx, MAGIC_NUMBER);
footer.setByte(idx + TOKEN_OFFSET, (byte) (BLOCK_TYPE_NON_COMPRESSED | compressionLevel)); footer.setByte(idx + TOKEN_OFFSET, (byte) (BLOCK_TYPE_NON_COMPRESSED | compressionLevel));
footer.setInt(idx + COMPRESSED_LENGTH_OFFSET, 0); footer.setInt(idx + COMPRESSED_LENGTH_OFFSET, 0);
footer.setInt(idx + DECOMPRESSED_LENGTH_OFFSET, 0); footer.setInt(idx + DECOMPRESSED_LENGTH_OFFSET, 0);
footer.setInt(idx + CHECKSUM_OFFSET, 0); footer.setInt(idx + CHECKSUM_OFFSET, 0);
footer.writerIndex(idx + HEADER_LENGTH); footer.writerIndex(idx + HEADER_LENGTH);
return ctx.writeAndFlush(footer, promise); return ctx.writeAndFlush(footer, promise);
} finally {
cleanup();
}
}
private void cleanup() {
compressor = null;
checksum = null;
if (buffer != null) {
buffer.release();
buffer = null;
}
} }
/** /**
@ -408,7 +415,10 @@ public class Lz4FrameEncoder extends MessageToByteEncoder<ByteBuf> {
@Override @Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
super.handlerRemoved(ctx); super.handlerRemoved(ctx);
cleanup(); if (buffer != null) {
buffer.release();
buffer = null;
}
} }
final ByteBuf getBackingBuffer() { final ByteBuf getBackingBuffer() {

View File

@ -15,30 +15,46 @@
*/ */
package io.netty.handler.codec.compression; package io.netty.handler.codec.compression;
import java.io.InputStream; import io.netty.bootstrap.Bootstrap;
import java.util.zip.Checksum; import io.netty.bootstrap.ServerBootstrap;
import io.netty.util.internal.PlatformDependent;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.EncoderException; import io.netty.handler.codec.EncoderException;
import net.jpountz.lz4.LZ4BlockInputStream; import net.jpountz.lz4.LZ4BlockInputStream;
import net.jpountz.lz4.LZ4Factory; import net.jpountz.lz4.LZ4Factory;
import net.jpountz.xxhash.XXHashFactory; import net.jpountz.xxhash.XXHashFactory;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.MockitoAnnotations; import org.mockito.MockitoAnnotations;
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.Checksum;
import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class Lz4FrameEncoderTest extends AbstractEncoderTest { public class Lz4FrameEncoderTest extends AbstractEncoderTest {
@ -221,4 +237,72 @@ public class Lz4FrameEncoderTest extends AbstractEncoderTest {
Assert.assertTrue(channel.releaseOutbound()); Assert.assertTrue(channel.releaseOutbound());
Assert.assertFalse(channel.releaseInbound()); Assert.assertFalse(channel.releaseInbound());
} }
@Test(timeout = 3000)
public void writingAfterClosedChannelDoesNotNPE() throws InterruptedException {
EventLoopGroup group = new NioEventLoopGroup(2);
Channel serverChannel = null;
Channel clientChannel = null;
final CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Throwable> writeFailCauseRef = new AtomicReference<Throwable>();
try {
ServerBootstrap sb = new ServerBootstrap();
sb.group(group);
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
}
});
Bootstrap bs = new Bootstrap();
bs.group(group);
bs.channel(NioSocketChannel.class);
bs.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new Lz4FrameEncoder());
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
final Channel finalClientChannel = clientChannel;
clientChannel.eventLoop().execute(new Runnable() {
@Override
public void run() {
finalClientChannel.close();
final int size = 27;
ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(size, size);
finalClientChannel.writeAndFlush(buf.writerIndex(buf.writerIndex() + size))
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
try {
writeFailCauseRef.set(future.cause());
} finally {
latch.countDown();
}
}
});
}
});
latch.await();
Throwable writeFailCause = writeFailCauseRef.get();
assertNotNull(writeFailCause);
Throwable writeFailCauseCause = writeFailCause.getCause();
if (writeFailCauseCause != null) {
assertThat(writeFailCauseCause, is(not(instanceOf(NullPointerException.class))));
}
} finally {
if (serverChannel != null) {
serverChannel.close();
}
if (clientChannel != null) {
clientChannel.close();
}
group.shutdownGracefully();
}
}
} }