Cleanup code for su request

This commit is contained in:
topjohnwu 2020-06-17 03:47:12 -07:00
parent 56602cb9a3
commit e7f1c03151
4 changed files with 76 additions and 109 deletions

View File

@ -20,8 +20,8 @@ import com.topjohnwu.magisk.extensions.get
import com.topjohnwu.magisk.extensions.startActivity
import com.topjohnwu.magisk.extensions.startActivityWithRoot
import com.topjohnwu.magisk.extensions.subscribeK
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.magisk.model.entity.toLog
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.superuser.Shell
import timber.log.Timber
@ -51,20 +51,8 @@ object SuCallbackHandler : ProviderCallHandler {
}
when (action) {
REQUEST -> {
val intent = context.intent<SuRequestActivity>()
.setAction(action)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
LOG -> handleLogs(context, data)
REQUEST -> handleRequest(context, data)
LOG -> handleLogging(context, data)
NOTIFY -> handleNotify(context, data)
TEST -> {
val mode = data.getInt("mode", 2)
@ -78,13 +66,26 @@ object SuCallbackHandler : ProviderCallHandler {
private fun Any?.toInt(): Int? {
return when (this) {
is Int -> this
is Long -> this.toInt()
is Number -> this.toInt()
else -> null
}
}
private fun handleLogs(context: Context, data: Bundle) {
private fun handleRequest(context: Context, data: Bundle) {
val intent = context.intent<SuRequestActivity>()
.setAction(REQUEST)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
private fun handleLogging(context: Context, data: Bundle) {
val fromUid = data["from.uid"].toInt() ?: return
if (fromUid == Process.myUid())
return

View File

@ -1,57 +0,0 @@
package com.topjohnwu.magisk.core.su
import android.net.LocalSocket
import android.net.LocalSocketAddress
import androidx.collection.ArrayMap
import timber.log.Timber
import java.io.*
abstract class SuConnector @Throws(IOException::class)
protected constructor(name: String) {
private val socket: LocalSocket = LocalSocket()
protected var out: DataOutputStream
protected var input: DataInputStream
init {
socket.connect(LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT))
out = DataOutputStream(BufferedOutputStream(socket.outputStream))
input = DataInputStream(BufferedInputStream(socket.inputStream))
}
private fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
@Throws(IOException::class)
fun readRequest(): Map<String, String> {
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
fun response() {
runCatching {
onResponse()
out.flush()
}.onFailure { Timber.e(it) }
runCatching {
input.close()
out.close()
socket.close()
}
}
@Throws(IOException::class)
protected abstract fun onResponse()
}

View File

@ -2,7 +2,10 @@ package com.topjohnwu.magisk.core.su
import android.content.Intent
import android.content.pm.PackageManager
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.CountDownTimer
import androidx.collection.ArrayMap
import com.topjohnwu.magisk.BuildConfig
import com.topjohnwu.magisk.core.Config
import com.topjohnwu.magisk.core.Const
@ -11,43 +14,36 @@ import com.topjohnwu.magisk.core.model.MagiskPolicy
import com.topjohnwu.magisk.core.model.toPolicy
import com.topjohnwu.magisk.extensions.now
import timber.log.Timber
import java.io.*
import java.util.concurrent.TimeUnit
abstract class SuRequestHandler(
private val packageManager: PackageManager,
private val policyDB: PolicyDao
) {
protected var timer: CountDownTimer = object : CountDownTimer(
TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) {
override fun onFinish() {
respond(MagiskPolicy.DENY, 0)
}
override fun onTick(remains: Long) {}
}
private val socket: LocalSocket = LocalSocket()
private lateinit var out: DataOutputStream
private lateinit var input: DataInputStream
protected var timer: CountDownTimer = DefaultCountDown()
set(value) {
field.cancel()
field = value
field.start()
}
protected lateinit var policy: MagiskPolicy
private val cleanupTasks = mutableListOf<() -> Unit>()
private lateinit var connector: SuConnector
private set
abstract fun onStart()
abstract fun onRespond()
fun start(intent: Intent): Boolean {
val socketName = intent.getStringExtra("socket") ?: return false
try {
connector = object : SuConnector(socketName) {
override fun onResponse() {
out.writeInt(policy.policy)
}
}
val map = connector.readRequest()
socket.connect(LocalSocketAddress(socketName, LocalSocketAddress.Namespace.ABSTRACT))
out = DataOutputStream(BufferedOutputStream(socket.outputStream))
input = DataInputStream(BufferedInputStream(socket.inputStream))
val map = readRequest()
val uid = map["uid"]?.toIntOrNull() ?: return false
policy = uid.toPolicy(packageManager)
} catch (e: Exception) {
@ -71,20 +67,10 @@ abstract class SuRequestHandler(
}
timer.start()
cleanupTasks.add {
timer.cancel()
}
onStart()
return true
}
private fun respond() {
connector.response()
cleanupTasks.forEach { it() }
onRespond()
}
fun respond(action: Int, time: Int) {
val until = if (time > 0)
TimeUnit.MILLISECONDS.toSeconds(now) + TimeUnit.MINUTES.toSeconds(time.toLong())
@ -98,6 +84,45 @@ abstract class SuRequestHandler(
if (until >= 0)
policyDB.update(policy).blockingAwait()
respond()
try {
out.writeInt(policy.policy)
out.flush()
} catch (e: IOException) {
Timber.e(e)
} finally {
runCatching {
input.close()
out.close()
socket.close()
}
}
timer.cancel()
}
@Throws(IOException::class)
private fun readRequest(): Map<String, String> {
fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
private inner class DefaultCountDown
: CountDownTimer(TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) {
override fun onFinish() {
respond(MagiskPolicy.DENY, 0)
}
override fun onTick(remains: Long) {}
}
}

View File

@ -85,6 +85,9 @@ class SuRequestViewModel(
val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy.packageName, pos).apply()
respond(action, Config.Value.TIMEOUT_LIST[pos])
// Kill activity after response
DieEvent().publish()
}
fun cancelTimer() {
@ -118,11 +121,6 @@ class SuRequestViewModel(
}
}
}
override fun onRespond() {
// Kill activity after response
DieEvent().publish()
}
}
}