Fix bugs in ZlibDecoder / Port the factorial example

- Fixed IndexOutOfBoundsException in ZlibDecoder
- Fixed a bug where ZlibDecoder raises an exception when a connection
  is closed
This commit is contained in:
Trustin Lee 2012-05-29 00:06:00 -07:00
parent e00f303e4f
commit 81e8c49931
9 changed files with 170 additions and 224 deletions

View File

@ -52,11 +52,9 @@ public class ZlibDecoder extends StreamToStreamDecoder {
throw new NullPointerException("wrapper");
}
synchronized (z) {
int resultCode = z.inflateInit(ZlibUtil.convertWrapperType(wrapper));
if (resultCode != JZlib.Z_OK) {
ZlibUtil.fail(z, "initialization failure", resultCode);
}
int resultCode = z.inflateInit(ZlibUtil.convertWrapperType(wrapper));
if (resultCode != JZlib.Z_OK) {
ZlibUtil.fail(z, "initialization failure", resultCode);
}
}
@ -73,12 +71,10 @@ public class ZlibDecoder extends StreamToStreamDecoder {
}
this.dictionary = dictionary;
synchronized (z) {
int resultCode;
resultCode = z.inflateInit(JZlib.W_ZLIB);
if (resultCode != JZlib.Z_OK) {
ZlibUtil.fail(z, "initialization failure", resultCode);
}
int resultCode;
resultCode = z.inflateInit(JZlib.W_ZLIB);
if (resultCode != JZlib.Z_OK) {
ZlibUtil.fail(z, "initialization failure", resultCode);
}
}
@ -95,30 +91,34 @@ public class ZlibDecoder extends StreamToStreamDecoder {
ChannelInboundHandlerContext<Byte> ctx,
ChannelBuffer in, ChannelBuffer out) throws Exception {
synchronized (z) {
if (!in.readable()) {
return;
}
try {
// Configure input.
int inputLength = in.readableBytes();
boolean inHasArray = in.hasArray();
z.avail_in = inputLength;
if (inHasArray) {
z.next_in = in.array();
z.next_in_index = in.arrayOffset() + in.readerIndex();
} else {
byte[] array = new byte[inputLength];
in.readBytes(array);
z.next_in = array;
z.next_in_index = 0;
}
int oldNextInIndex = z.next_in_index;
// Configure output.
int maxOutputLength = inputLength << 1;
boolean outHasArray = out.hasArray();
if (!outHasArray) {
z.next_out = new byte[maxOutputLength];
}
try {
// Configure input.
int inputLength = in.readableBytes();
boolean inHasArray = in.hasArray();
z.avail_in = inputLength;
if (inHasArray) {
z.next_in = in.array();
z.next_in_index = in.arrayOffset() + in.readerIndex();
} else {
byte[] array = new byte[inputLength];
in.readBytes(array);
z.next_in = array;
z.next_in_index = 0;
}
int oldNextInIndex = z.next_in_index;
// Configure output.
int maxOutputLength = inputLength << 1;
boolean outHasArray = out.hasArray();
if (!outHasArray) {
z.next_out = new byte[maxOutputLength];
}
loop: for (;;) {
z.avail_out = maxOutputLength;
if (outHasArray) {
@ -131,14 +131,7 @@ public class ZlibDecoder extends StreamToStreamDecoder {
int oldNextOutIndex = z.next_out_index;
// Decompress 'in' into 'out'
int resultCode;
try {
resultCode = z.inflate(JZlib.Z_SYNC_FLUSH);
} finally {
if (inHasArray) {
in.skipBytes(z.next_in_index - oldNextInIndex);
}
}
int resultCode = z.inflate(JZlib.Z_SYNC_FLUSH);
int outputLength = z.next_out_index - oldNextOutIndex;
if (outputLength > 0) {
if (outHasArray) {
@ -175,13 +168,17 @@ public class ZlibDecoder extends StreamToStreamDecoder {
}
}
} finally {
// Deference the external references explicitly to tell the VM that
// the allocated byte arrays are temporary so that the call stack
// can be utilized.
// I'm not sure if the modern VMs do this optimization though.
z.next_in = null;
z.next_out = null;
if (inHasArray) {
in.skipBytes(z.next_in_index - oldNextInIndex);
}
}
} finally {
// Deference the external references explicitly to tell the VM that
// the allocated byte arrays are temporary so that the call stack
// can be utilized.
// I'm not sure if the modern VMs do this optimization though.
z.next_in = null;
z.next_out = null;
}
}
}

View File

@ -15,13 +15,12 @@
*/
package io.netty.example.factorial;
import java.math.BigInteger;
import io.netty.buffer.ChannelBuffer;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.FrameDecoder;
import io.netty.handler.codec.StreamToMessageDecoder;
import java.math.BigInteger;
/**
* Decodes the binary representation of a {@link BigInteger} prepended
@ -29,36 +28,35 @@ import io.netty.handler.codec.FrameDecoder;
* {@link BigInteger} instance. For example, { 'F', 0, 0, 0, 1, 42 } will be
* decoded into new BigInteger("42").
*/
public class BigIntegerDecoder extends FrameDecoder {
public class BigIntegerDecoder extends StreamToMessageDecoder<BigInteger> {
@Override
protected Object decode(
ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {
public BigInteger decode(ChannelInboundHandlerContext<Byte> ctx, ChannelBuffer in) {
// Wait until the length prefix is available.
if (buffer.readableBytes() < 5) {
if (in.readableBytes() < 5) {
return null;
}
buffer.markReaderIndex();
in.markReaderIndex();
// Check the magic number.
int magicNumber = buffer.readUnsignedByte();
int magicNumber = in.readUnsignedByte();
if (magicNumber != 'F') {
buffer.resetReaderIndex();
in.resetReaderIndex();
throw new CorruptedFrameException(
"Invalid magic number: " + magicNumber);
}
// Wait until the whole data is available.
int dataLength = buffer.readInt();
if (buffer.readableBytes() < dataLength) {
buffer.resetReaderIndex();
int dataLength = in.readInt();
if (in.readableBytes() < dataLength) {
in.resetReaderIndex();
return null;
}
// Convert the received data into a new BigInteger.
byte[] decoded = new byte[dataLength];
buffer.readBytes(decoded);
in.readBytes(decoded);
return new BigInteger(decoded);
}

View File

@ -15,13 +15,10 @@
*/
package io.netty.example.factorial;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import io.netty.bootstrap.ClientBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.socket.nio.NioClientSocketChannelFactory;
import io.netty.channel.socket.nio.NioEventLoop;
import io.netty.channel.socket.nio.NioSocketChannel;
/**
* Sends a sequence of integers to a {@link FactorialServer} to calculate
@ -39,32 +36,28 @@ public class FactorialClient {
this.count = count;
}
public void run() {
// Configure the client.
ClientBootstrap bootstrap = new ClientBootstrap(
new NioClientSocketChannelFactory(
Executors.newCachedThreadPool()));
public void run() throws Exception {
ChannelBootstrap b = new ChannelBootstrap();
try {
b.eventLoop(new NioEventLoop())
.channel(new NioSocketChannel())
.remoteAddress(host, port)
.initializer(new FactorialClientInitializer(count));
// Set up the event pipeline factory.
bootstrap.setPipelineFactory(new FactorialClientPipelineFactory(count));
// Make a new connection.
ChannelFuture f = b.connect().sync();
// Make a new connection.
ChannelFuture connectFuture =
bootstrap.connect(new InetSocketAddress(host, port));
// Get the handler instance to retrieve the answer.
FactorialClientHandler handler =
(FactorialClientHandler) f.channel().pipeline().last();
// Wait until the connection is made successfully.
Channel channel = connectFuture.awaitUninterruptibly().channel();
// Print out the answer.
System.err.format(
"Factorial of %,d is: %,d", count, handler.getFactorial());
// Get the handler instance to retrieve the answer.
FactorialClientHandler handler =
(FactorialClientHandler) channel.pipeline().last();
// Print out the answer.
System.err.format(
"Factorial of %,d is: %,d", count, handler.getFactorial());
// Shut down all thread pools to exit.
bootstrap.releaseExternalResources();
} finally {
b.shutdown();
}
}
public static void main(String[] args) throws Exception {

View File

@ -15,22 +15,18 @@
*/
package io.netty.example.factorial;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import java.math.BigInteger;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import io.netty.channel.Channel;
import io.netty.channel.ChannelEvent;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelStateEvent;
import io.netty.channel.ExceptionEvent;
import io.netty.channel.MessageEvent;
import io.netty.channel.SimpleChannelUpstreamHandler;
/**
* Handler for a client-side channel. This handler maintains stateful
* information which is specific to a certain channel using member variables.
@ -38,12 +34,13 @@ import io.netty.channel.SimpleChannelUpstreamHandler;
* to create a new handler instance whenever you create a new channel and insert
* this handler to avoid a race condition.
*/
public class FactorialClientHandler extends SimpleChannelUpstreamHandler {
public class FactorialClientHandler extends ChannelInboundMessageHandlerAdapter<BigInteger> {
private static final Logger logger = Logger.getLogger(
FactorialClientHandler.class.getName());
// Stateful properties
private ChannelInboundHandlerContext<BigInteger> ctx;
private Queue<Object> out;
private int i = 1;
private int receivedMessages;
private final int count;
@ -69,34 +66,23 @@ public class FactorialClientHandler extends SimpleChannelUpstreamHandler {
}
@Override
public void handleUpstream(
ChannelHandlerContext ctx, ChannelEvent e) throws Exception {
if (e instanceof ChannelStateEvent) {
logger.info(e.toString());
}
super.handleUpstream(ctx, e);
public void channelActive(ChannelInboundHandlerContext<BigInteger> ctx) {
this.ctx = ctx;
out = ctx.out().messageBuffer();
sendNumbers();
}
@Override
public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) {
sendNumbers(e);
}
@Override
public void channelInterestChanged(ChannelHandlerContext ctx, ChannelStateEvent e) {
sendNumbers(e);
}
@Override
public void messageReceived(
ChannelHandlerContext ctx, final MessageEvent e) {
ChannelInboundHandlerContext<BigInteger> ctx, final BigInteger msg) {
receivedMessages ++;
if (receivedMessages == count) {
// Offer the answer after closing the connection.
e.channel().close().addListener(new ChannelFutureListener() {
ctx.channel().close().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
boolean offered = answer.offer((BigInteger) e.getMessage());
boolean offered = answer.offer(msg);
assert offered;
}
});
@ -105,23 +91,38 @@ public class FactorialClientHandler extends SimpleChannelUpstreamHandler {
@Override
public void exceptionCaught(
ChannelHandlerContext ctx, ExceptionEvent e) {
ChannelInboundHandlerContext<BigInteger> ctx, Throwable cause) throws Exception {
logger.log(
Level.WARNING,
"Unexpected exception from downstream.",
e.cause());
e.channel().close();
"Unexpected exception from downstream.", cause);
ctx.close();
}
private void sendNumbers(ChannelStateEvent e) {
Channel channel = e.channel();
while (channel.isWritable()) {
private void sendNumbers() {
// Do not send more than 4096 numbers.
boolean finished = false;
while (out.size() < 4096) {
if (i <= count) {
channel.write(Integer.valueOf(i));
out.add(Integer.valueOf(i));
i ++;
} else {
finished = true;
break;
}
}
ChannelFuture f = ctx.flush();
if (!finished) {
f.addListener(SEND_NUMBERS);
}
}
private final ChannelFutureListener SEND_NUMBERS = new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
sendNumbers();
}
}
};
}

View File

@ -15,10 +15,9 @@
*/
package io.netty.example.factorial;
import static io.netty.channel.Channels.*;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPipelineFactory;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.compression.ZlibDecoder;
import io.netty.handler.codec.compression.ZlibEncoder;
import io.netty.handler.codec.compression.ZlibWrapper;
@ -26,18 +25,17 @@ import io.netty.handler.codec.compression.ZlibWrapper;
/**
* Creates a newly configured {@link ChannelPipeline} for a client-side channel.
*/
public class FactorialClientPipelineFactory implements
ChannelPipelineFactory {
public class FactorialClientInitializer extends ChannelInitializer<SocketChannel> {
private final int count;
public FactorialClientPipelineFactory(int count) {
public FactorialClientInitializer(int count) {
this.count = count;
}
@Override
public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = pipeline();
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
// Enable stream compression (you can remove these two if unnecessary)
pipeline.addLast("deflater", new ZlibEncoder(ZlibWrapper.GZIP));
@ -49,7 +47,5 @@ public class FactorialClientPipelineFactory implements
// and then business logic.
pipeline.addLast("handler", new FactorialClientHandler(count));
return pipeline;
}
}

View File

@ -15,11 +15,9 @@
*/
package io.netty.example.factorial;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.socket.nio.NioServerSocketChannelFactory;
import io.netty.channel.ServerChannelBootstrap;
import io.netty.channel.socket.nio.NioEventLoop;
import io.netty.channel.socket.nio.NioServerSocketChannel;
/**
* Receives a sequence of integers from a {@link FactorialClient} to calculate
@ -33,17 +31,18 @@ public class FactorialServer {
this.port = port;
}
public void run() {
// Configure the server.
ServerBootstrap bootstrap = new ServerBootstrap(
new NioServerSocketChannelFactory(
Executors.newCachedThreadPool()));
public void run() throws Exception {
ServerChannelBootstrap b = new ServerChannelBootstrap();
try {
b.eventLoop(new NioEventLoop(), new NioEventLoop())
.channel(new NioServerSocketChannel())
.localAddress(port)
.childInitializer(new FactorialServerInitializer());
// Set up the event pipeline factory.
bootstrap.setPipelineFactory(new FactorialServerPipelineFactory());
// Bind and start to accept incoming connections.
bootstrap.bind(new InetSocketAddress(port));
b.bind().sync().channel().closeFuture().sync();
} finally {
b.shutdown();
}
}
public static void main(String[] args) throws Exception {

View File

@ -15,18 +15,14 @@
*/
package io.netty.example.factorial;
import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import java.math.BigInteger;
import java.util.Formatter;
import java.util.logging.Level;
import java.util.logging.Logger;
import io.netty.channel.ChannelEvent;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelStateEvent;
import io.netty.channel.ExceptionEvent;
import io.netty.channel.MessageEvent;
import io.netty.channel.SimpleChannelUpstreamHandler;
/**
* Handler for a server-side channel. This handler maintains stateful
* information which is specific to a certain channel using member variables.
@ -34,54 +30,36 @@ import io.netty.channel.SimpleChannelUpstreamHandler;
* to create a new handler instance whenever you create a new channel and insert
* this handler to avoid a race condition.
*/
public class FactorialServerHandler extends SimpleChannelUpstreamHandler {
public class FactorialServerHandler extends ChannelInboundMessageHandlerAdapter<BigInteger> {
private static final Logger logger = Logger.getLogger(
FactorialServerHandler.class.getName());
// Stateful properties.
private int lastMultiplier = 1;
private BigInteger factorial = new BigInteger(new byte[] { 1 });
@Override
public void handleUpstream(
ChannelHandlerContext ctx, ChannelEvent e) throws Exception {
if (e instanceof ChannelStateEvent) {
logger.info(e.toString());
}
super.handleUpstream(ctx, e);
}
private BigInteger lastMultiplier = new BigInteger("1");
private BigInteger factorial = new BigInteger("1");
@Override
public void messageReceived(
ChannelHandlerContext ctx, MessageEvent e) {
ChannelInboundHandlerContext<BigInteger> ctx, BigInteger msg) throws Exception {
// Calculate the cumulative factorial and send it to the client.
BigInteger number;
if (e.getMessage() instanceof BigInteger) {
number = (BigInteger) e.getMessage();
} else {
number = new BigInteger(e.getMessage().toString());
}
lastMultiplier = number.intValue();
factorial = factorial.multiply(number);
e.channel().write(factorial);
lastMultiplier = msg;
factorial = factorial.multiply(msg);
ctx.write(factorial);
}
@Override
public void channelDisconnected(ChannelHandlerContext ctx,
ChannelStateEvent e) throws Exception {
public void channelInactive(
ChannelInboundHandlerContext<BigInteger> ctx) throws Exception {
logger.info(new Formatter().format(
"Factorial of %,d is: %,d", lastMultiplier, factorial).toString());
}
@Override
public void exceptionCaught(
ChannelHandlerContext ctx, ExceptionEvent e) {
ChannelInboundHandlerContext<BigInteger> ctx, Throwable cause) throws Exception {
logger.log(
Level.WARNING,
"Unexpected exception from downstream.",
e.cause());
e.channel().close();
"Unexpected exception from downstream.", cause);
ctx.close();
}
}

View File

@ -15,10 +15,9 @@
*/
package io.netty.example.factorial;
import static io.netty.channel.Channels.*;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPipelineFactory;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.compression.ZlibDecoder;
import io.netty.handler.codec.compression.ZlibEncoder;
import io.netty.handler.codec.compression.ZlibWrapper;
@ -26,12 +25,10 @@ import io.netty.handler.codec.compression.ZlibWrapper;
/**
* Creates a newly configured {@link ChannelPipeline} for a server-side channel.
*/
public class FactorialServerPipelineFactory implements
ChannelPipelineFactory {
public class FactorialServerInitializer extends ChannelInitializer<SocketChannel> {
@Override
public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = pipeline();
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
// Enable stream compression (you can remove these two if unnecessary)
pipeline.addLast("deflater", new ZlibEncoder(ZlibWrapper.GZIP));
@ -45,7 +42,5 @@ public class FactorialServerPipelineFactory implements
// Please note we create a handler for every new channel
// because it has stateful properties.
pipeline.addLast("handler", new FactorialServerHandler());
return pipeline;
}
}

View File

@ -15,29 +15,22 @@
*/
package io.netty.example.factorial;
import java.math.BigInteger;
import io.netty.buffer.ChannelBuffer;
import io.netty.buffer.ChannelBuffers;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.oneone.OneToOneEncoder;
import io.netty.channel.ChannelOutboundHandlerContext;
import io.netty.handler.codec.MessageToStreamEncoder;
import java.math.BigInteger;
/**
* Encodes a {@link Number} into the binary representation prepended with
* a magic number ('F' or 0x46) and a 32-bit length prefix. For example, 42
* will be encoded to { 'F', 0, 0, 0, 1, 42 }.
*/
public class NumberEncoder extends OneToOneEncoder {
public class NumberEncoder extends MessageToStreamEncoder<Number> {
@Override
protected Object encode(
ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception {
if (!(msg instanceof Number)) {
// Ignore what this encoder can't encode.
return msg;
}
public void encode(
ChannelOutboundHandlerContext<Number> ctx, Number msg, ChannelBuffer out) throws Exception {
// Convert to a BigInteger first for easier implementation.
BigInteger v;
if (msg instanceof BigInteger) {
@ -50,13 +43,9 @@ public class NumberEncoder extends OneToOneEncoder {
byte[] data = v.toByteArray();
int dataLength = data.length;
// Construct a message.
ChannelBuffer buf = ChannelBuffers.dynamicBuffer();
buf.writeByte((byte) 'F'); // magic number
buf.writeInt(dataLength); // data length
buf.writeBytes(data); // data
// Return the constructed message.
return buf;
// Write a message.
out.writeByte((byte) 'F'); // magic number
out.writeInt(dataLength); // data length
out.writeBytes(data); // data
}
}