fix: Load functional patches correctly
This commit is contained in:
parent
c5f02e8c28
commit
4088609e2a
|
@ -1,4 +1,4 @@
|
|||
@file:Suppress("MemberVisibilityCanBePrivate")
|
||||
@file:Suppress("MemberVisibilityCanBePrivate", "unused")
|
||||
|
||||
package app.revanced.patcher.patch
|
||||
|
||||
|
@ -565,7 +565,7 @@ class PatchResult internal constructor(val patch: Patch<*>, val exception: Patch
|
|||
|
||||
/**
|
||||
* A loader for [Patch].
|
||||
* Loads patches from JAR or DEX files declared as public fields or properties.
|
||||
* Loads patches from JAR or DEX files declared as public static fields or returned by public static methods.
|
||||
* Patches with no name are not loaded.
|
||||
*
|
||||
* @param patchesFiles A set of JAR or DEX files to load the patches from.
|
||||
|
@ -576,30 +576,7 @@ sealed class PatchLoader private constructor(
|
|||
patchesFiles: Set<File>,
|
||||
private val getBinaryClassNames: (patchesFile: File) -> List<String>,
|
||||
private val classLoader: ClassLoader,
|
||||
) : PatchSet by mutableSetOf() {
|
||||
init {
|
||||
@Suppress("UNCHECKED_CAST", "LeakingThis")
|
||||
val thisSet = this as MutableSet<Patch<*>>
|
||||
|
||||
patchesFiles.asSequence().flatMap(getBinaryClassNames).map {
|
||||
classLoader.loadClass(it)
|
||||
}.flatMap {
|
||||
// Get all patches from the class declared as public fields.
|
||||
val patchFields = it.fields.filter { field ->
|
||||
Patch::class.java.isAssignableFrom(field.type)
|
||||
}.map { field -> field.get(null) as Patch<*> }
|
||||
|
||||
// Get all patches from the class declared as public methods.
|
||||
val patchMethods = it.methods.filter { method ->
|
||||
Patch::class.java.isAssignableFrom(method.returnType)
|
||||
}.map { method -> method.invoke(null) as Patch<*> }
|
||||
|
||||
patchFields + patchMethods
|
||||
}.filter {
|
||||
it.name != null
|
||||
}.toList().let(thisSet::addAll)
|
||||
}
|
||||
|
||||
) : PatchSet by classLoader.loadPatches(patchesFiles.flatMap(getBinaryClassNames)) {
|
||||
/**
|
||||
* A [PatchLoader] for JAR files.
|
||||
*
|
||||
|
@ -640,10 +617,43 @@ sealed class PatchLoader private constructor(
|
|||
PatchLoader::class.java.classLoader,
|
||||
),
|
||||
)
|
||||
|
||||
// Companion object required for unit tests.
|
||||
private companion object {
|
||||
/**
|
||||
* Loads named patches declared as public static fields or returned by public static methods from classes.
|
||||
*
|
||||
* @param binaryClassNames The binary class name of the classes to load the patches from.
|
||||
*
|
||||
* @return The loaded patches.
|
||||
*/
|
||||
private fun ClassLoader.loadPatches(binaryClassNames: List<String>) = binaryClassNames.asSequence().map {
|
||||
loadClass(it)
|
||||
}.flatMap {
|
||||
val isPatch = { cls: Class<*> -> Patch::class.java.isAssignableFrom(cls) }
|
||||
|
||||
val patchesFromFields = it.fields.filter { field ->
|
||||
isPatch(field.type) && field.canAccess(null)
|
||||
}.map { field ->
|
||||
field.get(null) as Patch<*>
|
||||
}
|
||||
|
||||
val patchesFromMethods = it.methods.filter { method ->
|
||||
isPatch(method.returnType) && method.canAccess(null)
|
||||
}.map { method ->
|
||||
method.invoke(null) as Patch<*>
|
||||
}
|
||||
|
||||
patchesFromFields + patchesFromMethods
|
||||
}.filter {
|
||||
it.name != null
|
||||
}.toSet()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads patches from JAR files.
|
||||
* Loads patches from JAR files declared as public static fields or returned by public static methods.
|
||||
* Patches with no name are not loaded.
|
||||
*
|
||||
* @param patchesFiles The JAR files to load the patches from.
|
||||
*
|
||||
|
@ -653,7 +663,8 @@ fun loadPatchesFromJar(patchesFiles: Set<File>): PatchSet =
|
|||
PatchLoader.Jar(patchesFiles)
|
||||
|
||||
/**
|
||||
* Loads patches from DEX files.
|
||||
* Loads patches from DEX files declared as public static fields or returned by public static methods.
|
||||
* Patches with no name are not loaded.
|
||||
*
|
||||
* @param patchesFiles The DEX files to load the patches from.
|
||||
*
|
||||
|
|
|
@ -20,7 +20,7 @@ import kotlin.test.assertEquals
|
|||
import kotlin.test.assertNull
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
object PatcherTest {
|
||||
internal object PatcherTest {
|
||||
private lateinit var patcher: Patcher
|
||||
|
||||
@BeforeEach
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
@file:Suppress("unused")
|
||||
|
||||
package app.revanced.patcher.patch
|
||||
|
||||
import org.junit.jupiter.api.Test
|
||||
import kotlin.reflect.KFunction
|
||||
import kotlin.reflect.full.*
|
||||
import kotlin.reflect.jvm.javaField
|
||||
import kotlin.reflect.jvm.javaMethod
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
// region Test patches.
|
||||
|
||||
val publicUnnamedPatch = bytecodePatch {
|
||||
}
|
||||
val publicPatch = bytecodePatch("Public") {
|
||||
}
|
||||
private val privateUnnamedPatch = bytecodePatch {
|
||||
}
|
||||
private val privatePatch = bytecodePatch("Private") {
|
||||
}
|
||||
|
||||
fun publicUnnamedPatchFunction() = publicUnnamedPatch
|
||||
fun publicNamedPatchFunction() = bytecodePatch("Public") { }
|
||||
private fun privateUnnamedPatchFunction() = privateUnnamedPatch
|
||||
private fun privateNamedPatchFunction() = privatePatch
|
||||
|
||||
// endregion
|
||||
|
||||
internal object PatchLoaderTest {
|
||||
private const val LOAD_PATCHES_FUNCTION_NAME = "loadPatches"
|
||||
private val TEST_PATCHES_CLASS = ::publicPatch.javaField!!.declaringClass.name
|
||||
private val TEST_PATCHES_CLASS_LOADER = ::publicPatch.javaClass.classLoader
|
||||
|
||||
@Test
|
||||
fun `loads patches correctly`() {
|
||||
// Get instance of private PatchLoader.Companion class.
|
||||
val patchLoaderCompanionObject = PatchLoader::class.java.declaredFields.first {
|
||||
it.type == PatchLoader::class.companionObject!!.javaObjectType
|
||||
}.apply {
|
||||
isAccessible = true
|
||||
}.get(null)
|
||||
|
||||
// Get private PatchLoader.Companion.loadPatches function from PatchLoader.Companion.
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val loadPatchesFunction = patchLoaderCompanionObject::class.declaredFunctions.first {
|
||||
it.name == LOAD_PATCHES_FUNCTION_NAME
|
||||
}.apply {
|
||||
javaMethod!!.isAccessible = true
|
||||
} as KFunction<Set<Patch<*>>>
|
||||
|
||||
// Call private PatchLoader.Companion.loadPatches function.
|
||||
val patches = loadPatchesFunction.call(
|
||||
patchLoaderCompanionObject,
|
||||
TEST_PATCHES_CLASS_LOADER,
|
||||
listOf(TEST_PATCHES_CLASS),
|
||||
)
|
||||
|
||||
assertEquals(
|
||||
2,
|
||||
patches.size,
|
||||
"Expected 2 patches to be loaded, " +
|
||||
"because there's only two named patches declared as a public static field or method.",
|
||||
)
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@ import app.revanced.patcher.fingerprint.methodFingerprint
|
|||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
object PatchTest {
|
||||
internal object PatchTest {
|
||||
@Test
|
||||
fun `can create patch with name`() {
|
||||
val patch = bytecodePatch(name = "Test") {}
|
||||
|
|
|
@ -7,7 +7,7 @@ import kotlin.test.Test
|
|||
import kotlin.test.assertNull
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class PatchOptionsTest {
|
||||
internal object PatchOptionsTest {
|
||||
private val optionsTestPatch = bytecodePatch {
|
||||
booleanPatchOption("bool", true)
|
||||
|
||||
|
|
|
@ -18,23 +18,24 @@ import kotlin.test.Test
|
|||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class InlineSmaliCompilerTest {
|
||||
internal object InlineSmaliCompilerTest {
|
||||
@Test
|
||||
fun `compiler should output valid instruction`() {
|
||||
fun `outputs valid instruction`() {
|
||||
val want = BuilderInstruction21c(Opcode.CONST_STRING, 0, ImmutableStringReference("Test")) as BuilderInstruction
|
||||
val have = "const-string v0, \"Test\"".toInstruction()
|
||||
instructionEquals(want, have)
|
||||
|
||||
assertInstructionsEqual(want, have)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compiler should support branching with own branches`() {
|
||||
fun `supports branching with own branches`() {
|
||||
val method = createMethod()
|
||||
val insnAmount = 8
|
||||
val insnIndex = insnAmount - 2
|
||||
val targetIndex = insnIndex - 1
|
||||
val instructionCount = 8
|
||||
val instructionIndex = instructionCount - 2
|
||||
val targetIndex = instructionIndex - 1
|
||||
|
||||
method.addInstructions(
|
||||
arrayOfNulls<String>(insnAmount).also {
|
||||
arrayOfNulls<String>(instructionCount).also {
|
||||
Arrays.fill(it, "const/4 v0, 0x0")
|
||||
}.joinToString("\n"),
|
||||
)
|
||||
|
@ -47,14 +48,15 @@ internal class InlineSmaliCompilerTest {
|
|||
""",
|
||||
)
|
||||
|
||||
val insn = method.getInstruction<BuilderInstruction21t>(insnIndex)
|
||||
assertEquals(targetIndex, insn.target.location.index)
|
||||
val instruction = method.getInstruction<BuilderInstruction21t>(instructionIndex)
|
||||
|
||||
assertEquals(targetIndex, instruction.target.location.index)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compiler should support branching to outside branches`() {
|
||||
fun `supports branching to outside branches`() {
|
||||
val method = createMethod()
|
||||
val insnIndex = 3
|
||||
val instructionIndex = 3
|
||||
val labelIndex = 1
|
||||
|
||||
method.addInstructions(
|
||||
|
@ -76,35 +78,30 @@ internal class InlineSmaliCompilerTest {
|
|||
ExternalLabel("test", method.getInstruction(1)),
|
||||
)
|
||||
|
||||
val insn = method.getInstruction<BuilderInstruction21t>(insnIndex)
|
||||
assertTrue(insn.target.isPlaced, "Label was not placed")
|
||||
assertEquals(labelIndex, insn.target.location.index)
|
||||
val instruction = method.getInstruction<BuilderInstruction21t>(instructionIndex)
|
||||
assertTrue(instruction.target.isPlaced, "Label was not placed")
|
||||
assertEquals(labelIndex, instruction.target.location.index)
|
||||
}
|
||||
|
||||
companion object {
|
||||
private fun createMethod(
|
||||
name: String = "dummy",
|
||||
returnType: String = "V",
|
||||
accessFlags: Int = AccessFlags.STATIC.value,
|
||||
registerCount: Int = 1,
|
||||
) = ImmutableMethod(
|
||||
"Ldummy;",
|
||||
name,
|
||||
emptyList(), // parameters
|
||||
returnType,
|
||||
accessFlags,
|
||||
emptySet(),
|
||||
emptySet(),
|
||||
MutableMethodImplementation(registerCount),
|
||||
).toMutable()
|
||||
private fun createMethod(
|
||||
name: String = "dummy",
|
||||
returnType: String = "V",
|
||||
accessFlags: Int = AccessFlags.STATIC.value,
|
||||
registerCount: Int = 1,
|
||||
) = ImmutableMethod(
|
||||
"Ldummy;",
|
||||
name,
|
||||
emptyList(), // parameters
|
||||
returnType,
|
||||
accessFlags,
|
||||
emptySet(),
|
||||
emptySet(),
|
||||
MutableMethodImplementation(registerCount),
|
||||
).toMutable()
|
||||
|
||||
private fun instructionEquals(
|
||||
want: BuilderInstruction,
|
||||
have: BuilderInstruction,
|
||||
) {
|
||||
assertEquals(want.opcode, have.opcode)
|
||||
assertEquals(want.format, have.format)
|
||||
assertEquals(want.codeUnits, have.codeUnits)
|
||||
}
|
||||
private fun assertInstructionsEqual(want: BuilderInstruction, have: BuilderInstruction) {
|
||||
assertEquals(want.opcode, have.opcode)
|
||||
assertEquals(want.format, have.format)
|
||||
assertEquals(want.codeUnits, have.codeUnits)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue