Bring back ChannelGroup.find(id)

This commit is contained in:
Trustin Lee 2013-11-18 15:59:44 +09:00
parent 2235873537
commit 786fdbd6e0
2 changed files with 55 additions and 30 deletions

View File

@ -21,6 +21,7 @@ import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ServerChannel; import io.netty.channel.ServerChannel;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
@ -99,6 +100,13 @@ public interface ChannelGroup extends Set<Channel>, Comparable<ChannelGroup> {
*/ */
String name(); String name();
/**
* Returns the {@link Channel} which has the specified {@link ChannelId}.
*
* @return the matching {@link Channel} if found. {@code null} otherwise.
*/
Channel find(ChannelId id);
/** /**
* Writes the specified {@code message} to all {@link Channel}s in this * Writes the specified {@code message} to all {@link Channel}s in this
* group. If the specified {@code message} is an instance of * group. If the specified {@code message} is an instance of

View File

@ -20,10 +20,11 @@ import io.netty.buffer.ByteBufHolder;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelId;
import io.netty.channel.ServerChannel; import io.netty.channel.ServerChannel;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.EventExecutor;
import io.netty.util.internal.ConcurrentSet; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import java.util.AbstractSet; import java.util.AbstractSet;
@ -32,6 +33,7 @@ import java.util.Collection;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
/** /**
@ -42,8 +44,8 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
private static final AtomicInteger nextId = new AtomicInteger(); private static final AtomicInteger nextId = new AtomicInteger();
private final String name; private final String name;
private final EventExecutor executor; private final EventExecutor executor;
private final ConcurrentSet<Channel> serverChannels = new ConcurrentSet<Channel>(); private final ConcurrentMap<ChannelId, Channel> serverChannels = PlatformDependent.newConcurrentHashMap();
private final ConcurrentSet<Channel> nonServerChannels = new ConcurrentSet<Channel>(); private final ConcurrentMap<ChannelId, Channel> nonServerChannels = PlatformDependent.newConcurrentHashMap();
private final ChannelFutureListener remover = new ChannelFutureListener() { private final ChannelFutureListener remover = new ChannelFutureListener() {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
@ -77,6 +79,16 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
return name; return name;
} }
@Override
public Channel find(ChannelId id) {
Channel c = nonServerChannels.get(id);
if (c != null) {
return c;
} else {
return serverChannels.get(id);
}
}
@Override @Override
public boolean isEmpty() { public boolean isEmpty() {
return nonServerChannels.isEmpty() && serverChannels.isEmpty(); return nonServerChannels.isEmpty() && serverChannels.isEmpty();
@ -92,9 +104,9 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
if (o instanceof Channel) { if (o instanceof Channel) {
Channel c = (Channel) o; Channel c = (Channel) o;
if (o instanceof ServerChannel) { if (o instanceof ServerChannel) {
return serverChannels.contains(c); return serverChannels.containsValue(c);
} else { } else {
return nonServerChannels.contains(c); return nonServerChannels.containsValue(c);
} }
} else { } else {
return false; return false;
@ -103,10 +115,10 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
@Override @Override
public boolean add(Channel channel) { public boolean add(Channel channel) {
ConcurrentSet<Channel> set = ConcurrentMap<ChannelId, Channel> map =
channel instanceof ServerChannel? serverChannels : nonServerChannels; channel instanceof ServerChannel? serverChannels : nonServerChannels;
boolean added = set.add(channel); boolean added = map.putIfAbsent(channel.id(), channel) == null;
if (added) { if (added) {
channel.closeFuture().addListener(remover); channel.closeFuture().addListener(remover);
} }
@ -115,17 +127,22 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
@Override @Override
public boolean remove(Object o) { public boolean remove(Object o) {
if (!(o instanceof Channel)) { Channel c = null;
return false; if (o instanceof ChannelId) {
c = nonServerChannels.remove(o);
if (c == null) {
c = serverChannels.remove(o);
} }
boolean removed; } else if (o instanceof Channel) {
Channel c = (Channel) o; c = (Channel) o;
if (c instanceof ServerChannel) { if (c instanceof ServerChannel) {
removed = serverChannels.remove(c); c = serverChannels.remove(c.id());
} else { } else {
removed = nonServerChannels.remove(c); c = nonServerChannels.remove(c.id());
} }
if (!removed) { }
if (c == null) {
return false; return false;
} }
@ -142,23 +159,23 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
@Override @Override
public Iterator<Channel> iterator() { public Iterator<Channel> iterator() {
return new CombinedIterator<Channel>( return new CombinedIterator<Channel>(
serverChannels.iterator(), serverChannels.values().iterator(),
nonServerChannels.iterator()); nonServerChannels.values().iterator());
} }
@Override @Override
public Object[] toArray() { public Object[] toArray() {
Collection<Channel> channels = new ArrayList<Channel>(size()); Collection<Channel> channels = new ArrayList<Channel>(size());
channels.addAll(serverChannels); channels.addAll(serverChannels.values());
channels.addAll(nonServerChannels); channels.addAll(nonServerChannels.values());
return channels.toArray(); return channels.toArray();
} }
@Override @Override
public <T> T[] toArray(T[] a) { public <T> T[] toArray(T[] a) {
Collection<Channel> channels = new ArrayList<Channel>(size()); Collection<Channel> channels = new ArrayList<Channel>(size());
channels.addAll(serverChannels); channels.addAll(serverChannels.values());
channels.addAll(nonServerChannels); channels.addAll(nonServerChannels.values());
return channels.toArray(a); return channels.toArray(a);
} }
@ -199,7 +216,7 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
} }
Map<Channel, ChannelFuture> futures = new LinkedHashMap<Channel, ChannelFuture>(size()); Map<Channel, ChannelFuture> futures = new LinkedHashMap<Channel, ChannelFuture>(size());
for (Channel c: nonServerChannels) { for (Channel c: nonServerChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.write(safeDuplicate(message))); futures.put(c, c.write(safeDuplicate(message)));
} }
@ -228,12 +245,12 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
Map<Channel, ChannelFuture> futures = Map<Channel, ChannelFuture> futures =
new LinkedHashMap<Channel, ChannelFuture>(size()); new LinkedHashMap<Channel, ChannelFuture>(size());
for (Channel c: serverChannels) { for (Channel c: serverChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.disconnect()); futures.put(c, c.disconnect());
} }
} }
for (Channel c: nonServerChannels) { for (Channel c: nonServerChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.disconnect()); futures.put(c, c.disconnect());
} }
@ -251,12 +268,12 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
Map<Channel, ChannelFuture> futures = Map<Channel, ChannelFuture> futures =
new LinkedHashMap<Channel, ChannelFuture>(size()); new LinkedHashMap<Channel, ChannelFuture>(size());
for (Channel c: serverChannels) { for (Channel c: serverChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.close()); futures.put(c, c.close());
} }
} }
for (Channel c: nonServerChannels) { for (Channel c: nonServerChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.close()); futures.put(c, c.close());
} }
@ -267,7 +284,7 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
@Override @Override
public ChannelGroup flush(ChannelMatcher matcher) { public ChannelGroup flush(ChannelMatcher matcher) {
for (Channel c: nonServerChannels) { for (Channel c: nonServerChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
c.flush(); c.flush();
} }
@ -283,7 +300,7 @@ public class DefaultChannelGroup extends AbstractSet<Channel> implements Channel
Map<Channel, ChannelFuture> futures = new LinkedHashMap<Channel, ChannelFuture>(size()); Map<Channel, ChannelFuture> futures = new LinkedHashMap<Channel, ChannelFuture>(size());
for (Channel c: nonServerChannels) { for (Channel c: nonServerChannels.values()) {
if (matcher.matches(c)) { if (matcher.matches(c)) {
futures.put(c, c.writeAndFlush(safeDuplicate(message))); futures.put(c, c.writeAndFlush(safeDuplicate(message)));
} }