Make sure we handle outbound messages of type ByteBuf special

This commit is contained in:
Norman Maurer 2013-03-08 14:05:40 +01:00 committed by Trustin Lee
parent 32efba34d8
commit 806e9b1f8c
12 changed files with 133 additions and 36 deletions

View File

@ -18,10 +18,11 @@ package io.netty.handler.codec.base64;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.MessageToMessageEncoder;
/**
* Encodes a {@link ByteBuf} into a Base64-encoded {@link ByteBuf}.
@ -38,7 +39,7 @@ import io.netty.handler.codec.MessageToMessageEncoder;
* </pre>
*/
@Sharable
public class Base64Encoder extends MessageToMessageEncoder<ByteBuf> {
public class Base64Encoder extends ChannelOutboundMessageHandlerAdapter<ByteBuf> {
private final boolean breakLines;
private final Base64Dialect dialect;
@ -61,8 +62,9 @@ public class Base64Encoder extends MessageToMessageEncoder<ByteBuf> {
}
@Override
protected Object encode(ChannelHandlerContext ctx,
public void flush(ChannelHandlerContext ctx,
ByteBuf msg) throws Exception {
return Base64.encode(msg, msg.readerIndex(), msg.readableBytes(), breakLines, dialect);
ByteBuf buf = Base64.encode(msg, msg.readerIndex(), msg.readableBytes(), breakLines, dialect);
ChannelHandlerUtil.addToNextOutboundBuffer(ctx, buf);
}
}

View File

@ -15,7 +15,6 @@
*/
package io.netty.handler.codec.bytes;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
@ -50,22 +49,13 @@ import io.netty.handler.codec.LengthFieldPrepender;
*/
public class ByteArrayEncoder extends ChannelOutboundMessageHandlerAdapter<byte[]> {
private final BufType nextBufferType;
public ByteArrayEncoder(BufType nextBufferType) {
if (nextBufferType == null) {
throw new NullPointerException("nextBufferType");
}
this.nextBufferType = nextBufferType;
}
@Override
public void flush(ChannelHandlerContext ctx, byte[] msg) throws Exception {
if (msg.length == 0) {
return;
}
switch (nextBufferType) {
switch (ctx.nextOutboundBufferType()) {
case BYTE:
ctx.nextOutboundByteBuffer().writeBytes(msg);
break;

View File

@ -15,7 +15,6 @@
*/
package io.netty.handler.codec.bytes;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedMessageChannel;
import org.junit.Before;
@ -35,7 +34,7 @@ public class ByteArrayEncoderTest {
@Before
public void setUp() {
ch = new EmbeddedMessageChannel(new ByteArrayEncoder(BufType.MESSAGE));
ch = new EmbeddedMessageChannel(new ByteArrayEncoder());
}
@Test

View File

@ -24,6 +24,7 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelOutboundMessageHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
@ -308,7 +309,7 @@ public class ChunkedWriteHandler
});
}
} else {
ctx.nextOutboundMessageBuffer().add(currentEvent);
ChannelHandlerUtil.addToNextOutboundBuffer(ctx, currentEvent);
this.currentEvent = null;
}

View File

@ -17,6 +17,7 @@ package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.BufUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
@ -25,8 +26,13 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOperationHandlerAdapter;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.ByteToByteEncoder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.testsuite.util.BogusSslContextFactory;
import org.junit.Test;
@ -54,6 +60,20 @@ public class SocketSslEchoTest extends AbstractSocketTest {
}
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, false);
}
@Test
public void testSslEchoWithChunkHandler() throws Throwable {
run();
}
public void testSslEchoWithChunkHandler(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, true);
}
private void testSslEcho0(ServerBootstrap sb, Bootstrap cb, final boolean chunkWriteHandler) throws Throwable {
final EchoHandler sh = new EchoHandler(true);
final EchoHandler ch = new EchoHandler(false);
@ -66,6 +86,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
@Override
public void initChannel(SocketChannel sch) throws Exception {
sch.pipeline().addFirst("ssl", new SslHandler(sse));
if (chunkWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", sh);
}
});
@ -74,6 +97,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
@Override
public void initChannel(SocketChannel sch) throws Exception {
sch.pipeline().addFirst("ssl", new SslHandler(cse));
if (chunkWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", ch);
}
});
@ -81,15 +107,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();
ChannelFuture hf = cc.pipeline().get(SslHandler.class).handshake();
final ChannelFuture firstByteWriteFuture =
cc.write(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
cc.write(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
final AtomicBoolean firstByteWriteFutureDone = new AtomicBoolean();
hf.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
firstByteWriteFutureDone.set(firstByteWriteFuture.isDone());
}
});
hf.sync();
assertFalse(firstByteWriteFutureDone.get());

View File

@ -16,6 +16,7 @@
package io.netty.channel;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
import io.netty.util.Attribute;
@ -225,25 +226,35 @@ public interface ChannelHandlerContext
<T> MessageBuf<T> outboundMessageBuffer();
/**
* Return the {@link ByteBuf} of the next {@link ChannelHandlerContext}.
* Return the {@link ByteBuf} of the next {@link ChannelInboundByteHandler} in the pipeline.
*/
ByteBuf nextInboundByteBuffer();
/**
* Return the {@link MessageBuf} of the next {@link ChannelHandlerContext}.
* Return the {@link MessageBuf} of the next {@link ChannelInboundMessageHandler} in the pipeline.
*/
MessageBuf<Object> nextInboundMessageBuffer();
/**
* Return the {@link ByteBuf} of the next {@link ChannelHandlerContext}.
* Return the {@link ByteBuf} of the next {@link ChannelOutboundByteHandler} in the pipeline.
*/
ByteBuf nextOutboundByteBuffer();
/**
* Return the {@link MessageBuf} of the next {@link ChannelHandlerContext}.
* Return the {@link MessageBuf} of the next {@link ChannelOutboundMessageHandler} in the pipeline.
*/
MessageBuf<Object> nextOutboundMessageBuffer();
/**
* Return the {@link BufType} of the next {@link ChannelInboundHandler} in the pipeline.
*/
BufType nextInboundBufferType();
/**
* Return the {@link BufType} of the next {@link ChannelOutboundHandler} in the pipeline.
*/
BufType nextOutboundBufferType();
@Override
ChannelHandlerContext fireChannelRegistered();

View File

@ -16,6 +16,7 @@
package io.netty.channel;
import io.netty.buffer.BufType;
import io.netty.buffer.BufUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
@ -84,8 +85,6 @@ public final class ChannelHandlerUtil {
SingleOutboundMessageHandler<T> handler) throws Exception {
MessageBuf<Object> in = ctx.outboundMessageBuffer();
MessageBuf<Object> out = null;
final int inSize = in.size();
if (inSize == 0) {
ctx.flush(promise);
@ -102,10 +101,7 @@ public final class ChannelHandlerUtil {
}
if (!handler.acceptOutboundMessage(msg)) {
if (out == null) {
out = ctx.nextOutboundMessageBuffer();
}
out.add(msg);
addToNextOutboundBuffer(ctx, msg);
processed ++;
continue;
}
@ -205,6 +201,35 @@ public final class ChannelHandlerUtil {
throw new IllegalStateException();
}
}
/**
* Add the msg to the next outbound buffer in the {@link ChannelPipeline}. This takes special care of
* msgs that are of type {@link ByteBuf}.
*/
public static boolean addToNextOutboundBuffer(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
if (ctx.nextOutboundBufferType() == BufType.BYTE) {
ctx.nextOutboundByteBuffer().writeBytes((ByteBuf) msg);
return true;
}
}
return ctx.nextOutboundMessageBuffer().add(msg);
}
/**
* Add the msg to the next inbound buffer in the {@link ChannelPipeline}. This takes special care of
* msgs that are of type {@link ByteBuf}.
*/
public static boolean addToNextInboundBuffer(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
if (ctx.nextInboundBufferType() == BufType.BYTE) {
ctx.nextInboundByteBuffer().writeBytes((ByteBuf) msg);
return true;
}
}
return ctx.nextInboundMessageBuffer().add(msg);
}
private ChannelHandlerUtil() { }
public interface SingleInboundMessageHandler<T> {

View File

@ -15,10 +15,15 @@
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
/**
* Special {@link ChannelInboundHandler} which store the inbound data in a {@link MessageBuf} for futher processing.
*
* If your {@link ChannelOutboundMessageHandler} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextInboundBuffer(ChannelHandlerContext, Object)}.
*/
public interface ChannelInboundMessageHandler<I> extends ChannelInboundHandler {

View File

@ -15,6 +15,7 @@
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerUtil.SingleInboundMessageHandler;
@ -41,6 +42,10 @@ import io.netty.util.internal.TypeParameterMatcher;
* }
* </pre>
*
* If your {@link ChannelInboundMessageHandlerAdapter} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextInboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> The type of the messages to handle
*/
public abstract class ChannelInboundMessageHandlerAdapter<I>

View File

@ -15,12 +15,17 @@
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
/**
* ChannelOutboundHandler implementation which operates on messages of a specific type
* by pass them in a {@link MessageBuf} and consume then from there.
*
* If your {@link ChannelOutboundMessageHandler} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextOutboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> the message type
*/
public interface ChannelOutboundMessageHandler<I> extends ChannelOutboundHandler {

View File

@ -15,6 +15,7 @@
*/
package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerUtil.SingleOutboundMessageHandler;
@ -24,6 +25,10 @@ import io.netty.util.internal.TypeParameterMatcher;
/**
* Abstract base class which handles messages of a specific type.
*
* If your {@link ChannelOutboundMessageHandlerAdapter} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextOutboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> The type of the messages to handle
*/
public abstract class ChannelOutboundMessageHandlerAdapter<I>

View File

@ -16,6 +16,7 @@
package io.netty.channel;
import io.netty.buffer.Buf;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.MessageBuf;
@ -1622,6 +1623,34 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements
return ctx;
}
@Override
public BufType nextInboundBufferType() {
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.next;
} while (!(ctx.handler() instanceof ChannelInboundHandler));
if (ctx.handler() instanceof ChannelInboundByteHandler) {
return BufType.BYTE;
} else {
return BufType.MESSAGE;
}
}
@Override
public BufType nextOutboundBufferType() {
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.prev;
} while (!(ctx.handler() instanceof ChannelOutboundHandler));
if (ctx.handler() instanceof ChannelOutboundByteHandler) {
return BufType.BYTE;
} else {
return BufType.MESSAGE;
}
}
private DefaultChannelHandlerContext findContextOutbound() {
DefaultChannelHandlerContext ctx = this;
do {