diff --git a/codec/src/main/java/io/netty/handler/codec/embedder/AbstractCodecEmbedder.java b/codec/src/main/java/io/netty/handler/codec/embedder/AbstractCodecEmbedder.java index 0608789da6..665f3a4d9e 100644 --- a/codec/src/main/java/io/netty/handler/codec/embedder/AbstractCodecEmbedder.java +++ b/codec/src/main/java/io/netty/handler/codec/embedder/AbstractCodecEmbedder.java @@ -15,12 +15,17 @@ */ package io.netty.handler.codec.embedder; +import io.netty.buffer.ChannelBuffer; import io.netty.channel.Channel; import io.netty.channel.ChannelBufferHolder; import io.netty.channel.ChannelBufferHolders; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoop; import io.netty.handler.codec.CodecException; @@ -45,8 +50,62 @@ abstract class AbstractCodecEmbedder implements CodecEmbedder { * handlers. */ protected AbstractCodecEmbedder(ChannelHandler... handlers) { - channel.pipeline().addLast(handlers); - channel.pipeline().addLast(new LastHandler()); + if (handlers == null) { + throw new NullPointerException("handlers"); + } + + int inboundType = 0; // 0 - unknown, 1 - stream, 2 - message + int outboundType = 0; + int nHandlers = 0; + ChannelPipeline p = channel.pipeline(); + for (ChannelHandler h: handlers) { + if (h == null) { + break; + } + nHandlers ++; + + p.addLast(h); + ChannelHandlerContext ctx = p.context(h); + if (inboundType == 0) { + if (ctx.canHandleInbound()) { + ChannelInboundHandlerContext inCtx = (ChannelInboundHandlerContext) ctx; + if (inCtx.inbound().hasByteBuffer()) { + inboundType = 1; + } else { + inboundType = 2; + } + } + } + if (ctx.canHandleOutbound()) { + ChannelOutboundHandlerContext outCtx = (ChannelOutboundHandlerContext) ctx; + if (outCtx.outbound().hasByteBuffer()) { + outboundType = 1; + } else { + outboundType = 2; + } + } + } + + if (nHandlers == 0) { + throw new IllegalArgumentException("handlers is empty."); + } + + if (inboundType == 0 && outboundType == 0) { + throw new IllegalArgumentException("handlers does not provide any buffers."); + } + + p.addFirst(StreamToChannelBufferEncoder.INSTANCE); + + if (inboundType == 1) { + p.addFirst(ChannelBufferToStreamDecoder.INSTANCE); + } + + if (outboundType == 1) { + p.addLast(ChannelBufferToStreamEncoder.INSTANCE); + } + + p.addLast(new LastHandler()); + loop.register(channel); } @@ -166,4 +225,85 @@ abstract class AbstractCodecEmbedder implements CodecEmbedder { productQueue.add(cause); } } + + @ChannelHandler.Sharable + private static final class StreamToChannelBufferEncoder extends ChannelOutboundHandlerAdapter { + + static final StreamToChannelBufferEncoder INSTANCE = new StreamToChannelBufferEncoder(); + + @Override + public ChannelBufferHolder newOutboundBuffer( + ChannelOutboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.byteBuffer(); + } + + @Override + public void flush(ChannelOutboundHandlerContext ctx, ChannelFuture future) throws Exception { + ChannelBuffer in = ctx.outbound().byteBuffer(); + if (in.readable()) { + ctx.nextOutboundMessageBuffer().add(in.readBytes(in.readableBytes())); + } + ctx.flush(future); + } + } + + @ChannelHandler.Sharable + private static final class ChannelBufferToStreamDecoder extends ChannelInboundHandlerAdapter { + + static final ChannelBufferToStreamDecoder INSTANCE = new ChannelBufferToStreamDecoder(); + + @Override + public ChannelBufferHolder newInboundBuffer( + ChannelInboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelInboundHandlerContext ctx) throws Exception { + Queue in = ctx.inbound().messageBuffer(); + for (;;) { + Object msg = in.poll(); + if (msg == null) { + break; + } + if (msg instanceof ChannelBuffer) { + ChannelBuffer buf = (ChannelBuffer) msg; + ctx.nextInboundByteBuffer().writeBytes(buf, buf.readerIndex(), buf.readableBytes()); + } else { + ctx.nextInboundMessageBuffer().add(msg); + } + } + ctx.fireInboundBufferUpdated(); + } + } + + @ChannelHandler.Sharable + private static final class ChannelBufferToStreamEncoder extends ChannelOutboundHandlerAdapter { + + static final ChannelBufferToStreamEncoder INSTANCE = new ChannelBufferToStreamEncoder(); + + @Override + public ChannelBufferHolder newOutboundBuffer( + ChannelOutboundHandlerContext ctx) throws Exception { + return ChannelBufferHolders.messageBuffer(); + } + + @Override + public void flush(ChannelOutboundHandlerContext ctx, ChannelFuture future) throws Exception { + Queue in = ctx.outbound().messageBuffer(); + for (;;) { + Object msg = in.poll(); + if (msg == null) { + break; + } + if (msg instanceof ChannelBuffer) { + ChannelBuffer buf = (ChannelBuffer) msg; + ctx.nextOutboundByteBuffer().writeBytes(buf, buf.readerIndex(), buf.readableBytes()); + } else { + ctx.nextOutboundMessageBuffer().add(msg); + } + } + ctx.flush(future); + } + } } diff --git a/codec/src/main/java/io/netty/handler/codec/embedder/EmbeddedChannel.java b/codec/src/main/java/io/netty/handler/codec/embedder/EmbeddedChannel.java index 89c0d21a0e..9a91dd1ce2 100644 --- a/codec/src/main/java/io/netty/handler/codec/embedder/EmbeddedChannel.java +++ b/codec/src/main/java/io/netty/handler/codec/embedder/EmbeddedChannel.java @@ -15,12 +15,12 @@ */ package io.netty.handler.codec.embedder; -import io.netty.buffer.ChannelBuffer; import io.netty.channel.AbstractChannel; import io.netty.channel.ChannelBufferHolder; import io.netty.channel.ChannelBufferHolders; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelType; import io.netty.channel.DefaultChannelConfig; import io.netty.channel.EventLoop; @@ -36,10 +36,15 @@ class EmbeddedChannel extends AbstractChannel { private int state; // 0 = OPEN, 1 = ACTIVE, 2 = CLOSED EmbeddedChannel(Queue productQueue) { - super(null, null, ChannelBufferHolders.catchAllBuffer()); + super(null, null, ChannelBufferHolders.messageBuffer()); this.productQueue = productQueue; } + @Override + public ChannelType type() { + return ChannelType.MESSAGE; + } + @Override public ChannelConfig config() { return config; @@ -97,13 +102,6 @@ class EmbeddedChannel extends AbstractChannel { @Override protected void doFlush(ChannelBufferHolder buf) throws Exception { - ChannelBuffer byteBuf = buf.byteBuffer(); - int byteBufLen = byteBuf.readableBytes(); - if (byteBufLen > 0) { - productQueue.add(byteBuf.readBytes(byteBufLen)); - byteBuf.clear(); - } - Queue msgBuf = buf.messageBuffer(); if (!msgBuf.isEmpty()) { productQueue.addAll(msgBuf);