diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastIPv6Test.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastIPv6Test.java index 26e51ac0e8..5ffa6e72c0 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastIPv6Test.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastIPv6Test.java @@ -24,7 +24,7 @@ import java.net.StandardProtocolFamily; import java.nio.channels.Channel; import java.nio.channels.spi.SelectorProvider; -public class DatagramUnicastIPv6Test extends DatagramUnicastTest { +public class DatagramUnicastIPv6Test extends DatagramUnicastInetTest { @BeforeAll public static void assumeIpv6Supported() { diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastInetTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastInetTest.java new file mode 100644 index 0000000000..94c0606f18 --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastInetTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramPacket; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class DatagramUnicastInetTest extends DatagramUnicastTest { + + @Test + public void testBindWithPortOnly(TestInfo testInfo) throws Throwable { + run(testInfo, DatagramUnicastInetTest::testBindWithPortOnly); + } + + private static void testBindWithPortOnly(Bootstrap sb, Bootstrap cb) throws Throwable { + Channel channel = null; + try { + cb.handler(new ChannelHandlerAdapter() { }); + channel = cb.bind(0).sync().channel(); + } finally { + closeChannel(channel); + } + } + + @Override + protected boolean isConnected(Channel channel) { + return ((DatagramChannel) channel).isConnected(); + } + + @Override + protected Channel setupClientChannel(Bootstrap cb, final byte[] bytes, final CountDownLatch latch, + final AtomicReference errorRef) throws Throwable { + cb.handler(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DatagramPacket msg) { + try { + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + InetSocketAddress localAddress = (InetSocketAddress) ctx.channel().localAddress(); + if (localAddress.getAddress().isAnyLocalAddress()) { + assertEquals(localAddress.getPort(), msg.recipient().getPort()); + } else { + // Test that the channel's localAddress is equal to the message's recipient + assertEquals(localAddress, msg.recipient()); + } + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + return cb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender, + final CountDownLatch latch, final AtomicReference errorRef, + final boolean echo) throws Throwable { + sb.handler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DatagramPacket msg) { + try { + if (sender == null) { + assertNotNull(msg.sender()); + } else { + InetSocketAddress senderAddress = (InetSocketAddress) sender; + if (senderAddress.getAddress().isAnyLocalAddress()) { + assertEquals(senderAddress.getPort(), msg.sender().getPort()); + } else { + assertEquals(sender, msg.sender()); + } + } + + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + // Test that the channel's localAddress is equal to the message's recipient + assertEquals(ctx.channel().localAddress(), msg.recipient()); + + if (echo) { + ctx.writeAndFlush(new DatagramPacket(buf.retainedDuplicate(), msg.sender())); + } + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + } + }); + return sb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected boolean supportDisconnect() { + return true; + } + + @Override + protected ChannelFuture write(Channel cc, ByteBuf buf, SocketAddress remote, WrapType wrapType) { + switch (wrapType) { + case DUP: + return cc.write(new DatagramPacket(buf.retainedDuplicate(), (InetSocketAddress) remote)); + case SLICE: + return cc.write(new DatagramPacket(buf.retainedSlice(), (InetSocketAddress) remote)); + case READ_ONLY: + return cc.write(new DatagramPacket(buf.retain().asReadOnly(), (InetSocketAddress) remote)); + case NONE: + return cc.write(new DatagramPacket(buf.retain(), (InetSocketAddress) remote)); + default: + throw new Error("unknown wrap type: " + wrapType); + } + } +} diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java index 52d04a68ec..bf905cbbe6 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java @@ -21,13 +21,9 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.channel.socket.DatagramChannel; -import io.netty.channel.socket.DatagramPacket; import io.netty.util.NetUtil; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; @@ -43,35 +39,19 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -public class DatagramUnicastTest extends AbstractDatagramTest { +public abstract class DatagramUnicastTest extends AbstractDatagramTest { private static final byte[] BYTES = {0, 1, 2, 3}; - private enum WrapType { + protected enum WrapType { NONE, DUP, SLICE, READ_ONLY } - @Test - public void testBindWithPortOnly(TestInfo testInfo) throws Throwable { - run(testInfo, this::testBindWithPortOnly); - } - - public void testBindWithPortOnly(Bootstrap sb, Bootstrap cb) throws Throwable { - Channel channel = null; - try { - cb.handler(new ChannelHandlerAdapter() { }); - channel = cb.bind(0).sync().channel(); - } finally { - closeChannel(channel); - } - } - @Test public void testSimpleSendDirectByteBuf(TestInfo testInfo) throws Throwable { run(testInfo, this::testSimpleSendDirectByteBuf); @@ -181,7 +161,7 @@ public class DatagramUnicastTest extends AbstractDatagramTest { try { cb.handler(new SimpleChannelInboundHandler() { @Override - public void messageReceived(ChannelHandlerContext ctx, Object msgs) throws Exception { + public void messageReceived(ChannelHandlerContext ctx, Object msgs) { // Nothing will be sent. } }); @@ -197,11 +177,13 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } final CountDownLatch latch = new CountDownLatch(count); - AtomicReference errorRef = new AtomicReference(); + AtomicReference errorRef = new AtomicReference<>(); sc = setupServerChannel(sb, bytes, sender, latch, errorRef, false); - InetSocketAddress addr = sendToAddress((InetSocketAddress) sc.localAddress()); - List futures = new ArrayList(count); + SocketAddress localAddr = sc.localAddress(); + SocketAddress addr = localAddr instanceof InetSocketAddress ? + sendToAddress((InetSocketAddress) localAddr) : localAddr; + List futures = new ArrayList<>(count); for (int i = 0; i < count; i++) { futures.add(write(cc, buf, addr, wrapType)); } @@ -227,20 +209,6 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } } - private static ChannelFuture write(Channel cc, ByteBuf buf, InetSocketAddress remote, WrapType wrapType) { - switch (wrapType) { - case DUP: - return cc.write(new DatagramPacket(buf.retainedDuplicate(), remote)); - case SLICE: - return cc.write(new DatagramPacket(buf.retainedSlice(), remote)); - case READ_ONLY: - return cc.write(new DatagramPacket(buf.retain().asReadOnly(), remote)); - case NONE: - return cc.write(new DatagramPacket(buf.retain(), remote)); - default: - throw new Error("unknown wrap type: " + wrapType); - } - } private void testSimpleSendWithConnect(Bootstrap sb, Bootstrap cb, ByteBuf buf, final byte[] bytes, int count) throws Throwable { try { @@ -254,47 +222,22 @@ public class DatagramUnicastTest extends AbstractDatagramTest { private void testSimpleSendWithConnect0(Bootstrap sb, Bootstrap cb, ByteBuf buf, final byte[] bytes, int count, WrapType wrapType) throws Throwable { - final CountDownLatch clientLatch = new CountDownLatch(count); - final AtomicReference clientErrorRef = new AtomicReference(); - cb.handler(new SimpleChannelInboundHandler() { - @Override - public void messageReceived(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception { - try { - ByteBuf buf = msg.content(); - assertEquals(bytes.length, buf.readableBytes()); - for (int i = 0; i < bytes.length; i++) { - assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); - } - - InetSocketAddress localAddress = (InetSocketAddress) ctx.channel().localAddress(); - if (localAddress.getAddress().isAnyLocalAddress()) { - assertEquals(localAddress.getPort(), msg.recipient().getPort()); - } else { - // Test that the channel's localAddress is equal to the message's recipient - assertEquals(localAddress, msg.recipient()); - } - } finally { - clientLatch.countDown(); - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - clientErrorRef.compareAndSet(null, cause); - } - }); - Channel sc = null; - DatagramChannel cc = null; + Channel cc = null; try { final CountDownLatch latch = new CountDownLatch(count); - final AtomicReference errorRef = new AtomicReference(); - cc = (DatagramChannel) cb.bind(newSocketAddress()).sync().channel(); + final AtomicReference errorRef = new AtomicReference<>(); + final CountDownLatch clientLatch = new CountDownLatch(count); + final AtomicReference clientErrorRef = new AtomicReference<>(); + cc = setupClientChannel(cb, bytes, clientLatch, clientErrorRef); sc = setupServerChannel(sb, bytes, cc.localAddress(), latch, errorRef, true); - cc.connect(sendToAddress((InetSocketAddress) sc.localAddress())).syncUninterruptibly(); + SocketAddress localAddr = sc.localAddress(); + SocketAddress addr = localAddr instanceof InetSocketAddress ? + sendToAddress((InetSocketAddress) localAddr) : localAddr; + cc.connect(addr).syncUninterruptibly(); - List futures = new ArrayList(); + List futures = new ArrayList<>(); for (int i = 0; i < count; i++) { futures.add(write(cc, buf, wrapType)); } @@ -318,21 +261,23 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } fail(); } - assertTrue(cc.isConnected()); + assertTrue(isConnected(cc)); assertNotNull(cc.localAddress()); assertNotNull(cc.remoteAddress()); - // Test what happens when we call disconnect() - cc.disconnect().syncUninterruptibly(); - assertFalse(cc.isConnected()); - assertNotNull(cc.localAddress()); - assertNull(cc.remoteAddress()); + if (supportDisconnect()) { + // Test what happens when we call disconnect() + cc.disconnect().syncUninterruptibly(); + assertFalse(isConnected(cc)); + assertNotNull(cc.localAddress()); + assertNull(cc.remoteAddress()); - ChannelFuture future = cc.writeAndFlush( - buf.retain().duplicate()).awaitUninterruptibly(); - assertTrue(future.cause() instanceof NotYetConnectedException, - "NotYetConnectedException expected, got: " + future.cause()); + ChannelFuture future = cc.writeAndFlush( + buf.retain().duplicate()).awaitUninterruptibly(); + assertTrue(future.cause() instanceof NotYetConnectedException, + "NotYetConnectedException expected, got: " + future.cause()); + } } finally { // release as we used buf.retain() before buf.release(); @@ -357,57 +302,20 @@ public class DatagramUnicastTest extends AbstractDatagramTest { } } - @SuppressWarnings("deprecation") - private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender, - final CountDownLatch latch, final AtomicReference errorRef, - final boolean echo) - throws Throwable { - sb.handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast(new SimpleChannelInboundHandler() { - @Override - public void messageReceived(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception { - try { - if (sender == null) { - assertNotNull(msg.sender()); - } else { - InetSocketAddress senderAddress = (InetSocketAddress) sender; - if (senderAddress.getAddress().isAnyLocalAddress()) { - assertEquals(senderAddress.getPort(), msg.sender().getPort()); - } else { - assertEquals(sender, msg.sender()); - } - } + protected abstract boolean isConnected(Channel channel); - ByteBuf buf = msg.content(); - assertEquals(bytes.length, buf.readableBytes()); - for (int i = 0; i < bytes.length; i++) { - assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); - } + protected abstract Channel setupClientChannel(Bootstrap cb, byte[] bytes, CountDownLatch latch, + AtomicReference errorRef) throws Throwable; - // Test that the channel's localAddress is equal to the message's recipient - assertEquals(ctx.channel().localAddress(), msg.recipient()); + protected abstract Channel setupServerChannel(Bootstrap sb, byte[] bytes, SocketAddress sender, + CountDownLatch latch, AtomicReference errorRef, + boolean echo) throws Throwable; - if (echo) { - ctx.writeAndFlush(new DatagramPacket(buf.retainedDuplicate(), msg.sender())); - } - } finally { - latch.countDown(); - } - } + protected abstract boolean supportDisconnect(); - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - errorRef.compareAndSet(null, cause); - } - }); - } - }); - return sb.bind(newSocketAddress()).sync().channel(); - } + protected abstract ChannelFuture write(Channel cc, ByteBuf buf, SocketAddress remote, WrapType wrapType); - private static void closeChannel(Channel channel) throws Exception { + protected static void closeChannel(Channel channel) throws Exception { if (channel != null) { channel.close().sync(); } diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannel.java new file mode 100644 index 0000000000..7f290a50dc --- /dev/null +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannel.java @@ -0,0 +1,384 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufConvertible; +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultAddressedEnvelope; +import io.netty.channel.EventLoop; +import io.netty.channel.unix.DomainDatagramChannel; +import io.netty.channel.unix.DomainDatagramChannelConfig; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.channel.unix.DomainDatagramSocketAddress; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.channel.unix.IovArray; +import io.netty.channel.unix.PeerCredentials; +import io.netty.channel.unix.UnixChannelUtil; +import io.netty.util.CharsetUtil; +import io.netty.util.UncheckedBooleanSupplier; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; + +import static io.netty.channel.epoll.LinuxSocket.newSocketDomainDgram; + +@UnstableApi +public final class EpollDomainDatagramChannel extends AbstractEpollChannel implements DomainDatagramChannel { + + private static final ChannelMetadata METADATA = new ChannelMetadata(true); + + private static final String EXPECTED_TYPES = + " (expected: " + + StringUtil.simpleClassName(DomainDatagramPacket.class) + ", " + + StringUtil.simpleClassName(AddressedEnvelope.class) + '<' + + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(DomainSocketAddress.class) + ">, " + + StringUtil.simpleClassName(ByteBuf.class) + ')'; + + private volatile boolean connected; + private volatile DomainSocketAddress local; + private volatile DomainSocketAddress remote; + + private final EpollDomainDatagramChannelConfig config; + + public EpollDomainDatagramChannel(EventLoop eventLoop) { + this(eventLoop, newSocketDomainDgram(), false); + } + + public EpollDomainDatagramChannel(EventLoop eventLoop, int fd) { + this(eventLoop, new LinuxSocket(fd), true); + } + + private EpollDomainDatagramChannel(EventLoop eventLoop, LinuxSocket socket, boolean active) { + super(null, eventLoop, socket, active); + config = new EpollDomainDatagramChannelConfig(this); + } + + @Override + public EpollDomainDatagramChannelConfig config() { + return config; + } + + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + super.doBind(localAddress); + local = (DomainSocketAddress) localAddress; + active = true; + } + + @Override + protected void doClose() throws Exception { + super.doClose(); + connected = active = false; + local = null; + remote = null; + } + + @Override + protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (super.doConnect(remoteAddress, localAddress)) { + if (localAddress != null) { + local = (DomainSocketAddress) localAddress; + } + remote = (DomainSocketAddress) remoteAddress; + connected = true; + return true; + } + return false; + } + + @Override + protected void doDisconnect() throws Exception { + doClose(); + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + int maxMessagesPerWrite = maxMessagesPerWrite(); + while (maxMessagesPerWrite > 0) { + Object msg = in.current(); + if (msg == null) { + break; + } + + try { + boolean done = false; + for (int i = config().getWriteSpinCount(); i > 0; --i) { + if (doWriteMessage(msg)) { + done = true; + break; + } + } + + if (done) { + in.remove(); + maxMessagesPerWrite--; + } else { + break; + } + } catch (IOException e) { + maxMessagesPerWrite--; + + // Continue on write error as a DatagramChannel can write to multiple remote peers + // + // See https://github.com/netty/netty/issues/2665 + in.remove(e); + } + } + + if (in.isEmpty()) { + // Did write all messages. + clearFlag(Native.EPOLLOUT); + } else { + // Did not write all messages. + setFlag(Native.EPOLLOUT); + } + } + + private boolean doWriteMessage(Object msg) throws Exception { + final ByteBuf data; + DomainSocketAddress remoteAddress; + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope envelope = + (AddressedEnvelope) msg; + data = envelope.content(); + remoteAddress = envelope.recipient(); + } else { + data = ((ByteBufConvertible) msg).asByteBuf(); + remoteAddress = null; + } + + final int dataLen = data.readableBytes(); + if (dataLen == 0) { + return true; + } + + final long writtenBytes; + if (data.hasMemoryAddress()) { + long memoryAddress = data.memoryAddress(); + if (remoteAddress == null) { + writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + } else { + writtenBytes = socket.sendToAddressDomainSocket(memoryAddress, data.readerIndex(), data.writerIndex(), + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } else if (data.nioBufferCount() > 1) { + IovArray array = registration().cleanIovArray(); + array.add(data, data.readerIndex(), data.readableBytes()); + int cnt = array.count(); + assert cnt != 0; + + if (remoteAddress == null) { + writtenBytes = socket.writevAddresses(array.memoryAddress(0), cnt); + } else { + writtenBytes = socket.sendToAddressesDomainSocket(array.memoryAddress(0), cnt, + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } else { + ByteBuffer nioData = data.internalNioBuffer(data.readerIndex(), data.readableBytes()); + if (remoteAddress == null) { + writtenBytes = socket.write(nioData, nioData.position(), nioData.limit()); + } else { + writtenBytes = socket.sendToDomainSocket(nioData, nioData.position(), nioData.limit(), + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } + + return writtenBytes > 0; + } + + @Override + protected Object filterOutboundMessage(Object msg) { + if (msg instanceof DomainDatagramPacket) { + DomainDatagramPacket packet = (DomainDatagramPacket) msg; + ByteBuf content = packet.content(); + return UnixChannelUtil.isBufferCopyNeededForWrite(content) ? + new DomainDatagramPacket(newDirectBuffer(packet, content), packet.recipient()) : msg; + } + + if (msg instanceof ByteBufConvertible) { + ByteBuf buf = ((ByteBufConvertible) msg).asByteBuf(); + return UnixChannelUtil.isBufferCopyNeededForWrite(buf) ? newDirectBuffer(buf) : buf; + } + + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope e = (AddressedEnvelope) msg; + if (e.content() instanceof ByteBufConvertible && + (e.recipient() == null || e.recipient() instanceof DomainSocketAddress)) { + + ByteBuf content = ((ByteBufConvertible) e.content()).asByteBuf(); + return UnixChannelUtil.isBufferCopyNeededForWrite(content) ? + new DefaultAddressedEnvelope<>( + newDirectBuffer(e, content), (DomainSocketAddress) e.recipient()) : e; + } + } + + throw new UnsupportedOperationException( + "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); + } + + @Override + public boolean isActive() { + return socket.isOpen() && (config.getActiveOnOpen() && isRegistered() || active); + } + + @Override + public boolean isConnected() { + return connected; + } + + @Override + public DomainSocketAddress localAddress() { + return (DomainSocketAddress) super.localAddress(); + } + + @Override + protected DomainSocketAddress localAddress0() { + return local; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + protected AbstractEpollUnsafe newUnsafe() { + return new EpollDomainDatagramChannelUnsafe(); + } + + /** + * Returns the unix credentials (uid, gid, pid) of the peer + * SO_PEERCRED + */ + public PeerCredentials peerCredentials() throws IOException { + return socket.getPeerCredentials(); + } + + @Override + public DomainSocketAddress remoteAddress() { + return (DomainSocketAddress) super.remoteAddress(); + } + + @Override + protected DomainSocketAddress remoteAddress0() { + return remote; + } + + final class EpollDomainDatagramChannelUnsafe extends AbstractEpollUnsafe { + + @Override + void epollInReady() { + assert eventLoop().inEventLoop(); + final DomainDatagramChannelConfig config = config(); + if (shouldBreakEpollInReady(config)) { + clearEpollIn0(); + return; + } + final EpollRecvByteAllocatorHandle allocHandle = recvBufAllocHandle(); + + final ChannelPipeline pipeline = pipeline(); + final ByteBufAllocator allocator = config.getAllocator(); + allocHandle.reset(config); + epollInBefore(); + + Throwable exception = null; + try { + ByteBuf byteBuf = null; + try { + boolean connected = isConnected(); + do { + byteBuf = allocHandle.allocate(allocator); + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + + final DomainDatagramPacket packet; + if (connected) { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + if (allocHandle.lastBytesRead() <= 0) { + // nothing was read, release the buffer. + byteBuf.release(); + break; + } + packet = new DomainDatagramPacket(byteBuf, (DomainSocketAddress) localAddress(), + (DomainSocketAddress) remoteAddress()); + } else { + final DomainDatagramSocketAddress remoteAddress; + if (byteBuf.hasMemoryAddress()) { + // has a memory address so use optimized call + remoteAddress = socket.recvFromAddressDomainSocket(byteBuf.memoryAddress(), + byteBuf.writerIndex(), byteBuf.capacity()); + } else { + ByteBuffer nioData = byteBuf.internalNioBuffer( + byteBuf.writerIndex(), byteBuf.writableBytes()); + remoteAddress = + socket.recvFromDomainSocket(nioData, nioData.position(), nioData.limit()); + } + + if (remoteAddress == null) { + allocHandle.lastBytesRead(-1); + byteBuf.release(); + break; + } + DomainSocketAddress localAddress = remoteAddress.localAddress(); + if (localAddress == null) { + localAddress = (DomainSocketAddress) localAddress(); + } + allocHandle.lastBytesRead(remoteAddress.receivedAmount()); + byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); + + packet = new DomainDatagramPacket(byteBuf, localAddress, remoteAddress); + } + + allocHandle.incMessagesRead(1); + + readPending = false; + pipeline.fireChannelRead(packet); + + byteBuf = null; + + // We use the TRUE_SUPPLIER as it is also ok to read less then what we did try to read (as long + // as we read anything). + } while (allocHandle.continueReading(UncheckedBooleanSupplier.TRUE_SUPPLIER)); + } catch (Throwable t) { + if (byteBuf != null) { + byteBuf.release(); + } + exception = t; + } + + allocHandle.readComplete(); + pipeline.fireChannelReadComplete(); + + if (exception != null) { + pipeline.fireExceptionCaught(exception); + } + readIfIsAutoRead(); + } finally { + epollInFinally(config); + } + } + } +} diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannelConfig.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannelConfig.java new file mode 100644 index 0000000000..fd31542070 --- /dev/null +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDomainDatagramChannelConfig.java @@ -0,0 +1,172 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.FixedRecvByteBufAllocator; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.unix.DomainDatagramChannelConfig; +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; +import java.util.Map; + +import static io.netty.channel.ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION; +import static io.netty.channel.ChannelOption.SO_SNDBUF; + +@UnstableApi +public final class EpollDomainDatagramChannelConfig extends EpollChannelConfig implements DomainDatagramChannelConfig { + + private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048); + + private boolean activeOnOpen; + + EpollDomainDatagramChannelConfig(EpollDomainDatagramChannel channel) { + super(channel); + setRecvByteBufAllocator(DEFAULT_RCVBUF_ALLOCATOR); + } + + @Override + @SuppressWarnings("deprecation") + public Map, Object> getOptions() { + return getOptions( + super.getOptions(), + DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, SO_SNDBUF); + } + + @Override + @SuppressWarnings({"unchecked", "deprecation"}) + public T getOption(ChannelOption option) { + if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + return (T) Boolean.valueOf(activeOnOpen); + } + if (option == SO_SNDBUF) { + return (T) Integer.valueOf(getSendBufferSize()); + } + return super.getOption(option); + } + + @Override + @SuppressWarnings("deprecation") + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + setActiveOnOpen((Boolean) value); + } else if (option == SO_SNDBUF) { + setSendBufferSize((Integer) value); + } else { + return super.setOption(option, value); + } + + return true; + } + + private void setActiveOnOpen(boolean activeOnOpen) { + if (channel.isRegistered()) { + throw new IllegalStateException("Can only changed before channel was registered"); + } + this.activeOnOpen = activeOnOpen; + } + + boolean getActiveOnOpen() { + return activeOnOpen; + } + + @Override + public EpollDomainDatagramChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public EpollDomainDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setMaxMessagesPerWrite(int maxMessagesPerWrite) { + super.setMaxMessagesPerWrite(maxMessagesPerWrite); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setSendBufferSize(int sendBufferSize) { + try { + ((EpollDomainDatagramChannel) channel).socket.setSendBufferSize(sendBufferSize); + return this; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public int getSendBufferSize() { + try { + return ((EpollDomainDatagramChannel) channel).socket.getSendBufferSize(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public EpollDomainDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public EpollDomainDatagramChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } +} diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java index 73fd631150..2c42e344d6 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/LinuxSocket.java @@ -345,6 +345,10 @@ final class LinuxSocket extends Socket { return new LinuxSocket(newSocketDomain0()); } + public static LinuxSocket newSocketDomainDgram() { + return new LinuxSocket(newSocketDomainDgram0()); + } + private static InetAddress unsafeInetAddrByName(String inetName) { try { return InetAddress.getByName(inetName); diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramUnicastTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramUnicastTest.java index c4f097e82d..19835a9a20 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramUnicastTest.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDatagramUnicastTest.java @@ -27,7 +27,7 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.DatagramPacket; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.testsuite.transport.TestsuitePermutation; -import io.netty.testsuite.transport.socket.DatagramUnicastTest; +import io.netty.testsuite.transport.socket.DatagramUnicastInetTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; @@ -41,7 +41,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; -public class EpollDatagramUnicastTest extends DatagramUnicastTest { +public class EpollDatagramUnicastTest extends DatagramUnicastInetTest { @Override protected List> newFactories() { return EpollSocketTestPermutation.INSTANCE.datagram(InternetProtocolFamily.IPv4); diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramPathTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramPathTest.java new file mode 100644 index 0000000000..5a043c032f --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramPathTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.AbstractClientSocketTest; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.io.FileNotFoundException; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class EpollDomainDatagramPathTest extends AbstractClientSocketTest { + + @Test + void testConnectPathDoesNotExist(TestInfo testInfo) throws Throwable { + run(testInfo, bootstrap -> { + try { + bootstrap.handler(new ChannelHandlerAdapter() { }) + .connect(EpollSocketTestPermutation.newSocketAddress()).sync().channel(); + fail("Expected FileNotFoundException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof FileNotFoundException); + } + }); + } + + @Test + void testWriteReceiverPathDoesNotExist(TestInfo testInfo) throws Throwable { + run(testInfo, bootstrap -> { + try { + Channel ch = bootstrap.handler(new ChannelHandlerAdapter() { }) + .bind(EpollSocketTestPermutation.newSocketAddress()).sync().channel(); + ch.writeAndFlush(new DomainDatagramPacket( + Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII), + EpollSocketTestPermutation.newSocketAddress())).sync(); + fail("Expected FileNotFoundException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof FileNotFoundException); + } + }); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.domainDatagramSocket(); + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramUnicastTest.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramUnicastTest.java new file mode 100644 index 0000000000..d404dc203e --- /dev/null +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollDomainDatagramUnicastTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.epoll; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.unix.DomainDatagramChannel; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.DatagramUnicastTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class EpollDomainDatagramUnicastTest extends DatagramUnicastTest { + + @Test + void testBind(TestInfo testInfo) throws Throwable { + run(testInfo, (bootstrap, bootstrap2) -> testBind(bootstrap2)); + } + + private void testBind(Bootstrap cb) throws Throwable { + Channel channel = null; + try { + channel = cb.handler(new ChannelHandlerAdapter() { }) + .bind(newSocketAddress()).sync().channel(); + assertThat(channel.localAddress()).isNotNull() + .isInstanceOf(DomainSocketAddress.class); + } finally { + closeChannel(channel); + } + } + + @Override + protected boolean supportDisconnect() { + return false; + } + + @Override + protected boolean isConnected(Channel channel) { + return ((DomainDatagramChannel) channel).isConnected(); + } + + @Override + protected List> newFactories() { + return EpollSocketTestPermutation.INSTANCE.domainDatagram(); + } + + @Override + protected SocketAddress newSocketAddress() { + return EpollSocketTestPermutation.newSocketAddress(); + } + + @Override + protected Channel setupClientChannel(Bootstrap cb, final byte[] bytes, final CountDownLatch latch, + final AtomicReference errorRef) throws Throwable { + cb.handler(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DomainDatagramPacket msg) { + try { + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + assertEquals(ctx.channel().localAddress(), msg.recipient()); + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + return cb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender, + final CountDownLatch latch, final AtomicReference errorRef, + final boolean echo) throws Throwable { + sb.handler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DomainDatagramPacket msg) { + try { + if (sender == null) { + assertNotNull(msg.sender()); + } else { + assertEquals(sender, msg.sender()); + } + + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + assertEquals(ctx.channel().localAddress(), msg.recipient()); + + if (echo) { + ctx.writeAndFlush(new DomainDatagramPacket(buf.retainedDuplicate(), msg.sender())); + } + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + } + }); + return sb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected ChannelFuture write(Channel cc, ByteBuf buf, SocketAddress remote, WrapType wrapType) { + switch (wrapType) { + case DUP: + return cc.write(new DomainDatagramPacket(buf.retainedDuplicate(), (DomainSocketAddress) remote)); + case SLICE: + return cc.write(new DomainDatagramPacket(buf.retainedSlice(), (DomainSocketAddress) remote)); + case READ_ONLY: + return cc.write(new DomainDatagramPacket(buf.retain().asReadOnly(), (DomainSocketAddress) remote)); + case NONE: + return cc.write(new DomainDatagramPacket(buf.retain(), (DomainSocketAddress) remote)); + default: + throw new Error("unknown wrap type: " + wrapType); + } + } +} diff --git a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java index cfaa2aaa73..9ff7fff700 100644 --- a/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java +++ b/transport-native-epoll/src/test/java/io/netty/channel/epoll/EpollSocketTestPermutation.java @@ -91,7 +91,7 @@ class EpollSocketTestPermutation extends SocketTestPermutation { @Override public List> clientSocket() { - List> toReturn = new ArrayList>(); + List> toReturn = new ArrayList<>(); toReturn.add(() -> new Bootstrap().group(EPOLL_WORKER_GROUP).channel(EpollSocketChannel.class)); toReturn.add(() -> new Bootstrap().group(nioWorkerGroup).channel(NioSocketChannel.class)); return toReturn; @@ -187,4 +187,14 @@ class EpollSocketTestPermutation extends SocketTestPermutation { public static DomainSocketAddress newSocketAddress() { return UnixTestUtils.newSocketAddress(); } + + public List> domainDatagram() { + return combo(domainDatagramSocket(), domainDatagramSocket()); + } + + public List> domainDatagramSocket() { + return Collections.singletonList( + () -> new Bootstrap().group(EPOLL_WORKER_GROUP).channel(EpollDomainDatagramChannel.class) + ); + } } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueDatagramChannel.java new file mode 100644 index 0000000000..44de714817 --- /dev/null +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueDatagramChannel.java @@ -0,0 +1,76 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.EventLoop; + +import java.io.IOException; + +abstract class AbstractKQueueDatagramChannel extends AbstractKQueueChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(true); + + AbstractKQueueDatagramChannel(Channel parent, EventLoop eventLoop, BsdSocket fd, boolean active) { + super(parent, eventLoop, fd, active); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + protected abstract boolean doWriteMessage(Object msg) throws Exception; + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + int maxMessagesPerWrite = maxMessagesPerWrite(); + while (maxMessagesPerWrite > 0) { + Object msg = in.current(); + if (msg == null) { + break; + } + + try { + boolean done = false; + for (int i = config().getWriteSpinCount(); i > 0; --i) { + if (doWriteMessage(msg)) { + done = true; + break; + } + } + + if (done) { + in.remove(); + maxMessagesPerWrite--; + } else { + break; + } + } catch (IOException e) { + maxMessagesPerWrite--; + + // Continue on write error as a DatagramChannel can write to multiple remote peers + // + // See https://github.com/netty/netty/issues/2665 + in.remove(e); + } + } + + // Whether all messages were written or not. + writeFilter(!in.isEmpty()); + } +} diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/BsdSocket.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/BsdSocket.java index c4cd6c8a04..4065882735 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/BsdSocket.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/BsdSocket.java @@ -92,6 +92,10 @@ final class BsdSocket extends Socket { return new BsdSocket(newSocketDomain0()); } + public static BsdSocket newSocketDomainDgram() { + return new BsdSocket(newSocketDomainDgram0()); + } + private static native long sendFile(int socketFd, DefaultFileRegion src, long baseOffset, long offset, long length) throws IOException; diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index c45973123e..116d35723f 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -50,8 +50,7 @@ import static io.netty.channel.kqueue.BsdSocket.newSocketDgram; import static java.util.Objects.requireNonNull; @UnstableApi -public final class KQueueDatagramChannel extends AbstractKQueueChannel implements DatagramChannel { - private static final ChannelMetadata METADATA = new ChannelMetadata(true); +public final class KQueueDatagramChannel extends AbstractKQueueDatagramChannel implements DatagramChannel { private static final String EXPECTED_TYPES = " (expected: " + StringUtil.simpleClassName(DatagramPacket.class) + ", " + StringUtil.simpleClassName(AddressedEnvelope.class) + '<' + @@ -86,11 +85,6 @@ public final class KQueueDatagramChannel extends AbstractKQueueChannel implement return (InetSocketAddress) super.localAddress(); } - @Override - public ChannelMetadata metadata() { - return METADATA; - } - @Override @SuppressWarnings("deprecation") public boolean isActive() { @@ -246,44 +240,7 @@ public final class KQueueDatagramChannel extends AbstractKQueueChannel implement } @Override - protected void doWrite(ChannelOutboundBuffer in) throws Exception { - int maxMessagesPerWrite = maxMessagesPerWrite(); - while (maxMessagesPerWrite > 0) { - Object msg = in.current(); - if (msg == null) { - break; - } - - try { - boolean done = false; - for (int i = config().getWriteSpinCount(); i > 0; --i) { - if (doWriteMessage(msg)) { - done = true; - break; - } - } - - if (done) { - in.remove(); - maxMessagesPerWrite --; - } else { - break; - } - } catch (IOException e) { - maxMessagesPerWrite --; - - // Continue on write error as a DatagramChannel can write to multiple remote peers - // - // See https://github.com/netty/netty/issues/2665 - in.remove(e); - } - } - - // Whether all messages were written or not. - writeFilter(!in.isEmpty()); - } - - private boolean doWriteMessage(Object msg) throws Exception { + protected boolean doWriteMessage(Object msg) throws Exception { final ByteBuf data; InetSocketAddress remoteAddress; if (msg instanceof AddressedEnvelope) { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannel.java new file mode 100644 index 0000000000..8fa50f8534 --- /dev/null +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannel.java @@ -0,0 +1,332 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufConvertible; +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultAddressedEnvelope; +import io.netty.channel.EventLoop; +import io.netty.channel.unix.DomainDatagramChannel; +import io.netty.channel.unix.DomainDatagramChannelConfig; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.channel.unix.DomainDatagramSocketAddress; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.channel.unix.IovArray; +import io.netty.channel.unix.PeerCredentials; +import io.netty.channel.unix.UnixChannelUtil; +import io.netty.util.CharsetUtil; +import io.netty.util.UncheckedBooleanSupplier; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; + +import static io.netty.channel.kqueue.BsdSocket.newSocketDomainDgram; + +@UnstableApi +public final class KQueueDomainDatagramChannel extends AbstractKQueueDatagramChannel implements DomainDatagramChannel { + + private static final String EXPECTED_TYPES = + " (expected: " + + StringUtil.simpleClassName(DomainDatagramPacket.class) + ", " + + StringUtil.simpleClassName(AddressedEnvelope.class) + '<' + + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(DomainSocketAddress.class) + ">, " + + StringUtil.simpleClassName(ByteBuf.class) + ')'; + + private volatile boolean connected; + private volatile DomainSocketAddress local; + private volatile DomainSocketAddress remote; + + private final KQueueDomainDatagramChannelConfig config; + + public KQueueDomainDatagramChannel(EventLoop eventLoop) { + this(eventLoop, newSocketDomainDgram(), false); + } + + public KQueueDomainDatagramChannel(EventLoop eventLoop, int fd) { + this(eventLoop, new BsdSocket(fd), true); + } + + private KQueueDomainDatagramChannel(EventLoop eventLoop, BsdSocket socket, boolean active) { + super(null, eventLoop, socket, active); + config = new KQueueDomainDatagramChannelConfig(this); + } + + @Override + public KQueueDomainDatagramChannelConfig config() { + return config; + } + + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + super.doBind(localAddress); + local = (DomainSocketAddress) localAddress; + active = true; + } + + @Override + protected void doClose() throws Exception { + super.doClose(); + connected = active = false; + local = null; + remote = null; + } + + @Override + protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (super.doConnect(remoteAddress, localAddress)) { + if (localAddress != null) { + local = (DomainSocketAddress) localAddress; + } + remote = (DomainSocketAddress) remoteAddress; + connected = true; + return true; + } + return false; + } + + @Override + protected void doDisconnect() throws Exception { + doClose(); + } + + @Override + protected boolean doWriteMessage(Object msg) throws Exception { + final ByteBuf data; + DomainSocketAddress remoteAddress; + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope envelope = + (AddressedEnvelope) msg; + data = envelope.content(); + remoteAddress = envelope.recipient(); + } else { + data = ((ByteBufConvertible) msg).asByteBuf(); + remoteAddress = null; + } + + final int dataLen = data.readableBytes(); + if (dataLen == 0) { + return true; + } + + final long writtenBytes; + if (data.hasMemoryAddress()) { + long memoryAddress = data.memoryAddress(); + if (remoteAddress == null) { + writtenBytes = socket.writeAddress(memoryAddress, data.readerIndex(), data.writerIndex()); + } else { + writtenBytes = socket.sendToAddressDomainSocket(memoryAddress, data.readerIndex(), data.writerIndex(), + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } else if (data.nioBufferCount() > 1) { + IovArray array = registration().cleanArray(); + array.add(data, data.readerIndex(), data.readableBytes()); + int cnt = array.count(); + assert cnt != 0; + + if (remoteAddress == null) { + writtenBytes = socket.writevAddresses(array.memoryAddress(0), cnt); + } else { + writtenBytes = socket.sendToAddressesDomainSocket(array.memoryAddress(0), cnt, + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } else { + ByteBuffer nioData = data.internalNioBuffer(data.readerIndex(), data.readableBytes()); + if (remoteAddress == null) { + writtenBytes = socket.write(nioData, nioData.position(), nioData.limit()); + } else { + writtenBytes = socket.sendToDomainSocket(nioData, nioData.position(), nioData.limit(), + remoteAddress.path().getBytes(CharsetUtil.UTF_8)); + } + } + + return writtenBytes > 0; + } + + @Override + protected Object filterOutboundMessage(Object msg) { + if (msg instanceof DomainDatagramPacket) { + DomainDatagramPacket packet = (DomainDatagramPacket) msg; + ByteBuf content = packet.content(); + return UnixChannelUtil.isBufferCopyNeededForWrite(content) ? + new DomainDatagramPacket(newDirectBuffer(packet, content), packet.recipient()) : msg; + } + + if (msg instanceof ByteBufConvertible) { + ByteBuf buf = ((ByteBufConvertible) msg).asByteBuf(); + return UnixChannelUtil.isBufferCopyNeededForWrite(buf) ? newDirectBuffer(buf) : buf; + } + + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope e = (AddressedEnvelope) msg; + if (e.content() instanceof ByteBufConvertible && + (e.recipient() == null || e.recipient() instanceof DomainSocketAddress)) { + + ByteBuf content = ((ByteBufConvertible) e.content()).asByteBuf(); + return UnixChannelUtil.isBufferCopyNeededForWrite(content) ? + new DefaultAddressedEnvelope<>( + newDirectBuffer(e, content), (DomainSocketAddress) e.recipient()) : e; + } + } + + throw new UnsupportedOperationException( + "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); + } + + @Override + public boolean isActive() { + return socket.isOpen() && (config.getActiveOnOpen() && isRegistered() || active); + } + + @Override + public boolean isConnected() { + return connected; + } + + @Override + public DomainSocketAddress localAddress() { + return (DomainSocketAddress) super.localAddress(); + } + + @Override + protected DomainSocketAddress localAddress0() { + return local; + } + + @Override + protected AbstractKQueueUnsafe newUnsafe() { + return new KQueueDomainDatagramChannelUnsafe(); + } + + /** + * Returns the unix credentials (uid, gid, pid) of the peer + * SO_PEERCRED + */ + public PeerCredentials peerCredentials() throws IOException { + return socket.getPeerCredentials(); + } + + @Override + public DomainSocketAddress remoteAddress() { + return (DomainSocketAddress) super.remoteAddress(); + } + + @Override + protected DomainSocketAddress remoteAddress0() { + return remote; + } + + final class KQueueDomainDatagramChannelUnsafe extends AbstractKQueueUnsafe { + + @Override + void readReady(KQueueRecvByteAllocatorHandle allocHandle) { + assert eventLoop().inEventLoop(); + final DomainDatagramChannelConfig config = config(); + if (shouldBreakReadReady(config)) { + clearReadFilter0(); + return; + } + final ChannelPipeline pipeline = pipeline(); + final ByteBufAllocator allocator = config.getAllocator(); + allocHandle.reset(config); + readReadyBefore(); + + Throwable exception = null; + try { + ByteBuf byteBuf = null; + try { + boolean connected = isConnected(); + do { + byteBuf = allocHandle.allocate(allocator); + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + + final DomainDatagramPacket packet; + if (connected) { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + if (allocHandle.lastBytesRead() <= 0) { + // nothing was read, release the buffer. + byteBuf.release(); + break; + } + packet = new DomainDatagramPacket(byteBuf, (DomainSocketAddress) localAddress(), + (DomainSocketAddress) remoteAddress()); + } else { + final DomainDatagramSocketAddress remoteAddress; + if (byteBuf.hasMemoryAddress()) { + // has a memory address so use optimized call + remoteAddress = socket.recvFromAddressDomainSocket(byteBuf.memoryAddress(), + byteBuf.writerIndex(), byteBuf.capacity()); + } else { + ByteBuffer nioData = byteBuf.internalNioBuffer( + byteBuf.writerIndex(), byteBuf.writableBytes()); + remoteAddress = + socket.recvFromDomainSocket(nioData, nioData.position(), nioData.limit()); + } + + if (remoteAddress == null) { + allocHandle.lastBytesRead(-1); + byteBuf.release(); + break; + } + DomainSocketAddress localAddress = remoteAddress.localAddress(); + if (localAddress == null) { + localAddress = (DomainSocketAddress) localAddress(); + } + allocHandle.lastBytesRead(remoteAddress.receivedAmount()); + byteBuf.writerIndex(byteBuf.writerIndex() + allocHandle.lastBytesRead()); + + packet = new DomainDatagramPacket(byteBuf, localAddress, remoteAddress); + } + + allocHandle.incMessagesRead(1); + + readPending = false; + pipeline.fireChannelRead(packet); + + byteBuf = null; + + // We use the TRUE_SUPPLIER as it is also ok to read less then what we did try to read (as long + // as we read anything). + } while (allocHandle.continueReading(UncheckedBooleanSupplier.TRUE_SUPPLIER)); + } catch (Throwable t) { + if (byteBuf != null) { + byteBuf.release(); + } + exception = t; + } + + allocHandle.readComplete(); + pipeline.fireChannelReadComplete(); + + if (exception != null) { + pipeline.fireExceptionCaught(exception); + } else { + readIfIsAutoRead(); + } + } finally { + readReadyFinally(config); + } + } + } +} diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannelConfig.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannelConfig.java new file mode 100644 index 0000000000..55d73e92f7 --- /dev/null +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDomainDatagramChannelConfig.java @@ -0,0 +1,178 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.FixedRecvByteBufAllocator; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.unix.DomainDatagramChannelConfig; +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; +import java.util.Map; + +import static io.netty.channel.ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION; +import static io.netty.channel.ChannelOption.SO_SNDBUF; + +@UnstableApi +public final class KQueueDomainDatagramChannelConfig + extends KQueueChannelConfig implements DomainDatagramChannelConfig { + + private static final RecvByteBufAllocator DEFAULT_RCVBUF_ALLOCATOR = new FixedRecvByteBufAllocator(2048); + + private boolean activeOnOpen; + + KQueueDomainDatagramChannelConfig(KQueueDomainDatagramChannel channel) { + super(channel); + setRecvByteBufAllocator(DEFAULT_RCVBUF_ALLOCATOR); + } + + @Override + @SuppressWarnings("deprecation") + public Map, Object> getOptions() { + return getOptions(super.getOptions(), + DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, SO_SNDBUF); + } + + @Override + @SuppressWarnings({"unchecked", "deprecation"}) + public T getOption(ChannelOption option) { + if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + return (T) Boolean.valueOf(activeOnOpen); + } + if (option == SO_SNDBUF) { + return (T) Integer.valueOf(getSendBufferSize()); + } + return super.getOption(option); + } + + @Override + @SuppressWarnings("deprecation") + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + setActiveOnOpen((Boolean) value); + } else if (option == SO_SNDBUF) { + setSendBufferSize((Integer) value); + } else { + return super.setOption(option, value); + } + + return true; + } + + private void setActiveOnOpen(boolean activeOnOpen) { + if (channel.isRegistered()) { + throw new IllegalStateException("Can only changed before channel was registered"); + } + this.activeOnOpen = activeOnOpen; + } + + boolean getActiveOnOpen() { + return activeOnOpen; + } + + @Override + public KQueueDomainDatagramChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public KQueueDomainDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setMaxMessagesPerWrite(int maxMessagesPerWrite) { + super.setMaxMessagesPerWrite(maxMessagesPerWrite); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setRcvAllocTransportProvidesGuess(boolean transportProvidesGuess) { + super.setRcvAllocTransportProvidesGuess(transportProvidesGuess); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setSendBufferSize(int sendBufferSize) { + try { + ((KQueueDomainDatagramChannel) channel).socket.setSendBufferSize(sendBufferSize); + return this; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public int getSendBufferSize() { + try { + return ((KQueueDomainDatagramChannel) channel).socket.getSendBufferSize(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public KQueueDomainDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public KQueueDomainDatagramChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramUnicastTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramUnicastTest.java index bca86d8272..48004910d3 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramUnicastTest.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDatagramUnicastTest.java @@ -18,11 +18,11 @@ package io.netty.channel.kqueue; import io.netty.bootstrap.Bootstrap; import io.netty.channel.socket.InternetProtocolFamily; import io.netty.testsuite.transport.TestsuitePermutation; -import io.netty.testsuite.transport.socket.DatagramUnicastTest; +import io.netty.testsuite.transport.socket.DatagramUnicastInetTest; import java.util.List; -public class KQueueDatagramUnicastTest extends DatagramUnicastTest { +public class KQueueDatagramUnicastTest extends DatagramUnicastInetTest { @Override protected List> newFactories() { return KQueueSocketTestPermutation.INSTANCE.datagram(InternetProtocolFamily.IPv4); diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramPathTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramPathTest.java new file mode 100644 index 0000000000..3c5a2a00b8 --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramPathTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.AbstractClientSocketTest; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.io.FileNotFoundException; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class KQueueDomainDatagramPathTest extends AbstractClientSocketTest { + + @Test + void testConnectPathDoesNotExist(TestInfo testInfo) throws Throwable { + run(testInfo, bootstrap -> { + try { + bootstrap.handler(new ChannelHandlerAdapter() { }) + .connect(KQueueSocketTestPermutation.newSocketAddress()).sync().channel(); + fail("Expected FileNotFoundException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof FileNotFoundException); + } + }); + } + + @Test + void testWriteReceiverPathDoesNotExist(TestInfo testInfo) throws Throwable { + run(testInfo, bootstrap -> { + try { + Channel ch = bootstrap.handler(new ChannelHandlerAdapter() { }) + .bind(KQueueSocketTestPermutation.newSocketAddress()).sync().channel(); + ch.writeAndFlush(new DomainDatagramPacket( + Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII), + KQueueSocketTestPermutation.newSocketAddress())).sync(); + fail("Expected FileNotFoundException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof FileNotFoundException); + } + }); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.domainDatagramSocket(); + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramUnicastTest.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramUnicastTest.java new file mode 100644 index 0000000000..5f416e840f --- /dev/null +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueDomainDatagramUnicastTest.java @@ -0,0 +1,170 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.kqueue; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.unix.DomainDatagramChannel; +import io.netty.channel.unix.DomainDatagramPacket; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.testsuite.transport.TestsuitePermutation; +import io.netty.testsuite.transport.socket.DatagramUnicastTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; + +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class KQueueDomainDatagramUnicastTest extends DatagramUnicastTest { + + @Test + void testBind(TestInfo testInfo) throws Throwable { + run(testInfo, (bootstrap, bootstrap2) -> testBind(bootstrap2)); + } + + private void testBind(Bootstrap cb) throws Throwable { + Channel channel = null; + try { + channel = cb.handler(new ChannelHandlerAdapter() { }) + .bind(newSocketAddress()).sync().channel(); + assertThat(channel.localAddress()).isNotNull() + .isInstanceOf(DomainSocketAddress.class); + } finally { + closeChannel(channel); + } + } + + @Override + protected boolean supportDisconnect() { + return false; + } + + @Override + protected boolean isConnected(Channel channel) { + return ((DomainDatagramChannel) channel).isConnected(); + } + + @Override + protected List> newFactories() { + return KQueueSocketTestPermutation.INSTANCE.domainDatagram(); + } + + @Override + protected SocketAddress newSocketAddress() { + return KQueueSocketTestPermutation.newSocketAddress(); + } + + @Override + protected Channel setupClientChannel(Bootstrap cb, final byte[] bytes, final CountDownLatch latch, + final AtomicReference errorRef) throws Throwable { + cb.handler(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DomainDatagramPacket msg) { + try { + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + assertEquals(ctx.channel().localAddress(), msg.recipient()); + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + return cb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender, + final CountDownLatch latch, final AtomicReference errorRef, + final boolean echo) throws Throwable { + sb.handler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + + @Override + public void messageReceived(ChannelHandlerContext ctx, DomainDatagramPacket msg) { + try { + if (sender == null) { + assertNotNull(msg.sender()); + } else { + assertEquals(sender, msg.sender()); + } + + ByteBuf buf = msg.content(); + assertEquals(bytes.length, buf.readableBytes()); + for (int i = 0; i < bytes.length; i++) { + assertEquals(bytes[i], buf.getByte(buf.readerIndex() + i)); + } + + assertEquals(ctx.channel().localAddress(), msg.recipient()); + + if (echo) { + ctx.writeAndFlush(new DomainDatagramPacket(buf.retainedDuplicate(), msg.sender())); + } + } finally { + latch.countDown(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + errorRef.compareAndSet(null, cause); + } + }); + } + }); + return sb.bind(newSocketAddress()).sync().channel(); + } + + @Override + protected ChannelFuture write(Channel cc, ByteBuf buf, SocketAddress remote, WrapType wrapType) { + switch (wrapType) { + case DUP: + return cc.write(new DomainDatagramPacket(buf.retainedDuplicate(), (DomainSocketAddress) remote)); + case SLICE: + return cc.write(new DomainDatagramPacket(buf.retainedSlice(), (DomainSocketAddress) remote)); + case READ_ONLY: + return cc.write(new DomainDatagramPacket(buf.retain().asReadOnly(), (DomainSocketAddress) remote)); + case NONE: + return cc.write(new DomainDatagramPacket(buf.retain(), (DomainSocketAddress) remote)); + default: + throw new Error("unknown wrap type: " + wrapType); + } + } +} diff --git a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java index 70ffe017e6..5b98659b60 100644 --- a/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java +++ b/transport-native-kqueue/src/test/java/io/netty/channel/kqueue/KQueueSocketTestPermutation.java @@ -32,8 +32,6 @@ import io.netty.testsuite.transport.TestsuitePermutation; import io.netty.testsuite.transport.TestsuitePermutation.BootstrapFactory; import io.netty.testsuite.transport.socket.SocketTestPermutation; import io.netty.util.concurrent.DefaultThreadFactory; -import io.netty.util.internal.logging.InternalLogger; -import io.netty.util.internal.logging.InternalLoggerFactory; import java.util.ArrayList; import java.util.Arrays; @@ -135,4 +133,15 @@ class KQueueSocketTestPermutation extends SocketTestPermutation { public static DomainSocketAddress newSocketAddress() { return UnixTestUtils.newSocketAddress(); } + + public List> domainDatagram() { + return combo(domainDatagramSocket(), domainDatagramSocket()); + } + + public List> domainDatagramSocket() { + return Collections.singletonList( + () -> new Bootstrap().group(KQUEUE_WORKER_GROUP).channel(KQueueDomainDatagramChannel.class) + ); + } + } diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.c b/transport-native-unix-common/src/main/c/netty_unix_socket.c index 03e9da2131..2524b41220 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.c +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.c @@ -45,7 +45,9 @@ #endif static jclass datagramSocketAddressClass = NULL; +static jclass domainDatagramSocketAddressClass = NULL; static jmethodID datagramSocketAddrMethodId = NULL; +static jmethodID domainDatagramSocketAddrMethodId = NULL; static jmethodID inetSocketAddrMethodId = NULL; static jclass inetSocketAddressClass = NULL; static int socketType = AF_INET; @@ -123,6 +125,22 @@ static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_st return obj; } +static jobject createDomainDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, int len, jobject local) { + struct sockaddr_un* s = (struct sockaddr_un*) addr; + int pathLength = strlen(s->sun_path); + jbyteArray pathBytes = (*env)->NewByteArray(env, pathLength); + if (pathBytes == NULL) { + return NULL; + } + + (*env)->SetByteArrayRegion(env, pathBytes, 0, pathLength, (jbyte*) &s->sun_path); + jobject obj = (*env)->NewObject(env, domainDatagramSocketAddressClass, domainDatagramSocketAddrMethodId, pathBytes, len, local); + if ((*env)->ExceptionCheck(env) == JNI_TRUE) { + return NULL; + } + return obj; +} + static jsize addressLength(const struct sockaddr_storage* addr) { int len = netty_unix_socket_ipAddressLength(addr); if (len == 4) { @@ -334,6 +352,35 @@ static jint _sendTo(JNIEnv* env, jint fd, jboolean ipv6, void* buffer, jint pos, return (jint) res; } +static jint _sendToDomainSocket(JNIEnv* env, jint fd, void* buffer, jint pos, jint limit, jbyteArray socketPath) { + struct sockaddr_un addr; + jint socket_path_len; + + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0); + socket_path_len = (*env)->GetArrayLength(env, socketPath); + if (socket_path_len > sizeof(addr.sun_path)) { + socket_path_len = sizeof(addr.sun_path); + } + memcpy(addr.sun_path, socket_path, socket_path_len); + + ssize_t res; + int err; + do { + res = sendto(fd, buffer + pos, (size_t) (limit - pos), 0, (struct sockaddr*) &addr, _UNIX_ADDR_LENGTH(socket_path_len)); + // keep on writing if it was interrupted + } while (res == -1 && ((err = errno) == EINTR)); + + (*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0); + + if (res < 0) { + return -err; + } + return (jint) res; +} + static jobject _recvFrom(JNIEnv* env, jint fd, void* buffer, jint pos, jint limit) { struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); @@ -365,6 +412,33 @@ static jobject _recvFrom(JNIEnv* env, jint fd, void* buffer, jint pos, jint limi return createDatagramSocketAddress(env, &addr, res, NULL); } +static jobject _recvFromDomainSocket(JNIEnv* env, jint fd, void* buffer, jint pos, jint limit) { + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + ssize_t res; + int err; + + do { + res = recvfrom(fd, buffer + pos, (size_t) (limit - pos), 0, (struct sockaddr*) &addr, &addrlen); + // Keep on reading if it was interrupted + } while (res == -1 && ((err = errno) == EINTR)); + + if (res < 0) { + if (err == EAGAIN || err == EWOULDBLOCK) { + // Nothing left to read + return NULL; + } + if (err == EBADF) { + netty_unix_errors_throwClosedChannelException(env); + return NULL; + } + netty_unix_errors_throwIOExceptionErrorNo(env, "_recvFromDomainSocket() failed: ", err); + return NULL; + } + + return createDomainDatagramSocketAddress(env, &addr, res, NULL); +} + void netty_unix_socket_getOptionHandleError(JNIEnv* env, int err) { netty_unix_socket_optionHandleError(env, err, "getsockopt() failed: "); } @@ -577,6 +651,14 @@ static jint netty_unix_socket_newSocketDomainFd(JNIEnv* env, jclass clazz) { return fd; } +static jint netty_unix_socket_newSocketDomainDgramFd(JNIEnv* env, jclass clazz) { + int fd = nettyNonBlockingSocket(PF_UNIX, SOCK_DGRAM, 0); + if (fd == -1) { + return -errno; + } + return fd; +} + static jint netty_unix_socket_sendTo(JNIEnv* env, jclass clazz, jint fd, jboolean ipv6, jobject jbuffer, jint pos, jint limit, jbyteArray address, jint scopeId, jint port, jint flags) { // We check that GetDirectBufferAddress will not return NULL in OnLoad return _sendTo(env, fd, ipv6, (*env)->GetDirectBufferAddress(env, jbuffer), pos, limit, address, scopeId, port, flags); @@ -612,6 +694,50 @@ static jint netty_unix_socket_sendToAddresses(JNIEnv* env, jclass clazz, jint fd return (jint) res; } +static jint netty_unix_socket_sendToDomainSocket(JNIEnv* env, jclass clazz, jint fd, jobject jbuffer, jint pos, jint limit, jbyteArray socketPath) { + // We check that GetDirectBufferAddress will not return NULL in OnLoad + return _sendToDomainSocket(env, fd, (*env)->GetDirectBufferAddress(env, jbuffer), pos, limit, socketPath); +} + +static jint netty_unix_socket_sendToAddressDomainSocket(JNIEnv* env, jclass clazz, jint fd, jlong memoryAddress, jint pos, jint limit, jbyteArray socketPath) { + return _sendToDomainSocket(env, fd, (void *) (intptr_t) memoryAddress, pos, limit, socketPath); +} + +static jint netty_unix_socket_sendToAddressesDomainSocket(JNIEnv* env, jclass clazz, jint fd, jlong memoryAddress, jint length, jbyteArray socketPath) { + struct sockaddr_un addr; + jint socket_path_len; + + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + jbyte* socket_path = (*env)->GetByteArrayElements(env, socketPath, 0); + socket_path_len = (*env)->GetArrayLength(env, socketPath); + if (socket_path_len > sizeof(addr.sun_path)) { + socket_path_len = sizeof(addr.sun_path); + } + memcpy(addr.sun_path, socket_path, socket_path_len); + + struct msghdr m = { 0 }; + m.msg_name = (void*) &addr; + m.msg_namelen = sizeof(struct sockaddr_un); + m.msg_iov = (struct iovec*) (intptr_t) memoryAddress; + m.msg_iovlen = length; + + ssize_t res; + int err; + do { + res = sendmsg(fd, &m, 0); + // keep on writing if it was interrupted + } while (res == -1 && ((err = errno) == EINTR)); + + (*env)->ReleaseByteArrayElements(env, socketPath, socket_path, 0); + + if (res < 0) { + return -err; + } + return (jint) res; +} + static jobject netty_unix_socket_recvFrom(JNIEnv* env, jclass clazz, jint fd, jobject jbuffer, jint pos, jint limit) { // We check that GetDirectBufferAddress will not return NULL in OnLoad return _recvFrom(env, fd, (*env)->GetDirectBufferAddress(env, jbuffer), pos, limit); @@ -621,6 +747,15 @@ static jobject netty_unix_socket_recvFromAddress(JNIEnv* env, jclass clazz, jint return _recvFrom(env, fd, (void *) (intptr_t) address, pos, limit); } +static jobject netty_unix_socket_recvFromDomainSocket(JNIEnv* env, jclass clazz, jint fd, jobject jbuffer, jint pos, jint limit) { + // We check that GetDirectBufferAddress will not return NULL in OnLoad + return _recvFromDomainSocket(env, fd, (*env)->GetDirectBufferAddress(env, jbuffer), pos, limit); +} + +static jobject netty_unix_socket_recvFromAddressDomainSocket(JNIEnv* env, jclass clazz, jint fd, jlong address, jint pos, jint limit) { + return _recvFromDomainSocket(env, fd, (void *) (intptr_t) address, pos, limit); +} + static jint netty_unix_socket_bindDomainSocket(JNIEnv* env, jclass clazz, jint fd, jbyteArray socketPath) { struct sockaddr_un addr; @@ -939,11 +1074,17 @@ static const JNINativeMethod fixed_method_table[] = { { "newSocketDgramFd", "(Z)I", (void *) netty_unix_socket_newSocketDgramFd }, { "newSocketStreamFd", "(Z)I", (void *) netty_unix_socket_newSocketStreamFd }, { "newSocketDomainFd", "()I", (void *) netty_unix_socket_newSocketDomainFd }, + { "newSocketDomainDgramFd", "()I", (void *) netty_unix_socket_newSocketDomainDgramFd }, { "sendTo", "(IZLjava/nio/ByteBuffer;II[BIII)I", (void *) netty_unix_socket_sendTo }, { "sendToAddress", "(IZJII[BIII)I", (void *) netty_unix_socket_sendToAddress }, { "sendToAddresses", "(IZJI[BIII)I", (void *) netty_unix_socket_sendToAddresses }, + { "sendToDomainSocket", "(ILjava/nio/ByteBuffer;II[B)I", (void *) netty_unix_socket_sendToDomainSocket }, + { "sendToAddressDomainSocket", "(IJII[B)I", (void *) netty_unix_socket_sendToAddressDomainSocket }, + { "sendToAddressesDomainSocket", "(IJI[B)I", (void *) netty_unix_socket_sendToAddressesDomainSocket }, // "recvFrom" has a dynamic signature // "recvFromAddress" has a dynamic signature + // "recvFromDomainSocket" has a dynamic signature + // "recvFromAddressDomainSocket" has a dynamic signature { "recvFd", "(I)I", (void *) netty_unix_socket_recvFd }, { "sendFd", "(II)I", (void *) netty_unix_socket_sendFd }, { "bindDomainSocket", "(I[B)I", (void *) netty_unix_socket_bindDomainSocket }, @@ -975,7 +1116,8 @@ static const JNINativeMethod fixed_method_table[] = { static const jint fixed_method_table_size = sizeof(fixed_method_table) / sizeof(fixed_method_table[0]); static jint dynamicMethodsTableSize() { - return fixed_method_table_size + 2; + // 4 is for the dynamic method signatures. + return fixed_method_table_size + 4; } static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) { @@ -1002,6 +1144,20 @@ static JNINativeMethod* createDynamicMethodsTable(const char* packagePrefix) { dynamicMethod->fnPtr = (void *) netty_unix_socket_recvFromAddress; netty_jni_util_free_dynamic_name(&dynamicTypeName); + ++dynamicMethod; + NETTY_JNI_UTIL_PREPEND(packagePrefix, "io/netty/channel/unix/DomainDatagramSocketAddress;", dynamicTypeName, error); + NETTY_JNI_UTIL_PREPEND("(ILjava/nio/ByteBuffer;II)L", dynamicTypeName, dynamicMethod->signature, error); + dynamicMethod->name = "recvFromDomainSocket"; + dynamicMethod->fnPtr = (void *) netty_unix_socket_recvFromDomainSocket; + netty_jni_util_free_dynamic_name(&dynamicTypeName); + + ++dynamicMethod; + NETTY_JNI_UTIL_PREPEND(packagePrefix, "io/netty/channel/unix/DomainDatagramSocketAddress;", dynamicTypeName, error); + NETTY_JNI_UTIL_PREPEND("(IJII)L", dynamicTypeName, dynamicMethod->signature, error); + dynamicMethod->name = "recvFromAddressDomainSocket"; + dynamicMethod->fnPtr = (void *) netty_unix_socket_recvFromAddressDomainSocket; + netty_jni_util_free_dynamic_name(&dynamicTypeName); + return dynamicMethods; error: free(dynamicTypeName); @@ -1038,6 +1194,14 @@ jint netty_unix_socket_JNI_OnLoad(JNIEnv* env, const char* packagePrefix) { netty_jni_util_free_dynamic_name(&nettyClassName); NETTY_JNI_UTIL_GET_METHOD(env, datagramSocketAddressClass, datagramSocketAddrMethodId, "", parameters, done); + NETTY_JNI_UTIL_PREPEND(packagePrefix, "io/netty/channel/unix/DomainDatagramSocketAddress", nettyClassName, done); + NETTY_JNI_UTIL_LOAD_CLASS(env, domainDatagramSocketAddressClass, nettyClassName, done); + + char parameters1[1024] = {0}; + snprintf(parameters1, sizeof(parameters1), "([BIL%s;)V", nettyClassName); + netty_jni_util_free_dynamic_name(&nettyClassName); + NETTY_JNI_UTIL_GET_METHOD(env, domainDatagramSocketAddressClass, domainDatagramSocketAddrMethodId, "", parameters1, done); + NETTY_JNI_UTIL_LOAD_CLASS(env, inetSocketAddressClass, "java/net/InetSocketAddress", done); NETTY_JNI_UTIL_GET_METHOD(env, inetSocketAddressClass, inetSocketAddrMethodId, "", "(Ljava/lang/String;I)V", done); @@ -1063,6 +1227,7 @@ done: void netty_unix_socket_JNI_OnUnLoad(JNIEnv* env, const char* packagePrefix) { NETTY_JNI_UTIL_UNLOAD_CLASS(env, datagramSocketAddressClass); + NETTY_JNI_UTIL_UNLOAD_CLASS(env, domainDatagramSocketAddressClass); NETTY_JNI_UTIL_UNLOAD_CLASS(env, inetSocketAddressClass); netty_jni_util_unregister_natives(env, packagePrefix, SOCKET_CLASSNAME); diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java new file mode 100644 index 0000000000..a26ef95d23 --- /dev/null +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java @@ -0,0 +1,39 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.channel.Channel; + +/** + * A {@link UnixChannel} that supports communication via + * UNIX domain datagram sockets. + */ +public interface DomainDatagramChannel extends UnixChannel, Channel { + + @Override + DomainDatagramChannelConfig config(); + + /** + * Return {@code true} if the {@link DomainDatagramChannel} is connected to the remote peer. + */ + boolean isConnected(); + + @Override + DomainSocketAddress localAddress(); + + @Override + DomainSocketAddress remoteAddress(); +} diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java new file mode 100644 index 0000000000..68b1a97de7 --- /dev/null +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +/** + * A {@link ChannelConfig} for a {@link DomainDatagramChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link ChannelConfig}, + * {@link DomainDatagramChannelConfig} allows the following options in the option map: + * + * + * + * + * + * + * + *
NameAssociated setter method
{@link ChannelOption#SO_SNDBUF}{@link #setSendBufferSize(int)}
+ */ +public interface DomainDatagramChannelConfig extends ChannelConfig { + + @Override + DomainDatagramChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + DomainDatagramChannelConfig setAutoClose(boolean autoClose); + + @Override + DomainDatagramChannelConfig setAutoRead(boolean autoRead); + + @Override + DomainDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + DomainDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + DomainDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + DomainDatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + /** + * Sets the {@link java.net.StandardSocketOptions#SO_SNDBUF} option. + */ + DomainDatagramChannelConfig setSendBufferSize(int sendBufferSize); + + /** + * Gets the {@link java.net.StandardSocketOptions#SO_SNDBUF} option. + */ + int getSendBufferSize(); + + @Override + DomainDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + + @Override + DomainDatagramChannelConfig setWriteSpinCount(int writeSpinCount); +} diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java new file mode 100644 index 0000000000..39a1cd3570 --- /dev/null +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java @@ -0,0 +1,86 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.DefaultAddressedEnvelope; + +/** + * The message container that is used for {@link DomainDatagramChannel} to communicate with the remote peer. + */ +public final class DomainDatagramPacket + extends DefaultAddressedEnvelope implements ByteBufHolder { + + /** + * Create a new instance with the specified packet {@code data} and {@code recipient} address. + */ + public DomainDatagramPacket(ByteBuf data, DomainSocketAddress recipient) { + super(data, recipient); + } + + /** + * Create a new instance with the specified packet {@code data}, {@code recipient} address, and {@code sender} + * address. + */ + public DomainDatagramPacket(ByteBuf data, DomainSocketAddress recipient, DomainSocketAddress sender) { + super(data, recipient, sender); + } + + @Override + public DomainDatagramPacket copy() { + return replace(content().copy()); + } + + @Override + public DomainDatagramPacket duplicate() { + return replace(content().duplicate()); + } + + @Override + public DomainDatagramPacket replace(ByteBuf content) { + return new DomainDatagramPacket(content, recipient(), sender()); + } + + @Override + public DomainDatagramPacket retain() { + super.retain(); + return this; + } + + @Override + public DomainDatagramPacket retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public DomainDatagramPacket retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public DomainDatagramPacket touch() { + super.touch(); + return this; + } + + @Override + public DomainDatagramPacket touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java new file mode 100644 index 0000000000..b67c6701e1 --- /dev/null +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java @@ -0,0 +1,48 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.util.CharsetUtil; + +/** + * Act as special {@link DomainSocketAddress} to be able to easily pass all needed data from JNI without the need + * to create more objects then needed. + *

+ * Internal usage only! + */ +public final class DomainDatagramSocketAddress extends DomainSocketAddress { + + private static final long serialVersionUID = -5925732678737768223L; + + private final DomainDatagramSocketAddress localAddress; + // holds the amount of received bytes + private final int receivedAmount; + + public DomainDatagramSocketAddress(byte[] socketPath, int receivedAmount, + DomainDatagramSocketAddress localAddress) { + super(new String(socketPath, CharsetUtil.UTF_8)); + this.localAddress = localAddress; + this.receivedAmount = receivedAmount; + } + + public DomainDatagramSocketAddress localAddress() { + return localAddress; + } + + public int receivedAmount() { + return receivedAmount; + } +} diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainSocketAddress.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainSocketAddress.java index 4acb772897..a3767f4135 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainSocketAddress.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/DomainSocketAddress.java @@ -24,7 +24,7 @@ import java.net.SocketAddress; * A address for a * Unix Domain Socket. */ -public final class DomainSocketAddress extends SocketAddress { +public class DomainSocketAddress extends SocketAddress { private static final long serialVersionUID = -6934618000832236893L; private final String socketPath; diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/Socket.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/Socket.java index 9ed98630a4..0452845803 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/Socket.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/Socket.java @@ -146,6 +146,14 @@ public class Socket extends FileDescriptor { return ioResult("sendTo", res); } + public final int sendToDomainSocket(ByteBuffer buf, int pos, int limit, byte[] path) throws IOException { + int res = sendToDomainSocket(fd, buf, pos, limit, path); + if (res >= 0) { + return res; + } + return ioResult("sendToDomainSocket", res); + } + public final int sendToAddress(long memoryAddress, int pos, int limit, InetAddress addr, int port) throws IOException { return sendToAddress(memoryAddress, pos, limit, addr, port, false); @@ -182,6 +190,14 @@ public class Socket extends FileDescriptor { return ioResult("sendToAddress", res); } + public final int sendToAddressDomainSocket(long memoryAddress, int pos, int limit, byte[] path) throws IOException { + int res = sendToAddressDomainSocket(fd, memoryAddress, pos, limit, path); + if (res >= 0) { + return res; + } + return ioResult("sendToAddressDomainSocket", res); + } + public final int sendToAddresses(long memoryAddress, int length, InetAddress addr, int port) throws IOException { return sendToAddresses(memoryAddress, length, addr, port, false); } @@ -217,6 +233,14 @@ public class Socket extends FileDescriptor { return ioResult("sendToAddresses", res); } + public final int sendToAddressesDomainSocket(long memoryAddress, int length, byte[] path) throws IOException { + int res = sendToAddressesDomainSocket(fd, memoryAddress, length, path); + if (res >= 0) { + return res; + } + return ioResult("sendToAddressesDomainSocket", res); + } + public final DatagramSocketAddress recvFrom(ByteBuffer buf, int pos, int limit) throws IOException { return recvFrom(fd, buf, pos, limit); } @@ -225,6 +249,16 @@ public class Socket extends FileDescriptor { return recvFromAddress(fd, memoryAddress, pos, limit); } + public final DomainDatagramSocketAddress recvFromDomainSocket(ByteBuffer buf, int pos, int limit) + throws IOException { + return recvFromDomainSocket(fd, buf, pos, limit); + } + + public final DomainDatagramSocketAddress recvFromAddressDomainSocket(long memoryAddress, int pos, int limit) + throws IOException { + return recvFromAddressDomainSocket(fd, memoryAddress, pos, limit); + } + public final int recvFd() throws IOException { int res = recvFd(fd); if (res > 0) { @@ -441,6 +475,10 @@ public class Socket extends FileDescriptor { return new Socket(newSocketDomain0()); } + public static Socket newSocketDomainDgram() { + return new Socket(newSocketDomainDgram0()); + } + public static void initialize() { if (INITIALIZED.compareAndSet(false, true)) { initialize(NetUtil.isIpV4StackPreferred()); @@ -479,6 +517,14 @@ public class Socket extends FileDescriptor { return res; } + protected static int newSocketDomainDgram0() { + int res = newSocketDomainDgramFd(); + if (res < 0) { + throw new ChannelException(newIOException("newSocketDomainDgram", res)); + } + return res; + } + private static native int shutdown(int fd, boolean read, boolean write); private static native int connect(int fd, boolean ipv6, byte[] address, int scopeId, int port); private static native int connectDomainSocket(int fd, byte[] path); @@ -504,10 +550,18 @@ public class Socket extends FileDescriptor { int fd, boolean ipv6, long memoryAddress, int length, byte[] address, int scopeId, int port, int flags); + private static native int sendToDomainSocket(int fd, ByteBuffer buf, int pos, int limit, byte[] path); + private static native int sendToAddressDomainSocket(int fd, long memoryAddress, int pos, int limit, byte[] path); + private static native int sendToAddressesDomainSocket(int fd, long memoryAddress, int length, byte[] path); + private static native DatagramSocketAddress recvFrom( int fd, ByteBuffer buf, int pos, int limit) throws IOException; private static native DatagramSocketAddress recvFromAddress( int fd, long memoryAddress, int pos, int limit) throws IOException; + private static native DomainDatagramSocketAddress recvFromDomainSocket( + int fd, ByteBuffer buf, int pos, int limit) throws IOException; + private static native DomainDatagramSocketAddress recvFromAddressDomainSocket( + int fd, long memoryAddress, int pos, int limit) throws IOException; private static native int recvFd(int fd); private static native int sendFd(int socketFd, int fd); private static native int msgFastopen(); @@ -515,6 +569,7 @@ public class Socket extends FileDescriptor { private static native int newSocketStreamFd(boolean ipv6); private static native int newSocketDgramFd(boolean ipv6); private static native int newSocketDomainFd(); + private static native int newSocketDomainDgramFd(); private static native int isReuseAddress(int fd) throws IOException; private static native int isReusePort(int fd) throws IOException; diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/Unix.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/Unix.java index 168622bd90..c020997bab 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/Unix.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/Unix.java @@ -43,7 +43,7 @@ public final class Unix { IOException.class, PortUnreachableException.class, // netty_unix_socket - DatagramSocketAddress.class, InetSocketAddress.class + DatagramSocketAddress.class, DomainDatagramSocketAddress.class, InetSocketAddress.class ); }