diff --git a/app/proguard-rules.pro b/app/proguard-rules.pro index e7e963c9d..48bd18808 100644 --- a/app/proguard-rules.pro +++ b/app/proguard-rules.pro @@ -25,9 +25,7 @@ # Snet -keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; } -keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback --keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { - void onResponse(java.lang.String); -} +-keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { *; } # Stub -keep class com.topjohnwu.magisk.core.App { (java.lang.Object); } diff --git a/app/src/main/java/com/topjohnwu/magisk/core/Const.kt b/app/src/main/java/com/topjohnwu/magisk/core/Const.kt index 97d8b484f..fd2fbd142 100644 --- a/app/src/main/java/com/topjohnwu/magisk/core/Const.kt +++ b/app/src/main/java/com/topjohnwu/magisk/core/Const.kt @@ -29,7 +29,7 @@ object Const { const val MAGISK_LOG = "/cache/magisk.log" // Versions - const val SNET_EXT_VER = 16 + const val SNET_EXT_VER = 17 const val SNET_REVISION = "22.0" const val BOOTCTL_REVISION = "22.0" diff --git a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/CheckSafetyNetEvent.kt b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/CheckSafetyNetEvent.kt index d3a977842..969bece1e 100644 --- a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/CheckSafetyNetEvent.kt +++ b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/CheckSafetyNetEvent.kt @@ -31,10 +31,9 @@ import java.io.ByteArrayInputStream import java.io.File import java.io.IOException import java.lang.reflect.InvocationHandler -import java.security.GeneralSecurityException +import java.lang.reflect.Proxy import java.security.SecureRandom import java.security.Signature -import java.security.cert.X509Certificate class CheckSafetyNetEvent( private val callback: (SafetyNetResult) -> Unit = {} @@ -42,19 +41,17 @@ class CheckSafetyNetEvent( private val svc get() = ServiceLocator.networkService - private lateinit var apk: File - private lateinit var dex: File + private lateinit var jar: File private lateinit var nonce: ByteArray override fun invoke(context: Context) { - apk = File("${context.filesDir.parent}/snet", "snet.jar") - dex = File(apk.parent, "snet.dex") + jar = File("${context.filesDir.parent}/snet", "snet.jar") scope.launch(Dispatchers.IO) { attest(context) { // Download and retry - Shell.sh("rm -rf " + apk.parent).exec() - apk.parentFile?.mkdir() + Shell.sh("rm -rf " + jar.parent).exec() + jar.parentFile?.mkdir() withContext(Dispatchers.Main) { showDialog(context) } @@ -65,25 +62,24 @@ class CheckSafetyNetEvent( private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) { val helper: SafetyNetHelper try { - val loader = createClassLoader(apk) + val loader = createClassLoader(jar) // Scan through the dex and find our helper class var clazz: Class<*>? = null loop@for (dex in loader.getDexFiles()) { for (name in dex.entries()) { - if (name.startsWith("x.")) { - val cls = loader.loadClass(name) - if (InvocationHandler::class.java.isAssignableFrom(cls)) { - clazz = cls - break@loop - } + val cls = loader.loadClass(name) + if (InvocationHandler::class.java.isAssignableFrom(cls)) { + clazz = cls + break@loop } } } - clazz ?: throw Exception("Cannot find SafetyNetHelper class") + clazz ?: throw Exception("Cannot find SafetyNetHelper implementation") - helper = clazz.getMethod("get", Class::class.java, Context::class.java, Any::class.java) - .invoke(null, SafetyNetHelper::class.java, context, this) as SafetyNetHelper + helper = Proxy.newProxyInstance( + loader, arrayOf(SafetyNetHelper::class.java), + clazz.newInstance() as InvocationHandler) as SafetyNetHelper if (helper.version != Const.SNET_EXT_VER) throw Exception("snet extension version mismatch") @@ -95,7 +91,7 @@ class CheckSafetyNetEvent( val random = SecureRandom() nonce = ByteArray(24) random.nextBytes(nonce) - helper.attest(nonce) + helper.attest(context, nonce, this) } // All of these fields are whitelisted @@ -114,7 +110,7 @@ class CheckSafetyNetEvent( } } try { - svc.fetchSafetynet().byteStream().writeTo(apk) + svc.fetchSafetynet().byteStream().writeTo(jar) attest(context, abort) } catch (e: IOException) { abort(e) @@ -147,7 +143,7 @@ class CheckSafetyNetEvent( Base64.decode(this, Base64.URL_SAFE) } - private fun String.parseJws(): SafetyNetResponse? { + private fun String.parseJws(): SafetyNetResponse { val jws = split('.') val secondDot = lastIndexOf('.') val rawHeader = String(jws[0].decode()) @@ -156,7 +152,8 @@ class CheckSafetyNetEvent( val signedBytes = substring(0, secondDot).toByteArray() val moshi = Moshi.Builder().build() - val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader) ?: return null + val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader) + ?: error("Invalid JWS header") val alg = when (header.algorithm) { "RS256" -> "SHA256withRSA" @@ -165,41 +162,30 @@ class CheckSafetyNetEvent( signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER) "SHA256withECDSA" } - else -> return null + else -> error("Unsupported algorithm: ${header.algorithm}") } // Verify signature - val certB64 = header.certificates?.first() ?: return null - val certDer = certB64.decode() - val bis = ByteArrayInputStream(certDer) - val cert: X509Certificate - try { - cert = CryptoUtils.readCertificate(bis) - val verifier = Signature.getInstance(alg) - verifier.initVerify(cert.publicKey) - verifier.update(signedBytes) - if (!verifier.verify(signature)) - return null - } catch (e: GeneralSecurityException) { - Timber.e(e) - return null - } + val certB64 = header.certificates?.first() ?: error("Cannot find certificate in JWS") + val bis = ByteArrayInputStream(certB64.decode()) + val cert = CryptoUtils.readCertificate(bis) + val verifier = Signature.getInstance(alg) + verifier.initVerify(cert.publicKey) + verifier.update(signedBytes) + if (!verifier.verify(signature)) + error("Signature mismatch") // Verify hostname - val hostNameVerifier = JsseDefaultHostnameAuthorizer(setOf()) - try { - if (!hostNameVerifier.verify("attest.android.com", cert)) - return null - } catch (e: IOException) { - Timber.e(e) - return null - } + val hostnameVerifier = JsseDefaultHostnameAuthorizer(setOf()) + if (!hostnameVerifier.verify("attest.android.com", cert)) + error("Hostname mismatch") - val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload) ?: return null + val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload) + ?: error("Invalid SafetyNet response") // Verify results if (!response.nonce.decode().contentEquals(nonce)) - return null + error("nonce mismatch") return response } @@ -207,7 +193,10 @@ class CheckSafetyNetEvent( override fun onResponse(response: String?) { if (response != null) { scope.launch(Dispatchers.Default) { - val res = response.parseJws() + val res = runCatching { response.parseJws() }.getOrElse { + Timber.e(it) + INVALID_RESPONSE + } withContext(Dispatchers.Main) { callback(SafetyNetResult(res)) } @@ -231,3 +220,6 @@ data class SafetyNetResponse( val basicIntegrity: Boolean, val evaluationType: String = "" ) + +// Special instance to indicate invalid SafetyNet response +val INVALID_RESPONSE = SafetyNetResponse("", ctsProfileMatch = false, basicIntegrity = false) diff --git a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetyNetHelper.kt b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetyNetHelper.kt index 013f45dd5..92ef7c4ec 100644 --- a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetyNetHelper.kt +++ b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetyNetHelper.kt @@ -1,10 +1,12 @@ package com.topjohnwu.magisk.ui.safetynet +import android.content.Context + interface SafetyNetHelper { val version: Int - fun attest(nonce: ByteArray) + fun attest(context: Context, nonce: ByteArray, callback: Callback) interface Callback { fun onResponse(response: String?) diff --git a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetynetViewModel.kt b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetynetViewModel.kt index 59131a661..1a0b5dcf0 100644 --- a/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetynetViewModel.kt +++ b/app/src/main/java/com/topjohnwu/magisk/ui/safetynet/SafetynetViewModel.kt @@ -6,7 +6,7 @@ import com.topjohnwu.magisk.R import com.topjohnwu.magisk.arch.BaseViewModel import com.topjohnwu.magisk.utils.set -data class SafetyNetResult( +class SafetyNetResult( val response: SafetyNetResponse? = null, val dismiss: Boolean = false ) @@ -42,37 +42,43 @@ class SafetynetViewModel : BaseViewModel() { init { cachedResult?.also { - handleResponse(SafetyNetResult(it)) + handleResult(SafetyNetResult(it)) } ?: attest() } private fun attest() { isChecking = true - CheckSafetyNetEvent { - handleResponse(it) - }.publish() + CheckSafetyNetEvent(::handleResult).publish() } fun reset() = attest() - private fun handleResponse(response: SafetyNetResult) { + private fun handleResult(result: SafetyNetResult) { isChecking = false - if (response.dismiss) { + if (result.dismiss) { back() return } - response.response?.apply { - val result = ctsProfileMatch && basicIntegrity + result.response?.apply { cachedResult = this - ctsState = ctsProfileMatch - basicIntegrityState = basicIntegrity - evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC" - isSuccess = result - safetyNetTitle = - if (result) R.string.safetynet_attest_success - else R.string.safetynet_attest_failure + if (this === INVALID_RESPONSE) { + isSuccess = false + ctsState = false + basicIntegrityState = false + evalType = "N/A" + safetyNetTitle = R.string.safetynet_res_invalid + } else { + val success = ctsProfileMatch && basicIntegrity + isSuccess = success + ctsState = ctsProfileMatch + basicIntegrityState = basicIntegrity + evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC" + safetyNetTitle = + if (success) R.string.safetynet_attest_success + else R.string.safetynet_attest_failure + } } ?: { isSuccess = false ctsState = false