Call Freeable.free() if a Freeable message reaches the end of the ChannelPipeline to guard against resource leakage

This commit is contained in:
Norman Maurer 2013-01-07 08:44:16 +01:00
parent cf2fbf7883
commit 26595471fb
2 changed files with 147 additions and 18 deletions

View File

@ -17,6 +17,7 @@ package io.netty.channel;
import io.netty.buffer.Buf; import io.netty.buffer.Buf;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Freeable;
import io.netty.buffer.MessageBuf; import io.netty.buffer.MessageBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.logging.InternalLogger; import io.netty.logging.InternalLogger;
@ -48,6 +49,8 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext head; final DefaultChannelHandlerContext head;
private volatile DefaultChannelHandlerContext tail; private volatile DefaultChannelHandlerContext tail;
private final DefaultChannelHandlerContext tailCtx;
private final Map<String, DefaultChannelHandlerContext> name2ctx = private final Map<String, DefaultChannelHandlerContext> name2ctx =
new HashMap<String, DefaultChannelHandlerContext>(4); new HashMap<String, DefaultChannelHandlerContext>(4);
private boolean firedChannelActive; private boolean firedChannelActive;
@ -56,6 +59,8 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final Map<EventExecutorGroup, EventExecutor> childExecutors = final Map<EventExecutorGroup, EventExecutor> childExecutors =
new IdentityHashMap<EventExecutorGroup, EventExecutor>(); new IdentityHashMap<EventExecutorGroup, EventExecutor>();
private static final TailHandler TAIL_HANDLER = new TailHandler();
public DefaultChannelPipeline(Channel channel) { public DefaultChannelPipeline(Channel channel) {
if (channel == null) { if (channel == null) {
throw new NullPointerException("channel"); throw new NullPointerException("channel");
@ -63,9 +68,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
this.channel = channel; this.channel = channel;
HeadHandler headHandler = new HeadHandler(); HeadHandler headHandler = new HeadHandler();
tailCtx = new DefaultChannelHandlerContext(
this, null, null, null, generateName(TAIL_HANDLER), TAIL_HANDLER);
head = new DefaultChannelHandlerContext( head = new DefaultChannelHandlerContext(
this, null, null, null, generateName(headHandler), headHandler); this, null, null, tailCtx, generateName(headHandler), headHandler);
tail = head; tailCtx.prev = head;
tail = tailCtx;
unsafe = channel.unsafe(); unsafe = channel.unsafe();
} }
@ -119,10 +127,12 @@ final class DefaultChannelPipeline implements ChannelPipeline {
if (nextCtx != null) { if (nextCtx != null) {
nextCtx.prev = newCtx; nextCtx.prev = newCtx;
} }
head.next = newCtx; if (head.next == tailCtx) {
if (tail == head) {
tail = newCtx; tail = newCtx;
newCtx.next = tailCtx;
tailCtx.prev = newCtx;
} }
head.next = newCtx;
name2ctx.put(name, newCtx); name2ctx.put(name, newCtx);
@ -143,8 +153,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
checkDuplicateName(name); checkDuplicateName(name);
oldTail = tail; oldTail = tail;
newTail = new DefaultChannelHandlerContext(this, group, oldTail, null, name, handler); newTail = new DefaultChannelHandlerContext(this, group, null, null, name, handler);
if (!newTail.channel().isRegistered() || newTail.executor().inEventLoop()) { if (!newTail.channel().isRegistered() || newTail.executor().inEventLoop()) {
addLast0(name, oldTail, newTail); addLast0(name, oldTail, newTail);
return this; return this;
@ -171,7 +180,21 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final String name, DefaultChannelHandlerContext oldTail, DefaultChannelHandlerContext newTail) { final String name, DefaultChannelHandlerContext oldTail, DefaultChannelHandlerContext newTail) {
callBeforeAdd(newTail); callBeforeAdd(newTail);
oldTail.next = newTail; DefaultChannelHandlerContext prev = oldTail.prev;
if (oldTail == tailCtx) {
// This is the first handler added
tailCtx.prev = newTail;
newTail.next = tailCtx;
prev.next = newTail;
newTail.prev = prev;
} else {
oldTail.next = newTail;
newTail.prev = oldTail;
prev.next = oldTail;
oldTail.prev = prev;
}
tail = newTail; tail = newTail;
name2ctx.put(name, newTail); name2ctx.put(name, newTail);
@ -361,12 +384,15 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Future<?> future; Future<?> future;
synchronized (this) { synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (head == tail) { if (head == tail) {
return null; return null;
} else if (ctx == head) { } else if (ctx == head) {
throw new Error(); // Should never happen. throw new Error(); // Should never happen.
} else if (ctx == tail) { } else if (ctx == tail) {
if (head == tail) { if (tail == tailCtx) {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
@ -425,7 +451,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
@Override @Override
public ChannelHandler removeFirst() { public ChannelHandler removeFirst() {
if (head == tail) { if (head.next == tailCtx) {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
return remove(head.next).handler(); return remove(head.next).handler();
@ -436,7 +462,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext oldTail; final DefaultChannelHandlerContext oldTail;
synchronized (this) { synchronized (this) {
if (head == tail) { if (tail == tailCtx) {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
oldTail = tail; oldTail = tail;
@ -464,7 +490,9 @@ final class DefaultChannelPipeline implements ChannelPipeline {
private void removeLast0(DefaultChannelHandlerContext oldTail) { private void removeLast0(DefaultChannelHandlerContext oldTail) {
callBeforeRemove(oldTail); callBeforeRemove(oldTail);
oldTail.prev.next = null; tailCtx.prev = oldTail.prev;
oldTail.prev.next = tailCtx;
tail = oldTail.prev; tail = oldTail.prev;
name2ctx.remove(oldTail.name()); name2ctx.remove(oldTail.name());
@ -493,10 +521,13 @@ final class DefaultChannelPipeline implements ChannelPipeline {
final DefaultChannelHandlerContext ctx, final String newName, ChannelHandler newHandler) { final DefaultChannelHandlerContext ctx, final String newName, ChannelHandler newHandler) {
Future<?> future; Future<?> future;
synchronized (this) { synchronized (this) {
if (ctx == tailCtx) {
throw new NoSuchElementException();
}
if (ctx == head) { if (ctx == head) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} else if (ctx == tail) { } else if (ctx == tail) {
if (head == tail) { if (tail == tailCtx) {
throw new NoSuchElementException(); throw new NoSuchElementException();
} }
final DefaultChannelHandlerContext oldTail = tail; final DefaultChannelHandlerContext oldTail = tail;
@ -688,7 +719,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
@Override @Override
public ChannelHandler last() { public ChannelHandler last() {
DefaultChannelHandlerContext last = tail; DefaultChannelHandlerContext last = tail;
if (last == head || last == null) { if (last == tailCtx || last == null) {
return null; return null;
} }
return last.handler(); return last.handler();
@ -743,6 +774,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
DefaultChannelHandlerContext ctx = head.next; DefaultChannelHandlerContext ctx = head.next;
for (;;) { for (;;) {
if (ctx == null) { if (ctx == null) {
return null; return null;
} }
@ -791,7 +823,7 @@ final class DefaultChannelPipeline implements ChannelPipeline {
Map<String, ChannelHandler> map = new LinkedHashMap<String, ChannelHandler>(); Map<String, ChannelHandler> map = new LinkedHashMap<String, ChannelHandler>();
DefaultChannelHandlerContext ctx = head.next; DefaultChannelHandlerContext ctx = head.next;
for (;;) { for (;;) {
if (ctx == null) { if (ctx == null || ctx == tailCtx) {
return map; return map;
} }
map.put(ctx.name(), ctx.handler()); map.put(ctx.name(), ctx.handler());
@ -1331,7 +1363,6 @@ final class DefaultChannelPipeline implements ChannelPipeline {
ctx = ctx.prev; ctx = ctx.prev;
} }
if (executor.inEventLoop()) { if (executor.inEventLoop()) {
write0(ctx, message, promise, msgBuf); write0(ctx, message, promise, msgBuf);
return promise; return promise;
@ -1483,6 +1514,21 @@ final class DefaultChannelPipeline implements ChannelPipeline {
} }
} }
private static final class TailHandler extends ChannelInboundMessageHandlerAdapter<Freeable> {
public TailHandler() {
super(Freeable.class);
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, Freeable msg) throws Exception {
if (logger.isWarnEnabled()) {
logger.warn("Freeable reached end-of-pipeline, call " + msg + ".free() to" +
" guard against resource leakage!");
}
msg.free();
}
}
private final class HeadHandler implements ChannelOutboundHandler { private final class HeadHandler implements ChannelOutboundHandler {
@Override @Override
public Buf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception { public Buf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception {

View File

@ -15,13 +15,89 @@
*/ */
package io.netty.channel; package io.netty.channel;
import io.netty.buffer.Freeable;
import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalEventLoopGroup;
import org.junit.Test; import org.junit.Test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*; import static org.junit.Assert.*;
public class DefaultChannelPipelineTest { public class DefaultChannelPipelineTest {
@Test
public void testFreeCalled() throws InterruptedException{
final CountDownLatch free = new CountDownLatch(1);
Freeable holder = new Freeable() {
@Override
public void free() {
free.countDown();
}
@Override
public boolean isFreed() {
return free.getCount() == 0;
}
};
LocalChannel channel = new LocalChannel();
LocalEventLoopGroup group = new LocalEventLoopGroup();
group.register(channel).awaitUninterruptibly();
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel);
StringInboundHandler handler = new StringInboundHandler();
pipeline.addLast(handler);
pipeline.fireChannelActive();
pipeline.inboundMessageBuffer().add(holder);
pipeline.fireInboundBufferUpdated();
assertTrue(free.await(10, TimeUnit.SECONDS));
assertTrue(handler.called);
}
private static final class StringInboundHandler extends ChannelInboundMessageHandlerAdapter<String> {
boolean called;
public StringInboundHandler() {
super(String.class);
}
@Override
public boolean isSupported(Object msg) throws Exception {
called = true;
return super.isSupported(msg);
}
@Override
protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception {
fail();
}
}
@Test
public void testRemoveChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());
ChannelHandler handler1 = newHandler();
ChannelHandler handler2 = newHandler();
ChannelHandler handler3 = newHandler();
pipeline.addLast("handler1", handler1);
pipeline.addLast("handler2", handler2);
pipeline.addLast("handler3", handler3);
assertSame(pipeline.get("handler1"), handler1);
assertSame(pipeline.get("handler2"), handler2);
assertSame(pipeline.get("handler3"), handler3);
pipeline.remove(handler1);
pipeline.remove(handler2);
pipeline.remove(handler3);
}
@Test @Test
public void testReplaceChannelHandler() { public void testReplaceChannelHandler() {
DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel());
@ -107,8 +183,11 @@ public class DefaultChannelPipelineTest {
while (ctx != null) { while (ctx != null) {
int i = toInt(ctx.name()); int i = toInt(ctx.name());
int j = next(ctx); int j = next(ctx);
if (j != -1) {
assertTrue(i < j); assertTrue(i < j);
} else {
assertNull(ctx.next.next);
}
ctx = ctx.next; ctx = ctx.next;
} }
@ -125,7 +204,11 @@ public class DefaultChannelPipelineTest {
} }
private static int toInt(String name) { private static int toInt(String name) {
return Integer.parseInt(name); try {
return Integer.parseInt(name);
} catch (NumberFormatException e) {
return -1;
}
} }
private static void verifyContextNumber(DefaultChannelPipeline pipeline, int expectedNumber) { private static void verifyContextNumber(DefaultChannelPipeline pipeline, int expectedNumber) {