Make sure we handle outbound messages of type ByteBuf special

This commit is contained in:
Norman Maurer 2013-03-08 14:05:40 +01:00 committed by Trustin Lee
parent 32efba34d8
commit 806e9b1f8c
12 changed files with 133 additions and 36 deletions

View File

@ -18,10 +18,11 @@ package io.netty.handler.codec.base64;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DelimiterBasedFrameDecoder; import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters; import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.MessageToMessageEncoder;
/** /**
* Encodes a {@link ByteBuf} into a Base64-encoded {@link ByteBuf}. * Encodes a {@link ByteBuf} into a Base64-encoded {@link ByteBuf}.
@ -38,7 +39,7 @@ import io.netty.handler.codec.MessageToMessageEncoder;
* </pre> * </pre>
*/ */
@Sharable @Sharable
public class Base64Encoder extends MessageToMessageEncoder<ByteBuf> { public class Base64Encoder extends ChannelOutboundMessageHandlerAdapter<ByteBuf> {
private final boolean breakLines; private final boolean breakLines;
private final Base64Dialect dialect; private final Base64Dialect dialect;
@ -61,8 +62,9 @@ public class Base64Encoder extends MessageToMessageEncoder<ByteBuf> {
} }
@Override @Override
protected Object encode(ChannelHandlerContext ctx, public void flush(ChannelHandlerContext ctx,
ByteBuf msg) throws Exception { ByteBuf msg) throws Exception {
return Base64.encode(msg, msg.readerIndex(), msg.readableBytes(), breakLines, dialect); ByteBuf buf = Base64.encode(msg, msg.readerIndex(), msg.readableBytes(), breakLines, dialect);
ChannelHandlerUtil.addToNextOutboundBuffer(ctx, buf);
} }
} }

View File

@ -15,7 +15,6 @@
*/ */
package io.netty.handler.codec.bytes; package io.netty.handler.codec.bytes;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -50,22 +49,13 @@ import io.netty.handler.codec.LengthFieldPrepender;
*/ */
public class ByteArrayEncoder extends ChannelOutboundMessageHandlerAdapter<byte[]> { public class ByteArrayEncoder extends ChannelOutboundMessageHandlerAdapter<byte[]> {
private final BufType nextBufferType;
public ByteArrayEncoder(BufType nextBufferType) {
if (nextBufferType == null) {
throw new NullPointerException("nextBufferType");
}
this.nextBufferType = nextBufferType;
}
@Override @Override
public void flush(ChannelHandlerContext ctx, byte[] msg) throws Exception { public void flush(ChannelHandlerContext ctx, byte[] msg) throws Exception {
if (msg.length == 0) { if (msg.length == 0) {
return; return;
} }
switch (nextBufferType) { switch (ctx.nextOutboundBufferType()) {
case BYTE: case BYTE:
ctx.nextOutboundByteBuffer().writeBytes(msg); ctx.nextOutboundByteBuffer().writeBytes(msg);
break; break;

View File

@ -15,7 +15,6 @@
*/ */
package io.netty.handler.codec.bytes; package io.netty.handler.codec.bytes;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.embedded.EmbeddedMessageChannel; import io.netty.channel.embedded.EmbeddedMessageChannel;
import org.junit.Before; import org.junit.Before;
@ -35,7 +34,7 @@ public class ByteArrayEncoderTest {
@Before @Before
public void setUp() { public void setUp() {
ch = new EmbeddedMessageChannel(new ByteArrayEncoder(BufType.MESSAGE)); ch = new EmbeddedMessageChannel(new ByteArrayEncoder());
} }
@Test @Test

View File

@ -24,6 +24,7 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelHandlerUtil;
import io.netty.channel.ChannelOutboundMessageHandler; import io.netty.channel.ChannelOutboundMessageHandler;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
@ -308,7 +309,7 @@ public class ChunkedWriteHandler
}); });
} }
} else { } else {
ctx.nextOutboundMessageBuffer().add(currentEvent); ChannelHandlerUtil.addToNextOutboundBuffer(ctx, currentEvent);
this.currentEvent = null; this.currentEvent = null;
} }

View File

@ -17,6 +17,7 @@ package io.netty.testsuite.transport.socket;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.BufUtil;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
@ -25,8 +26,13 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter; import io.netty.channel.ChannelInboundByteHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOperationHandlerAdapter;
import io.netty.channel.ChannelOutboundMessageHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.ByteToByteEncoder;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.testsuite.util.BogusSslContextFactory; import io.netty.testsuite.util.BogusSslContextFactory;
import org.junit.Test; import org.junit.Test;
@ -54,6 +60,20 @@ public class SocketSslEchoTest extends AbstractSocketTest {
} }
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable { public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, false);
}
@Test
public void testSslEchoWithChunkHandler() throws Throwable {
run();
}
public void testSslEchoWithChunkHandler(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, true);
}
private void testSslEcho0(ServerBootstrap sb, Bootstrap cb, final boolean chunkWriteHandler) throws Throwable {
final EchoHandler sh = new EchoHandler(true); final EchoHandler sh = new EchoHandler(true);
final EchoHandler ch = new EchoHandler(false); final EchoHandler ch = new EchoHandler(false);
@ -66,6 +86,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
@Override @Override
public void initChannel(SocketChannel sch) throws Exception { public void initChannel(SocketChannel sch) throws Exception {
sch.pipeline().addFirst("ssl", new SslHandler(sse)); sch.pipeline().addFirst("ssl", new SslHandler(sse));
if (chunkWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", sh); sch.pipeline().addLast("handler", sh);
} }
}); });
@ -74,6 +97,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
@Override @Override
public void initChannel(SocketChannel sch) throws Exception { public void initChannel(SocketChannel sch) throws Exception {
sch.pipeline().addFirst("ssl", new SslHandler(cse)); sch.pipeline().addFirst("ssl", new SslHandler(cse));
if (chunkWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler());
}
sch.pipeline().addLast("handler", ch); sch.pipeline().addLast("handler", ch);
} }
}); });
@ -81,15 +107,9 @@ public class SocketSslEchoTest extends AbstractSocketTest {
Channel sc = sb.bind().sync().channel(); Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel(); Channel cc = cb.connect().sync().channel();
ChannelFuture hf = cc.pipeline().get(SslHandler.class).handshake(); ChannelFuture hf = cc.pipeline().get(SslHandler.class).handshake();
final ChannelFuture firstByteWriteFuture = cc.write(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
cc.write(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
final AtomicBoolean firstByteWriteFutureDone = new AtomicBoolean(); final AtomicBoolean firstByteWriteFutureDone = new AtomicBoolean();
hf.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
firstByteWriteFutureDone.set(firstByteWriteFuture.isDone());
}
});
hf.sync(); hf.sync();
assertFalse(firstByteWriteFutureDone.get()); assertFalse(firstByteWriteFutureDone.get());

View File

@ -16,6 +16,7 @@
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.util.Attribute; import io.netty.util.Attribute;
@ -225,25 +226,35 @@ public interface ChannelHandlerContext
<T> MessageBuf<T> outboundMessageBuffer(); <T> MessageBuf<T> outboundMessageBuffer();
/** /**
* Return the {@link ByteBuf} of the next {@link ChannelHandlerContext}. * Return the {@link ByteBuf} of the next {@link ChannelInboundByteHandler} in the pipeline.
*/ */
ByteBuf nextInboundByteBuffer(); ByteBuf nextInboundByteBuffer();
/** /**
* Return the {@link MessageBuf} of the next {@link ChannelHandlerContext}. * Return the {@link MessageBuf} of the next {@link ChannelInboundMessageHandler} in the pipeline.
*/ */
MessageBuf<Object> nextInboundMessageBuffer(); MessageBuf<Object> nextInboundMessageBuffer();
/** /**
* Return the {@link ByteBuf} of the next {@link ChannelHandlerContext}. * Return the {@link ByteBuf} of the next {@link ChannelOutboundByteHandler} in the pipeline.
*/ */
ByteBuf nextOutboundByteBuffer(); ByteBuf nextOutboundByteBuffer();
/** /**
* Return the {@link MessageBuf} of the next {@link ChannelHandlerContext}. * Return the {@link MessageBuf} of the next {@link ChannelOutboundMessageHandler} in the pipeline.
*/ */
MessageBuf<Object> nextOutboundMessageBuffer(); MessageBuf<Object> nextOutboundMessageBuffer();
/**
* Return the {@link BufType} of the next {@link ChannelInboundHandler} in the pipeline.
*/
BufType nextInboundBufferType();
/**
* Return the {@link BufType} of the next {@link ChannelOutboundHandler} in the pipeline.
*/
BufType nextOutboundBufferType();
@Override @Override
ChannelHandlerContext fireChannelRegistered(); ChannelHandlerContext fireChannelRegistered();

View File

@ -16,6 +16,7 @@
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.BufType;
import io.netty.buffer.BufUtil; import io.netty.buffer.BufUtil;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
@ -84,8 +85,6 @@ public final class ChannelHandlerUtil {
SingleOutboundMessageHandler<T> handler) throws Exception { SingleOutboundMessageHandler<T> handler) throws Exception {
MessageBuf<Object> in = ctx.outboundMessageBuffer(); MessageBuf<Object> in = ctx.outboundMessageBuffer();
MessageBuf<Object> out = null;
final int inSize = in.size(); final int inSize = in.size();
if (inSize == 0) { if (inSize == 0) {
ctx.flush(promise); ctx.flush(promise);
@ -102,10 +101,7 @@ public final class ChannelHandlerUtil {
} }
if (!handler.acceptOutboundMessage(msg)) { if (!handler.acceptOutboundMessage(msg)) {
if (out == null) { addToNextOutboundBuffer(ctx, msg);
out = ctx.nextOutboundMessageBuffer();
}
out.add(msg);
processed ++; processed ++;
continue; continue;
} }
@ -205,6 +201,35 @@ public final class ChannelHandlerUtil {
throw new IllegalStateException(); throw new IllegalStateException();
} }
} }
/**
* Add the msg to the next outbound buffer in the {@link ChannelPipeline}. This takes special care of
* msgs that are of type {@link ByteBuf}.
*/
public static boolean addToNextOutboundBuffer(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
if (ctx.nextOutboundBufferType() == BufType.BYTE) {
ctx.nextOutboundByteBuffer().writeBytes((ByteBuf) msg);
return true;
}
}
return ctx.nextOutboundMessageBuffer().add(msg);
}
/**
* Add the msg to the next inbound buffer in the {@link ChannelPipeline}. This takes special care of
* msgs that are of type {@link ByteBuf}.
*/
public static boolean addToNextInboundBuffer(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
if (ctx.nextInboundBufferType() == BufType.BYTE) {
ctx.nextInboundByteBuffer().writeBytes((ByteBuf) msg);
return true;
}
}
return ctx.nextInboundMessageBuffer().add(msg);
}
private ChannelHandlerUtil() { } private ChannelHandlerUtil() { }
public interface SingleInboundMessageHandler<T> { public interface SingleInboundMessageHandler<T> {

View File

@ -15,10 +15,15 @@
*/ */
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
/** /**
* Special {@link ChannelInboundHandler} which store the inbound data in a {@link MessageBuf} for futher processing. * Special {@link ChannelInboundHandler} which store the inbound data in a {@link MessageBuf} for futher processing.
*
* If your {@link ChannelOutboundMessageHandler} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextInboundBuffer(ChannelHandlerContext, Object)}.
*/ */
public interface ChannelInboundMessageHandler<I> extends ChannelInboundHandler { public interface ChannelInboundMessageHandler<I> extends ChannelInboundHandler {

View File

@ -15,6 +15,7 @@
*/ */
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerUtil.SingleInboundMessageHandler; import io.netty.channel.ChannelHandlerUtil.SingleInboundMessageHandler;
@ -41,6 +42,10 @@ import io.netty.util.internal.TypeParameterMatcher;
* } * }
* </pre> * </pre>
* *
* If your {@link ChannelInboundMessageHandlerAdapter} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextInboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> The type of the messages to handle * @param <I> The type of the messages to handle
*/ */
public abstract class ChannelInboundMessageHandlerAdapter<I> public abstract class ChannelInboundMessageHandlerAdapter<I>

View File

@ -15,12 +15,17 @@
*/ */
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
/** /**
* ChannelOutboundHandler implementation which operates on messages of a specific type * ChannelOutboundHandler implementation which operates on messages of a specific type
* by pass them in a {@link MessageBuf} and consume then from there. * by pass them in a {@link MessageBuf} and consume then from there.
* *
* If your {@link ChannelOutboundMessageHandler} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextOutboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> the message type * @param <I> the message type
*/ */
public interface ChannelOutboundMessageHandler<I> extends ChannelOutboundHandler { public interface ChannelOutboundMessageHandler<I> extends ChannelOutboundHandler {

View File

@ -15,6 +15,7 @@
*/ */
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerUtil.SingleOutboundMessageHandler; import io.netty.channel.ChannelHandlerUtil.SingleOutboundMessageHandler;
@ -24,6 +25,10 @@ import io.netty.util.internal.TypeParameterMatcher;
/** /**
* Abstract base class which handles messages of a specific type. * Abstract base class which handles messages of a specific type.
* *
* If your {@link ChannelOutboundMessageHandlerAdapter} handles messages of type {@link ByteBuf} or {@link Object}
* and you want to add a {@link ByteBuf} to the next buffer in the {@link ChannelPipeline} use
* {@link ChannelHandlerUtil#addToNextOutboundBuffer(ChannelHandlerContext, Object)}.
*
* @param <I> The type of the messages to handle * @param <I> The type of the messages to handle
*/ */
public abstract class ChannelOutboundMessageHandlerAdapter<I> public abstract class ChannelOutboundMessageHandlerAdapter<I>

View File

@ -16,6 +16,7 @@
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.Buf; import io.netty.buffer.Buf;
import io.netty.buffer.BufType;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
@ -1622,6 +1623,34 @@ final class DefaultChannelHandlerContext extends DefaultAttributeMap implements
return ctx; return ctx;
} }
@Override
public BufType nextInboundBufferType() {
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.next;
} while (!(ctx.handler() instanceof ChannelInboundHandler));
if (ctx.handler() instanceof ChannelInboundByteHandler) {
return BufType.BYTE;
} else {
return BufType.MESSAGE;
}
}
@Override
public BufType nextOutboundBufferType() {
DefaultChannelHandlerContext ctx = this;
do {
ctx = ctx.prev;
} while (!(ctx.handler() instanceof ChannelOutboundHandler));
if (ctx.handler() instanceof ChannelOutboundByteHandler) {
return BufType.BYTE;
} else {
return BufType.MESSAGE;
}
}
private DefaultChannelHandlerContext findContextOutbound() { private DefaultChannelHandlerContext findContextOutbound() {
DefaultChannelHandlerContext ctx = this; DefaultChannelHandlerContext ctx = this;
do { do {