Correctly handle ChannelInputShutdownEvent in ReplayingDecoder

Motivation:

b112673554 added ChannelInputShutdownEvent support to ByteToMessageDecoder but missed updating the code for ReplayingDecoder. This has the effect:

- If a ChannelInputShutdownEvent is fired ByteToMessageDecoder (the super-class of ReplayingDecoder) will call the channelInputClosed(...) method which will pass the incorrect buffer to the decode method of ReplayingDecoder.

Modifications:

Share more code between ByteToMessageDEcoder and ReplayingDecoder and so also support ChannelInputShutdownEvent correctly in ReplayingDecoder

Result:

ChannelInputShutdownEvent is corrrectly handle in ReplayingDecoder as well.
This commit is contained in:
Norman Maurer 2016-04-11 09:15:30 +02:00
parent 527c07e41b
commit f94edd2e92
3 changed files with 60 additions and 32 deletions

View File

@ -323,12 +323,7 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
private void channelInputClosed(ChannelHandlerContext ctx, boolean callChannelInactive) throws Exception { private void channelInputClosed(ChannelHandlerContext ctx, boolean callChannelInactive) throws Exception {
RecyclableArrayList out = RecyclableArrayList.newInstance(); RecyclableArrayList out = RecyclableArrayList.newInstance();
try { try {
if (cumulation != null) { channelInputClosed(ctx, out);
callDecode(ctx, cumulation, out);
decodeLast(ctx, cumulation, out);
} else {
decodeLast(ctx, Unpooled.EMPTY_BUFFER, out);
}
} catch (DecoderException e) { } catch (DecoderException e) {
throw e; throw e;
} catch (Exception e) { } catch (Exception e) {
@ -355,6 +350,19 @@ public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter
} }
} }
/**
* Called when the input of the channel was closed which may be because it changed to inactive or because of
* {@link ChannelInputShutdownEvent}.
*/
void channelInputClosed(ChannelHandlerContext ctx, List<Object> out) throws Exception {
if (cumulation != null) {
callDecode(ctx, cumulation, out);
decodeLast(ctx, cumulation, out);
} else {
decodeLast(ctx, Unpooled.EMPTY_BUFFER, out);
}
}
/** /**
* Called once data should be decoded from the given {@link ByteBuf}. This method will call * Called once data should be decoded from the given {@link ByteBuf}. This method will call
* {@link #decode(ChannelHandlerContext, ByteBuf, List)} as long as decoding should take place. * {@link #decode(ChannelHandlerContext, ByteBuf, List)} as long as decoding should take place.

View File

@ -16,6 +16,7 @@
package io.netty.handler.codec; package io.netty.handler.codec;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
@ -322,37 +323,19 @@ public abstract class ReplayingDecoder<S> extends ByteToMessageDecoder {
} }
@Override @Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception { final void channelInputClosed(ChannelHandlerContext ctx, List<Object> out) throws Exception {
RecyclableArrayList out = RecyclableArrayList.newInstance();
try { try {
replayable.terminate(); replayable.terminate();
callDecode(ctx, internalBuffer(), out); if (cumulation != null) {
decodeLast(ctx, replayable, out); callDecode(ctx, internalBuffer(), out);
decodeLast(ctx, replayable, out);
} else {
replayable.setCumulation(Unpooled.EMPTY_BUFFER);
decodeLast(ctx, replayable, out);
}
} catch (Signal replay) { } catch (Signal replay) {
// Ignore // Ignore
replay.expect(REPLAY); replay.expect(REPLAY);
} catch (DecoderException e) {
throw e;
} catch (Exception e) {
throw new DecoderException(e);
} finally {
try {
if (cumulation != null) {
cumulation.release();
cumulation = null;
}
int size = out.size();
if (size > 0) {
fireChannelRead(ctx, out, size);
// Something was read, call fireChannelReadComplete()
ctx.fireChannelReadComplete();
}
ctx.fireChannelInactive();
} finally {
// recycle in all cases
out.recycle();
}
} }
} }

View File

@ -20,11 +20,13 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import org.junit.Test; import org.junit.Test;
import java.util.List; import java.util.List;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.util.ReferenceCountUtil.releaseLater; import static io.netty.util.ReferenceCountUtil.releaseLater;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -225,4 +227,39 @@ public class ReplayingDecoderTest {
assertEquals(3, (int) queue.take()); assertEquals(3, (int) queue.take());
assertTrue(queue.isEmpty()); assertTrue(queue.isEmpty());
} }
@Test
public void testChannelInputShutdownEvent() {
final AtomicReference<Error> error = new AtomicReference<Error>();
EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder<Integer>(0) {
private boolean decoded;
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
if (!(in instanceof ReplayingDecoderByteBuf)) {
error.set(new AssertionError("in must be of type " + ReplayingDecoderByteBuf.class
+ " but was " + in.getClass()));
return;
}
if (!decoded) {
decoded = true;
in.readByte();
state(1);
} else {
// This will throw an ReplayingError
in.skipBytes(Integer.MAX_VALUE);
}
}
});
assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(new byte[] {0, 1})));
channel.pipeline().fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE);
assertFalse(channel.finishAndReleaseAll());
Error err = error.get();
if (err != null) {
throw err;
}
}
} }