Simplified and hardened ObjectDecoder by extending LengthFieldBasedFrameDecoder

This commit is contained in:
Trustin Lee 2010-05-13 13:40:36 +00:00
parent 688ec9d927
commit d19aa4924e

View File

@ -22,7 +22,7 @@ import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBufferInputStream; import org.jboss.netty.buffer.ChannelBufferInputStream;
import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.FrameDecoder; import org.jboss.netty.handler.codec.frame.LengthFieldBasedFrameDecoder;
/** /**
* A decoder which deserializes the received {@link ChannelBuffer}s into Java * A decoder which deserializes the received {@link ChannelBuffer}s into Java
@ -41,9 +41,8 @@ import org.jboss.netty.handler.codec.frame.FrameDecoder;
* @apiviz.landmark * @apiviz.landmark
* @apiviz.has org.jboss.netty.handler.codec.serialization.ObjectDecoderInputStream - - - compatible with * @apiviz.has org.jboss.netty.handler.codec.serialization.ObjectDecoderInputStream - - - compatible with
*/ */
public class ObjectDecoder extends FrameDecoder { public class ObjectDecoder extends LengthFieldBasedFrameDecoder {
private final int maxObjectSize;
private final ClassLoader classLoader; private final ClassLoader classLoader;
/** /**
@ -79,36 +78,20 @@ public class ObjectDecoder extends FrameDecoder {
* of the serialized object * of the serialized object
*/ */
public ObjectDecoder(int maxObjectSize, ClassLoader classLoader) { public ObjectDecoder(int maxObjectSize, ClassLoader classLoader) {
if (maxObjectSize <= 0) { super(maxObjectSize, 0, 4, 0, 4);
throw new IllegalArgumentException("maxObjectSize: " + maxObjectSize);
}
this.maxObjectSize = maxObjectSize;
this.classLoader = classLoader; this.classLoader = classLoader;
} }
@Override @Override
protected Object decode( protected Object decode(
ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception { ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {
if (buffer.readableBytes() < 4) {
ChannelBuffer frame = (ChannelBuffer) super.decode(ctx, channel, buffer);
if (frame == null) {
return null; return null;
} }
int dataLen = buffer.getInt(buffer.readerIndex());
if (dataLen <= 0) {
throw new StreamCorruptedException("invalid data length: " + dataLen);
}
if (dataLen > maxObjectSize) {
throw new StreamCorruptedException(
"data length too big: " + dataLen + " (max: " + maxObjectSize + ')');
}
if (buffer.readableBytes() < dataLen + 4) {
return null;
}
buffer.skipBytes(4);
return new CompactObjectInputStream( return new CompactObjectInputStream(
new ChannelBufferInputStream(buffer, dataLen), classLoader).readObject(); new ChannelBufferInputStream(frame), classLoader).readObject();
} }
} }