Minor code refactoring

This commit is contained in:
topjohnwu 2021-01-12 00:07:48 -08:00
parent 79dfdb29e7
commit d2acd59ea8
12 changed files with 100 additions and 135 deletions

View File

@ -369,8 +369,8 @@ bool validate_manager(string &pkg, int userid, struct stat *st) {
void exec_sql(int client) { void exec_sql(int client) {
run_finally f([=]{ close(client); }); run_finally f([=]{ close(client); });
char *sql = read_string(client); string sql = read_string(client);
char *err = db_exec(sql, [&](db_row &row) -> bool { char *err = db_exec(sql.data(), [client](db_row &row) -> bool {
string out; string out;
bool first = true; bool first = true;
for (auto it : row) { for (auto it : row) {
@ -380,11 +380,9 @@ void exec_sql(int client) {
out += '='; out += '=';
out += it.second; out += it.second;
} }
write_int(client, out.length()); write_string(client, out);
xwrite(client, out.data(), out.length());
return true; return true;
}); });
free(sql);
write_int(client, 0); write_int(client, 0);
db_err_cmd(err, return; ); db_err_cmd(err, return; );
} }

View File

@ -56,9 +56,8 @@ int magisk_main(int argc, char *argv[]) {
} else if (argv[1] == "-v"sv) { } else if (argv[1] == "-v"sv) {
int fd = connect_daemon(); int fd = connect_daemon();
write_int(fd, CHECK_VERSION); write_int(fd, CHECK_VERSION);
char *v = read_string(fd); string v = read_string(fd);
printf("%s\n", v); printf("%s\n", v.data());
free(v);
return 0; return 0;
} else if (argv[1] == "-V"sv) { } else if (argv[1] == "-V"sv) {
int fd = connect_daemon(); int fd = connect_daemon();
@ -101,13 +100,12 @@ int magisk_main(int argc, char *argv[]) {
int fd = connect_daemon(); int fd = connect_daemon();
write_int(fd, SQLITE_CMD); write_int(fd, SQLITE_CMD);
write_string(fd, argv[2]); write_string(fd, argv[2]);
string res;
for (;;) { for (;;) {
char *res = read_string(fd); read_string(fd, res);
if (res[0] == '\0') { if (res.empty())
return 0; return 0;
} printf("%s\n", res.data());
printf("%s\n", res);
free(res);
} }
} else if (argv[1] == "--remove-modules"sv) { } else if (argv[1] == "--remove-modules"sv) {
int fd = connect_daemon(); int fd = connect_daemon();
@ -116,8 +114,8 @@ int magisk_main(int argc, char *argv[]) {
} else if (argv[1] == "--path"sv) { } else if (argv[1] == "--path"sv) {
int fd = connect_daemon(); int fd = connect_daemon();
write_int(fd, GET_PATH); write_int(fd, GET_PATH);
char *path = read_string(fd); string path = read_string(fd);
printf("%s\n", path); printf("%s\n", path.data());
return 0; return 0;
} else if (argc >= 3 && argv[1] == "--install-module"sv) { } else if (argc >= 3 && argv[1] == "--install-module"sv) {
install_module(argv[2]); install_module(argv[2]);

View File

@ -4,6 +4,8 @@
#include <socket.hpp> #include <socket.hpp>
#include <utils.hpp> #include <utils.hpp>
using namespace std;
static size_t socket_len(sockaddr_un *sun) { static size_t socket_len(sockaddr_un *sun) {
if (sun->sun_path[0]) if (sun->sun_path[0])
return sizeof(sa_family_t) + strlen(sun->sun_path) + 1; return sizeof(sa_family_t) + strlen(sun->sun_path) + 1;
@ -160,43 +162,21 @@ void write_int_be(int fd, int val) {
xwrite(fd, &nl, sizeof(nl)); xwrite(fd, &nl, sizeof(nl));
} }
static char *rd_str(int fd, int len) { void read_string(int fd, std::string &str) {
char *val = (char *) xmalloc(sizeof(char) * (len + 1));
xxread(fd, val, len);
val[len] = '\0';
return val;
}
char* read_string(int fd) {
int len = read_int(fd); int len = read_int(fd);
return rd_str(fd, len); str.clear();
str.resize(len);
xxread(fd, str.data(), len);
} }
char* read_string_be(int fd) { string read_string(int fd) {
int len = read_int_be(fd); string str;
return rd_str(fd, len); read_string(fd, str);
return str;
} }
void write_string(int fd, const char *val) { void write_string(int fd, string_view str) {
if (fd < 0) return; if (fd < 0) return;
int len = strlen(val); write_int(fd, str.size());
write_int(fd, len); xwrite(fd, str.data(), str.size());
xwrite(fd, val, len);
}
void write_string_be(int fd, const char *val) {
int len = strlen(val);
write_int_be(fd, len);
xwrite(fd, val, len);
}
void write_key_value(int fd, const char *key, const char *val) {
write_string_be(fd, key);
write_string_be(fd, val);
}
void write_key_token(int fd, const char *key, int tok) {
char val[16];
sprintf(val, "%d", tok);
write_key_value(fd, key, val);
} }

View File

@ -2,6 +2,7 @@
#include <sys/un.h> #include <sys/un.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <string_view>
socklen_t setup_sockaddr(sockaddr_un *sun, const char *name); socklen_t setup_sockaddr(sockaddr_un *sun, const char *name);
int socket_accept(int sockfd, int timeout); int socket_accept(int sockfd, int timeout);
@ -12,9 +13,6 @@ int read_int(int fd);
int read_int_be(int fd); int read_int_be(int fd);
void write_int(int fd, int val); void write_int(int fd, int val);
void write_int_be(int fd, int val); void write_int_be(int fd, int val);
char *read_string(int fd); std::string read_string(int fd);
char *read_string_be(int fd); void read_string(int fd, std::string &str);
void write_string(int fd, const char *val); void write_string(int fd, std::string_view str);
void write_string_be(int fd, const char *val);
void write_key_value(int fd, const char *key, const char *val);
void write_key_token(int fd, const char *key, int tok);

View File

@ -196,9 +196,8 @@ void SARInit::first_stage_prep() {
int client = xaccept4(sockfd, nullptr, nullptr, SOCK_CLOEXEC); int client = xaccept4(sockfd, nullptr, nullptr, SOCK_CLOEXEC);
// Write backup files // Write backup files
char *tmp_dir = read_string(client); string tmp_dir = read_string(client);
chdir(tmp_dir); chdir(tmp_dir.data());
free(tmp_dir);
int cfg = xopen(INTLROOT "/config", O_WRONLY | O_CREAT, 0); int cfg = xopen(INTLROOT "/config", O_WRONLY | O_CREAT, 0);
xwrite(cfg, config.buf, config.sz); xwrite(cfg, config.buf, config.sz);
close(cfg); close(cfg);

View File

@ -160,11 +160,9 @@ static int add_list(const char *pkg, const char *proc) {
} }
int add_list(int client) { int add_list(int client) {
char *pkg = read_string(client); string pkg = read_string(client);
char *proc = read_string(client); string proc = read_string(client);
int ret = add_list(pkg, proc); int ret = add_list(pkg.data(), proc.data());
free(pkg);
free(proc);
if (ret == DAEMON_SUCCESS) if (ret == DAEMON_SUCCESS)
update_uid_map(); update_uid_map();
return ret; return ret;
@ -200,11 +198,9 @@ static int rm_list(const char *pkg, const char *proc) {
} }
int rm_list(int client) { int rm_list(int client) {
char *pkg = read_string(client); string pkg = read_string(client);
char *proc = read_string(client); string proc = read_string(client);
int ret = rm_list(pkg, proc); int ret = rm_list(pkg.data(), proc.data());
free(pkg);
free(proc);
if (ret == DAEMON_SUCCESS) if (ret == DAEMON_SUCCESS)
update_uid_map(); update_uid_map();
return ret; return ret;
@ -246,16 +242,18 @@ static bool init_list() {
if (MAGISKTMP != "/sbin") if (MAGISKTMP != "/sbin")
add_hide_set(GMS_PKG, GMS_PKG); add_hide_set(GMS_PKG, GMS_PKG);
update_uid_map();
return true; return true;
} }
void ls_list(int client) { void ls_list(int client) {
FILE *out = fdopen(recv_fd(client), "a");
for (auto &hide : hide_set)
fprintf(out, "%s|%s\n", hide.first.data(), hide.second.data());
fclose(out);
write_int(client, DAEMON_SUCCESS); write_int(client, DAEMON_SUCCESS);
for (auto &hide : hide_set) {
write_int(client, hide.first.size() + hide.second.size() + 1);
xwrite(client, hide.first.data(), hide.first.size());
xwrite(client, "|", 1);
xwrite(client, hide.second.data(), hide.second.size());
}
write_int(client, 0);
close(client); close(client);
} }
@ -268,7 +266,7 @@ static void update_hide_config() {
} }
int launch_magiskhide(bool late_props) { int launch_magiskhide(bool late_props) {
mutex_guard g(hide_state_lock); mutex_guard lock(hide_state_lock);
if (SDK_INT < 19) if (SDK_INT < 19)
return DAEMON_ERROR; return DAEMON_ERROR;
@ -300,6 +298,11 @@ int launch_magiskhide(bool late_props) {
hide_state = true; hide_state = true;
update_hide_config(); update_hide_config();
// Unlock here or else we'll be stuck in deadlock
lock.unlock();
update_uid_map();
return DAEMON_SUCCESS; return DAEMON_SUCCESS;
} }

View File

@ -5,7 +5,7 @@
#include "magiskhide.hpp" #include "magiskhide.hpp"
using namespace std::literals; using namespace std;
[[noreturn]] static void usage(char *arg0) { [[noreturn]] static void usage(char *arg0) {
fprintf(stderr, fprintf(stderr,
@ -112,8 +112,6 @@ int magiskhide_main(int argc, char *argv[]) {
write_string(fd, argv[2]); write_string(fd, argv[2]);
write_string(fd, argv[3] ? argv[3] : ""); write_string(fd, argv[3] ? argv[3] : "");
} }
if (req == LS_HIDELIST)
send_fd(fd, STDOUT_FILENO);
// Get response // Get response
int code = read_int(fd); int code = read_int(fd);
@ -122,30 +120,41 @@ int magiskhide_main(int argc, char *argv[]) {
break; break;
case HIDE_NOT_ENABLED: case HIDE_NOT_ENABLED:
fprintf(stderr, "MagiskHide is not enabled\n"); fprintf(stderr, "MagiskHide is not enabled\n");
break; goto return_code;
case HIDE_IS_ENABLED: case HIDE_IS_ENABLED:
fprintf(stderr, "MagiskHide is enabled\n"); fprintf(stderr, "MagiskHide is enabled\n");
break; goto return_code;
case HIDE_ITEM_EXIST: case HIDE_ITEM_EXIST:
fprintf(stderr, "Target already exists in hide list\n"); fprintf(stderr, "Target already exists in hide list\n");
break; goto return_code;
case HIDE_ITEM_NOT_EXIST: case HIDE_ITEM_NOT_EXIST:
fprintf(stderr, "Target does not exist in hide list\n"); fprintf(stderr, "Target does not exist in hide list\n");
break; goto return_code;
case HIDE_NO_NS: case HIDE_NO_NS:
fprintf(stderr, "Your kernel doesn't support mount namespace\n"); fprintf(stderr, "Your kernel doesn't support mount namespace\n");
break; goto return_code;
case HIDE_INVALID_PKG: case HIDE_INVALID_PKG:
fprintf(stderr, "Invalid package / process name\n"); fprintf(stderr, "Invalid package / process name\n");
break; goto return_code;
case ROOT_REQUIRED: case ROOT_REQUIRED:
fprintf(stderr, "Root is required for this operation\n"); fprintf(stderr, "Root is required for this operation\n");
break; goto return_code;
case DAEMON_ERROR: case DAEMON_ERROR:
default: default:
fprintf(stderr, "Daemon error\n"); fprintf(stderr, "Daemon error\n");
return DAEMON_ERROR; return DAEMON_ERROR;
} }
if (req == LS_HIDELIST) {
string res;
for (;;) {
read_string(fd, res);
if (res.empty())
break;
printf("%s\n", res.data());
}
}
return_code:
return req == HIDE_STATUS ? (code == HIDE_IS_ENABLED ? 0 : 1) : code != DAEMON_SUCCESS; return req == HIDE_STATUS ? (code == HIDE_IS_ENABLED ? 0 : 1) : code != DAEMON_SUCCESS;
} }

View File

@ -30,7 +30,7 @@ enum {
? info->uid / 100000 : 0) ? info->uid / 100000 : 0)
#define get_cmd(to) \ #define get_cmd(to) \
(to.command[0] ? to.command : to.shell[0] ? to.shell : DEFAULT_SHELL) (to.command.empty() ? (to.shell.empty() ? DEFAULT_SHELL : to.shell.data()) : to.command.data())
class Extra { class Extra {
const char *key; const char *key;
@ -42,7 +42,7 @@ class Extra {
union { union {
int int_val; int int_val;
bool bool_val; bool bool_val;
const char * str_val; const char *str_val;
}; };
char buf[32]; char buf[32];
public: public:

View File

@ -43,18 +43,6 @@ static void usage(int status) {
exit(status); exit(status);
} }
static char *concat_commands(int argc, char *argv[]) {
char command[ARG_MAX];
command[0] = '\0';
for (int i = optind - 1; i < argc; ++i) {
if (command[0])
sprintf(command, "%s %s", command, argv[i]);
else
strcpy(command, argv[i]);
}
return strdup(command);
}
static void sighandler(int sig) { static void sighandler(int sig) {
restore_stdin(); restore_stdin();
@ -114,7 +102,13 @@ int su_client_main(int argc, char *argv[]) {
while ((c = getopt_long(argc, argv, "c:hlmps:Vvuz:M", long_opts, nullptr)) != -1) { while ((c = getopt_long(argc, argv, "c:hlmps:Vvuz:M", long_opts, nullptr)) != -1) {
switch (c) { switch (c) {
case 'c': case 'c':
su_req.command = concat_commands(argc, argv); for (int i = optind - 1; i < argc; ++i) {
if (!su_req.command.empty())
su_req.command += ' ';
su_req.command += '\'';
su_req.command += argv[i];
su_req.command += '\'';
}
optind = argc; optind = argc;
break; break;
case 'h': case 'h':

View File

@ -45,19 +45,9 @@ struct su_req_base {
} __attribute__((packed)); } __attribute__((packed));
struct su_request : public su_req_base { struct su_request : public su_req_base {
const char *shell = DEFAULT_SHELL; std::string shell = DEFAULT_SHELL;
const char *command = ""; std::string command;
su_request(bool dyn = false) : dyn(dyn) {} };
~su_request() {
if (dyn) {
free(const_cast<char*>(shell));
free(const_cast<char*>(command));
}
}
private:
bool dyn;
} __attribute__((packed));
struct su_context { struct su_context {
std::shared_ptr<su_info> info; std::shared_ptr<su_info> info;

View File

@ -163,14 +163,14 @@ void su_daemon_handler(int client, struct ucred *credential) {
su_context ctx = { su_context ctx = {
.info = get_su_info(credential->uid), .info = get_su_info(credential->uid),
.req = su_request(true), .req = su_request(),
.pid = credential->pid .pid = credential->pid
}; };
// Read su_request // Read su_request
xxread(client, &ctx.req, sizeof(su_req_base)); xxread(client, &ctx.req, sizeof(su_req_base));
ctx.req.shell = read_string(client); read_string(client, ctx.req.shell);
ctx.req.command = read_string(client); read_string(client, ctx.req.command);
if (ctx.info->access.log) if (ctx.info->access.log)
app_log(ctx); app_log(ctx);
@ -223,18 +223,18 @@ void su_daemon_handler(int client, struct ucred *credential) {
xsetsid(); xsetsid();
// Get pts_slave // Get pts_slave
char *pts_slave = read_string(client); string pts_slave = read_string(client);
// The FDs for each of the streams // The FDs for each of the streams
int infd = recv_fd(client); int infd = recv_fd(client);
int outfd = recv_fd(client); int outfd = recv_fd(client);
int errfd = recv_fd(client); int errfd = recv_fd(client);
if (pts_slave[0]) { if (!pts_slave.empty()) {
LOGD("su: pts_slave=[%s]\n", pts_slave); LOGD("su: pts_slave=[%s]\n", pts_slave.data());
// Check pts_slave file is owned by daemon_from_uid // Check pts_slave file is owned by daemon_from_uid
struct stat st; struct stat st;
xstat(pts_slave, &st); xstat(pts_slave.data(), &st);
// If caller is not root, ensure the owner of pts_slave is the caller // If caller is not root, ensure the owner of pts_slave is the caller
if(st.st_uid != ctx.info->uid && ctx.info->uid != 0) if(st.st_uid != ctx.info->uid && ctx.info->uid != 0)
@ -243,7 +243,7 @@ void su_daemon_handler(int client, struct ucred *credential) {
// Opening the TTY has to occur after the // Opening the TTY has to occur after the
// fork() and setsid() so that it becomes // fork() and setsid() so that it becomes
// our controlling TTY and not the daemon's // our controlling TTY and not the daemon's
int ptsfd = xopen(pts_slave, O_RDWR); int ptsfd = xopen(pts_slave.data(), O_RDWR);
if (infd < 0) if (infd < 0)
infd = ptsfd; infd = ptsfd;
@ -253,8 +253,6 @@ void su_daemon_handler(int client, struct ucred *credential) {
errfd = ptsfd; errfd = ptsfd;
} }
free(pts_slave);
// Swap out stdin, stdout, stderr // Swap out stdin, stdout, stderr
xdup2(infd, STDIN_FILENO); xdup2(infd, STDIN_FILENO);
xdup2(outfd, STDOUT_FILENO); xdup2(outfd, STDOUT_FILENO);
@ -290,13 +288,13 @@ void su_daemon_handler(int client, struct ucred *credential) {
break; break;
} }
const char *argv[] = { nullptr, nullptr, nullptr, nullptr }; const char *argv[4] = { nullptr };
argv[0] = ctx.req.login ? "-" : ctx.req.shell; argv[0] = ctx.req.login ? "-" : ctx.req.shell.data();
if (ctx.req.command[0]) { if (!ctx.req.command.empty()) {
argv[1] = "-c"; argv[1] = "-c";
argv[2] = ctx.req.command; argv[2] = ctx.req.command.data();
} }
// Setup environment // Setup environment
@ -321,7 +319,7 @@ void su_daemon_handler(int client, struct ucred *credential) {
setenv("HOME", pw->pw_dir, 1); setenv("HOME", pw->pw_dir, 1);
setenv("USER", pw->pw_name, 1); setenv("USER", pw->pw_name, 1);
setenv("LOGNAME", pw->pw_name, 1); setenv("LOGNAME", pw->pw_name, 1);
setenv("SHELL", ctx.req.shell, 1); setenv("SHELL", ctx.req.shell.data(), 1);
} }
} }
@ -330,8 +328,8 @@ void su_daemon_handler(int client, struct ucred *credential) {
sigemptyset(&block_set); sigemptyset(&block_set);
sigprocmask(SIG_SETMASK, &block_set, nullptr); sigprocmask(SIG_SETMASK, &block_set, nullptr);
set_identity(ctx.req.uid); set_identity(ctx.req.uid);
execvp(ctx.req.shell, (char **) argv); execvp(ctx.req.shell.data(), (char **) argv);
fprintf(stderr, "Cannot execute %s: %s\n", ctx.req.shell, strerror(errno)); fprintf(stderr, "Cannot execute %s: %s\n", ctx.req.shell.data(), strerror(errno));
PLOGE("exec"); PLOGE("exec");
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }

View File

@ -13,15 +13,13 @@ public:
explicit mutex_guard(pthread_mutex_t &m): mutex(&m) { explicit mutex_guard(pthread_mutex_t &m): mutex(&m) {
pthread_mutex_lock(mutex); pthread_mutex_lock(mutex);
} }
void unlock() {
explicit mutex_guard(pthread_mutex_t *m): mutex(m) {
pthread_mutex_lock(mutex);
}
~mutex_guard() {
pthread_mutex_unlock(mutex); pthread_mutex_unlock(mutex);
mutex = nullptr;
}
~mutex_guard() {
if (mutex) pthread_mutex_unlock(mutex);
} }
private: private:
pthread_mutex_t *mutex; pthread_mutex_t *mutex;
}; };