From 786fdbd6e05d3cd3bd8e43f4f47cc5d0f1f9da9a Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Mon, 18 Nov 2013 15:59:44 +0900 Subject: [PATCH] Bring back ChannelGroup.find(id) --- .../io/netty/channel/group/ChannelGroup.java | 8 ++ .../channel/group/DefaultChannelGroup.java | 77 +++++++++++-------- 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/group/ChannelGroup.java b/transport/src/main/java/io/netty/channel/group/ChannelGroup.java index e5b70c260d..0ec17b03f7 100644 --- a/transport/src/main/java/io/netty/channel/group/ChannelGroup.java +++ b/transport/src/main/java/io/netty/channel/group/ChannelGroup.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBufHolder; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ServerChannel; import io.netty.util.CharsetUtil; @@ -99,6 +100,13 @@ public interface ChannelGroup extends Set, Comparable { */ String name(); + /** + * Returns the {@link Channel} which has the specified {@link ChannelId}. + * + * @return the matching {@link Channel} if found. {@code null} otherwise. + */ + Channel find(ChannelId id); + /** * Writes the specified {@code message} to all {@link Channel}s in this * group. If the specified {@code message} is an instance of diff --git a/transport/src/main/java/io/netty/channel/group/DefaultChannelGroup.java b/transport/src/main/java/io/netty/channel/group/DefaultChannelGroup.java index 3bd75fd64e..51afe51254 100644 --- a/transport/src/main/java/io/netty/channel/group/DefaultChannelGroup.java +++ b/transport/src/main/java/io/netty/channel/group/DefaultChannelGroup.java @@ -20,10 +20,11 @@ import io.netty.buffer.ByteBufHolder; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelId; import io.netty.channel.ServerChannel; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; -import io.netty.util.internal.ConcurrentSet; +import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import java.util.AbstractSet; @@ -32,6 +33,7 @@ import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; /** @@ -42,8 +44,8 @@ public class DefaultChannelGroup extends AbstractSet implements Channel private static final AtomicInteger nextId = new AtomicInteger(); private final String name; private final EventExecutor executor; - private final ConcurrentSet serverChannels = new ConcurrentSet(); - private final ConcurrentSet nonServerChannels = new ConcurrentSet(); + private final ConcurrentMap serverChannels = PlatformDependent.newConcurrentHashMap(); + private final ConcurrentMap nonServerChannels = PlatformDependent.newConcurrentHashMap(); private final ChannelFutureListener remover = new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { @@ -77,6 +79,16 @@ public class DefaultChannelGroup extends AbstractSet implements Channel return name; } + @Override + public Channel find(ChannelId id) { + Channel c = nonServerChannels.get(id); + if (c != null) { + return c; + } else { + return serverChannels.get(id); + } + } + @Override public boolean isEmpty() { return nonServerChannels.isEmpty() && serverChannels.isEmpty(); @@ -92,9 +104,9 @@ public class DefaultChannelGroup extends AbstractSet implements Channel if (o instanceof Channel) { Channel c = (Channel) o; if (o instanceof ServerChannel) { - return serverChannels.contains(c); + return serverChannels.containsValue(c); } else { - return nonServerChannels.contains(c); + return nonServerChannels.containsValue(c); } } else { return false; @@ -103,10 +115,10 @@ public class DefaultChannelGroup extends AbstractSet implements Channel @Override public boolean add(Channel channel) { - ConcurrentSet set = + ConcurrentMap map = channel instanceof ServerChannel? serverChannels : nonServerChannels; - boolean added = set.add(channel); + boolean added = map.putIfAbsent(channel.id(), channel) == null; if (added) { channel.closeFuture().addListener(remover); } @@ -115,17 +127,22 @@ public class DefaultChannelGroup extends AbstractSet implements Channel @Override public boolean remove(Object o) { - if (!(o instanceof Channel)) { - return false; + Channel c = null; + if (o instanceof ChannelId) { + c = nonServerChannels.remove(o); + if (c == null) { + c = serverChannels.remove(o); + } + } else if (o instanceof Channel) { + c = (Channel) o; + if (c instanceof ServerChannel) { + c = serverChannels.remove(c.id()); + } else { + c = nonServerChannels.remove(c.id()); + } } - boolean removed; - Channel c = (Channel) o; - if (c instanceof ServerChannel) { - removed = serverChannels.remove(c); - } else { - removed = nonServerChannels.remove(c); - } - if (!removed) { + + if (c == null) { return false; } @@ -142,23 +159,23 @@ public class DefaultChannelGroup extends AbstractSet implements Channel @Override public Iterator iterator() { return new CombinedIterator( - serverChannels.iterator(), - nonServerChannels.iterator()); + serverChannels.values().iterator(), + nonServerChannels.values().iterator()); } @Override public Object[] toArray() { Collection channels = new ArrayList(size()); - channels.addAll(serverChannels); - channels.addAll(nonServerChannels); + channels.addAll(serverChannels.values()); + channels.addAll(nonServerChannels.values()); return channels.toArray(); } @Override public T[] toArray(T[] a) { Collection channels = new ArrayList(size()); - channels.addAll(serverChannels); - channels.addAll(nonServerChannels); + channels.addAll(serverChannels.values()); + channels.addAll(nonServerChannels.values()); return channels.toArray(a); } @@ -199,7 +216,7 @@ public class DefaultChannelGroup extends AbstractSet implements Channel } Map futures = new LinkedHashMap(size()); - for (Channel c: nonServerChannels) { + for (Channel c: nonServerChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.write(safeDuplicate(message))); } @@ -228,12 +245,12 @@ public class DefaultChannelGroup extends AbstractSet implements Channel Map futures = new LinkedHashMap(size()); - for (Channel c: serverChannels) { + for (Channel c: serverChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.disconnect()); } } - for (Channel c: nonServerChannels) { + for (Channel c: nonServerChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.disconnect()); } @@ -251,12 +268,12 @@ public class DefaultChannelGroup extends AbstractSet implements Channel Map futures = new LinkedHashMap(size()); - for (Channel c: serverChannels) { + for (Channel c: serverChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.close()); } } - for (Channel c: nonServerChannels) { + for (Channel c: nonServerChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.close()); } @@ -267,7 +284,7 @@ public class DefaultChannelGroup extends AbstractSet implements Channel @Override public ChannelGroup flush(ChannelMatcher matcher) { - for (Channel c: nonServerChannels) { + for (Channel c: nonServerChannels.values()) { if (matcher.matches(c)) { c.flush(); } @@ -283,7 +300,7 @@ public class DefaultChannelGroup extends AbstractSet implements Channel Map futures = new LinkedHashMap(size()); - for (Channel c: nonServerChannels) { + for (Channel c: nonServerChannels.values()) { if (matcher.matches(c)) { futures.put(c, c.writeAndFlush(safeDuplicate(message))); }