Handle invalid SafetyNet results

Fix #4253
This commit is contained in:
topjohnwu 2021-04-20 03:39:47 -07:00
parent 1b9d8e068a
commit fb8000b58b
5 changed files with 68 additions and 70 deletions

View File

@ -25,9 +25,7 @@
# Snet # Snet
-keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; } -keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; }
-keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback -keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback
-keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { -keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { *; }
void onResponse(java.lang.String);
}
# Stub # Stub
-keep class com.topjohnwu.magisk.core.App { <init>(java.lang.Object); } -keep class com.topjohnwu.magisk.core.App { <init>(java.lang.Object); }

View File

@ -29,7 +29,7 @@ object Const {
const val MAGISK_LOG = "/cache/magisk.log" const val MAGISK_LOG = "/cache/magisk.log"
// Versions // Versions
const val SNET_EXT_VER = 16 const val SNET_EXT_VER = 17
const val SNET_REVISION = "22.0" const val SNET_REVISION = "22.0"
const val BOOTCTL_REVISION = "22.0" const val BOOTCTL_REVISION = "22.0"

View File

@ -31,10 +31,9 @@ import java.io.ByteArrayInputStream
import java.io.File import java.io.File
import java.io.IOException import java.io.IOException
import java.lang.reflect.InvocationHandler import java.lang.reflect.InvocationHandler
import java.security.GeneralSecurityException import java.lang.reflect.Proxy
import java.security.SecureRandom import java.security.SecureRandom
import java.security.Signature import java.security.Signature
import java.security.cert.X509Certificate
class CheckSafetyNetEvent( class CheckSafetyNetEvent(
private val callback: (SafetyNetResult) -> Unit = {} private val callback: (SafetyNetResult) -> Unit = {}
@ -42,19 +41,17 @@ class CheckSafetyNetEvent(
private val svc get() = ServiceLocator.networkService private val svc get() = ServiceLocator.networkService
private lateinit var apk: File private lateinit var jar: File
private lateinit var dex: File
private lateinit var nonce: ByteArray private lateinit var nonce: ByteArray
override fun invoke(context: Context) { override fun invoke(context: Context) {
apk = File("${context.filesDir.parent}/snet", "snet.jar") jar = File("${context.filesDir.parent}/snet", "snet.jar")
dex = File(apk.parent, "snet.dex")
scope.launch(Dispatchers.IO) { scope.launch(Dispatchers.IO) {
attest(context) { attest(context) {
// Download and retry // Download and retry
Shell.sh("rm -rf " + apk.parent).exec() Shell.sh("rm -rf " + jar.parent).exec()
apk.parentFile?.mkdir() jar.parentFile?.mkdir()
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
showDialog(context) showDialog(context)
} }
@ -65,25 +62,24 @@ class CheckSafetyNetEvent(
private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) { private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) {
val helper: SafetyNetHelper val helper: SafetyNetHelper
try { try {
val loader = createClassLoader(apk) val loader = createClassLoader(jar)
// Scan through the dex and find our helper class // Scan through the dex and find our helper class
var clazz: Class<*>? = null var clazz: Class<*>? = null
loop@for (dex in loader.getDexFiles()) { loop@for (dex in loader.getDexFiles()) {
for (name in dex.entries()) { for (name in dex.entries()) {
if (name.startsWith("x.")) { val cls = loader.loadClass(name)
val cls = loader.loadClass(name) if (InvocationHandler::class.java.isAssignableFrom(cls)) {
if (InvocationHandler::class.java.isAssignableFrom(cls)) { clazz = cls
clazz = cls break@loop
break@loop
}
} }
} }
} }
clazz ?: throw Exception("Cannot find SafetyNetHelper class") clazz ?: throw Exception("Cannot find SafetyNetHelper implementation")
helper = clazz.getMethod("get", Class::class.java, Context::class.java, Any::class.java) helper = Proxy.newProxyInstance(
.invoke(null, SafetyNetHelper::class.java, context, this) as SafetyNetHelper loader, arrayOf(SafetyNetHelper::class.java),
clazz.newInstance() as InvocationHandler) as SafetyNetHelper
if (helper.version != Const.SNET_EXT_VER) if (helper.version != Const.SNET_EXT_VER)
throw Exception("snet extension version mismatch") throw Exception("snet extension version mismatch")
@ -95,7 +91,7 @@ class CheckSafetyNetEvent(
val random = SecureRandom() val random = SecureRandom()
nonce = ByteArray(24) nonce = ByteArray(24)
random.nextBytes(nonce) random.nextBytes(nonce)
helper.attest(nonce) helper.attest(context, nonce, this)
} }
// All of these fields are whitelisted // All of these fields are whitelisted
@ -114,7 +110,7 @@ class CheckSafetyNetEvent(
} }
} }
try { try {
svc.fetchSafetynet().byteStream().writeTo(apk) svc.fetchSafetynet().byteStream().writeTo(jar)
attest(context, abort) attest(context, abort)
} catch (e: IOException) { } catch (e: IOException) {
abort(e) abort(e)
@ -147,7 +143,7 @@ class CheckSafetyNetEvent(
Base64.decode(this, Base64.URL_SAFE) Base64.decode(this, Base64.URL_SAFE)
} }
private fun String.parseJws(): SafetyNetResponse? { private fun String.parseJws(): SafetyNetResponse {
val jws = split('.') val jws = split('.')
val secondDot = lastIndexOf('.') val secondDot = lastIndexOf('.')
val rawHeader = String(jws[0].decode()) val rawHeader = String(jws[0].decode())
@ -156,7 +152,8 @@ class CheckSafetyNetEvent(
val signedBytes = substring(0, secondDot).toByteArray() val signedBytes = substring(0, secondDot).toByteArray()
val moshi = Moshi.Builder().build() val moshi = Moshi.Builder().build()
val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader) ?: return null val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader)
?: error("Invalid JWS header")
val alg = when (header.algorithm) { val alg = when (header.algorithm) {
"RS256" -> "SHA256withRSA" "RS256" -> "SHA256withRSA"
@ -165,41 +162,30 @@ class CheckSafetyNetEvent(
signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER) signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER)
"SHA256withECDSA" "SHA256withECDSA"
} }
else -> return null else -> error("Unsupported algorithm: ${header.algorithm}")
} }
// Verify signature // Verify signature
val certB64 = header.certificates?.first() ?: return null val certB64 = header.certificates?.first() ?: error("Cannot find certificate in JWS")
val certDer = certB64.decode() val bis = ByteArrayInputStream(certB64.decode())
val bis = ByteArrayInputStream(certDer) val cert = CryptoUtils.readCertificate(bis)
val cert: X509Certificate val verifier = Signature.getInstance(alg)
try { verifier.initVerify(cert.publicKey)
cert = CryptoUtils.readCertificate(bis) verifier.update(signedBytes)
val verifier = Signature.getInstance(alg) if (!verifier.verify(signature))
verifier.initVerify(cert.publicKey) error("Signature mismatch")
verifier.update(signedBytes)
if (!verifier.verify(signature))
return null
} catch (e: GeneralSecurityException) {
Timber.e(e)
return null
}
// Verify hostname // Verify hostname
val hostNameVerifier = JsseDefaultHostnameAuthorizer(setOf()) val hostnameVerifier = JsseDefaultHostnameAuthorizer(setOf())
try { if (!hostnameVerifier.verify("attest.android.com", cert))
if (!hostNameVerifier.verify("attest.android.com", cert)) error("Hostname mismatch")
return null
} catch (e: IOException) {
Timber.e(e)
return null
}
val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload) ?: return null val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload)
?: error("Invalid SafetyNet response")
// Verify results // Verify results
if (!response.nonce.decode().contentEquals(nonce)) if (!response.nonce.decode().contentEquals(nonce))
return null error("nonce mismatch")
return response return response
} }
@ -207,7 +193,10 @@ class CheckSafetyNetEvent(
override fun onResponse(response: String?) { override fun onResponse(response: String?) {
if (response != null) { if (response != null) {
scope.launch(Dispatchers.Default) { scope.launch(Dispatchers.Default) {
val res = response.parseJws() val res = runCatching { response.parseJws() }.getOrElse {
Timber.e(it)
INVALID_RESPONSE
}
withContext(Dispatchers.Main) { withContext(Dispatchers.Main) {
callback(SafetyNetResult(res)) callback(SafetyNetResult(res))
} }
@ -231,3 +220,6 @@ data class SafetyNetResponse(
val basicIntegrity: Boolean, val basicIntegrity: Boolean,
val evaluationType: String = "" val evaluationType: String = ""
) )
// Special instance to indicate invalid SafetyNet response
val INVALID_RESPONSE = SafetyNetResponse("", ctsProfileMatch = false, basicIntegrity = false)

View File

@ -1,10 +1,12 @@
package com.topjohnwu.magisk.ui.safetynet package com.topjohnwu.magisk.ui.safetynet
import android.content.Context
interface SafetyNetHelper { interface SafetyNetHelper {
val version: Int val version: Int
fun attest(nonce: ByteArray) fun attest(context: Context, nonce: ByteArray, callback: Callback)
interface Callback { interface Callback {
fun onResponse(response: String?) fun onResponse(response: String?)

View File

@ -6,7 +6,7 @@ import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.arch.BaseViewModel import com.topjohnwu.magisk.arch.BaseViewModel
import com.topjohnwu.magisk.utils.set import com.topjohnwu.magisk.utils.set
data class SafetyNetResult( class SafetyNetResult(
val response: SafetyNetResponse? = null, val response: SafetyNetResponse? = null,
val dismiss: Boolean = false val dismiss: Boolean = false
) )
@ -42,37 +42,43 @@ class SafetynetViewModel : BaseViewModel() {
init { init {
cachedResult?.also { cachedResult?.also {
handleResponse(SafetyNetResult(it)) handleResult(SafetyNetResult(it))
} ?: attest() } ?: attest()
} }
private fun attest() { private fun attest() {
isChecking = true isChecking = true
CheckSafetyNetEvent { CheckSafetyNetEvent(::handleResult).publish()
handleResponse(it)
}.publish()
} }
fun reset() = attest() fun reset() = attest()
private fun handleResponse(response: SafetyNetResult) { private fun handleResult(result: SafetyNetResult) {
isChecking = false isChecking = false
if (response.dismiss) { if (result.dismiss) {
back() back()
return return
} }
response.response?.apply { result.response?.apply {
val result = ctsProfileMatch && basicIntegrity
cachedResult = this cachedResult = this
ctsState = ctsProfileMatch if (this === INVALID_RESPONSE) {
basicIntegrityState = basicIntegrity isSuccess = false
evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC" ctsState = false
isSuccess = result basicIntegrityState = false
safetyNetTitle = evalType = "N/A"
if (result) R.string.safetynet_attest_success safetyNetTitle = R.string.safetynet_res_invalid
else R.string.safetynet_attest_failure } else {
val success = ctsProfileMatch && basicIntegrity
isSuccess = success
ctsState = ctsProfileMatch
basicIntegrityState = basicIntegrity
evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC"
safetyNetTitle =
if (success) R.string.safetynet_attest_success
else R.string.safetynet_attest_failure
}
} ?: { } ?: {
isSuccess = false isSuccess = false
ctsState = false ctsState = false