Minor code refactoring

This commit is contained in:
topjohnwu 2021-01-12 00:07:48 -08:00
parent 79dfdb29e7
commit d2acd59ea8
12 changed files with 100 additions and 135 deletions

View File

@ -369,8 +369,8 @@ bool validate_manager(string &pkg, int userid, struct stat *st) {
void exec_sql(int client) {
run_finally f([=]{ close(client); });
char *sql = read_string(client);
char *err = db_exec(sql, [&](db_row &row) -> bool {
string sql = read_string(client);
char *err = db_exec(sql.data(), [client](db_row &row) -> bool {
string out;
bool first = true;
for (auto it : row) {
@ -380,11 +380,9 @@ void exec_sql(int client) {
out += '=';
out += it.second;
}
write_int(client, out.length());
xwrite(client, out.data(), out.length());
write_string(client, out);
return true;
});
free(sql);
write_int(client, 0);
db_err_cmd(err, return; );
}

View File

@ -56,9 +56,8 @@ int magisk_main(int argc, char *argv[]) {
} else if (argv[1] == "-v"sv) {
int fd = connect_daemon();
write_int(fd, CHECK_VERSION);
char *v = read_string(fd);
printf("%s\n", v);
free(v);
string v = read_string(fd);
printf("%s\n", v.data());
return 0;
} else if (argv[1] == "-V"sv) {
int fd = connect_daemon();
@ -101,13 +100,12 @@ int magisk_main(int argc, char *argv[]) {
int fd = connect_daemon();
write_int(fd, SQLITE_CMD);
write_string(fd, argv[2]);
string res;
for (;;) {
char *res = read_string(fd);
if (res[0] == '\0') {
read_string(fd, res);
if (res.empty())
return 0;
}
printf("%s\n", res);
free(res);
printf("%s\n", res.data());
}
} else if (argv[1] == "--remove-modules"sv) {
int fd = connect_daemon();
@ -116,8 +114,8 @@ int magisk_main(int argc, char *argv[]) {
} else if (argv[1] == "--path"sv) {
int fd = connect_daemon();
write_int(fd, GET_PATH);
char *path = read_string(fd);
printf("%s\n", path);
string path = read_string(fd);
printf("%s\n", path.data());
return 0;
} else if (argc >= 3 && argv[1] == "--install-module"sv) {
install_module(argv[2]);

View File

@ -4,6 +4,8 @@
#include <socket.hpp>
#include <utils.hpp>
using namespace std;
static size_t socket_len(sockaddr_un *sun) {
if (sun->sun_path[0])
return sizeof(sa_family_t) + strlen(sun->sun_path) + 1;
@ -160,43 +162,21 @@ void write_int_be(int fd, int val) {
xwrite(fd, &nl, sizeof(nl));
}
static char *rd_str(int fd, int len) {
char *val = (char *) xmalloc(sizeof(char) * (len + 1));
xxread(fd, val, len);
val[len] = '\0';
return val;
}
char* read_string(int fd) {
void read_string(int fd, std::string &str) {
int len = read_int(fd);
return rd_str(fd, len);
str.clear();
str.resize(len);
xxread(fd, str.data(), len);
}
char* read_string_be(int fd) {
int len = read_int_be(fd);
return rd_str(fd, len);
string read_string(int fd) {
string str;
read_string(fd, str);
return str;
}
void write_string(int fd, const char *val) {
void write_string(int fd, string_view str) {
if (fd < 0) return;
int len = strlen(val);
write_int(fd, len);
xwrite(fd, val, len);
}
void write_string_be(int fd, const char *val) {
int len = strlen(val);
write_int_be(fd, len);
xwrite(fd, val, len);
}
void write_key_value(int fd, const char *key, const char *val) {
write_string_be(fd, key);
write_string_be(fd, val);
}
void write_key_token(int fd, const char *key, int tok) {
char val[16];
sprintf(val, "%d", tok);
write_key_value(fd, key, val);
write_int(fd, str.size());
xwrite(fd, str.data(), str.size());
}

View File

@ -2,6 +2,7 @@
#include <sys/un.h>
#include <sys/socket.h>
#include <string_view>
socklen_t setup_sockaddr(sockaddr_un *sun, const char *name);
int socket_accept(int sockfd, int timeout);
@ -12,9 +13,6 @@ int read_int(int fd);
int read_int_be(int fd);
void write_int(int fd, int val);
void write_int_be(int fd, int val);
char *read_string(int fd);
char *read_string_be(int fd);
void write_string(int fd, const char *val);
void write_string_be(int fd, const char *val);
void write_key_value(int fd, const char *key, const char *val);
void write_key_token(int fd, const char *key, int tok);
std::string read_string(int fd);
void read_string(int fd, std::string &str);
void write_string(int fd, std::string_view str);

View File

@ -196,9 +196,8 @@ void SARInit::first_stage_prep() {
int client = xaccept4(sockfd, nullptr, nullptr, SOCK_CLOEXEC);
// Write backup files
char *tmp_dir = read_string(client);
chdir(tmp_dir);
free(tmp_dir);
string tmp_dir = read_string(client);
chdir(tmp_dir.data());
int cfg = xopen(INTLROOT "/config", O_WRONLY | O_CREAT, 0);
xwrite(cfg, config.buf, config.sz);
close(cfg);

View File

@ -160,11 +160,9 @@ static int add_list(const char *pkg, const char *proc) {
}
int add_list(int client) {
char *pkg = read_string(client);
char *proc = read_string(client);
int ret = add_list(pkg, proc);
free(pkg);
free(proc);
string pkg = read_string(client);
string proc = read_string(client);
int ret = add_list(pkg.data(), proc.data());
if (ret == DAEMON_SUCCESS)
update_uid_map();
return ret;
@ -200,11 +198,9 @@ static int rm_list(const char *pkg, const char *proc) {
}
int rm_list(int client) {
char *pkg = read_string(client);
char *proc = read_string(client);
int ret = rm_list(pkg, proc);
free(pkg);
free(proc);
string pkg = read_string(client);
string proc = read_string(client);
int ret = rm_list(pkg.data(), proc.data());
if (ret == DAEMON_SUCCESS)
update_uid_map();
return ret;
@ -246,16 +242,18 @@ static bool init_list() {
if (MAGISKTMP != "/sbin")
add_hide_set(GMS_PKG, GMS_PKG);
update_uid_map();
return true;
}
void ls_list(int client) {
FILE *out = fdopen(recv_fd(client), "a");
for (auto &hide : hide_set)
fprintf(out, "%s|%s\n", hide.first.data(), hide.second.data());
fclose(out);
write_int(client, DAEMON_SUCCESS);
for (auto &hide : hide_set) {
write_int(client, hide.first.size() + hide.second.size() + 1);
xwrite(client, hide.first.data(), hide.first.size());
xwrite(client, "|", 1);
xwrite(client, hide.second.data(), hide.second.size());
}
write_int(client, 0);
close(client);
}
@ -268,7 +266,7 @@ static void update_hide_config() {
}
int launch_magiskhide(bool late_props) {
mutex_guard g(hide_state_lock);
mutex_guard lock(hide_state_lock);
if (SDK_INT < 19)
return DAEMON_ERROR;
@ -300,6 +298,11 @@ int launch_magiskhide(bool late_props) {
hide_state = true;
update_hide_config();
// Unlock here or else we'll be stuck in deadlock
lock.unlock();
update_uid_map();
return DAEMON_SUCCESS;
}

View File

@ -5,7 +5,7 @@
#include "magiskhide.hpp"
using namespace std::literals;
using namespace std;
[[noreturn]] static void usage(char *arg0) {
fprintf(stderr,
@ -112,8 +112,6 @@ int magiskhide_main(int argc, char *argv[]) {
write_string(fd, argv[2]);
write_string(fd, argv[3] ? argv[3] : "");
}
if (req == LS_HIDELIST)
send_fd(fd, STDOUT_FILENO);
// Get response
int code = read_int(fd);
@ -122,30 +120,41 @@ int magiskhide_main(int argc, char *argv[]) {
break;
case HIDE_NOT_ENABLED:
fprintf(stderr, "MagiskHide is not enabled\n");
break;
goto return_code;
case HIDE_IS_ENABLED:
fprintf(stderr, "MagiskHide is enabled\n");
break;
goto return_code;
case HIDE_ITEM_EXIST:
fprintf(stderr, "Target already exists in hide list\n");
break;
goto return_code;
case HIDE_ITEM_NOT_EXIST:
fprintf(stderr, "Target does not exist in hide list\n");
break;
goto return_code;
case HIDE_NO_NS:
fprintf(stderr, "Your kernel doesn't support mount namespace\n");
break;
goto return_code;
case HIDE_INVALID_PKG:
fprintf(stderr, "Invalid package / process name\n");
break;
goto return_code;
case ROOT_REQUIRED:
fprintf(stderr, "Root is required for this operation\n");
break;
goto return_code;
case DAEMON_ERROR:
default:
fprintf(stderr, "Daemon error\n");
return DAEMON_ERROR;
}
if (req == LS_HIDELIST) {
string res;
for (;;) {
read_string(fd, res);
if (res.empty())
break;
printf("%s\n", res.data());
}
}
return_code:
return req == HIDE_STATUS ? (code == HIDE_IS_ENABLED ? 0 : 1) : code != DAEMON_SUCCESS;
}

View File

@ -30,7 +30,7 @@ enum {
? info->uid / 100000 : 0)
#define get_cmd(to) \
(to.command[0] ? to.command : to.shell[0] ? to.shell : DEFAULT_SHELL)
(to.command.empty() ? (to.shell.empty() ? DEFAULT_SHELL : to.shell.data()) : to.command.data())
class Extra {
const char *key;
@ -42,7 +42,7 @@ class Extra {
union {
int int_val;
bool bool_val;
const char * str_val;
const char *str_val;
};
char buf[32];
public:

View File

@ -43,18 +43,6 @@ static void usage(int status) {
exit(status);
}
static char *concat_commands(int argc, char *argv[]) {
char command[ARG_MAX];
command[0] = '\0';
for (int i = optind - 1; i < argc; ++i) {
if (command[0])
sprintf(command, "%s %s", command, argv[i]);
else
strcpy(command, argv[i]);
}
return strdup(command);
}
static void sighandler(int sig) {
restore_stdin();
@ -114,7 +102,13 @@ int su_client_main(int argc, char *argv[]) {
while ((c = getopt_long(argc, argv, "c:hlmps:Vvuz:M", long_opts, nullptr)) != -1) {
switch (c) {
case 'c':
su_req.command = concat_commands(argc, argv);
for (int i = optind - 1; i < argc; ++i) {
if (!su_req.command.empty())
su_req.command += ' ';
su_req.command += '\'';
su_req.command += argv[i];
su_req.command += '\'';
}
optind = argc;
break;
case 'h':

View File

@ -45,19 +45,9 @@ struct su_req_base {
} __attribute__((packed));
struct su_request : public su_req_base {
const char *shell = DEFAULT_SHELL;
const char *command = "";
su_request(bool dyn = false) : dyn(dyn) {}
~su_request() {
if (dyn) {
free(const_cast<char*>(shell));
free(const_cast<char*>(command));
}
}
private:
bool dyn;
} __attribute__((packed));
std::string shell = DEFAULT_SHELL;
std::string command;
};
struct su_context {
std::shared_ptr<su_info> info;

View File

@ -163,14 +163,14 @@ void su_daemon_handler(int client, struct ucred *credential) {
su_context ctx = {
.info = get_su_info(credential->uid),
.req = su_request(true),
.req = su_request(),
.pid = credential->pid
};
// Read su_request
xxread(client, &ctx.req, sizeof(su_req_base));
ctx.req.shell = read_string(client);
ctx.req.command = read_string(client);
read_string(client, ctx.req.shell);
read_string(client, ctx.req.command);
if (ctx.info->access.log)
app_log(ctx);
@ -223,18 +223,18 @@ void su_daemon_handler(int client, struct ucred *credential) {
xsetsid();
// Get pts_slave
char *pts_slave = read_string(client);
string pts_slave = read_string(client);
// The FDs for each of the streams
int infd = recv_fd(client);
int outfd = recv_fd(client);
int errfd = recv_fd(client);
if (pts_slave[0]) {
LOGD("su: pts_slave=[%s]\n", pts_slave);
if (!pts_slave.empty()) {
LOGD("su: pts_slave=[%s]\n", pts_slave.data());
// Check pts_slave file is owned by daemon_from_uid
struct stat st;
xstat(pts_slave, &st);
xstat(pts_slave.data(), &st);
// If caller is not root, ensure the owner of pts_slave is the caller
if(st.st_uid != ctx.info->uid && ctx.info->uid != 0)
@ -243,7 +243,7 @@ void su_daemon_handler(int client, struct ucred *credential) {
// Opening the TTY has to occur after the
// fork() and setsid() so that it becomes
// our controlling TTY and not the daemon's
int ptsfd = xopen(pts_slave, O_RDWR);
int ptsfd = xopen(pts_slave.data(), O_RDWR);
if (infd < 0)
infd = ptsfd;
@ -253,8 +253,6 @@ void su_daemon_handler(int client, struct ucred *credential) {
errfd = ptsfd;
}
free(pts_slave);
// Swap out stdin, stdout, stderr
xdup2(infd, STDIN_FILENO);
xdup2(outfd, STDOUT_FILENO);
@ -290,13 +288,13 @@ void su_daemon_handler(int client, struct ucred *credential) {
break;
}
const char *argv[] = { nullptr, nullptr, nullptr, nullptr };
const char *argv[4] = { nullptr };
argv[0] = ctx.req.login ? "-" : ctx.req.shell;
argv[0] = ctx.req.login ? "-" : ctx.req.shell.data();
if (ctx.req.command[0]) {
if (!ctx.req.command.empty()) {
argv[1] = "-c";
argv[2] = ctx.req.command;
argv[2] = ctx.req.command.data();
}
// Setup environment
@ -321,7 +319,7 @@ void su_daemon_handler(int client, struct ucred *credential) {
setenv("HOME", pw->pw_dir, 1);
setenv("USER", pw->pw_name, 1);
setenv("LOGNAME", pw->pw_name, 1);
setenv("SHELL", ctx.req.shell, 1);
setenv("SHELL", ctx.req.shell.data(), 1);
}
}
@ -330,8 +328,8 @@ void su_daemon_handler(int client, struct ucred *credential) {
sigemptyset(&block_set);
sigprocmask(SIG_SETMASK, &block_set, nullptr);
set_identity(ctx.req.uid);
execvp(ctx.req.shell, (char **) argv);
fprintf(stderr, "Cannot execute %s: %s\n", ctx.req.shell, strerror(errno));
execvp(ctx.req.shell.data(), (char **) argv);
fprintf(stderr, "Cannot execute %s: %s\n", ctx.req.shell.data(), strerror(errno));
PLOGE("exec");
exit(EXIT_FAILURE);
}

View File

@ -13,15 +13,13 @@ public:
explicit mutex_guard(pthread_mutex_t &m): mutex(&m) {
pthread_mutex_lock(mutex);
}
explicit mutex_guard(pthread_mutex_t *m): mutex(m) {
pthread_mutex_lock(mutex);
}
~mutex_guard() {
void unlock() {
pthread_mutex_unlock(mutex);
mutex = nullptr;
}
~mutex_guard() {
if (mutex) pthread_mutex_unlock(mutex);
}
private:
pthread_mutex_t *mutex;
};