More robust automatic messageType detection for ChannelInboundMessageHandlerAdapter and MessageToMessageDecoder

This commit is contained in:
Trustin Lee 2013-02-08 15:44:41 +09:00
parent 38ee575839
commit fa1b49de98
14 changed files with 70 additions and 155 deletions

View File

@ -41,21 +41,14 @@ import io.netty.handler.codec.MessageToMessageDecoder;
* so that this handler can intercept HTTP requests after {@link HttpObjectDecoder}
* converts {@link ByteBuf}s into HTTP requests.
*/
public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object> {
public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObject> {
private EmbeddedByteChannel decoder;
private HttpMessage message;
private boolean decodeStarted;
/**
* Creates a new instance.
*/
protected HttpContentDecoder() {
super(HttpObject.class);
}
@Override
protected Object decode(ChannelHandlerContext ctx, Object msg) throws Exception {
protected Object decode(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
if (msg instanceof HttpResponse && ((HttpResponse) msg).getStatus().code() == 100) {
// 100-continue response must be passed through.
return msg;
@ -115,7 +108,7 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
}
@Override
protected void freeInboundMessage(Object msg) throws Exception {
protected void freeInboundMessage(HttpObject msg) throws Exception {
if (decoder == null) {
// if the decoder was null we returned the original message so we are not allowed to free it
return;

View File

@ -66,8 +66,6 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* a {@link TooLongFrameException} will be raised.
*/
public HttpObjectAggregator(int maxContentLength) {
super(HttpObject.class);
if (maxContentLength <= 0) {
throw new IllegalArgumentException(
"maxContentLength must be a positive integer: " +

View File

@ -52,8 +52,6 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
* a {@link TooLongFrameException} will be raised.
*/
public SpdyHttpDecoder(int version, int maxContentLength) {
super(SpdyDataFrame.class, SpdyControlFrame.class);
if (version < SpdyConstants.SPDY_MIN_VERSION || version > SpdyConstants.SPDY_MAX_VERSION) {
throw new IllegalArgumentException(
"unsupported version: " + version);
@ -66,6 +64,11 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
this.maxContentLength = maxContentLength;
}
@Override
public boolean acceptInboundMessage(Object msg) throws Exception {
return msg instanceof SpdyDataFrame || msg instanceof SpdyControlFrame;
}
@Override
public Object decode(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof SpdySynStreamFrame) {

View File

@ -73,9 +73,10 @@ public abstract class MessageToMessageCodec<INBOUND_IN, OUTBOUND_IN>
private final MessageToMessageDecoder<INBOUND_IN> decoder =
new MessageToMessageDecoder<INBOUND_IN>() {
@Override
public boolean isDecodable(Object msg) throws Exception {
return MessageToMessageCodec.this.isDecodable(msg);
public boolean acceptInboundMessage(Object msg) throws Exception {
return isDecodable(msg);
}
@Override

View File

@ -16,12 +16,9 @@
package io.netty.handler.codec;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelInboundMessageHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelStateHandlerAdapter;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
/**
* {@link ChannelInboundMessageHandler} which decodes from one message to an other message
@ -45,81 +42,25 @@ import io.netty.channel.ChannelStateHandlerAdapter;
* </pre>
*
*/
public abstract class MessageToMessageDecoder<I>
extends ChannelStateHandlerAdapter implements ChannelInboundMessageHandler<I> {
public abstract class MessageToMessageDecoder<I> extends ChannelInboundMessageHandlerAdapter<I> {
private final Class<?>[] acceptedMsgTypes;
protected MessageToMessageDecoder() {
super(MessageToMessageDecoder.class, 0);
}
/**
* The types which will be accepted by the decoder. If a received message is an other type it will be just forwarded
* to the next {@link ChannelInboundMessageHandler} in the {@link ChannelPipeline}
*/
protected MessageToMessageDecoder(Class<?>... acceptedMsgTypes) {
this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
protected MessageToMessageDecoder(
@SuppressWarnings("rawtypes")
Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
int messageTypeParamIndex) {
super(parameterizedHandlerType, messageTypeParamIndex);
}
@Override
public MessageBuf<I> newInboundBuffer(ChannelHandlerContext ctx) throws Exception {
return Unpooled.messageBuffer();
protected final void messageReceived(ChannelHandlerContext ctx, I msg) throws Exception {
Object decoded = decode(ctx, msg);
if (decoded != null) {
ctx.nextInboundMessageBuffer().add(decoded);
}
@Override
public void inboundBufferUpdated(ChannelHandlerContext ctx)
throws Exception {
MessageBuf<I> in = ctx.inboundMessageBuffer();
MessageBuf<Object> out = ctx.nextInboundMessageBuffer();
boolean notify = false;
for (;;) {
try {
Object msg = in.poll();
if (msg == null) {
break;
}
if (!isDecodable(msg)) {
out.add(msg);
notify = true;
continue;
}
@SuppressWarnings("unchecked")
I imsg = (I) msg;
boolean free = true;
try {
Object omsg = decode(ctx, imsg);
if (omsg == null) {
// Decoder consumed a message but returned null.
// Probably it needs more messages because it's an aggregator.
continue;
}
if (omsg == imsg) {
free = false;
}
if (ChannelHandlerUtil.unfoldAndAdd(ctx, omsg, true)) {
notify = true;
}
} finally {
if (free) {
freeInboundMessage(imsg);
}
}
} catch (Throwable t) {
if (t instanceof CodecException) {
ctx.fireExceptionCaught(t);
} else {
ctx.fireExceptionCaught(new DecoderException(t));
}
}
}
if (notify) {
ctx.fireInboundBufferUpdated();
}
}
/**
* Returns {@code true} if and only if the specified message can be decoded by this decoder.
*/
public boolean isDecodable(Object msg) throws Exception {
return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
}
/**
@ -133,18 +74,4 @@ public abstract class MessageToMessageDecoder<I>
* @throws Exception is thrown if an error accour
*/
protected abstract Object decode(ChannelHandlerContext ctx, I msg) throws Exception;
/**
* Is called after a message was processed via {@link #decode(ChannelHandlerContext, Object)} to free
* up any resources that is held by the inbound message. You may want to override this if your implementation
* just pass-through the input message or need it for later usage.
*/
protected void freeInboundMessage(I msg) throws Exception {
ChannelHandlerUtil.freeMessage(msg);
}
@Override
public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception {
ctx.inboundMessageBuffer().free();
}
}

View File

@ -53,8 +53,6 @@ public class Base64Decoder extends MessageToMessageDecoder<ByteBuf> {
}
public Base64Decoder(Base64Dialect dialect) {
super(ByteBuf.class);
if (dialect == null) {
throw new NullPointerException("dialect");
}

View File

@ -48,10 +48,6 @@ import io.netty.handler.codec.MessageToMessageDecoder;
*/
public class ByteArrayDecoder extends MessageToMessageDecoder<ByteBuf> {
public ByteArrayDecoder() {
super(ByteBuf.class);
}
@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
byte[] array;

View File

@ -15,6 +15,9 @@
*/
package io.netty.handler.codec.protobuf;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandler.Sharable;
@ -25,10 +28,6 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.MessageToMessageDecoder;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
/**
* Decodes a received {@link ByteBuf} into a
* <a href="http://code.google.com/p/protobuf/">Google Protocol Buffers</a>
@ -74,8 +73,6 @@ public class ProtobufDecoder extends MessageToMessageDecoder<ByteBuf> {
}
public ProtobufDecoder(MessageLite prototype, ExtensionRegistry extensionRegistry) {
super(ByteBuf.class);
if (prototype == null) {
throw new NullPointerException("prototype");
}

View File

@ -68,8 +68,6 @@ public class StringDecoder extends MessageToMessageDecoder<ByteBuf> {
* Creates a new instance with the specified character set.
*/
public StringDecoder(Charset charset) {
super(ByteBuf.class);
if (charset == null) {
throw new NullPointerException("charset");
}

View File

@ -15,18 +15,19 @@
*/
package io.netty.handler.codec.bytes;
import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.core.Is.*;
import static org.junit.Assert.*;
import io.netty.channel.embedded.EmbeddedMessageChannel;
import java.util.Random;
import org.junit.Before;
import org.junit.Test;
import java.util.Random;
import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.core.Is.*;
import static org.junit.Assert.*;
/**
*/
@SuppressWarnings("ZeroLengthArrayAllocation")
public class ByteArrayDecoderTest {
private EmbeddedMessageChannel ch;

View File

@ -40,8 +40,8 @@ public class SctpInboundByteStreamHandler extends ChannelInboundMessageHandlerAd
}
@Override
public boolean isSupported(Object msg) throws Exception {
if (super.isSupported(msg)) {
public boolean acceptInboundMessage(Object msg) throws Exception {
if (super.acceptInboundMessage(msg)) {
return isDecodable((SctpMessage) msg);
}
return false;

View File

@ -23,7 +23,7 @@ import io.netty.handler.codec.MessageToMessageDecoder;
public abstract class SctpMessageToMessageDecoder extends MessageToMessageDecoder<SctpMessage> {
@Override
public boolean isDecodable(Object msg) throws Exception {
public boolean acceptInboundMessage(Object msg) throws Exception {
if (msg instanceof SctpMessage) {
SctpMessage sctpMsg = (SctpMessage) msg;
if (sctpMsg.isComplete()) {

View File

@ -18,7 +18,8 @@ package io.netty.channel;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@ -50,36 +51,38 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
private static final ConcurrentMap<Class<?>, Class<?>> messageTypeMap =
new ConcurrentHashMap<Class<?>, Class<?>>();
private final Class<?> acceptedMsgType = findMessageType();
private final Class<?> acceptedMsgType;
private Class<?> findMessageType() {
Class<?> thisClass = getClass();
protected ChannelInboundMessageHandlerAdapter() {
this(ChannelInboundMessageHandlerAdapter.class, 0);
}
protected ChannelInboundMessageHandlerAdapter(
@SuppressWarnings("rawtypes")
Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
int messageTypeParamIndex) {
acceptedMsgType = findMessageType(parameterizedHandlerType, messageTypeParamIndex);
}
private Class<?> findMessageType(Class<?> parameterizedHandlerType, int messageTypeParamIndex) {
final Class<?> thisClass = getClass();
Class<?> messageType = messageTypeMap.get(thisClass);
if (messageType == null) {
for (Method m: getClass().getDeclaredMethods()) {
if (!"messageReceived".equals(m.getName())) {
continue;
}
if (m.isSynthetic() || m.isBridge()) {
continue;
}
Class<?>[] p = m.getParameterTypes();
if (p.length != 2) {
continue;
}
if (p[0] != ChannelHandlerContext.class) {
continue;
}
messageType = p[1];
break;
}
if (messageType == null) {
Class<?> currentClass = thisClass;
for (;;) {
if (currentClass.getSuperclass() == parameterizedHandlerType) {
Type[] types = ((ParameterizedType) currentClass.getGenericSuperclass()).getActualTypeArguments();
if (types.length - 1 < messageTypeParamIndex || !(types[0] instanceof Class)) {
throw new IllegalStateException(
"cannot determine the inbound message type of " + thisClass.getSimpleName());
}
messageType = (Class<?>) types[0];
break;
}
currentClass = currentClass.getSuperclass();
}
messageTypeMap.put(thisClass, messageType);
}
@ -114,7 +117,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
}
try {
if (!isSupported(msg)) {
if (!acceptInboundMessage(msg)) {
out.add(msg);
unsupportedFound = true;
continue;
@ -153,7 +156,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
*
* @param msg the message
*/
public boolean isSupported(Object msg) throws Exception {
public boolean acceptInboundMessage(Object msg) throws Exception {
return acceptedMsgType.isInstance(msg);
}

View File

@ -156,9 +156,9 @@ public class DefaultChannelPipelineTest {
boolean called;
@Override
public boolean isSupported(Object msg) throws Exception {
public boolean acceptInboundMessage(Object msg) throws Exception {
called = true;
return super.isSupported(msg);
return super.acceptInboundMessage(msg);
}
@Override