Handle invalid SafetyNet results

Fix #4253
This commit is contained in:
topjohnwu 2021-04-20 03:39:47 -07:00
parent 1b9d8e068a
commit fb8000b58b
5 changed files with 68 additions and 70 deletions

View File

@ -25,9 +25,7 @@
# Snet
-keepclassmembers class com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper { *; }
-keep,allowobfuscation interface com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback
-keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback {
void onResponse(java.lang.String);
}
-keepclassmembers class * implements com.topjohnwu.magisk.ui.safetynet.SafetyNetHelper$Callback { *; }
# Stub
-keep class com.topjohnwu.magisk.core.App { <init>(java.lang.Object); }

View File

@ -29,7 +29,7 @@ object Const {
const val MAGISK_LOG = "/cache/magisk.log"
// Versions
const val SNET_EXT_VER = 16
const val SNET_EXT_VER = 17
const val SNET_REVISION = "22.0"
const val BOOTCTL_REVISION = "22.0"

View File

@ -31,10 +31,9 @@ import java.io.ByteArrayInputStream
import java.io.File
import java.io.IOException
import java.lang.reflect.InvocationHandler
import java.security.GeneralSecurityException
import java.lang.reflect.Proxy
import java.security.SecureRandom
import java.security.Signature
import java.security.cert.X509Certificate
class CheckSafetyNetEvent(
private val callback: (SafetyNetResult) -> Unit = {}
@ -42,19 +41,17 @@ class CheckSafetyNetEvent(
private val svc get() = ServiceLocator.networkService
private lateinit var apk: File
private lateinit var dex: File
private lateinit var jar: File
private lateinit var nonce: ByteArray
override fun invoke(context: Context) {
apk = File("${context.filesDir.parent}/snet", "snet.jar")
dex = File(apk.parent, "snet.dex")
jar = File("${context.filesDir.parent}/snet", "snet.jar")
scope.launch(Dispatchers.IO) {
attest(context) {
// Download and retry
Shell.sh("rm -rf " + apk.parent).exec()
apk.parentFile?.mkdir()
Shell.sh("rm -rf " + jar.parent).exec()
jar.parentFile?.mkdir()
withContext(Dispatchers.Main) {
showDialog(context)
}
@ -65,25 +62,24 @@ class CheckSafetyNetEvent(
private suspend fun attest(context: Context, onError: suspend (Exception) -> Unit) {
val helper: SafetyNetHelper
try {
val loader = createClassLoader(apk)
val loader = createClassLoader(jar)
// Scan through the dex and find our helper class
var clazz: Class<*>? = null
loop@for (dex in loader.getDexFiles()) {
for (name in dex.entries()) {
if (name.startsWith("x.")) {
val cls = loader.loadClass(name)
if (InvocationHandler::class.java.isAssignableFrom(cls)) {
clazz = cls
break@loop
}
val cls = loader.loadClass(name)
if (InvocationHandler::class.java.isAssignableFrom(cls)) {
clazz = cls
break@loop
}
}
}
clazz ?: throw Exception("Cannot find SafetyNetHelper class")
clazz ?: throw Exception("Cannot find SafetyNetHelper implementation")
helper = clazz.getMethod("get", Class::class.java, Context::class.java, Any::class.java)
.invoke(null, SafetyNetHelper::class.java, context, this) as SafetyNetHelper
helper = Proxy.newProxyInstance(
loader, arrayOf(SafetyNetHelper::class.java),
clazz.newInstance() as InvocationHandler) as SafetyNetHelper
if (helper.version != Const.SNET_EXT_VER)
throw Exception("snet extension version mismatch")
@ -95,7 +91,7 @@ class CheckSafetyNetEvent(
val random = SecureRandom()
nonce = ByteArray(24)
random.nextBytes(nonce)
helper.attest(nonce)
helper.attest(context, nonce, this)
}
// All of these fields are whitelisted
@ -114,7 +110,7 @@ class CheckSafetyNetEvent(
}
}
try {
svc.fetchSafetynet().byteStream().writeTo(apk)
svc.fetchSafetynet().byteStream().writeTo(jar)
attest(context, abort)
} catch (e: IOException) {
abort(e)
@ -147,7 +143,7 @@ class CheckSafetyNetEvent(
Base64.decode(this, Base64.URL_SAFE)
}
private fun String.parseJws(): SafetyNetResponse? {
private fun String.parseJws(): SafetyNetResponse {
val jws = split('.')
val secondDot = lastIndexOf('.')
val rawHeader = String(jws[0].decode())
@ -156,7 +152,8 @@ class CheckSafetyNetEvent(
val signedBytes = substring(0, secondDot).toByteArray()
val moshi = Moshi.Builder().build()
val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader) ?: return null
val header = moshi.adapter(JwsHeader::class.java).fromJson(rawHeader)
?: error("Invalid JWS header")
val alg = when (header.algorithm) {
"RS256" -> "SHA256withRSA"
@ -165,41 +162,30 @@ class CheckSafetyNetEvent(
signature = ASN1Primitive.fromByteArray(signature).getEncoded(ASN1Encoding.DER)
"SHA256withECDSA"
}
else -> return null
else -> error("Unsupported algorithm: ${header.algorithm}")
}
// Verify signature
val certB64 = header.certificates?.first() ?: return null
val certDer = certB64.decode()
val bis = ByteArrayInputStream(certDer)
val cert: X509Certificate
try {
cert = CryptoUtils.readCertificate(bis)
val verifier = Signature.getInstance(alg)
verifier.initVerify(cert.publicKey)
verifier.update(signedBytes)
if (!verifier.verify(signature))
return null
} catch (e: GeneralSecurityException) {
Timber.e(e)
return null
}
val certB64 = header.certificates?.first() ?: error("Cannot find certificate in JWS")
val bis = ByteArrayInputStream(certB64.decode())
val cert = CryptoUtils.readCertificate(bis)
val verifier = Signature.getInstance(alg)
verifier.initVerify(cert.publicKey)
verifier.update(signedBytes)
if (!verifier.verify(signature))
error("Signature mismatch")
// Verify hostname
val hostNameVerifier = JsseDefaultHostnameAuthorizer(setOf())
try {
if (!hostNameVerifier.verify("attest.android.com", cert))
return null
} catch (e: IOException) {
Timber.e(e)
return null
}
val hostnameVerifier = JsseDefaultHostnameAuthorizer(setOf())
if (!hostnameVerifier.verify("attest.android.com", cert))
error("Hostname mismatch")
val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload) ?: return null
val response = moshi.adapter(SafetyNetResponse::class.java).fromJson(payload)
?: error("Invalid SafetyNet response")
// Verify results
if (!response.nonce.decode().contentEquals(nonce))
return null
error("nonce mismatch")
return response
}
@ -207,7 +193,10 @@ class CheckSafetyNetEvent(
override fun onResponse(response: String?) {
if (response != null) {
scope.launch(Dispatchers.Default) {
val res = response.parseJws()
val res = runCatching { response.parseJws() }.getOrElse {
Timber.e(it)
INVALID_RESPONSE
}
withContext(Dispatchers.Main) {
callback(SafetyNetResult(res))
}
@ -231,3 +220,6 @@ data class SafetyNetResponse(
val basicIntegrity: Boolean,
val evaluationType: String = ""
)
// Special instance to indicate invalid SafetyNet response
val INVALID_RESPONSE = SafetyNetResponse("", ctsProfileMatch = false, basicIntegrity = false)

View File

@ -1,10 +1,12 @@
package com.topjohnwu.magisk.ui.safetynet
import android.content.Context
interface SafetyNetHelper {
val version: Int
fun attest(nonce: ByteArray)
fun attest(context: Context, nonce: ByteArray, callback: Callback)
interface Callback {
fun onResponse(response: String?)

View File

@ -6,7 +6,7 @@ import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.arch.BaseViewModel
import com.topjohnwu.magisk.utils.set
data class SafetyNetResult(
class SafetyNetResult(
val response: SafetyNetResponse? = null,
val dismiss: Boolean = false
)
@ -42,37 +42,43 @@ class SafetynetViewModel : BaseViewModel() {
init {
cachedResult?.also {
handleResponse(SafetyNetResult(it))
handleResult(SafetyNetResult(it))
} ?: attest()
}
private fun attest() {
isChecking = true
CheckSafetyNetEvent {
handleResponse(it)
}.publish()
CheckSafetyNetEvent(::handleResult).publish()
}
fun reset() = attest()
private fun handleResponse(response: SafetyNetResult) {
private fun handleResult(result: SafetyNetResult) {
isChecking = false
if (response.dismiss) {
if (result.dismiss) {
back()
return
}
response.response?.apply {
val result = ctsProfileMatch && basicIntegrity
result.response?.apply {
cachedResult = this
ctsState = ctsProfileMatch
basicIntegrityState = basicIntegrity
evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC"
isSuccess = result
safetyNetTitle =
if (result) R.string.safetynet_attest_success
else R.string.safetynet_attest_failure
if (this === INVALID_RESPONSE) {
isSuccess = false
ctsState = false
basicIntegrityState = false
evalType = "N/A"
safetyNetTitle = R.string.safetynet_res_invalid
} else {
val success = ctsProfileMatch && basicIntegrity
isSuccess = success
ctsState = ctsProfileMatch
basicIntegrityState = basicIntegrity
evalType = if (evaluationType.contains("HARDWARE")) "HARDWARE" else "BASIC"
safetyNetTitle =
if (success) R.string.safetynet_attest_success
else R.string.safetynet_attest_failure
}
} ?: {
isSuccess = false
ctsState = false