More robust automatic messageType detection for ChannelInboundMessageHandlerAdapter and MessageToMessageDecoder
This commit is contained in:
parent
38ee575839
commit
fa1b49de98
@ -41,21 +41,14 @@ import io.netty.handler.codec.MessageToMessageDecoder;
|
||||
* so that this handler can intercept HTTP requests after {@link HttpObjectDecoder}
|
||||
* converts {@link ByteBuf}s into HTTP requests.
|
||||
*/
|
||||
public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object> {
|
||||
public abstract class HttpContentDecoder extends MessageToMessageDecoder<HttpObject> {
|
||||
|
||||
private EmbeddedByteChannel decoder;
|
||||
private HttpMessage message;
|
||||
private boolean decodeStarted;
|
||||
|
||||
/**
|
||||
* Creates a new instance.
|
||||
*/
|
||||
protected HttpContentDecoder() {
|
||||
super(HttpObject.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Object decode(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
protected Object decode(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
|
||||
if (msg instanceof HttpResponse && ((HttpResponse) msg).getStatus().code() == 100) {
|
||||
// 100-continue response must be passed through.
|
||||
return msg;
|
||||
@ -115,7 +108,7 @@ public abstract class HttpContentDecoder extends MessageToMessageDecoder<Object>
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void freeInboundMessage(Object msg) throws Exception {
|
||||
protected void freeInboundMessage(HttpObject msg) throws Exception {
|
||||
if (decoder == null) {
|
||||
// if the decoder was null we returned the original message so we are not allowed to free it
|
||||
return;
|
||||
|
@ -66,8 +66,6 @@ public class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {
|
||||
* a {@link TooLongFrameException} will be raised.
|
||||
*/
|
||||
public HttpObjectAggregator(int maxContentLength) {
|
||||
super(HttpObject.class);
|
||||
|
||||
if (maxContentLength <= 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"maxContentLength must be a positive integer: " +
|
||||
|
@ -52,8 +52,6 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
|
||||
* a {@link TooLongFrameException} will be raised.
|
||||
*/
|
||||
public SpdyHttpDecoder(int version, int maxContentLength) {
|
||||
super(SpdyDataFrame.class, SpdyControlFrame.class);
|
||||
|
||||
if (version < SpdyConstants.SPDY_MIN_VERSION || version > SpdyConstants.SPDY_MAX_VERSION) {
|
||||
throw new IllegalArgumentException(
|
||||
"unsupported version: " + version);
|
||||
@ -66,6 +64,11 @@ public class SpdyHttpDecoder extends MessageToMessageDecoder<Object> {
|
||||
this.maxContentLength = maxContentLength;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
return msg instanceof SpdyDataFrame || msg instanceof SpdyControlFrame;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object decode(ChannelHandlerContext ctx, Object msg) throws Exception {
|
||||
if (msg instanceof SpdySynStreamFrame) {
|
||||
|
@ -73,9 +73,10 @@ public abstract class MessageToMessageCodec<INBOUND_IN, OUTBOUND_IN>
|
||||
|
||||
private final MessageToMessageDecoder<INBOUND_IN> decoder =
|
||||
new MessageToMessageDecoder<INBOUND_IN>() {
|
||||
|
||||
@Override
|
||||
public boolean isDecodable(Object msg) throws Exception {
|
||||
return MessageToMessageCodec.this.isDecodable(msg);
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
return isDecodable(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -16,12 +16,9 @@
|
||||
package io.netty.handler.codec;
|
||||
|
||||
import io.netty.buffer.MessageBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelHandlerUtil;
|
||||
import io.netty.channel.ChannelInboundMessageHandler;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.ChannelStateHandlerAdapter;
|
||||
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
|
||||
|
||||
/**
|
||||
* {@link ChannelInboundMessageHandler} which decodes from one message to an other message
|
||||
@ -45,81 +42,25 @@ import io.netty.channel.ChannelStateHandlerAdapter;
|
||||
* </pre>
|
||||
*
|
||||
*/
|
||||
public abstract class MessageToMessageDecoder<I>
|
||||
extends ChannelStateHandlerAdapter implements ChannelInboundMessageHandler<I> {
|
||||
public abstract class MessageToMessageDecoder<I> extends ChannelInboundMessageHandlerAdapter<I> {
|
||||
|
||||
private final Class<?>[] acceptedMsgTypes;
|
||||
protected MessageToMessageDecoder() {
|
||||
super(MessageToMessageDecoder.class, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* The types which will be accepted by the decoder. If a received message is an other type it will be just forwarded
|
||||
* to the next {@link ChannelInboundMessageHandler} in the {@link ChannelPipeline}
|
||||
*/
|
||||
protected MessageToMessageDecoder(Class<?>... acceptedMsgTypes) {
|
||||
this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
|
||||
protected MessageToMessageDecoder(
|
||||
@SuppressWarnings("rawtypes")
|
||||
Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
|
||||
int messageTypeParamIndex) {
|
||||
super(parameterizedHandlerType, messageTypeParamIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MessageBuf<I> newInboundBuffer(ChannelHandlerContext ctx) throws Exception {
|
||||
return Unpooled.messageBuffer();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void inboundBufferUpdated(ChannelHandlerContext ctx)
|
||||
throws Exception {
|
||||
MessageBuf<I> in = ctx.inboundMessageBuffer();
|
||||
MessageBuf<Object> out = ctx.nextInboundMessageBuffer();
|
||||
boolean notify = false;
|
||||
for (;;) {
|
||||
try {
|
||||
Object msg = in.poll();
|
||||
if (msg == null) {
|
||||
break;
|
||||
}
|
||||
if (!isDecodable(msg)) {
|
||||
out.add(msg);
|
||||
notify = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
I imsg = (I) msg;
|
||||
boolean free = true;
|
||||
try {
|
||||
Object omsg = decode(ctx, imsg);
|
||||
if (omsg == null) {
|
||||
// Decoder consumed a message but returned null.
|
||||
// Probably it needs more messages because it's an aggregator.
|
||||
continue;
|
||||
}
|
||||
if (omsg == imsg) {
|
||||
free = false;
|
||||
}
|
||||
if (ChannelHandlerUtil.unfoldAndAdd(ctx, omsg, true)) {
|
||||
notify = true;
|
||||
}
|
||||
} finally {
|
||||
if (free) {
|
||||
freeInboundMessage(imsg);
|
||||
}
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
if (t instanceof CodecException) {
|
||||
ctx.fireExceptionCaught(t);
|
||||
} else {
|
||||
ctx.fireExceptionCaught(new DecoderException(t));
|
||||
}
|
||||
}
|
||||
protected final void messageReceived(ChannelHandlerContext ctx, I msg) throws Exception {
|
||||
Object decoded = decode(ctx, msg);
|
||||
if (decoded != null) {
|
||||
ctx.nextInboundMessageBuffer().add(decoded);
|
||||
}
|
||||
if (notify) {
|
||||
ctx.fireInboundBufferUpdated();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns {@code true} if and only if the specified message can be decoded by this decoder.
|
||||
*/
|
||||
public boolean isDecodable(Object msg) throws Exception {
|
||||
return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -133,18 +74,4 @@ public abstract class MessageToMessageDecoder<I>
|
||||
* @throws Exception is thrown if an error accour
|
||||
*/
|
||||
protected abstract Object decode(ChannelHandlerContext ctx, I msg) throws Exception;
|
||||
|
||||
/**
|
||||
* Is called after a message was processed via {@link #decode(ChannelHandlerContext, Object)} to free
|
||||
* up any resources that is held by the inbound message. You may want to override this if your implementation
|
||||
* just pass-through the input message or need it for later usage.
|
||||
*/
|
||||
protected void freeInboundMessage(I msg) throws Exception {
|
||||
ChannelHandlerUtil.freeMessage(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void freeInboundBuffer(ChannelHandlerContext ctx) throws Exception {
|
||||
ctx.inboundMessageBuffer().free();
|
||||
}
|
||||
}
|
||||
|
@ -53,8 +53,6 @@ public class Base64Decoder extends MessageToMessageDecoder<ByteBuf> {
|
||||
}
|
||||
|
||||
public Base64Decoder(Base64Dialect dialect) {
|
||||
super(ByteBuf.class);
|
||||
|
||||
if (dialect == null) {
|
||||
throw new NullPointerException("dialect");
|
||||
}
|
||||
|
@ -48,10 +48,6 @@ import io.netty.handler.codec.MessageToMessageDecoder;
|
||||
*/
|
||||
public class ByteArrayDecoder extends MessageToMessageDecoder<ByteBuf> {
|
||||
|
||||
public ByteArrayDecoder() {
|
||||
super(ByteBuf.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Object decode(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
|
||||
byte[] array;
|
||||
|
@ -15,6 +15,9 @@
|
||||
*/
|
||||
package io.netty.handler.codec.protobuf;
|
||||
|
||||
import com.google.protobuf.ExtensionRegistry;
|
||||
import com.google.protobuf.Message;
|
||||
import com.google.protobuf.MessageLite;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.ByteBufInputStream;
|
||||
import io.netty.channel.ChannelHandler.Sharable;
|
||||
@ -25,10 +28,6 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
||||
import io.netty.handler.codec.LengthFieldPrepender;
|
||||
import io.netty.handler.codec.MessageToMessageDecoder;
|
||||
|
||||
import com.google.protobuf.ExtensionRegistry;
|
||||
import com.google.protobuf.Message;
|
||||
import com.google.protobuf.MessageLite;
|
||||
|
||||
/**
|
||||
* Decodes a received {@link ByteBuf} into a
|
||||
* <a href="http://code.google.com/p/protobuf/">Google Protocol Buffers</a>
|
||||
@ -74,8 +73,6 @@ public class ProtobufDecoder extends MessageToMessageDecoder<ByteBuf> {
|
||||
}
|
||||
|
||||
public ProtobufDecoder(MessageLite prototype, ExtensionRegistry extensionRegistry) {
|
||||
super(ByteBuf.class);
|
||||
|
||||
if (prototype == null) {
|
||||
throw new NullPointerException("prototype");
|
||||
}
|
||||
|
@ -68,8 +68,6 @@ public class StringDecoder extends MessageToMessageDecoder<ByteBuf> {
|
||||
* Creates a new instance with the specified character set.
|
||||
*/
|
||||
public StringDecoder(Charset charset) {
|
||||
super(ByteBuf.class);
|
||||
|
||||
if (charset == null) {
|
||||
throw new NullPointerException("charset");
|
||||
}
|
||||
|
@ -15,18 +15,19 @@
|
||||
*/
|
||||
package io.netty.handler.codec.bytes;
|
||||
|
||||
import static io.netty.buffer.Unpooled.*;
|
||||
import static org.hamcrest.core.Is.*;
|
||||
import static org.junit.Assert.*;
|
||||
import io.netty.channel.embedded.EmbeddedMessageChannel;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import static io.netty.buffer.Unpooled.*;
|
||||
import static org.hamcrest.core.Is.*;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
*/
|
||||
@SuppressWarnings("ZeroLengthArrayAllocation")
|
||||
public class ByteArrayDecoderTest {
|
||||
|
||||
private EmbeddedMessageChannel ch;
|
||||
|
@ -40,8 +40,8 @@ public class SctpInboundByteStreamHandler extends ChannelInboundMessageHandlerAd
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSupported(Object msg) throws Exception {
|
||||
if (super.isSupported(msg)) {
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
if (super.acceptInboundMessage(msg)) {
|
||||
return isDecodable((SctpMessage) msg);
|
||||
}
|
||||
return false;
|
||||
|
@ -23,7 +23,7 @@ import io.netty.handler.codec.MessageToMessageDecoder;
|
||||
public abstract class SctpMessageToMessageDecoder extends MessageToMessageDecoder<SctpMessage> {
|
||||
|
||||
@Override
|
||||
public boolean isDecodable(Object msg) throws Exception {
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
if (msg instanceof SctpMessage) {
|
||||
SctpMessage sctpMsg = (SctpMessage) msg;
|
||||
if (sctpMsg.isComplete()) {
|
||||
|
@ -18,7 +18,8 @@ package io.netty.channel;
|
||||
import io.netty.buffer.MessageBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
|
||||
@ -50,34 +51,36 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
|
||||
private static final ConcurrentMap<Class<?>, Class<?>> messageTypeMap =
|
||||
new ConcurrentHashMap<Class<?>, Class<?>>();
|
||||
|
||||
private final Class<?> acceptedMsgType = findMessageType();
|
||||
private final Class<?> acceptedMsgType;
|
||||
|
||||
private Class<?> findMessageType() {
|
||||
Class<?> thisClass = getClass();
|
||||
protected ChannelInboundMessageHandlerAdapter() {
|
||||
this(ChannelInboundMessageHandlerAdapter.class, 0);
|
||||
}
|
||||
|
||||
protected ChannelInboundMessageHandlerAdapter(
|
||||
@SuppressWarnings("rawtypes")
|
||||
Class<? extends ChannelInboundMessageHandlerAdapter> parameterizedHandlerType,
|
||||
int messageTypeParamIndex) {
|
||||
acceptedMsgType = findMessageType(parameterizedHandlerType, messageTypeParamIndex);
|
||||
}
|
||||
|
||||
private Class<?> findMessageType(Class<?> parameterizedHandlerType, int messageTypeParamIndex) {
|
||||
final Class<?> thisClass = getClass();
|
||||
Class<?> messageType = messageTypeMap.get(thisClass);
|
||||
if (messageType == null) {
|
||||
for (Method m: getClass().getDeclaredMethods()) {
|
||||
if (!"messageReceived".equals(m.getName())) {
|
||||
continue;
|
||||
}
|
||||
if (m.isSynthetic() || m.isBridge()) {
|
||||
continue;
|
||||
}
|
||||
Class<?>[] p = m.getParameterTypes();
|
||||
if (p.length != 2) {
|
||||
continue;
|
||||
}
|
||||
if (p[0] != ChannelHandlerContext.class) {
|
||||
continue;
|
||||
}
|
||||
Class<?> currentClass = thisClass;
|
||||
for (;;) {
|
||||
if (currentClass.getSuperclass() == parameterizedHandlerType) {
|
||||
Type[] types = ((ParameterizedType) currentClass.getGenericSuperclass()).getActualTypeArguments();
|
||||
if (types.length - 1 < messageTypeParamIndex || !(types[0] instanceof Class)) {
|
||||
throw new IllegalStateException(
|
||||
"cannot determine the inbound message type of " + thisClass.getSimpleName());
|
||||
}
|
||||
|
||||
messageType = p[1];
|
||||
break;
|
||||
}
|
||||
|
||||
if (messageType == null) {
|
||||
throw new IllegalStateException(
|
||||
"cannot determine the inbound message type of " + thisClass.getSimpleName());
|
||||
messageType = (Class<?>) types[0];
|
||||
break;
|
||||
}
|
||||
currentClass = currentClass.getSuperclass();
|
||||
}
|
||||
|
||||
messageTypeMap.put(thisClass, messageType);
|
||||
@ -114,7 +117,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
|
||||
}
|
||||
|
||||
try {
|
||||
if (!isSupported(msg)) {
|
||||
if (!acceptInboundMessage(msg)) {
|
||||
out.add(msg);
|
||||
unsupportedFound = true;
|
||||
continue;
|
||||
@ -153,7 +156,7 @@ public abstract class ChannelInboundMessageHandlerAdapter<I>
|
||||
*
|
||||
* @param msg the message
|
||||
*/
|
||||
public boolean isSupported(Object msg) throws Exception {
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
return acceptedMsgType.isInstance(msg);
|
||||
}
|
||||
|
||||
|
@ -156,9 +156,9 @@ public class DefaultChannelPipelineTest {
|
||||
boolean called;
|
||||
|
||||
@Override
|
||||
public boolean isSupported(Object msg) throws Exception {
|
||||
public boolean acceptInboundMessage(Object msg) throws Exception {
|
||||
called = true;
|
||||
return super.isSupported(msg);
|
||||
return super.acceptInboundMessage(msg);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
Loading…
Reference in New Issue
Block a user