From 26595471fb58524dde69b9ecc5d6a31060a5983d Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Mon, 7 Jan 2013 08:44:16 +0100 Subject: [PATCH] Call Freeable.free() if a Freeable message reaches the end of the ChannelPipeline to guard against resource leakage --- .../netty/channel/DefaultChannelPipeline.java | 76 ++++++++++++---- .../channel/DefaultChannelPipelineTest.java | 89 ++++++++++++++++++- 2 files changed, 147 insertions(+), 18 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java index d428e6fd2d..ff18485383 100755 --- a/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -17,6 +17,7 @@ package io.netty.channel; import io.netty.buffer.Buf; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Freeable; import io.netty.buffer.MessageBuf; import io.netty.buffer.Unpooled; import io.netty.logging.InternalLogger; @@ -48,6 +49,8 @@ final class DefaultChannelPipeline implements ChannelPipeline { final DefaultChannelHandlerContext head; private volatile DefaultChannelHandlerContext tail; + private final DefaultChannelHandlerContext tailCtx; + private final Map name2ctx = new HashMap(4); private boolean firedChannelActive; @@ -56,6 +59,8 @@ final class DefaultChannelPipeline implements ChannelPipeline { final Map childExecutors = new IdentityHashMap(); + private static final TailHandler TAIL_HANDLER = new TailHandler(); + public DefaultChannelPipeline(Channel channel) { if (channel == null) { throw new NullPointerException("channel"); @@ -63,9 +68,12 @@ final class DefaultChannelPipeline implements ChannelPipeline { this.channel = channel; HeadHandler headHandler = new HeadHandler(); + tailCtx = new DefaultChannelHandlerContext( + this, null, null, null, generateName(TAIL_HANDLER), TAIL_HANDLER); head = new DefaultChannelHandlerContext( - this, null, null, null, generateName(headHandler), headHandler); - tail = head; + this, null, null, tailCtx, generateName(headHandler), headHandler); + tailCtx.prev = head; + tail = tailCtx; unsafe = channel.unsafe(); } @@ -119,10 +127,12 @@ final class DefaultChannelPipeline implements ChannelPipeline { if (nextCtx != null) { nextCtx.prev = newCtx; } - head.next = newCtx; - if (tail == head) { + if (head.next == tailCtx) { tail = newCtx; + newCtx.next = tailCtx; + tailCtx.prev = newCtx; } + head.next = newCtx; name2ctx.put(name, newCtx); @@ -143,8 +153,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { checkDuplicateName(name); oldTail = tail; - newTail = new DefaultChannelHandlerContext(this, group, oldTail, null, name, handler); - + newTail = new DefaultChannelHandlerContext(this, group, null, null, name, handler); if (!newTail.channel().isRegistered() || newTail.executor().inEventLoop()) { addLast0(name, oldTail, newTail); return this; @@ -171,7 +180,21 @@ final class DefaultChannelPipeline implements ChannelPipeline { final String name, DefaultChannelHandlerContext oldTail, DefaultChannelHandlerContext newTail) { callBeforeAdd(newTail); - oldTail.next = newTail; + DefaultChannelHandlerContext prev = oldTail.prev; + if (oldTail == tailCtx) { + // This is the first handler added + tailCtx.prev = newTail; + newTail.next = tailCtx; + prev.next = newTail; + newTail.prev = prev; + } else { + oldTail.next = newTail; + newTail.prev = oldTail; + + prev.next = oldTail; + oldTail.prev = prev; + } + tail = newTail; name2ctx.put(name, newTail); @@ -361,12 +384,15 @@ final class DefaultChannelPipeline implements ChannelPipeline { Future future; synchronized (this) { + if (ctx == tailCtx) { + throw new NoSuchElementException(); + } if (head == tail) { return null; } else if (ctx == head) { throw new Error(); // Should never happen. } else if (ctx == tail) { - if (head == tail) { + if (tail == tailCtx) { throw new NoSuchElementException(); } @@ -425,7 +451,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { @Override public ChannelHandler removeFirst() { - if (head == tail) { + if (head.next == tailCtx) { throw new NoSuchElementException(); } return remove(head.next).handler(); @@ -436,7 +462,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { final DefaultChannelHandlerContext oldTail; synchronized (this) { - if (head == tail) { + if (tail == tailCtx) { throw new NoSuchElementException(); } oldTail = tail; @@ -464,7 +490,9 @@ final class DefaultChannelPipeline implements ChannelPipeline { private void removeLast0(DefaultChannelHandlerContext oldTail) { callBeforeRemove(oldTail); - oldTail.prev.next = null; + tailCtx.prev = oldTail.prev; + oldTail.prev.next = tailCtx; + tail = oldTail.prev; name2ctx.remove(oldTail.name()); @@ -493,10 +521,13 @@ final class DefaultChannelPipeline implements ChannelPipeline { final DefaultChannelHandlerContext ctx, final String newName, ChannelHandler newHandler) { Future future; synchronized (this) { + if (ctx == tailCtx) { + throw new NoSuchElementException(); + } if (ctx == head) { throw new IllegalArgumentException(); } else if (ctx == tail) { - if (head == tail) { + if (tail == tailCtx) { throw new NoSuchElementException(); } final DefaultChannelHandlerContext oldTail = tail; @@ -688,7 +719,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { @Override public ChannelHandler last() { DefaultChannelHandlerContext last = tail; - if (last == head || last == null) { + if (last == tailCtx || last == null) { return null; } return last.handler(); @@ -743,6 +774,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { DefaultChannelHandlerContext ctx = head.next; for (;;) { + if (ctx == null) { return null; } @@ -791,7 +823,7 @@ final class DefaultChannelPipeline implements ChannelPipeline { Map map = new LinkedHashMap(); DefaultChannelHandlerContext ctx = head.next; for (;;) { - if (ctx == null) { + if (ctx == null || ctx == tailCtx) { return map; } map.put(ctx.name(), ctx.handler()); @@ -1331,7 +1363,6 @@ final class DefaultChannelPipeline implements ChannelPipeline { ctx = ctx.prev; } - if (executor.inEventLoop()) { write0(ctx, message, promise, msgBuf); return promise; @@ -1483,6 +1514,21 @@ final class DefaultChannelPipeline implements ChannelPipeline { } } + private static final class TailHandler extends ChannelInboundMessageHandlerAdapter { + public TailHandler() { + super(Freeable.class); + } + + @Override + protected void messageReceived(ChannelHandlerContext ctx, Freeable msg) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Freeable reached end-of-pipeline, call " + msg + ".free() to" + + " guard against resource leakage!"); + } + msg.free(); + } + } + private final class HeadHandler implements ChannelOutboundHandler { @Override public Buf newOutboundBuffer(ChannelHandlerContext ctx) throws Exception { diff --git a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java index c66976308f..6ee7cb8ae8 100644 --- a/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java +++ b/transport/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -15,13 +15,89 @@ */ package io.netty.channel; + +import io.netty.buffer.Freeable; import io.netty.channel.ChannelHandler.Sharable; import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalEventLoopGroup; import org.junit.Test; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + import static org.junit.Assert.*; public class DefaultChannelPipelineTest { + @Test + public void testFreeCalled() throws InterruptedException{ + final CountDownLatch free = new CountDownLatch(1); + + Freeable holder = new Freeable() { + @Override + public void free() { + free.countDown(); + } + + @Override + public boolean isFreed() { + return free.getCount() == 0; + } + }; + LocalChannel channel = new LocalChannel(); + LocalEventLoopGroup group = new LocalEventLoopGroup(); + group.register(channel).awaitUninterruptibly(); + DefaultChannelPipeline pipeline = new DefaultChannelPipeline(channel); + + StringInboundHandler handler = new StringInboundHandler(); + pipeline.addLast(handler); + pipeline.fireChannelActive(); + pipeline.inboundMessageBuffer().add(holder); + pipeline.fireInboundBufferUpdated(); + + assertTrue(free.await(10, TimeUnit.SECONDS)); + assertTrue(handler.called); + } + + private static final class StringInboundHandler extends ChannelInboundMessageHandlerAdapter { + boolean called; + + public StringInboundHandler() { + super(String.class); + } + + @Override + public boolean isSupported(Object msg) throws Exception { + called = true; + return super.isSupported(msg); + } + + @Override + protected void messageReceived(ChannelHandlerContext ctx, String msg) throws Exception { + fail(); + } + } + + + @Test + public void testRemoveChannelHandler() { + DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); + + ChannelHandler handler1 = newHandler(); + ChannelHandler handler2 = newHandler(); + ChannelHandler handler3 = newHandler(); + + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + pipeline.addLast("handler3", handler3); + assertSame(pipeline.get("handler1"), handler1); + assertSame(pipeline.get("handler2"), handler2); + assertSame(pipeline.get("handler3"), handler3); + + pipeline.remove(handler1); + pipeline.remove(handler2); + pipeline.remove(handler3); + } + @Test public void testReplaceChannelHandler() { DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); @@ -107,8 +183,11 @@ public class DefaultChannelPipelineTest { while (ctx != null) { int i = toInt(ctx.name()); int j = next(ctx); - - assertTrue(i < j); + if (j != -1) { + assertTrue(i < j); + } else { + assertNull(ctx.next.next); + } ctx = ctx.next; } @@ -125,7 +204,11 @@ public class DefaultChannelPipelineTest { } private static int toInt(String name) { - return Integer.parseInt(name); + try { + return Integer.parseInt(name); + } catch (NumberFormatException e) { + return -1; + } } private static void verifyContextNumber(DefaultChannelPipeline pipeline, int expectedNumber) {