Call Freeable.free() if a Freeable message reaches the end of the ChannelPipeline to guard against resource leakage

This commit is contained in:
Norman Maurer 2013-01-07 08:44:16 +01:00
parent cf2fbf7883
commit 26595471fb
2 changed files with 147 additions and 18 deletions

View File

@ -17,6 +17,7 @@ package io.netty.channel;
import io.netty.buffer.Buf;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Freeable;
import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled;
import io.netty.logging.InternalLogger;
@ -48,6 +49,8 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext head;
private volatile DefaultChannelHandlerContext tail;
private final DefaultChannelHandlerContext tailCtx;
private final Map<String, DefaultChannelHandlerContext> name2ctx =
new HashMap<String, DefaultChannelHandlerContext>(4);
private boolean firedChannelActive;
@ -56,6 +59,8 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final Map<EventExecutorGroup, EventExecutor> childExecutors =
new IdentityHashMap<EventExecutorGroup, EventExecutor>();
private static final TailHandler TAIL_HANDLER = new TailHandler();
public DefaultChannelPipeline(Channel channel) {
if (channel == null) {
throw new NullPointerException("channel");
@ -63,9 +68,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
this.channel = channel;
HeadHandler headHandler = new HeadHandler();
tailCtx = new DefaultChannelHandlerContext(
this, null, null, null, generateName(TAIL_HANDLER), TAIL_HANDLER);
head = new DefaultChannelHandlerContext(
this, null, null, null, generateName(headHandler), headHandler);
tail = head;
this, null, null, tailCtx, generateName(headHandler), headHandler);
tailCtx.prev = head;
tail = tailCtx;
unsafe = channel.unsafe();
}
@ -119,10 +127,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
if (nextCtx != null) {
nextCtx.prev = newCtx;
}
head.next = newCtx;
if (tail == head) {
if (head.next == tailCtx) {
tail = newCtx;
newCtx.next = tailCtx;
tailCtx.prev = newCtx;
}
head.next = newCtx;
name2ctx.put(name, newCtx);
@ -143,8 +153,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
checkDuplicateName(name);
oldTail = tail;
newTail = new DefaultChannelHandlerContext(this, group, oldTail, null, name, handler);
newTail = new DefaultChannelHandlerContext(this, group, null, null, name, handler);
if (!newTail.channel().isRegistered() || newTail.executor().inEventLoop()) {
addLast0(name, oldTail, newTail);
return this;
@ -171,7 +180,21 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final String name, DefaultChannelHandlerContext oldTail, DefaultChannelHandlerContext newTail) {
callBeforeAdd(newTail);
DefaultChannelHandlerContext prev = oldTail.prev;
if (oldTail == tailCtx) {
// This is the first handler added
tailCtx.prev = newTail;
newTail.next = tailCtx;
prev.next = newTail;
newTail.prev = prev;
} else {
oldTail.next = newTail;
newTail.prev = oldTail;
prev.next = oldTail;
oldTail.prev = prev;
}
tail = newTail;
name2ctx.put(name, newTail);
@ -361,12 +384,15 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Future<?> future;
synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (head == tail) {
return null;
} else if (ctx == head) {
throw new Error(); // Should never happen.
} else if (ctx == tail) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}
@ -425,7 +451,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
@Override
public ChannelHandler removeFirst() {
if (head == tail) {
if (head.next == tailCtx) {
throw new NoSuchElementException();
}
return remove(head.next).handler();
@ -436,7 +462,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext oldTail;
synchronized (this) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}
oldTail = tail;
@ -464,7 +490,9 @@ final class DefaultChannelPipeline implements ChannelPipeline {
private void removeLast0(DefaultChannelHandlerContext oldTail) {
callBeforeRemove(oldTail);
oldTail.prev.next = null;
tailCtx.prev = oldTail.prev;
oldTail.prev.next = tailCtx;
tail = oldTail.prev;
name2ctx.remove(oldTail.name());
@ -493,10 +521,13 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext ctx, final String newName, ChannelHandler newHandler) {
Future<?> future;
synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (ctx == head) {
throw new IllegalArgumentException();
} else if (ctx == tail) {
if (head == tail) {
if (tail == tailCtx) {
throw new NoSuchElementException();
}
final DefaultChannelHandlerContext oldTail = tail;
@ -688,7 +719,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
@Override
public ChannelHandler last() {
DefaultChannelHandlerContext last = tail;
if (last == head || last == null) {
if (last == tailCtx || last == null) {
return null;
}
return last.handler();
@ -743,6 +774,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
DefaultChannelHandlerContext ctx = head.next;
for (;;) {
if (ctx == null) {
return null;
}
@ -791,7 +823,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Map<String, ChannelHandler> map = new LinkedHashMap<String, ChannelHandler>();
DefaultChannelHandlerContext ctx = head.next;
for (;;) {
if (ctx == null) {
if (ctx == null || ctx == tailCtx) {
return map;
}
map.put(ctx.name(), ctx.handler());
@ -1331,7 +1363,6 @@ final class DefaultChannelPipeline implements ChannelPipeline {
ctx = ctx.prev;
}
if (executor.inEventLoop()) {
write0(ctx, message, promise, msgBuf);
return promise;
@ -1483,6 +1514,21 @@ final class DefaultChannelPipeline implements ChannelPipeline {
}
}
private static final class TailHandler extends ChannelInboundMessageHandlerAdapter<Freeable> {
public TailHandler() {
super(Freeable.class);
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, Freeable msg) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Freeable reached end-of-pipeline, call " + msg + ".free() to" +
" guard against resource leakage!");
}
msg.free();
}
}
private final class HeadHandler implements ChannelOutboundHandler {
@Override
public Buf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception {

View File

@ -15,13 +15,89 @@
*/
package io.netty.channel;
import io.netty.buffer.Freeable;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalEventLoopGroup;
import org.junit.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*;
public class DefaultChannelPipelineTest {
@Test
public void testFreeCalled() throws InterruptedException{
final CountDownLatch free = new CountDownLatch(1);
Freeable holder = new Freeable() {
@Override
public void free() {
free.countDown();
}
@Override
public boolean isFreed() {
return free.getCount() == 0;
}
};
LocalChannel channel = new LocalChannel();
LocalEventLoopGroup group = new LocalEventLoopGroup();
group.register(channel).awaitUninterruptibly();
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel);
StringInboundHandler handler = new StringInboundHandler();
pipeline.addLast(handler);
pipeline.fireChannelActive();
pipeline.inboundMessageBuffer().add(holder);
pipeline.fireInboundBufferUpdated();
assertTrue(free.await(10, TimeUnit.SECONDS));
assertTrue(handler.called);
}
private static final class StringInboundHandler extends ChannelInboundMessageHandlerAdapter<String> {
boolean called;
public StringInboundHandler() {
super(String.class);
}
@Override
public boolean isSupported(Object msg) throws Exception {
called = true;
return super.isSupported(msg);
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
fail();
}
}
@Test
public void testRemoveChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());
ChannelHandler handler1 = newHandler();
ChannelHandler handler2 = newHandler();
ChannelHandler handler3 = newHandler();
pipeline.addLast("handler1", handler1);
pipeline.addLast("handler2", handler2);
pipeline.addLast("handler3", handler3);
assertSame(pipeline.get("handler1"), handler1);
assertSame(pipeline.get("handler2"), handler2);
assertSame(pipeline.get("handler3"), handler3);
pipeline.remove(handler1);
pipeline.remove(handler2);
pipeline.remove(handler3);
}
@Test
public void testReplaceChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());
@ -107,8 +183,11 @@ public class DefaultChannelPipelineTest {
while (ctx != null) {
int i = toInt(ctx.name());
int j = next(ctx);
if (j != -1) {
assertTrue(i < j);
} else {
assertNull(ctx.next.next);
}
ctx = ctx.next;
}
@ -125,7 +204,11 @@ public class DefaultChannelPipelineTest {
}
private static int toInt(String name) {
try {
return Integer.parseInt(name);
} catch (NumberFormatException e) {
return -1;
}
}
private static void verifyContextNumber(DefaultChannelPipeline pipeline, int expectedNumber) {