From 72f9f502bbb0ab5445c280aabb7e759e74eaff16 Mon Sep 17 00:00:00 2001 From: norman Date: Mon, 2 Apr 2012 11:07:11 +0200 Subject: [PATCH] Add support for UDP multicast in NIO. See #216 Add some javadocs. See #216 Use the correct key to lookup MembershipKey. See #216 --- .../socket/nio/NioDatagramChannel.java | 160 +++++++++++++++++- 1 file changed, 151 insertions(+), 9 deletions(-) diff --git a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java index f916f7f088..d199d135df 100644 --- a/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java +++ b/transport/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java @@ -22,13 +22,21 @@ import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelSink; import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.util.internal.DetectionUtil; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.NetworkInterface; import java.net.SocketAddress; +import java.net.SocketException; import java.nio.channels.DatagramChannel; +import java.nio.channels.MembershipKey; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; /** * Provides an NIO based {@link io.netty.channel.socket.DatagramChannel}. @@ -39,7 +47,7 @@ public final class NioDatagramChannel extends AbstractNioChannel implements io.n * The {@link DatagramChannelConfig}. */ private final NioDatagramChannelConfig config; - + private Map> memberships; static NioDatagramChannel create(ChannelFactory factory, ChannelPipeline pipeline, ChannelSink sink, NioDatagramWorker worker) { @@ -99,28 +107,162 @@ public final class NioDatagramChannel extends AbstractNioChannel implements io.n @Override public void joinGroup(InetAddress multicastAddress) { - throw new UnsupportedOperationException(); + try { + joinGroup(multicastAddress, NetworkInterface.getByInetAddress(getLocalAddress().getAddress()), null); + } catch (SocketException e) { + throw new ChannelException(e); + } } @Override - public void joinGroup(InetSocketAddress multicastAddress, - NetworkInterface networkInterface) { - throw new UnsupportedOperationException(); + public void joinGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface) { + joinGroup(multicastAddress.getAddress(), networkInterface, null); } + /** + * Joins the specified multicast group at the specified interface using the specified source. + */ + public void joinGroup(InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source) { + if (DetectionUtil.javaVersion() < 7) { + throw new UnsupportedOperationException(); + } else { + if (multicastAddress == null) { + throw new NullPointerException("multicastAddress"); + } + + if (networkInterface == null) { + throw new NullPointerException("networkInterface"); + } + + try { + MembershipKey key = getJdkChannel().getChannel().join(multicastAddress, networkInterface); + synchronized (this) { + if (memberships == null) { + memberships = new HashMap>(); + + } + List keys = memberships.get(multicastAddress); + if (keys == null) { + keys = new ArrayList(); + memberships.put(multicastAddress, keys); + } + + keys.add(key); + } + } catch (IOException e) { + throw new ChannelException(e); + } + } + } + @Override public void leaveGroup(InetAddress multicastAddress) { - throw new UnsupportedOperationException(); + try { + leaveGroup(multicastAddress, NetworkInterface.getByInetAddress(getLocalAddress().getAddress()), null); + } catch (SocketException e) { + throw new ChannelException(e); + } + } @Override public void leaveGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface) { - throw new UnsupportedOperationException(); + leaveGroup(multicastAddress.getAddress(), networkInterface, null); } - - + /** + * Leave the specified multicast group at the specified interface using the specified source. + */ + public void leaveGroup(InetAddress multicastAddress, + NetworkInterface networkInterface, InetAddress source) { + if (DetectionUtil.javaVersion() < 7) { + throw new UnsupportedOperationException(); + } else { + if (multicastAddress == null) { + throw new NullPointerException("multicastAddress"); + } + + if (networkInterface == null) { + throw new NullPointerException("networkInterface"); + } + + synchronized (this) { + if (memberships != null) { + List keys = memberships.get(multicastAddress); + if (keys != null) { + Iterator keyIt = keys.iterator(); + + while(keyIt.hasNext()) { + MembershipKey key = keyIt.next(); + if (networkInterface.equals(key.networkInterface())) { + if (source == null && key.sourceAddress() == null || (source != null && source.equals(key.sourceAddress()))) { + key.drop(); + keyIt.remove(); + } + + } + } + if (keys.isEmpty()) { + memberships.remove(multicastAddress); + } + } + } + } + + + } + } + + /** + * Block the given sourceToBlock address for the given multicastAddress on the given networkInterface + * + */ + public void block(InetAddress multicastAddress, + NetworkInterface networkInterface, InetAddress sourceToBlock) { + if (DetectionUtil.javaVersion() < 7) { + throw new UnsupportedOperationException(); + } else { + if (multicastAddress == null) { + throw new NullPointerException("multicastAddress"); + } + if (sourceToBlock == null) { + throw new NullPointerException("sourceToBlock"); + } + + if (networkInterface == null) { + throw new NullPointerException("networkInterface"); + } + synchronized (this) { + if (memberships != null) { + List keys = memberships.get(multicastAddress); + for (MembershipKey key: keys) { + if (networkInterface.equals(key.networkInterface())) { + try { + key.block(sourceToBlock); + } catch (IOException e) { + throw new ChannelException(e); + } + } + } + } + } + + + } + } + /** + * Block the given sourceToBlock address for the given multicastAddress + * + */ + public void block(InetAddress multicastAddress, InetAddress sourceToBlock) { + try { + block(multicastAddress, NetworkInterface.getByInetAddress(getLocalAddress().getAddress()), sourceToBlock); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + @Override public ChannelFuture write(Object message, SocketAddress remoteAddress) { if (remoteAddress == null || remoteAddress.equals(getRemoteAddress())) {