Check if message is supported before cast. See #678

This commit is contained in:
Norman Maurer 2012-10-24 07:03:02 +02:00
parent c43b9b4dd2
commit 985fa97c9b
8 changed files with 79 additions and 40 deletions

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandler; import io.netty.channel.ChannelInboundByteHandler;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelHandlerUtil;
public abstract class ByteToMessageDecoder<O> public abstract class ByteToMessageDecoder<O>
extends ChannelInboundHandlerAdapter implements ChannelInboundByteHandler { extends ChannelInboundHandlerAdapter implements ChannelInboundByteHandler {
@ -51,7 +52,7 @@ public abstract class ByteToMessageDecoder<O>
} }
try { try {
if (CodecUtil.unfoldAndAdd(ctx, decodeLast(ctx, in), true)) { if (ChannelHandlerUtil.unfoldAndAdd(ctx, decodeLast(ctx, in), true)) {
ctx.fireInboundBufferUpdated(); ctx.fireInboundBufferUpdated();
} }
} catch (Throwable t) { } catch (Throwable t) {
@ -86,7 +87,7 @@ public abstract class ByteToMessageDecoder<O>
} }
} }
if (CodecUtil.unfoldAndAdd(ctx, o, true)) { if (ChannelHandlerUtil.unfoldAndAdd(ctx, o, true)) {
decoded = true; decoded = true;
} else { } else {
break; break;

View File

@ -20,13 +20,14 @@ import io.netty.buffer.MessageBuf;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter; import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelHandlerUtil;
public abstract class MessageToByteEncoder<I> extends ChannelOutboundMessageHandlerAdapter<I> { public abstract class MessageToByteEncoder<I> extends ChannelOutboundMessageHandlerAdapter<I> {
private final Class<?>[] acceptedMsgTypes; private final Class<?>[] acceptedMsgTypes;
protected MessageToByteEncoder(Class<?>... acceptedMsgTypes) { protected MessageToByteEncoder(Class<?>... acceptedMsgTypes) {
this.acceptedMsgTypes = CodecUtil.acceptedMessageTypes(acceptedMsgTypes); this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
} }
@Override @Override
@ -41,7 +42,7 @@ public abstract class MessageToByteEncoder<I> extends ChannelOutboundMessageHand
} }
if (!isEncodable(msg)) { if (!isEncodable(msg)) {
CodecUtil.addToNextOutboundBuffer(ctx, msg); ChannelHandlerUtil.addToNextOutboundBuffer(ctx, msg);
continue; continue;
} }
@ -67,7 +68,7 @@ public abstract class MessageToByteEncoder<I> extends ChannelOutboundMessageHand
* @param msg the message * @param msg the message
*/ */
public boolean isEncodable(Object msg) throws Exception { public boolean isEncodable(Object msg) throws Exception {
return CodecUtil.acceptMessage(acceptedMsgTypes, msg); return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
} }
public abstract void encode(ChannelHandlerContext ctx, I msg, ByteBuf out) throws Exception; public abstract void encode(ChannelHandlerContext ctx, I msg, ByteBuf out) throws Exception;

View File

@ -21,6 +21,7 @@ import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandler; import io.netty.channel.ChannelInboundMessageHandler;
import io.netty.channel.ChannelOutboundMessageHandler; import io.netty.channel.ChannelOutboundMessageHandler;
import io.netty.channel.ChannelHandlerUtil;
public abstract class MessageToMessageCodec<INBOUND_IN, INBOUND_OUT, OUTBOUND_IN, OUTBOUND_OUT> public abstract class MessageToMessageCodec<INBOUND_IN, INBOUND_OUT, OUTBOUND_IN, OUTBOUND_OUT>
extends ChannelHandlerAdapter extends ChannelHandlerAdapter
@ -62,8 +63,8 @@ public abstract class MessageToMessageCodec<INBOUND_IN, INBOUND_OUT, OUTBOUND_IN
protected MessageToMessageCodec( protected MessageToMessageCodec(
Class<?>[] acceptedInboundMsgTypes, Class<?>[] acceptedOutboundMsgTypes) { Class<?>[] acceptedInboundMsgTypes, Class<?>[] acceptedOutboundMsgTypes) {
this.acceptedInboundMsgTypes = CodecUtil.acceptedMessageTypes(acceptedInboundMsgTypes); this.acceptedInboundMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedInboundMsgTypes);
this.acceptedOutboundMsgTypes = CodecUtil.acceptedMessageTypes(acceptedOutboundMsgTypes); this.acceptedOutboundMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedOutboundMsgTypes);
} }
@Override @Override
@ -93,7 +94,7 @@ public abstract class MessageToMessageCodec<INBOUND_IN, INBOUND_OUT, OUTBOUND_IN
* @param msg the message * @param msg the message
*/ */
public boolean isDecodable(Object msg) throws Exception { public boolean isDecodable(Object msg) throws Exception {
return CodecUtil.acceptMessage(acceptedInboundMsgTypes, msg); return ChannelHandlerUtil.acceptMessage(acceptedInboundMsgTypes, msg);
} }
/** /**
@ -102,7 +103,7 @@ public abstract class MessageToMessageCodec<INBOUND_IN, INBOUND_OUT, OUTBOUND_IN
* @param msg the message * @param msg the message
*/ */
public boolean isEncodable(Object msg) throws Exception { public boolean isEncodable(Object msg) throws Exception {
return CodecUtil.acceptMessage(acceptedOutboundMsgTypes, msg); return ChannelHandlerUtil.acceptMessage(acceptedOutboundMsgTypes, msg);
} }
public abstract OUTBOUND_OUT encode(ChannelHandlerContext ctx, OUTBOUND_IN msg) throws Exception; public abstract OUTBOUND_OUT encode(ChannelHandlerContext ctx, OUTBOUND_IN msg) throws Exception;

View File

@ -20,6 +20,7 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInboundMessageHandler; import io.netty.channel.ChannelInboundMessageHandler;
import io.netty.channel.ChannelHandlerUtil;
public abstract class MessageToMessageDecoder<I, O> public abstract class MessageToMessageDecoder<I, O>
extends ChannelInboundHandlerAdapter implements ChannelInboundMessageHandler<I> { extends ChannelInboundHandlerAdapter implements ChannelInboundMessageHandler<I> {
@ -27,7 +28,7 @@ public abstract class MessageToMessageDecoder<I, O>
private final Class<?>[] acceptedMsgTypes; private final Class<?>[] acceptedMsgTypes;
protected MessageToMessageDecoder(Class<?>... acceptedMsgTypes) { protected MessageToMessageDecoder(Class<?>... acceptedMsgTypes) {
this.acceptedMsgTypes = CodecUtil.acceptedMessageTypes(acceptedMsgTypes); this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
} }
@Override @Override
@ -47,7 +48,7 @@ public abstract class MessageToMessageDecoder<I, O>
break; break;
} }
if (!isDecodable(msg)) { if (!isDecodable(msg)) {
CodecUtil.addToNextInboundBuffer(ctx, msg); ChannelHandlerUtil.addToNextInboundBuffer(ctx, msg);
notify = true; notify = true;
continue; continue;
} }
@ -61,7 +62,7 @@ public abstract class MessageToMessageDecoder<I, O>
continue; continue;
} }
if (CodecUtil.unfoldAndAdd(ctx, omsg, true)) { if (ChannelHandlerUtil.unfoldAndAdd(ctx, omsg, true)) {
notify = true; notify = true;
} }
} catch (Throwable t) { } catch (Throwable t) {
@ -83,7 +84,7 @@ public abstract class MessageToMessageDecoder<I, O>
* @param msg the message * @param msg the message
*/ */
public boolean isDecodable(Object msg) throws Exception { public boolean isDecodable(Object msg) throws Exception {
return CodecUtil.acceptMessage(acceptedMsgTypes, msg); return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
} }
public abstract O decode(ChannelHandlerContext ctx, I msg) throws Exception; public abstract O decode(ChannelHandlerContext ctx, I msg) throws Exception;

View File

@ -19,13 +19,14 @@ import io.netty.buffer.MessageBuf;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter; import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelHandlerUtil;
public abstract class MessageToMessageEncoder<I, O> extends ChannelOutboundMessageHandlerAdapter<I> { public abstract class MessageToMessageEncoder<I, O> extends ChannelOutboundMessageHandlerAdapter<I> {
private final Class<?>[] acceptedMsgTypes; private final Class<?>[] acceptedMsgTypes;
protected MessageToMessageEncoder(Class<?>... acceptedMsgTypes) { protected MessageToMessageEncoder(Class<?>... acceptedMsgTypes) {
this.acceptedMsgTypes = CodecUtil.acceptedMessageTypes(acceptedMsgTypes); this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
} }
@Override @Override
@ -39,7 +40,7 @@ public abstract class MessageToMessageEncoder<I, O> extends ChannelOutboundMessa
} }
if (!isEncodable(msg)) { if (!isEncodable(msg)) {
CodecUtil.addToNextOutboundBuffer(ctx, msg); ChannelHandlerUtil.addToNextOutboundBuffer(ctx, msg);
continue; continue;
} }
@ -52,7 +53,7 @@ public abstract class MessageToMessageEncoder<I, O> extends ChannelOutboundMessa
continue; continue;
} }
CodecUtil.unfoldAndAdd(ctx, omsg, false); ChannelHandlerUtil.unfoldAndAdd(ctx, omsg, false);
} catch (Throwable t) { } catch (Throwable t) {
if (t instanceof CodecException) { if (t instanceof CodecException) {
ctx.fireExceptionCaught(t); ctx.fireExceptionCaught(t);
@ -71,7 +72,7 @@ public abstract class MessageToMessageEncoder<I, O> extends ChannelOutboundMessa
* @param msg the message * @param msg the message
*/ */
public boolean isEncodable(Object msg) throws Exception { public boolean isEncodable(Object msg) throws Exception {
return CodecUtil.acceptMessage(acceptedMsgTypes, msg); return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
} }
public abstract O encode(ChannelHandlerContext ctx, I msg) throws Exception; public abstract O encode(ChannelHandlerContext ctx, I msg) throws Exception;

View File

@ -21,6 +21,7 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.util.internal.Signal; import io.netty.util.internal.Signal;
/** /**
@ -368,7 +369,7 @@ public abstract class ReplayingDecoder<O, S> extends ByteToMessageDecoder<O> {
} }
try { try {
if (CodecUtil.unfoldAndAdd(ctx, decodeLast(ctx, replayable), true)) { if (ChannelHandlerUtil.unfoldAndAdd(ctx, decodeLast(ctx, replayable), true)) {
fireInboundBufferUpdated(ctx, in); fireInboundBufferUpdated(ctx, in);
} }
} catch (Signal replay) { } catch (Signal replay) {
@ -432,7 +433,7 @@ public abstract class ReplayingDecoder<O, S> extends ByteToMessageDecoder<O> {
} }
// A successful decode // A successful decode
if (CodecUtil.unfoldAndAdd(ctx, result, true)) { if (ChannelHandlerUtil.unfoldAndAdd(ctx, result, true)) {
decoded = true; decoded = true;
} }
} catch (Throwable t) { } catch (Throwable t) {

View File

@ -13,17 +13,13 @@
* License for the specific language governing permissions and limitations * License for the specific language governing permissions and limitations
* under the License. * under the License.
*/ */
package io.netty.handler.codec; package io.netty.channel;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.NoSuchBufferException;
final class CodecUtil { public final class ChannelHandlerUtil {
static boolean unfoldAndAdd( public static boolean unfoldAndAdd(
ChannelHandlerContext ctx, Object msg, boolean inbound) throws Exception { ChannelHandlerContext ctx, Object msg, boolean inbound) throws Exception {
if (msg == null) { if (msg == null) {
return false; return false;
@ -84,7 +80,7 @@ final class CodecUtil {
private static final Class<?>[] EMPTY_TYPES = new Class<?>[0]; private static final Class<?>[] EMPTY_TYPES = new Class<?>[0];
static Class<?>[] acceptedMessageTypes(Class<?>[] acceptedMsgTypes) { public static Class<?>[] acceptedMessageTypes(Class<?>[] acceptedMsgTypes) {
if (acceptedMsgTypes == null) { if (acceptedMsgTypes == null) {
return EMPTY_TYPES; return EMPTY_TYPES;
} }
@ -103,7 +99,7 @@ final class CodecUtil {
return newAllowedMsgTypes; return newAllowedMsgTypes;
} }
static boolean acceptMessage(Class<?>[] acceptedMsgTypes, Object msg) { public static boolean acceptMessage(Class<?>[] acceptedMsgTypes, Object msg) {
if (acceptedMsgTypes.length == 0) { if (acceptedMsgTypes.length == 0) {
return true; return true;
} }
@ -117,7 +113,7 @@ final class CodecUtil {
return false; return false;
} }
static void addToNextOutboundBuffer(ChannelHandlerContext ctx, Object msg) { public static void addToNextOutboundBuffer(ChannelHandlerContext ctx, Object msg) {
try { try {
ctx.nextOutboundMessageBuffer().add(msg); ctx.nextOutboundMessageBuffer().add(msg);
} catch (NoSuchBufferException e) { } catch (NoSuchBufferException e) {
@ -128,7 +124,7 @@ final class CodecUtil {
} }
} }
static void addToNextInboundBuffer(ChannelHandlerContext ctx, Object msg) { public static void addToNextInboundBuffer(ChannelHandlerContext ctx, Object msg) {
try { try {
ctx.nextInboundMessageBuffer().add(msg); ctx.nextInboundMessageBuffer().add(msg);
} catch (NoSuchBufferException e) { } catch (NoSuchBufferException e) {
@ -139,7 +135,7 @@ final class CodecUtil {
} }
} }
private CodecUtil() { private ChannelHandlerUtil() {
// Unused // Unused
} }
} }

View File

@ -21,33 +21,70 @@ import io.netty.buffer.Unpooled;
public abstract class ChannelInboundMessageHandlerAdapter<I> public abstract class ChannelInboundMessageHandlerAdapter<I>
extends ChannelInboundHandlerAdapter implements ChannelInboundMessageHandler<I> { extends ChannelInboundHandlerAdapter implements ChannelInboundMessageHandler<I> {
private final Class<?>[] acceptedMsgTypes;
protected ChannelInboundMessageHandlerAdapter(Class<?>... acceptedMsgTypes) {
this.acceptedMsgTypes = ChannelHandlerUtil.acceptedMessageTypes(acceptedMsgTypes);
}
@Override @Override
public MessageBuf<I> newInboundBuffer(ChannelHandlerContext ctx) throws Exception { public MessageBuf<I> newInboundBuffer(ChannelHandlerContext ctx) throws Exception {
return Unpooled.messageBuffer(); return Unpooled.messageBuffer();
} }
@SuppressWarnings("unchecked")
@Override @Override
public final void inboundBufferUpdated(ChannelHandlerContext ctx) throws Exception { public final void inboundBufferUpdated(ChannelHandlerContext ctx) throws Exception {
if (!beginMessageReceived(ctx)) { if (!beginMessageReceived(ctx)) {
return; return;
} }
boolean unsupportedFound = false;
try {
MessageBuf<I> in = ctx.inboundMessageBuffer(); MessageBuf<I> in = ctx.inboundMessageBuffer();
for (;;) { for (;;) {
I msg = in.poll(); Object msg = in.poll();
if (msg == null) { if (msg == null) {
break; break;
} }
try { try {
messageReceived(ctx, msg); if (!isSupported(msg)) {
ChannelHandlerUtil.addToNextInboundBuffer(ctx, msg);
unsupportedFound = true;
continue;
}
if (unsupportedFound) {
// the last message were unsupported, but now we received one that is supported.
// So reset the flag and notify the next context
unsupportedFound = false;
ctx.fireInboundBufferUpdated();
}
messageReceived(ctx, (I) msg);
} catch (Throwable t) { } catch (Throwable t) {
ctx.fireExceptionCaught(t); ctx.fireExceptionCaught(t);
} }
} }
} finally {
if (unsupportedFound) {
ctx.fireInboundBufferUpdated();
}
}
endMessageReceived(ctx); endMessageReceived(ctx);
} }
/**
* Returns {@code true} if and only if the specified message can be handled by this handler.
*
* @param msg the message
*/
public boolean isSupported(Object msg) throws Exception {
return ChannelHandlerUtil.acceptMessage(acceptedMsgTypes, msg);
}
/** /**
* Will get notified once {@link #inboundBufferUpdated(ChannelHandlerContext)} was called. * Will get notified once {@link #inboundBufferUpdated(ChannelHandlerContext)} was called.
* *