More robust automatic messageType detection for ChannelInboundMessageHandlerAdapter and MessageToMessageDecoder

This commit is contained in:
Trustin Lee 2013-02-08 15:44:41 +09:00
parent 38ee575839
commit fa1b49de98
14 changed files with 70 additions and 155 deletions

View File

@ -41,21 +41,14 @@ import io.netty.handler.codec.MessageToMessageDecoder;
* so that this handler can intercept HTTP requests after {@link HttpObjectDecoder} * so that this handler can intercept HTTP requests after {@link HttpObjectDecoder}
* converts {@link ByteBuf}s into HTTP requests. * converts {@link ByteBuf}s into HTTP requests.
*/ */
public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object> { public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObject> {
private EmbeddedByteChannel decoder; private EmbeddedByteChannel decoder;
private HttpMessage message; private HttpMessage message;
private boolean decodeStarted; private boolean decodeStarted;
/**
* Creates a new instance.
*/
protected HttpContentDecoder() {
super(HttpObject.class);
}
@Override @Override
protected Object decode(ChannelHandlerContext ctx, Object msg) throws Exception { protected Object decode(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
if (msg instanceof HttpResponse && ((HttpResponse) msg).getStatus().code() == 100) { if (msg instanceof HttpResponse && ((HttpResponse) msg).getStatus().code() == 100) {
// 100-continue response must be passed through. // 100-continue response must be passed through.
return msg; return msg;
@ -115,7 +108,7 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
} }
@Override @Override
protected void freeInboundMessage(Object msg) throws Exception { protected void freeInboundMessage(HttpObject msg) throws Exception {
if (decoder == null) { if (decoder == null) {
// if the decoder was null we returned the original message so we are not allowed to free it // if the decoder was null we returned the original message so we are not allowed to free it
return; return;

View File

@ -66,8 +66,6 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
* a {@link TooLongFrameException} will be raised. * a {@link TooLongFrameException} will be raised.
*/ */
public HttpObjectAggregator(int maxContentLength) { public HttpObjectAggregator(int maxContentLength) {
super(HttpObject.class);
if (maxContentLength <= 0) { if (maxContentLength <= 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"maxContentLength must be a positive integer: " + "maxContentLength must be a positive integer: " +

View File

@ -52,8 +52,6 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
* a {@link TooLongFrameException} will be raised. * a {@link TooLongFrameException} will be raised.
*/ */
public SpdyHttpDecoder(int version, int maxContentLength) { public SpdyHttpDecoder(int version, int maxContentLength) {
super(SpdyDataFrame.class, SpdyControlFrame.class);
if (version < SpdyConstants.SPDY_MIN_VERSION || version > SpdyConstants.SPDY_MAX_VERSION) { if (version < SpdyConstants.SPDY_MIN_VERSION || version > SpdyConstants.SPDY_MAX_VERSION) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"unsupported version: " + version); "unsupported version: " + version);
@ -66,6 +64,11 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
this.maxContentLength = maxContentLength; this.maxContentLength = maxContentLength;
} }
@Override
public boolean acceptInboundMessage(Object msg) throws Exception {
return msg instanceof SpdyDataFrame || msg instanceof SpdyControlFrame;
}
@Override @Override
public Object decode(ChannelHandlerContext ctx, Object msg) throws Exception { public Object decode(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof SpdySynStreamFrame) { if (msg instanceof SpdySynStreamFrame) {

View File

@ -73,9 +73,10 @@ public abstract class MessageToMessageCodec<INBOUND_IN, OUTBOUND_IN>
private final MessageToMessageDecoder<INBOUND_IN> decoder = private final MessageToMessageDecoder<INBOUND_IN> decoder =
new MessageToMessageDecoder<INBOUND_IN>() { new MessageToMessageDecoder<INBOUND_IN>() {
@Override @Override
public boolean isDecodable(Object msg) throws Exception { public boolean acceptInboundMessage(Object msg) throws Exception {
return MessageToMessageCodec.this.isDecodable(msg); return isDecodable(msg);
} }
@Override @Override

View File

@ -16,12 +16,9 @@
package io.netty.handler.codec; package io.netty.handler.codec;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelInboundMessageHandler; import io.netty.channel.ChannelInboundMessageHandler;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.ChannelStateHandlerAdapter;
/** /**
* {@link ChannelInboundMessageHandler} which decodes from one message to an other message * {@link ChannelInboundMessageHandler} which decodes from one message to an other message
@ -45,81 +42,25 @@ import io.netty.channel.ChannelStateHandlerAdapter;
* </pre> * </pre>
* *
*/ */
public abstract class MessageToMessageDecoder<I> public abstract class MessageToMessageDecoder<I> extends ChannelInboundMessageHandlerAdapter<I> {
extends ChannelStateHandlerAdapter implements ChannelInboundMessageHandler<I> {
private final Class<?>[] acceptedMsgTypes; protected MessageToMessageDecoder() {
super(MessageToMessageDecoder.class, 0);
}
/** protected MessageToMessageDecoder(
* The types which will be accepted by the decoder. If a received message is an other type it will be just forwarded @SuppressWarnings("rawtypes")
* to the next {@link ChannelInboundMessageHandler} in the {@link ChannelPipeline} Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
*/ int messageTypeParamIndex) {
protected MessageToMessageDecoder(Class<?>... acceptedMsgTypes) { super(parameterizedHandlerType, messageTypeParamIndex);
this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
} }
@Override @Override
public MessageBuf<I> newInboundBuffer(ChannelHandlerContext ctx) throws Exception { protected final void messageReceived(ChannelHandlerContext ctx, I msg) throws Exception {
return Unpooled.messageBuffer(); Object decoded = decode(ctx, msg);
if (decoded != null) {
ctx.nextInboundMessageBuffer().add(decoded);
} }
@Override
public void inboundBufferUpdated(ChannelHandlerContext ctx)
throws Exception {
MessageBuf<I> in = ctx.inboundMessageBuffer();
MessageBuf<Object> out = ctx.nextInboundMessageBuffer();
boolean notify = false;
for (;;) {
try {
Object msg = in.poll();
if (msg == null) {
break;
}
if (!isDecodable(msg)) {
out.add(msg);
notify = true;
continue;
}
@SuppressWarnings("unchecked")
I imsg = (I) msg;
boolean free = true;
try {
Object omsg = decode(ctx, imsg);
if (omsg == null) {
// Decoder consumed a message but returned null.
// Probably it needs more messages because it's an aggregator.
continue;
}
if (omsg == imsg) {
free = false;
}
if (ChannelHandlerUtil.unfoldAndAdd(ctx, omsg, true)) {
notify = true;
}
} finally {
if (free) {
freeInboundMessage(imsg);
}
}
} catch (Throwable t) {
if (t instanceof CodecException) {
ctx.fireExceptionCaught(t);
} else {
ctx.fireExceptionCaught(new DecoderException(t));
}
}
}
if (notify) {
ctx.fireInboundBufferUpdated();
}
}
/**
* Returns {@code true} if and only if the specified message can be decoded by this decoder.
*/
public boolean isDecodable(Object msg) throws Exception {
return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
} }
/** /**
@ -133,18 +74,4 @@ public abstract class MessageToMessageDecoder<I>
* @throws Exception is thrown if an error accour * @throws Exception is thrown if an error accour
*/ */
protected abstract Object decode(ChannelHandlerContext ctx, I msg) throws Exception; protected abstract Object decode(ChannelHandlerContext ctx, I msg) throws Exception;
/**
* Is called after a message was processed via {@link #decode(ChannelHandlerContext, Object)} to free
* up any resources that is held by the inbound message. You may want to override this if your implementation
* just pass-through the input message or need it for later usage.
*/
protected void freeInboundMessage(I msg) throws Exception {
ChannelHandlerUtil.freeMessage(msg);
}
@Override
public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception {
ctx.inboundMessageBuffer().free();
}
} }

View File

@ -53,8 +53,6 @@ public class Base64Decoder extends MessageToMessageDecoder<ByteBuf> {
} }
public Base64Decoder(Base64Dialect dialect) { public Base64Decoder(Base64Dialect dialect) {
super(ByteBuf.class);
if (dialect == null) { if (dialect == null) {
throw new NullPointerException("dialect"); throw new NullPointerException("dialect");
} }

View File

@ -48,10 +48,6 @@ import io.netty.handler.codec.MessageToMessageDecoder;
*/ */
public class ByteArrayDecoder extends MessageToMessageDecoder<ByteBuf> { public class ByteArrayDecoder extends MessageToMessageDecoder<ByteBuf> {
public ByteArrayDecoder() {
super(ByteBuf.class);
}
@Override @Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { protected Object decode(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
byte[] array; byte[] array;

View File

@ -15,6 +15,9 @@
*/ */
package io.netty.handler.codec.protobuf; package io.netty.handler.codec.protobuf;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandler.Sharable;
@ -25,10 +28,6 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
/** /**
* Decodes a received {@link ByteBuf} into a * Decodes a received {@link ByteBuf} into a
* <a href="http://code.google.com/p/protobuf/">Google Protocol Buffers</a> * <a href="http://code.google.com/p/protobuf/">Google Protocol Buffers</a>
@ -74,8 +73,6 @@ public class ProtobufDecoder extends MessageToMessageDecoder<ByteBuf> {
} }
public ProtobufDecoder(MessageLite prototype, ExtensionRegistry extensionRegistry) { public ProtobufDecoder(MessageLite prototype, ExtensionRegistry extensionRegistry) {
super(ByteBuf.class);
if (prototype == null) { if (prototype == null) {
throw new NullPointerException("prototype"); throw new NullPointerException("prototype");
} }

View File

@ -68,8 +68,6 @@ public class StringDecoder extends MessageToMessageDecoder<ByteBuf> {
* Creates a new instance with the specified character set. * Creates a new instance with the specified character set.
*/ */
public StringDecoder(Charset charset) { public StringDecoder(Charset charset) {
super(ByteBuf.class);
if (charset == null) { if (charset == null) {
throw new NullPointerException("charset"); throw new NullPointerException("charset");
} }

View File

@ -15,18 +15,19 @@
*/ */
package io.netty.handler.codec.bytes; package io.netty.handler.codec.bytes;
import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.core.Is.*;
import static org.junit.Assert.*;
import io.netty.channel.embedded.EmbeddedMessageChannel; import io.netty.channel.embedded.EmbeddedMessageChannel;
import java.util.Random;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.Random;
import static io.netty.buffer.Unpooled.*;
import static org.hamcrest.core.Is.*;
import static org.junit.Assert.*;
/** /**
*/ */
@SuppressWarnings("ZeroLengthArrayAllocation")
public class ByteArrayDecoderTest { public class ByteArrayDecoderTest {
private EmbeddedMessageChannel ch; private EmbeddedMessageChannel ch;

View File

@ -40,8 +40,8 @@ public class SctpInboundByteStreamHandler extends ChannelInboundMessageHandlerAd
} }
@Override @Override
public boolean isSupported(Object msg) throws Exception { public boolean acceptInboundMessage(Object msg) throws Exception {
if (super.isSupported(msg)) { if (super.acceptInboundMessage(msg)) {
return isDecodable((SctpMessage) msg); return isDecodable((SctpMessage) msg);
} }
return false; return false;

View File

@ -23,7 +23,7 @@ import io.netty.handler.codec.MessageToMessageDecoder;
public abstract class SctpMessageToMessageDecoder extends MessageToMessageDecoder<SctpMessage> { public abstract class SctpMessageToMessageDecoder extends MessageToMessageDecoder<SctpMessage> {
@Override @Override
public boolean isDecodable(Object msg) throws Exception { public boolean acceptInboundMessage(Object msg) throws Exception {
if (msg instanceof SctpMessage) { if (msg instanceof SctpMessage) {
SctpMessage sctpMsg = (SctpMessage) msg; SctpMessage sctpMsg = (SctpMessage) msg;
if (sctpMsg.isComplete()) { if (sctpMsg.isComplete()) {

View File

@ -18,7 +18,8 @@ package io.netty.channel;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentMap;
@ -50,36 +51,38 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
private static final ConcurrentMap<Class<?>, Class<?>> messageTypeMap = private static final ConcurrentMap<Class<?>, Class<?>> messageTypeMap =
new ConcurrentHashMap<Class<?>, Class<?>>(); new ConcurrentHashMap<Class<?>, Class<?>>();
private final Class<?> acceptedMsgType = findMessageType(); private final Class<?> acceptedMsgType;
private Class<?> findMessageType() { protected ChannelInboundMessageHandlerAdapter() {
Class<?> thisClass = getClass(); this(ChannelInboundMessageHandlerAdapter.class, 0);
}
protected ChannelInboundMessageHandlerAdapter(
@SuppressWarnings("rawtypes")
Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
int messageTypeParamIndex) {
acceptedMsgType = findMessageType(parameterizedHandlerType, messageTypeParamIndex);
}
private Class<?> findMessageType(Class<?> parameterizedHandlerType, int messageTypeParamIndex) {
final Class<?> thisClass = getClass();
Class<?> messageType = messageTypeMap.get(thisClass); Class<?> messageType = messageTypeMap.get(thisClass);
if (messageType == null) { if (messageType == null) {
for (Method m: getClass().getDeclaredMethods()) { Class<?> currentClass = thisClass;
if (!"messageReceived".equals(m.getName())) { for (;;) {
continue; if (currentClass.getSuperclass() == parameterizedHandlerType) {
} Type[] types = ((ParameterizedType) currentClass.getGenericSuperclass()).getActualTypeArguments();
if (m.isSynthetic() || m.isBridge()) { if (types.length - 1 < messageTypeParamIndex || !(types[0] instanceof Class)) {
continue;
}
Class<?>[] p = m.getParameterTypes();
if (p.length != 2) {
continue;
}
if (p[0] != ChannelHandlerContext.class) {
continue;
}
messageType = p[1];
break;
}
if (messageType == null) {
throw new IllegalStateException( throw new IllegalStateException(
"cannot determine the inbound message type of " + thisClass.getSimpleName()); "cannot determine the inbound message type of " + thisClass.getSimpleName());
} }
messageType = (Class<?>) types[0];
break;
}
currentClass = currentClass.getSuperclass();
}
messageTypeMap.put(thisClass, messageType); messageTypeMap.put(thisClass, messageType);
} }
@ -114,7 +117,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
} }
try { try {
if (!isSupported(msg)) { if (!acceptInboundMessage(msg)) {
out.add(msg); out.add(msg);
unsupportedFound = true; unsupportedFound = true;
continue; continue;
@ -153,7 +156,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
* *
* @param msg the message * @param msg the message
*/ */
public boolean isSupported(Object msg) throws Exception { public boolean acceptInboundMessage(Object msg) throws Exception {
return acceptedMsgType.isInstance(msg); return acceptedMsgType.isInstance(msg);
} }

View File

@ -156,9 +156,9 @@ public class DefaultChannelPipelineTest {
boolean called; boolean called;
@Override @Override
public boolean isSupported(Object msg) throws Exception { public boolean acceptInboundMessage(Object msg) throws Exception {
called = true; called = true;
return super.isSupported(msg); return super.acceptInboundMessage(msg);
} }
@Override @Override