fix: Load functional patches correctly

This commit is contained in:
oSumAtrIX 2024-04-27 02:16:41 +02:00
parent c5f02e8c28
commit 4088609e2a
No known key found for this signature in database
GPG Key ID: A9B3094ACDB604B4
6 changed files with 144 additions and 70 deletions

View File

@ -1,4 +1,4 @@
@file:Suppress("MemberVisibilityCanBePrivate")
@file:Suppress("MemberVisibilityCanBePrivate", "unused")
package app.revanced.patcher.patch
@ -565,7 +565,7 @@ class PatchResult internal constructor(val patch: Patch<*>, val exception: Patch
/**
* A loader for [Patch].
* Loads patches from JAR or DEX files declared as public fields or properties.
* Loads patches from JAR or DEX files declared as public static fields or returned by public static methods.
* Patches with no name are not loaded.
*
* @param patchesFiles A set of JAR or DEX files to load the patches from.
@ -576,30 +576,7 @@ sealed class PatchLoader private constructor(
patchesFiles: Set<File>,
private val getBinaryClassNames: (patchesFile: File) -> List<String>,
private val classLoader: ClassLoader,
) : PatchSet by mutableSetOf() {
init {
@Suppress("UNCHECKED_CAST", "LeakingThis")
val thisSet = this as MutableSet<Patch<*>>
patchesFiles.asSequence().flatMap(getBinaryClassNames).map {
classLoader.loadClass(it)
}.flatMap {
// Get all patches from the class declared as public fields.
val patchFields = it.fields.filter { field ->
Patch::class.java.isAssignableFrom(field.type)
}.map { field -> field.get(null) as Patch<*> }
// Get all patches from the class declared as public methods.
val patchMethods = it.methods.filter { method ->
Patch::class.java.isAssignableFrom(method.returnType)
}.map { method -> method.invoke(null) as Patch<*> }
patchFields + patchMethods
}.filter {
it.name != null
}.toList().let(thisSet::addAll)
}
) : PatchSet by classLoader.loadPatches(patchesFiles.flatMap(getBinaryClassNames)) {
/**
* A [PatchLoader] for JAR files.
*
@ -640,10 +617,43 @@ sealed class PatchLoader private constructor(
PatchLoader::class.java.classLoader,
),
)
// Companion object required for unit tests.
private companion object {
/**
* Loads named patches declared as public static fields or returned by public static methods from classes.
*
* @param binaryClassNames The binary class name of the classes to load the patches from.
*
* @return The loaded patches.
*/
private fun ClassLoader.loadPatches(binaryClassNames: List<String>) = binaryClassNames.asSequence().map {
loadClass(it)
}.flatMap {
val isPatch = { cls: Class<*> -> Patch::class.java.isAssignableFrom(cls) }
val patchesFromFields = it.fields.filter { field ->
isPatch(field.type) && field.canAccess(null)
}.map { field ->
field.get(null) as Patch<*>
}
val patchesFromMethods = it.methods.filter { method ->
isPatch(method.returnType) && method.canAccess(null)
}.map { method ->
method.invoke(null) as Patch<*>
}
patchesFromFields + patchesFromMethods
}.filter {
it.name != null
}.toSet()
}
}
/**
* Loads patches from JAR files.
* Loads patches from JAR files declared as public static fields or returned by public static methods.
* Patches with no name are not loaded.
*
* @param patchesFiles The JAR files to load the patches from.
*
@ -653,7 +663,8 @@ fun loadPatchesFromJar(patchesFiles: Set<File>): PatchSet =
PatchLoader.Jar(patchesFiles)
/**
* Loads patches from DEX files.
* Loads patches from DEX files declared as public static fields or returned by public static methods.
* Patches with no name are not loaded.
*
* @param patchesFiles The DEX files to load the patches from.
*

View File

@ -20,7 +20,7 @@ import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue
object PatcherTest {
internal object PatcherTest {
private lateinit var patcher: Patcher
@BeforeEach

View File

@ -0,0 +1,66 @@
@file:Suppress("unused")
package app.revanced.patcher.patch
import org.junit.jupiter.api.Test
import kotlin.reflect.KFunction
import kotlin.reflect.full.*
import kotlin.reflect.jvm.javaField
import kotlin.reflect.jvm.javaMethod
import kotlin.test.assertEquals
// region Test patches.
val publicUnnamedPatch = bytecodePatch {
}
val publicPatch = bytecodePatch("Public") {
}
private val privateUnnamedPatch = bytecodePatch {
}
private val privatePatch = bytecodePatch("Private") {
}
fun publicUnnamedPatchFunction() = publicUnnamedPatch
fun publicNamedPatchFunction() = bytecodePatch("Public") { }
private fun privateUnnamedPatchFunction() = privateUnnamedPatch
private fun privateNamedPatchFunction() = privatePatch
// endregion
internal object PatchLoaderTest {
private const val LOAD_PATCHES_FUNCTION_NAME = "loadPatches"
private val TEST_PATCHES_CLASS = ::publicPatch.javaField!!.declaringClass.name
private val TEST_PATCHES_CLASS_LOADER = ::publicPatch.javaClass.classLoader
@Test
fun `loads patches correctly`() {
// Get instance of private PatchLoader.Companion class.
val patchLoaderCompanionObject = PatchLoader::class.java.declaredFields.first {
it.type == PatchLoader::class.companionObject!!.javaObjectType
}.apply {
isAccessible = true
}.get(null)
// Get private PatchLoader.Companion.loadPatches function from PatchLoader.Companion.
@Suppress("UNCHECKED_CAST")
val loadPatchesFunction = patchLoaderCompanionObject::class.declaredFunctions.first {
it.name == LOAD_PATCHES_FUNCTION_NAME
}.apply {
javaMethod!!.isAccessible = true
} as KFunction<Set<Patch<*>>>
// Call private PatchLoader.Companion.loadPatches function.
val patches = loadPatchesFunction.call(
patchLoaderCompanionObject,
TEST_PATCHES_CLASS_LOADER,
listOf(TEST_PATCHES_CLASS),
)
assertEquals(
2,
patches.size,
"Expected 2 patches to be loaded, " +
"because there's only two named patches declared as a public static field or method.",
)
}
}

View File

@ -4,7 +4,7 @@ import app.revanced.patcher.fingerprint.methodFingerprint
import kotlin.test.Test
import kotlin.test.assertEquals
object PatchTest {
internal object PatchTest {
@Test
fun `can create patch with name`() {
val patch = bytecodePatch(name = "Test") {}

View File

@ -7,7 +7,7 @@ import kotlin.test.Test
import kotlin.test.assertNull
import kotlin.test.assertTrue
internal class PatchOptionsTest {
internal object PatchOptionsTest {
private val optionsTestPatch = bytecodePatch {
booleanPatchOption("bool", true)

View File

@ -18,23 +18,24 @@ import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class InlineSmaliCompilerTest {
internal object InlineSmaliCompilerTest {
@Test
fun `compiler should output valid instruction`() {
fun `outputs valid instruction`() {
val want = BuilderInstruction21c(Opcode.CONST_STRING, 0, ImmutableStringReference("Test")) as BuilderInstruction
val have = "const-string v0, \"Test\"".toInstruction()
instructionEquals(want, have)
assertInstructionsEqual(want, have)
}
@Test
fun `compiler should support branching with own branches`() {
fun `supports branching with own branches`() {
val method = createMethod()
val insnAmount = 8
val insnIndex = insnAmount - 2
val targetIndex = insnIndex - 1
val instructionCount = 8
val instructionIndex = instructionCount - 2
val targetIndex = instructionIndex - 1
method.addInstructions(
arrayOfNulls<String>(insnAmount).also {
arrayOfNulls<String>(instructionCount).also {
Arrays.fill(it, "const/4 v0, 0x0")
}.joinToString("\n"),
)
@ -47,14 +48,15 @@ internal class InlineSmaliCompilerTest {
""",
)
val insn = method.getInstruction<BuilderInstruction21t>(insnIndex)
assertEquals(targetIndex, insn.target.location.index)
val instruction = method.getInstruction<BuilderInstruction21t>(instructionIndex)
assertEquals(targetIndex, instruction.target.location.index)
}
@Test
fun `compiler should support branching to outside branches`() {
fun `supports branching to outside branches`() {
val method = createMethod()
val insnIndex = 3
val instructionIndex = 3
val labelIndex = 1
method.addInstructions(
@ -76,35 +78,30 @@ internal class InlineSmaliCompilerTest {
ExternalLabel("test", method.getInstruction(1)),
)
val insn = method.getInstruction<BuilderInstruction21t>(insnIndex)
assertTrue(insn.target.isPlaced, "Label was not placed")
assertEquals(labelIndex, insn.target.location.index)
val instruction = method.getInstruction<BuilderInstruction21t>(instructionIndex)
assertTrue(instruction.target.isPlaced, "Label was not placed")
assertEquals(labelIndex, instruction.target.location.index)
}
companion object {
private fun createMethod(
name: String = "dummy",
returnType: String = "V",
accessFlags: Int = AccessFlags.STATIC.value,
registerCount: Int = 1,
) = ImmutableMethod(
"Ldummy;",
name,
emptyList(), // parameters
returnType,
accessFlags,
emptySet(),
emptySet(),
MutableMethodImplementation(registerCount),
).toMutable()
private fun createMethod(
name: String = "dummy",
returnType: String = "V",
accessFlags: Int = AccessFlags.STATIC.value,
registerCount: Int = 1,
) = ImmutableMethod(
"Ldummy;",
name,
emptyList(), // parameters
returnType,
accessFlags,
emptySet(),
emptySet(),
MutableMethodImplementation(registerCount),
).toMutable()
private fun instructionEquals(
want: BuilderInstruction,
have: BuilderInstruction,
) {
assertEquals(want.opcode, have.opcode)
assertEquals(want.format, have.format)
assertEquals(want.codeUnits, have.codeUnits)
}
private fun assertInstructionsEqual(want: BuilderInstruction, have: BuilderInstruction) {
assertEquals(want.opcode, have.opcode)
assertEquals(want.format, have.format)
assertEquals(want.codeUnits, have.codeUnits)
}
}