diff --git a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java index ef8d07b6e4..aad9422bcb 100644 --- a/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java +++ b/handler/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java @@ -62,6 +62,8 @@ import static io.netty.handler.ssl.OpenSsl.memoryAddress; import static io.netty.util.internal.EmptyArrays.EMPTY_CERTIFICATES; import static io.netty.util.internal.EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES; import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Math.max; +import static java.lang.Math.min; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; @@ -159,6 +161,8 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = MAX_ENCRYPTED_PACKET_LENGTH - MAX_PLAINTEXT_LENGTH; + private static final int MAX_ENCRYPTION_OVERHEAD_DIFF = Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH; + private static final AtomicIntegerFieldUpdater DESTROYED_UPDATER = AtomicIntegerFieldUpdater.newUpdater(ReferenceCountedOpenSslEngine.class, "destroyed"); @@ -440,7 +444,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } else { final int pos = dst.position(); final int limit = dst.limit(); - final int len = Math.min(MAX_ENCRYPTED_PACKET_LENGTH, limit - pos); + final int len = min(MAX_ENCRYPTED_PACKET_LENGTH, limit - pos); final ByteBuf buf = alloc.directBuffer(len); try { final long addr = memoryAddress(buf); @@ -593,7 +597,30 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } } - if (dst.remaining() < MAX_ENCRYPTED_PACKET_LENGTH) { + int endOffset = offset + length; + int srcsLen = 0; + + for (int i = offset; i < endOffset; ++i) { + final ByteBuffer src = srcs[i]; + if (src == null) { + throw new IllegalArgumentException("srcs[" + i + "] is null"); + } + if (srcsLen == MAX_PLAINTEXT_LENGTH) { + continue; + } + + srcsLen += src.remaining(); + if (srcsLen > MAX_PLAINTEXT_LENGTH || srcsLen < 0) { + // If srcLen > MAX_PLAINTEXT_LENGTH or secLen < 0 just set it to MAX_PLAINTEXT_LENGTH. + // This also help us to guard against overflow. + // We not break out here as we still need to check for null entries in srcs[]. + srcsLen = MAX_PLAINTEXT_LENGTH; + } + } + + int maxEncryptedLen = calculateOutNetBufSize(srcsLen); + + if (dst.remaining() < maxEncryptedLen) { // Can not hold the maximum packet so we need to tell the caller to use a bigger destination // buffer. return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); @@ -601,18 +628,14 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc // There was no pending data in the network BIO -- encrypt any application data int bytesProduced = 0; int bytesConsumed = 0; - int endOffset = offset + length; loop: for (int i = offset; i < endOffset; ++i) { final ByteBuffer src = srcs[i]; - if (src == null) { - throw new IllegalArgumentException("srcs[" + i + "] is null"); - } while (src.hasRemaining()) { final SSLEngineResult pendingNetResult; // Write plaintext application data to the SSL engine int result = writePlaintextData( - src, Math.min(src.remaining(), MAX_PLAINTEXT_LENGTH - bytesConsumed)); + src, min(src.remaining(), MAX_PLAINTEXT_LENGTH - bytesConsumed)); if (result > 0) { bytesConsumed += result; @@ -816,7 +839,7 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc } // Write more encrypted data into the BIO. Ensure we only read one packet at a time as // stated in the SSLEngine javadocs. - int written = writeEncryptedData(src, Math.min(packetLengthRemaining, src.remaining())); + int written = writeEncryptedData(src, min(packetLengthRemaining, src.remaining())); if (written > 0) { packetLengthRemaining -= written; if (packetLengthRemaining == 0) { @@ -1640,6 +1663,11 @@ public class ReferenceCountedOpenSslEngine extends SSLEngine implements Referenc return destroyed != 0; } + static int calculateOutNetBufSize(int pendingBytes) { + return min(MAX_ENCRYPTED_PACKET_LENGTH, MAX_ENCRYPTION_OVERHEAD_LENGTH + + min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes)); + } + private final class OpenSslSession implements SSLSession, ApplicationProtocolAccessor { private final OpenSslSessionContext sessionContext; diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 199c7da51b..13b7ba17d4 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -209,6 +209,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH out.writerIndex(writerIndex + result.bytesProduced()); return result; } + + @Override + int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { + return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes); + } }, JDK(false, MERGE_CUMULATOR) { @Override @@ -220,6 +225,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH out.writerIndex(writerIndex + result.bytesProduced()); return result; } + + @Override + int calculateOutNetBufSize(SslHandler handler, int pendingBytes) { + return handler.maxPacketBufferSize; + } }; static SslEngineType forEngine(SSLEngine engine) { @@ -237,6 +247,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) throws SSLException; + abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes); + // BEGIN Platform-dependent flags /** @@ -339,7 +351,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH throw new NullPointerException("delegatedTaskExecutor"); } this.engine = engine; - this.engineType = SslEngineType.forEngine(engine); + engineType = SslEngineType.forEngine(engine); this.delegatedTaskExecutor = delegatedTaskExecutor; this.startTls = startTls; maxPacketBufferSize = engine.getSession().getPacketBufferSize(); @@ -572,7 +584,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf buf = (ByteBuf) msg; if (out == null) { - out = allocateOutNetBuf(ctx); + out = allocateOutNetBuf(ctx, buf.readableBytes()); } SSLEngineResult result = wrap(alloc, engine, buf, out); @@ -653,7 +665,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH // See https://github.com/netty/netty/issues/5860 while (!ctx.isRemoved()) { if (out == null) { - out = allocateOutNetBuf(ctx); + out = allocateOutNetBuf(ctx, 0); } SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); @@ -1556,8 +1568,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt * the specified amount of pending bytes. */ - private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx) { - return allocate(ctx, maxPacketBufferSize); + private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes) { + return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes)); } private final class LazyChannelPromise extends DefaultPromise { diff --git a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java index 9911738312..986b20a37b 100644 --- a/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java @@ -26,12 +26,16 @@ import org.junit.BeforeClass; import org.junit.Test; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; public class OpenSslEngineTest extends SSLEngineTest { @@ -115,6 +119,146 @@ public class OpenSslEngineTest extends SSLEngineTest { } } + @Test + public void testOnlySmallBufferNeededForWrap() throws Exception { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + handshake(clientEngine, serverEngine); + + // Allocate a buffer which is small enough and set the limit to the capacity to mark its whole content + // as readable. + int srcLen = 1024; + ByteBuffer src = allocateBuffer(srcLen); + + ByteBuffer dstTooSmall = allocateBuffer( + src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH - 1); + ByteBuffer dst = allocateBuffer( + src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + + // Check that we fail to wrap if the dst buffers capacity is not at least + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + SSLEngineResult result = clientEngine.wrap(src, dstTooSmall); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + assertEquals(src.remaining(), src.capacity()); + assertEquals(dst.remaining(), dst.capacity()); + + // Check that we can wrap with a dst buffer that has the capacity of + // src.capacity() + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH + result = clientEngine.wrap(src, dst); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(srcLen, result.bytesConsumed()); + assertEquals(0, src.remaining()); + assertTrue(result.bytesProduced() > srcLen); + assertEquals(src.capacity() - result.bytesConsumed(), src.remaining()); + assertEquals(dst.capacity() - result.bytesProduced(), dst.remaining()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + @Test + public void testNeededDstCapacityIsCorrectlyCalculated() throws Exception { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + handshake(clientEngine, serverEngine); + + ByteBuffer src = allocateBuffer(1024); + ByteBuffer src2 = src.duplicate(); + + ByteBuffer dst = allocateBuffer(src.capacity() + + ReferenceCountedOpenSslEngine.MAX_ENCRYPTION_OVERHEAD_LENGTH); + + SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(0, src.position()); + assertEquals(0, src2.position()); + assertEquals(0, dst.position()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + @Test + public void testSrcsLenOverFlowCorrectlyHandled() throws Exception { + clientSslCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build(); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build(); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + serverEngine = serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + handshake(clientEngine, serverEngine); + + ByteBuffer src = allocateBuffer(1024); + List srcList = new ArrayList(); + long srcsLen = 0; + long maxLen = ((long) Integer.MAX_VALUE) * 2; + + while (srcsLen < maxLen) { + ByteBuffer dup = src.duplicate(); + srcList.add(dup); + srcsLen += dup.capacity(); + } + + ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[srcList.size()]); + + ByteBuffer dst = allocateBuffer(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH - 1); + + SSLEngineResult result = clientEngine.wrap(srcs, dst); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + + for (ByteBuffer buffer : srcs) { + assertEquals(0, buffer.position()); + } + assertEquals(0, dst.position()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + @Test + public void testCalculateOutNetBufSizeOverflow() { + assertEquals(ReferenceCountedOpenSslEngine.MAX_ENCRYPTED_PACKET_LENGTH, + ReferenceCountedOpenSslEngine.calculateOutNetBufSize(Integer.MAX_VALUE)); + } + @Override protected SslProvider sslClientProvider() { return SslProvider.OPENSSL;