diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java index 381df73309..716379fc67 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java @@ -20,6 +20,11 @@ import org.junit.BeforeClass; import static org.junit.Assume.assumeTrue; public class JdkOpenSslEngineInteroptTest extends SSLEngineTest { + + public JdkOpenSslEngineInteroptTest(BufferType type) { + super(type); + } + @BeforeClass public static void checkOpenSsl() { assumeTrue(OpenSsl.isAvailable()); diff --git a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java index 9a5723043b..5856fee1c2 100644 --- a/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java @@ -39,6 +39,10 @@ public class JdkSslEngineTest extends SSLEngineTest { private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; private static final String APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE = "my-protocol-FOO"; + public JdkSslEngineTest(BufferType type) { + super(type); + } + @Test public void testNpn() throws Exception { try { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 2ef0a6c6c2..9911738312 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -38,6 +38,10 @@ public class OpenSslEngineTest extends SSLEngineTest { private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2"; private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; + public OpenSslEngineTest(BufferType type) { + super(type); + } + @BeforeClass public static void checkOpenSsl() { assumeTrue(OpenSsl.isAvailable()); @@ -78,7 +82,7 @@ public class OpenSslEngineTest extends SSLEngineTest { new String[]{PROTOCOL_SSL_V2_HELLO, PROTOCOL_TLS_V1_2}); } @Test - public void testWrapHeapBuffersNoWritePendingError() throws Exception { + public void testWrapBuffersNoWritePendingError() throws Exception { clientSslCtx = SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .sslProvider(sslClientProvider()) @@ -94,9 +98,11 @@ public class OpenSslEngineTest extends SSLEngineTest { serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); handshake(clientEngine, serverEngine); - ByteBuffer src = ByteBuffer.allocate(1024 * 10); - ThreadLocalRandom.current().nextBytes(src.array()); - ByteBuffer dst = ByteBuffer.allocate(1); + ByteBuffer src = allocateBuffer(1024 * 10); + byte[] data = new byte[src.capacity()]; + ThreadLocalRandom.current().nextBytes(data); + src.put(data).flip(); + ByteBuffer dst = allocateBuffer(1); // Try to wrap multiple times so we are more likely to hit the issue. for (int i = 0; i < 100; i++) { src.position(0); diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java index 9954f692b2..7cba76ffc5 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java @@ -21,6 +21,10 @@ import static org.junit.Assume.assumeTrue; public class OpenSslJdkSslEngineInteroptTest extends SSLEngineTest { + public OpenSslJdkSslEngineInteroptTest(BufferType type) { + super(type); + } + @BeforeClass public static void checkOpenSsl() { assumeTrue(OpenSsl.isAvailable()); diff --git a/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java index 40b3193b76..6d38940cbf 100644 --- a/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java @@ -20,6 +20,11 @@ import io.netty.util.ReferenceCountUtil; import javax.net.ssl.SSLEngine; public class ReferenceCountedOpenSslEngineTest extends OpenSslEngineTest { + + public ReferenceCountedOpenSslEngineTest(BufferType type) { + super(type); + } + @Override protected SslProvider sslClientProvider() { return SslProvider.OPENSSL_REFCNT; diff --git a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java index 6dc5ca7f40..1e288b644f 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SSLEngineTest.java @@ -19,6 +19,7 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; @@ -40,9 +41,12 @@ import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Promise; import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ThreadLocalRandom; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -55,6 +59,8 @@ import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.security.cert.Certificate; import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -77,6 +83,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.verify; +@RunWith(Parameterized.class) public abstract class SSLEngineTest { private static final String X509_CERT_PEM = @@ -186,6 +193,205 @@ public abstract class SSLEngineTest { } } + enum BufferType { + Direct, + Heap, + Mixed + } + + @Parameterized.Parameters(name = "{index}: bufferType = {0}") + public static Collection data() { + List params = new ArrayList(); + for (BufferType type: BufferType.values()) { + params.add(type); + } + return params; + } + + private final BufferType type; + + protected SSLEngineTest(BufferType type) { + this.type = type; + } + + protected ByteBuffer allocateBuffer(int len) { + switch (type) { + case Direct: + return ByteBuffer.allocateDirect(len); + case Heap: + return ByteBuffer.allocate(len); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + ByteBuffer.allocateDirect(len) : ByteBuffer.allocate(len); + default: + throw new Error(); + } + } + + private static final class TestByteBufAllocator implements ByteBufAllocator { + + private final ByteBufAllocator allocator; + private final BufferType type; + + TestByteBufAllocator(ByteBufAllocator allocator, BufferType type) { + this.allocator = allocator; + this.type = type; + } + + @Override + public ByteBuf buffer() { + switch (type) { + case Direct: + return allocator.directBuffer(); + case Heap: + return allocator.heapBuffer(); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + allocator.directBuffer() : allocator.heapBuffer(); + default: + throw new Error(); + } + } + + @Override + public ByteBuf buffer(int initialCapacity) { + switch (type) { + case Direct: + return allocator.directBuffer(initialCapacity); + case Heap: + return allocator.heapBuffer(initialCapacity); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + allocator.directBuffer(initialCapacity) : allocator.heapBuffer(initialCapacity); + default: + throw new Error(); + } + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + switch (type) { + case Direct: + return allocator.directBuffer(initialCapacity, maxCapacity); + case Heap: + return allocator.heapBuffer(initialCapacity, maxCapacity); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + allocator.directBuffer(initialCapacity, maxCapacity) : + allocator.heapBuffer(initialCapacity, maxCapacity); + default: + throw new Error(); + } + } + + @Override + public ByteBuf ioBuffer() { + return allocator.ioBuffer(); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return allocator.ioBuffer(initialCapacity); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return allocator.ioBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf heapBuffer() { + return allocator.heapBuffer(); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return allocator.heapBuffer(initialCapacity); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return allocator.heapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf directBuffer() { + return allocator.directBuffer(); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return allocator.directBuffer(initialCapacity); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return allocator.directBuffer(initialCapacity, maxCapacity); + } + + @Override + public CompositeByteBuf compositeBuffer() { + switch (type) { + case Direct: + return allocator.compositeDirectBuffer(); + case Heap: + return allocator.compositeHeapBuffer(); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + allocator.compositeDirectBuffer() : + allocator.compositeHeapBuffer(); + default: + throw new Error(); + } + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + switch (type) { + case Direct: + return allocator.compositeDirectBuffer(maxNumComponents); + case Heap: + return allocator.compositeHeapBuffer(maxNumComponents); + case Mixed: + return ThreadLocalRandom.current().nextBoolean() ? + allocator.compositeDirectBuffer(maxNumComponents) : + allocator.compositeHeapBuffer(maxNumComponents); + default: + throw new Error(); + } + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return allocator.compositeHeapBuffer(); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return allocator.compositeHeapBuffer(maxNumComponents); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return allocator.compositeDirectBuffer(); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return allocator.compositeDirectBuffer(maxNumComponents); + } + + @Override + public boolean isDirectBufferPooled() { + return allocator.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return allocator.calculateNewCapacity(minNewCapacity, maxCapacity); + } + } + @Before public void setup() { MockitoAnnotations.initMocks(this); @@ -340,6 +546,8 @@ public abstract class SSLEngineTest { sb.childHandler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); SSLEngine engine = serverSslCtx.newEngine(ch.alloc()); engine.setUseClientMode(false); @@ -398,6 +606,8 @@ public abstract class SSLEngineTest { cb.handler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); p.addLast(clientSslCtx.newHandler(ch.alloc())); p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); @@ -550,6 +760,8 @@ public abstract class SSLEngineTest { .childHandler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); p.addLast(serverSslCtx.newHandler(ch.alloc())); p.addLast(new ChannelInboundHandlerAdapter() { @@ -600,6 +812,8 @@ public abstract class SSLEngineTest { .handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); SslHandler sslHandler = clientSslCtx.newHandler(ch.alloc()); // The renegotiate is not expected to succeed, so we should stop trying in a timely manner so @@ -678,19 +892,19 @@ public abstract class SSLEngineTest { } } - protected static void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { - ByteBuffer cTOs = ByteBuffer.allocateDirect(clientEngine.getSession().getPacketBufferSize()); - ByteBuffer sTOc = ByteBuffer.allocateDirect(serverEngine.getSession().getPacketBufferSize()); + protected void handshake(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { + ByteBuffer cTOs = allocateBuffer(clientEngine.getSession().getPacketBufferSize()); + ByteBuffer sTOc = allocateBuffer(serverEngine.getSession().getPacketBufferSize()); - ByteBuffer serverAppReadBuffer = ByteBuffer.allocateDirect( + ByteBuffer serverAppReadBuffer = allocateBuffer( serverEngine.getSession().getApplicationBufferSize()); - ByteBuffer clientAppReadBuffer = ByteBuffer.allocateDirect( + ByteBuffer clientAppReadBuffer = allocateBuffer( clientEngine.getSession().getApplicationBufferSize()); clientEngine.beginHandshake(); serverEngine.beginHandshake(); - ByteBuffer empty = ByteBuffer.allocate(0); + ByteBuffer empty = allocateBuffer(0); SSLEngineResult clientResult; SSLEngineResult serverResult; @@ -844,6 +1058,8 @@ public abstract class SSLEngineTest { sb.childHandler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); p.addLast(serverSslCtx.newHandler(ch.alloc())); p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); @@ -867,6 +1083,8 @@ public abstract class SSLEngineTest { cb.handler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ChannelPipeline p = ch.pipeline(); p.addLast(clientSslCtx.newHandler(ch.alloc())); p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); @@ -907,6 +1125,8 @@ public abstract class SSLEngineTest { serverChannel = sb.childHandler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ch.pipeline().addFirst(serverSslCtx.newHandler(ch.alloc())); ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { @Override @@ -955,9 +1175,14 @@ public abstract class SSLEngineTest { cb = new Bootstrap(); cb.group(new NioEventLoopGroup()); cb.channel(NioSocketChannel.class); - clientChannel = cb.handler( - new SslHandler(clientSslCtx.newEngine(ByteBufAllocator.DEFAULT))) - .connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + clientChannel = cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + ch.pipeline().addLast(new SslHandler(clientSslCtx.newEngine(ch.alloc()))); + } + + }).connect(serverChannel.localAddress()).syncUninterruptibly().channel(); promise.syncUninterruptibly(); } @@ -982,9 +1207,9 @@ public abstract class SSLEngineTest { byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII); try { - ByteBuffer plainClientOut = ByteBuffer.allocate(client.getSession().getApplicationBufferSize()); - ByteBuffer encryptedClientToServer = ByteBuffer.allocate(server.getSession().getPacketBufferSize() * 2); - ByteBuffer plainServerIn = ByteBuffer.allocate(server.getSession().getApplicationBufferSize()); + ByteBuffer plainClientOut = allocateBuffer(client.getSession().getApplicationBufferSize()); + ByteBuffer encryptedClientToServer = allocateBuffer(server.getSession().getPacketBufferSize() * 2); + ByteBuffer plainServerIn = allocateBuffer(server.getSession().getApplicationBufferSize()); handshake(client, server); @@ -1016,7 +1241,7 @@ public abstract class SSLEngineTest { // try with too small output buffer first (to check BUFFER_OVERFLOW case) int remaining = encryptedClientToServer.remaining(); - ByteBuffer small = ByteBuffer.allocate(3); + ByteBuffer small = allocateBuffer(3); result = server.unwrap(encryptedClientToServer, small); assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); assertEquals(remaining, encryptedClientToServer.remaining()); @@ -1061,7 +1286,7 @@ public abstract class SSLEngineTest { try { // Allocate an buffer that is bigger then the max plain record size. - ByteBuffer plainServerOut = ByteBuffer.allocate(server.getSession().getApplicationBufferSize() * 2); + ByteBuffer plainServerOut = allocateBuffer(server.getSession().getApplicationBufferSize() * 2); handshake(client, server); @@ -1069,7 +1294,7 @@ public abstract class SSLEngineTest { plainServerOut.position(plainServerOut.capacity()); plainServerOut.flip(); - ByteBuffer encryptedServerToClient = ByteBuffer.allocate(server.getSession().getPacketBufferSize()); + ByteBuffer encryptedServerToClient = allocateBuffer(server.getSession().getPacketBufferSize()); int encryptedServerToClientPos = encryptedServerToClient.position(); int plainServerOutPos = plainServerOut.position(); @@ -1093,9 +1318,9 @@ public abstract class SSLEngineTest { SSLEngine client = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); try { - ByteBuffer src = ByteBuffer.allocate(client.getSession().getApplicationBufferSize()); - ByteBuffer dst = ByteBuffer.allocate(client.getSession().getPacketBufferSize()); - ByteBuffer empty = ByteBuffer.allocateDirect(0); + ByteBuffer src = allocateBuffer(client.getSession().getApplicationBufferSize()); + ByteBuffer dst = allocateBuffer(client.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(0); SSLEngineResult clientResult = client.wrap(empty, dst); assertEquals(SSLEngineResult.Status.OK, clientResult.getStatus()); @@ -1159,9 +1384,9 @@ public abstract class SSLEngineTest { } } - private static void testBeginHandshakeCloseOutbound(SSLEngine engine) throws SSLException { - ByteBuffer dst = ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize()); - ByteBuffer empty = ByteBuffer.allocateDirect(0); + private void testBeginHandshakeCloseOutbound(SSLEngine engine) throws SSLException { + ByteBuffer dst = allocateBuffer(engine.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(0); engine.beginHandshake(); engine.closeOutbound(); @@ -1233,12 +1458,12 @@ public abstract class SSLEngineTest { SSLEngine server = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); try { - ByteBuffer plainClientOut = ByteBuffer.allocate(client.getSession().getApplicationBufferSize()); - ByteBuffer plainServerOut = ByteBuffer.allocate(server.getSession().getApplicationBufferSize()); + ByteBuffer plainClientOut = allocateBuffer(client.getSession().getApplicationBufferSize()); + ByteBuffer plainServerOut = allocateBuffer(server.getSession().getApplicationBufferSize()); - ByteBuffer encryptedClientToServer = ByteBuffer.allocate(client.getSession().getPacketBufferSize()); - ByteBuffer encryptedServerToClient = ByteBuffer.allocate(server.getSession().getPacketBufferSize()); - ByteBuffer empty = ByteBuffer.allocate(0); + ByteBuffer encryptedClientToServer = allocateBuffer(client.getSession().getPacketBufferSize()); + ByteBuffer encryptedServerToClient = allocateBuffer(server.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(0); handshake(client, server); @@ -1376,14 +1601,14 @@ public abstract class SSLEngineTest { try { // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the // unwrap call without exceed MAX_ENCRYPTED_PACKET_LENGTH. - ByteBuffer plainClientOut = ByteBuffer.allocate(1024); - ByteBuffer plainServerOut = ByteBuffer.allocate(server.getSession().getApplicationBufferSize()); + ByteBuffer plainClientOut = allocateBuffer(1024); + ByteBuffer plainServerOut = allocateBuffer(server.getSession().getApplicationBufferSize()); - ByteBuffer encClientToServer = ByteBuffer.allocate(client.getSession().getPacketBufferSize()); + ByteBuffer encClientToServer = allocateBuffer(client.getSession().getPacketBufferSize()); int positionOffset = 1; // We need to be able to hold 2 records + positionOffset - ByteBuffer combinedEncClientToServer = ByteBuffer.allocate( + ByteBuffer combinedEncClientToServer = allocateBuffer( encClientToServer.capacity() * 2 + positionOffset); combinedEncClientToServer.position(positionOffset);