diff --git a/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java b/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java index 0b1050a790..b94c893cc6 100644 --- a/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java +++ b/codec/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java @@ -37,7 +37,20 @@ import java.nio.ByteBuffer; import java.util.concurrent.TimeUnit; import java.util.zip.Checksum; -import static io.netty.handler.codec.compression.Lz4Constants.*; +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_NON_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.CHECKSUM_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSED_LENGTH_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSION_LEVEL_BASE; +import static io.netty.handler.codec.compression.Lz4Constants.DECOMPRESSED_LENGTH_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; +import static io.netty.handler.codec.compression.Lz4Constants.HEADER_LENGTH; +import static io.netty.handler.codec.compression.Lz4Constants.MAGIC_NUMBER; +import static io.netty.handler.codec.compression.Lz4Constants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.MIN_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.TOKEN_OFFSET; +import static io.netty.util.internal.ThrowableUtil.unknownStackTrace; /** * Compresses a {@link ByteBuf} using the LZ4 format. @@ -56,6 +69,9 @@ import static io.netty.handler.codec.compression.Lz4Constants.*; * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ public class Lz4FrameEncoder extends MessageToByteEncoder { + private static final EncoderException ENCODE_FINSHED_EXCEPTION = unknownStackTrace(new EncoderException( + new IllegalStateException("encode finished and not enough space to write remaining data")), + Lz4FrameEncoder.class, "encode"); static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE; private final int blockSize; @@ -63,12 +79,12 @@ public class Lz4FrameEncoder extends MessageToByteEncoder { /** * Underlying compressor in use. */ - private LZ4Compressor compressor; + private final LZ4Compressor compressor; /** * Underlying checksum calculator in use. */ - private ByteBufChecksum checksum; + private final ByteBufChecksum checksum; /** * Compression level of current LZ4 encoder (depends on {@link #blockSize}). @@ -228,6 +244,10 @@ public class Lz4FrameEncoder extends MessageToByteEncoder { @Override protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { if (finished) { + if (!out.isWritable(in.readableBytes())) { + // out should be EMPTY_BUFFER because we should have allocated enough space above in allocateBuffer. + throw ENCODE_FINSHED_EXCEPTION; + } out.writeBytes(in); return; } @@ -301,33 +321,20 @@ public class Lz4FrameEncoder extends MessageToByteEncoder { } finished = true; - try { - final ByteBuf footer = ctx.alloc().heapBuffer( - compressor.maxCompressedLength(buffer.readableBytes()) + HEADER_LENGTH); - flushBufferedData(footer); + final ByteBuf footer = ctx.alloc().heapBuffer( + compressor.maxCompressedLength(buffer.readableBytes()) + HEADER_LENGTH); + flushBufferedData(footer); - final int idx = footer.writerIndex(); - footer.setLong(idx, MAGIC_NUMBER); - footer.setByte(idx + TOKEN_OFFSET, (byte) (BLOCK_TYPE_NON_COMPRESSED | compressionLevel)); - footer.setInt(idx + COMPRESSED_LENGTH_OFFSET, 0); - footer.setInt(idx + DECOMPRESSED_LENGTH_OFFSET, 0); - footer.setInt(idx + CHECKSUM_OFFSET, 0); + final int idx = footer.writerIndex(); + footer.setLong(idx, MAGIC_NUMBER); + footer.setByte(idx + TOKEN_OFFSET, (byte) (BLOCK_TYPE_NON_COMPRESSED | compressionLevel)); + footer.setInt(idx + COMPRESSED_LENGTH_OFFSET, 0); + footer.setInt(idx + DECOMPRESSED_LENGTH_OFFSET, 0); + footer.setInt(idx + CHECKSUM_OFFSET, 0); - footer.writerIndex(idx + HEADER_LENGTH); + footer.writerIndex(idx + HEADER_LENGTH); - return ctx.writeAndFlush(footer, promise); - } finally { - cleanup(); - } - } - - private void cleanup() { - compressor = null; - checksum = null; - if (buffer != null) { - buffer.release(); - buffer = null; - } + return ctx.writeAndFlush(footer, promise); } /** @@ -408,7 +415,10 @@ public class Lz4FrameEncoder extends MessageToByteEncoder { @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { super.handlerRemoved(ctx); - cleanup(); + if (buffer != null) { + buffer.release(); + buffer = null; + } } final ByteBuf getBackingBuffer() { diff --git a/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java b/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java index 8a7241f617..dabe5a6d95 100644 --- a/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java +++ b/codec/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java @@ -15,30 +15,46 @@ */ package io.netty.handler.codec.compression; -import java.io.InputStream; -import java.util.zip.Checksum; - -import io.netty.util.internal.PlatformDependent; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.EncoderException; import net.jpountz.lz4.LZ4BlockInputStream; import net.jpountz.lz4.LZ4Factory; import net.jpountz.xxhash.XXHashFactory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.zip.Checksum; + import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; import static org.mockito.Mockito.when; public class Lz4FrameEncoderTest extends AbstractEncoderTest { @@ -221,4 +237,72 @@ public class Lz4FrameEncoderTest extends AbstractEncoderTest { Assert.assertTrue(channel.releaseOutbound()); Assert.assertFalse(channel.releaseInbound()); } + + @Test(timeout = 3000) + public void writingAfterClosedChannelDoesNotNPE() throws InterruptedException { + EventLoopGroup group = new NioEventLoopGroup(2); + Channel serverChannel = null; + Channel clientChannel = null; + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference writeFailCauseRef = new AtomicReference(); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + } + }); + + Bootstrap bs = new Bootstrap(); + bs.group(group); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new Lz4FrameEncoder()); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + + final Channel finalClientChannel = clientChannel; + clientChannel.eventLoop().execute(new Runnable() { + @Override + public void run() { + finalClientChannel.close(); + final int size = 27; + ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(size, size); + finalClientChannel.writeAndFlush(buf.writerIndex(buf.writerIndex() + size)) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + try { + writeFailCauseRef.set(future.cause()); + } finally { + latch.countDown(); + } + } + }); + } + }); + latch.await(); + Throwable writeFailCause = writeFailCauseRef.get(); + assertNotNull(writeFailCause); + Throwable writeFailCauseCause = writeFailCause.getCause(); + if (writeFailCauseCause != null) { + assertThat(writeFailCauseCause, is(not(instanceOf(NullPointerException.class)))); + } + } finally { + if (serverChannel != null) { + serverChannel.close(); + } + if (clientChannel != null) { + clientChannel.close(); + } + group.shutdownGracefully(); + } + } }