[#6023] Ensure LastInboundHandler correctly handle different events / data

Motivation:

LastInboundHandler maintains 2 queues which may contain the same data and tries to match these up when you read elements out of it. Because of this it can happen that you remove an element only out of one queue and so double free stuff later.

Modifications:

Just use one "queue" to store things.

Result:

Not possible to only remove things from one queue and so get into trouble later when release everything that sits in the handler.
This commit is contained in:
Norman Maurer 2016-11-18 11:39:21 +00:00
parent 0b3122d8ff
commit e5070b7266
2 changed files with 48 additions and 36 deletions

View File

@ -154,6 +154,7 @@ public class Http2FrameCodecTest {
assertEquals(expected, inboundData); assertEquals(expected, inboundData);
assertEquals(1, inboundData.refCnt()); assertEquals(1, inboundData.refCnt());
expected.release(); expected.release();
inboundData.release();
assertNull(inboundHandler.readInbound()); assertNull(inboundHandler.readInbound());
inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, false).streamId(stream.id())); inboundHandler.writeOutbound(new DefaultHttp2HeadersFrame(response, false).streamId(stream.id()));
@ -198,25 +199,25 @@ public class Http2FrameCodecTest {
assertNotNull(stream); assertNotNull(stream);
assertEquals(State.OPEN, stream.state()); assertEquals(State.OPEN, stream.state());
Http2StreamActiveEvent activeEvent = inboundHandler.readInboundMessagesAndEvents(); Http2StreamActiveEvent activeEvent = inboundHandler.readInboundMessageOrUserEvent();
assertNotNull(activeEvent); assertNotNull(activeEvent);
assertEquals(stream.id(), activeEvent.streamId()); assertEquals(stream.id(), activeEvent.streamId());
Http2HeadersFrame expectedHeaders = new DefaultHttp2HeadersFrame(request, false, 31).streamId(stream.id()); Http2HeadersFrame expectedHeaders = new DefaultHttp2HeadersFrame(request, false, 31).streamId(stream.id());
Http2HeadersFrame actualHeaders = inboundHandler.readInboundMessagesAndEvents(); Http2HeadersFrame actualHeaders = inboundHandler.readInboundMessageOrUserEvent();
assertEquals(expectedHeaders, actualHeaders); assertEquals(expectedHeaders, actualHeaders);
frameListener.onRstStreamRead(http2HandlerCtx, 3, Http2Error.NO_ERROR.code()); frameListener.onRstStreamRead(http2HandlerCtx, 3, Http2Error.NO_ERROR.code());
Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(Http2Error.NO_ERROR).streamId(stream.id()); Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(Http2Error.NO_ERROR).streamId(stream.id());
Http2ResetFrame actualRst = inboundHandler.readInboundMessagesAndEvents(); Http2ResetFrame actualRst = inboundHandler.readInboundMessageOrUserEvent();
assertEquals(expectedRst, actualRst); assertEquals(expectedRst, actualRst);
Http2StreamClosedEvent closedEvent = inboundHandler.readInboundMessagesAndEvents(); Http2StreamClosedEvent closedEvent = inboundHandler.readInboundMessageOrUserEvent();
assertNotNull(closedEvent); assertNotNull(closedEvent);
assertEquals(stream.id(), closedEvent.streamId()); assertEquals(stream.id(), closedEvent.streamId());
assertNull(inboundHandler.readInboundMessagesAndEvents()); assertNull(inboundHandler.readInboundMessageOrUserEvent());
} }
@Test @Test
@ -370,14 +371,14 @@ public class Http2FrameCodecTest {
Http2Stream stream = framingCodec.connectionHandler().connection().stream(3); Http2Stream stream = framingCodec.connectionHandler().connection().stream(3);
assertNotNull(stream); assertNotNull(stream);
Http2StreamActiveEvent activeEvent = inboundHandler.readInboundMessagesAndEvents(); Http2StreamActiveEvent activeEvent = inboundHandler.readInboundMessageOrUserEvent();
assertNotNull(activeEvent); assertNotNull(activeEvent);
assertEquals(stream.id(), activeEvent.streamId()); assertEquals(stream.id(), activeEvent.streamId());
StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo");
framingCodec.connectionHandler().onError(http2HandlerCtx, streamEx); framingCodec.connectionHandler().onError(http2HandlerCtx, streamEx);
Http2HeadersFrame headersFrame = inboundHandler.readInboundMessagesAndEvents(); Http2HeadersFrame headersFrame = inboundHandler.readInboundMessageOrUserEvent();
assertNotNull(headersFrame); assertNotNull(headersFrame);
try { try {
@ -387,11 +388,11 @@ public class Http2FrameCodecTest {
assertEquals(streamEx, e); assertEquals(streamEx, e);
} }
Http2StreamClosedEvent closedEvent = inboundHandler.readInboundMessagesAndEvents(); Http2StreamClosedEvent closedEvent = inboundHandler.readInboundMessageOrUserEvent();
assertNotNull(closedEvent); assertNotNull(closedEvent);
assertEquals(stream.id(), closedEvent.streamId()); assertEquals(stream.id(), closedEvent.streamId());
assertNull(inboundHandler.readInboundMessagesAndEvents()); assertNull(inboundHandler.readInboundMessageOrUserEvent());
} }
@Test @Test

View File

@ -23,8 +23,8 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import java.util.ArrayDeque; import java.util.ArrayList;
import java.util.Queue; import java.util.List;
import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.LockSupport;
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
@ -33,9 +33,7 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS;
* Channel handler that allows to easily access inbound messages. * Channel handler that allows to easily access inbound messages.
*/ */
public class LastInboundHandler extends ChannelDuplexHandler { public class LastInboundHandler extends ChannelDuplexHandler {
private final Queue<Object> inboundMessages = new ArrayDeque<Object>(); private final List<Object> queue = new ArrayList<Object>();
private final Queue<Object> userEvents = new ArrayDeque<Object>();
private final Queue<Object> inboundMessagesAndUserEvents = new ArrayDeque<Object>();
private Throwable lastException; private Throwable lastException;
private ChannelHandlerContext ctx; private ChannelHandlerContext ctx;
private boolean channelActive; private boolean channelActive;
@ -69,14 +67,12 @@ public class LastInboundHandler extends ChannelDuplexHandler {
@Override @Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
inboundMessages.add(msg); queue.add(msg);
inboundMessagesAndUserEvents.add(msg);
} }
@Override @Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
userEvents.add(evt); queue.add(new UserEvent(evt));
inboundMessagesAndUserEvents.add(evt);
} }
@Override @Override
@ -99,11 +95,15 @@ public class LastInboundHandler extends ChannelDuplexHandler {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T readInbound() { public <T> T readInbound() {
T message = (T) inboundMessages.poll(); for (int i = 0; i < queue.size(); i++) {
if (message == inboundMessagesAndUserEvents.peek()) { Object o = queue.get(i);
inboundMessagesAndUserEvents.poll(); if (!(o instanceof UserEvent)) {
queue.remove(i);
return (T) o;
}
} }
return message;
return null;
} }
public <T> T blockingReadInbound() { public <T> T blockingReadInbound() {
@ -116,27 +116,30 @@ public class LastInboundHandler extends ChannelDuplexHandler {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T readUserEvent() { public <T> T readUserEvent() {
T message = (T) userEvents.poll(); for (int i = 0; i < queue.size(); i++) {
if (message == inboundMessagesAndUserEvents.peek()) { Object o = queue.get(i);
inboundMessagesAndUserEvents.poll(); if (o instanceof UserEvent) {
queue.remove(i);
return (T) ((UserEvent) o).evt;
}
} }
return message;
return null;
} }
/** /**
* Useful to test order of events and messages. * Useful to test order of events and messages.
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T readInboundMessagesAndEvents() { public <T> T readInboundMessageOrUserEvent() {
T message = (T) inboundMessagesAndUserEvents.poll(); if (queue.isEmpty()) {
return null;
if (message == inboundMessages.peek()) {
inboundMessages.poll();
} else if (message == userEvents.peek()) {
userEvents.poll();
} }
Object o = queue.remove(0);
return message; if (o instanceof UserEvent) {
return (T) ((UserEvent) o).evt;
}
return (T) o;
} }
public void writeOutbound(Object... msgs) throws Exception { public void writeOutbound(Object... msgs) throws Exception {
@ -153,7 +156,7 @@ public class LastInboundHandler extends ChannelDuplexHandler {
public void finishAndReleaseAll() throws Exception { public void finishAndReleaseAll() throws Exception {
checkException(); checkException();
Object o; Object o;
while ((o = readInboundMessagesAndEvents()) != null) { while ((o = readInboundMessageOrUserEvent()) != null) {
ReferenceCountUtil.release(o); ReferenceCountUtil.release(o);
} }
} }
@ -161,4 +164,12 @@ public class LastInboundHandler extends ChannelDuplexHandler {
public Channel channel() { public Channel channel() {
return ctx.channel(); return ctx.channel();
} }
private static final class UserEvent {
private final Object evt;
UserEvent(Object evt) {
this.evt = evt;
}
}
} }