Use C++ smart pointer for caching su_info

This commit is contained in:
topjohnwu 2019-07-07 00:31:49 -07:00
parent 4f206fd918
commit ab90901793
3 changed files with 67 additions and 82 deletions

View File

@ -10,6 +10,8 @@
#include "su.h" #include "su.h"
using namespace std;
bool CONNECT_BROADCAST; bool CONNECT_BROADCAST;
#define START_ACTIVITY \ #define START_ACTIVITY \
@ -33,21 +35,21 @@ static inline const char *get_command(const su_request *to) {
return DEFAULT_SHELL; return DEFAULT_SHELL;
} }
static inline void get_user(char *user, su_info *info) { static inline void get_user(char *user, const su_info *info) {
sprintf(user, "%d", sprintf(user, "%d",
info->cfg[SU_MULTIUSER_MODE] == MULTIUSER_MODE_USER info->cfg[SU_MULTIUSER_MODE] == MULTIUSER_MODE_USER
? info->uid / 100000 ? info->uid / 100000
: 0); : 0);
} }
static inline void get_uid(char *uid, su_info *info) { static inline void get_uid(char *uid, const su_info *info) {
sprintf(uid, "%d", sprintf(uid, "%d",
info->cfg[SU_MULTIUSER_MODE] == MULTIUSER_MODE_OWNER_MANAGED info->cfg[SU_MULTIUSER_MODE] == MULTIUSER_MODE_OWNER_MANAGED
? info->uid % 100000 ? info->uid % 100000
: info->uid); : info->uid);
} }
static void exec_am_cmd(const char **args, su_info *info) { static void exec_am_cmd(const char **args, const su_info *info) {
char component[128]; char component[128];
sprintf(component, "%s/%s", info->str[SU_MANAGER].data(), args[3][0] == 'b' ? "a.h" : "a.m"); sprintf(component, "%s/%s", info->str[SU_MANAGER].data(), args[3][0] == 'b' ? "a.h" : "a.m");
char user[8]; char user[8];
@ -76,29 +78,29 @@ static void exec_am_cmd(const char **args, su_info *info) {
"--ei", "to.uid", toUid, \ "--ei", "to.uid", toUid, \
"--ei", "pid", pid, \ "--ei", "pid", pid, \
"--ei", "policy", policy, \ "--ei", "policy", policy, \
"--es", "command", get_command(&ctx->req), \ "--es", "command", get_command(&ctx.req), \
"--ez", "notify", ctx->info->access.notify ? "true" : "false", \ "--ez", "notify", ctx.info->access.notify ? "true" : "false", \
nullptr nullptr
void app_log(su_context *ctx) { void app_log(const su_context &ctx) {
char fromUid[8]; char fromUid[8];
get_uid(fromUid, ctx->info); get_uid(fromUid, ctx.info.get());
char toUid[8]; char toUid[8];
sprintf(toUid, "%d", ctx->req.uid); sprintf(toUid, "%d", ctx.req.uid);
char pid[8]; char pid[8];
sprintf(pid, "%d", ctx->pid); sprintf(pid, "%d", ctx.pid);
char policy[2]; char policy[2];
sprintf(policy, "%d", ctx->info->access.policy); sprintf(policy, "%d", ctx.info->access.policy);
if (CONNECT_BROADCAST) { if (CONNECT_BROADCAST) {
const char *cmd[] = { START_BROADCAST, LOG_BODY }; const char *cmd[] = { START_BROADCAST, LOG_BODY };
exec_am_cmd(cmd, ctx->info); exec_am_cmd(cmd, ctx.info.get());
} else { } else {
const char *cmd[] = { START_ACTIVITY, LOG_BODY }; const char *cmd[] = { START_ACTIVITY, LOG_BODY };
exec_am_cmd(cmd, ctx->info); exec_am_cmd(cmd, ctx.info.get());
} }
} }
@ -108,30 +110,30 @@ void app_log(su_context *ctx) {
"--ei", "policy", policy, \ "--ei", "policy", policy, \
nullptr nullptr
void app_notify(su_context *ctx) { void app_notify(const su_context &ctx) {
char fromUid[8]; char fromUid[8];
get_uid(fromUid, ctx->info); get_uid(fromUid, ctx.info.get());
char policy[2]; char policy[2];
sprintf(policy, "%d", ctx->info->access.policy); sprintf(policy, "%d", ctx.info->access.policy);
if (CONNECT_BROADCAST) { if (CONNECT_BROADCAST) {
const char *cmd[] = { START_BROADCAST, NOTIFY_BODY }; const char *cmd[] = { START_BROADCAST, NOTIFY_BODY };
exec_am_cmd(cmd, ctx->info); exec_am_cmd(cmd, ctx.info.get());
} else { } else {
const char *cmd[] = { START_ACTIVITY, NOTIFY_BODY }; const char *cmd[] = { START_ACTIVITY, NOTIFY_BODY };
exec_am_cmd(cmd, ctx->info); exec_am_cmd(cmd, ctx.info.get());
} }
} }
void app_connect(const char *socket, su_info *info) { void app_connect(const char *socket, const shared_ptr<su_info> &info) {
const char *cmd[] = { const char *cmd[] = {
START_ACTIVITY, "request", START_ACTIVITY, "request",
"--es", "socket", socket, "--es", "socket", socket,
nullptr nullptr
}; };
exec_am_cmd(cmd, info); exec_am_cmd(cmd, info.get());
} }
void broadcast_test() { void broadcast_test() {
@ -144,7 +146,7 @@ void broadcast_test() {
exec_am_cmd(cmd, &info); exec_am_cmd(cmd, &info);
} }
void socket_send_request(int fd, su_info *info) { void socket_send_request(int fd, const shared_ptr<su_info> &info) {
write_key_token(fd, "uid", info->uid); write_key_token(fd, "uid", info->uid);
write_string_be(fd, "eof"); write_string_be(fd, "eof");
} }

View File

@ -1,44 +1,39 @@
/* su.h - Store all general su info #pragma once
*/
#ifndef _SU_H_
#define _SU_H_
#include <limits.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <limits.h>
#include <memory>
#include <db.h> #include <db.h>
#define DEFAULT_SHELL "/system/bin/sh" #define DEFAULT_SHELL "/system/bin/sh"
// Constants for atty // Constants for atty
#define ATTY_IN 1 #define ATTY_IN (1 << 0)
#define ATTY_OUT 2 #define ATTY_OUT (1 << 1)
#define ATTY_ERR 4 #define ATTY_ERR (1 << 2)
class su_info { class su_info {
public: public:
unsigned uid; /* Unique key to find su_info */ /* Unique key */
int count; /* Just a count for debugging purpose */ const unsigned uid;
/* These values should be guarded with internal lock */ /* These should be guarded with internal lock */
db_settings cfg; db_settings cfg;
db_strings str; db_strings str;
su_access access; su_access access;
struct stat mgr_st; struct stat mgr_st;
/* These should be guarded with global cache lock */ /* This should be guarded with global cache lock */
int ref; long timestamp;
time_t timestamp;
su_info(unsigned uid = 0); su_info(unsigned uid = 0);
~su_info(); ~su_info();
void lock(); void lock();
void unlock(); void unlock();
bool isFresh(); bool is_fresh();
void newRef(); void refresh();
void deRef();
private: private:
pthread_mutex_t _lock; /* Internal lock */ pthread_mutex_t _lock; /* Internal lock */
@ -60,16 +55,14 @@ struct su_request : public su_req_base {
} __attribute__((packed)); } __attribute__((packed));
struct su_context { struct su_context {
su_info *info; std::shared_ptr<su_info> info;
su_request req; su_request req;
pid_t pid; pid_t pid;
}; };
// connect.c // connect.c
void app_log(su_context *ctx); void app_log(const su_context &ctx);
void app_notify(su_context *ctx); void app_notify(const su_context &ctx);
void app_connect(const char *socket, su_info *info); void app_connect(const char *socket, const std::shared_ptr<su_info> &info);
void socket_send_request(int fd, su_info *info); void socket_send_request(int fd, const std::shared_ptr<su_info> &info);
#endif

View File

@ -20,14 +20,16 @@
#include "su.h" #include "su.h"
#include "pts.h" #include "pts.h"
using namespace std;
#define LOCK_CACHE() pthread_mutex_lock(&cache_lock) #define LOCK_CACHE() pthread_mutex_lock(&cache_lock)
#define UNLOCK_CACHE() pthread_mutex_unlock(&cache_lock) #define UNLOCK_CACHE() pthread_mutex_unlock(&cache_lock)
static pthread_mutex_t cache_lock = PTHREAD_MUTEX_INITIALIZER; static pthread_mutex_t cache_lock = PTHREAD_MUTEX_INITIALIZER;
static su_info *cache; static shared_ptr<su_info> cached;
su_info::su_info(unsigned uid) : su_info::su_info(unsigned uid) :
uid(uid), count(0), access(DEFAULT_SU_ACCESS), mgr_st({}), ref(0), uid(uid), access(DEFAULT_SU_ACCESS), mgr_st({}),
timestamp(0), _lock(PTHREAD_MUTEX_INITIALIZER) {} timestamp(0), _lock(PTHREAD_MUTEX_INITIALIZER) {}
su_info::~su_info() { su_info::~su_info() {
@ -42,27 +44,20 @@ void su_info::unlock() {
pthread_mutex_unlock(&_lock); pthread_mutex_unlock(&_lock);
} }
bool su_info::isFresh() { bool su_info::is_fresh() {
return time(nullptr) - timestamp < 3; /* 3 seconds */ timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
long current = ts.tv_sec * 1000L + ts.tv_nsec / 1000L;
return current - timestamp < 3000; /* 3 seconds */
} }
void su_info::newRef() { void su_info::refresh() {
timestamp = time(nullptr); timespec ts;
++ref; clock_gettime(CLOCK_MONOTONIC, &ts);
timestamp = ts.tv_sec * 1000L + ts.tv_nsec / 1000L;
} }
void su_info::deRef() { static void database_check(const shared_ptr<su_info> &info) {
LOCK_CACHE();
--ref;
if (ref == 0 && !isFresh()) {
if (cache == this)
cache = nullptr;
delete this;
}
UNLOCK_CACHE();
}
static void database_check(su_info *info) {
int uid = info->uid; int uid = info->uid;
get_db_settings(info->cfg); get_db_settings(info->cfg);
get_db_strings(info->str); get_db_strings(info->str);
@ -91,28 +86,24 @@ static void database_check(su_info *info) {
validate_manager(info->str[SU_MANAGER], uid / 100000, &info->mgr_st); validate_manager(info->str[SU_MANAGER], uid / 100000, &info->mgr_st);
} }
static su_info *get_su_info(unsigned uid) { static shared_ptr<su_info> get_su_info(unsigned uid) {
su_info *info = nullptr; shared_ptr<su_info> info;
// Get from cache or new instance // Get from cache or new instance
LOCK_CACHE(); LOCK_CACHE();
if (cache && cache->uid == uid && cache->isFresh()) { if (!cached || cached->uid != uid || !cached->is_fresh())
info = cache; cached = make_shared<su_info>(uid);
} else { info = cached;
if (cache && cache->ref == 0) info->refresh();
delete cache;
cache = info = new su_info(uid);
}
info->newRef();
UNLOCK_CACHE(); UNLOCK_CACHE();
LOGD("su: request from uid=[%d] (#%d)\n", info->uid, ++info->count); LOGD("su: request from uid=[%d]\n", info->uid);
// Lock before the policy is determined // Lock before the policy is determined
info->lock(); info->lock();
if (info->access.policy == QUERY) { if (info->access.policy == QUERY) {
// Not cached, get data from database // Not cached, get data from database
database_check(info); database_check(info);
// Check su access settings // Check su access settings
@ -195,13 +186,12 @@ static void set_identity(unsigned uid) {
void su_daemon_handler(int client, struct ucred *credential) { void su_daemon_handler(int client, struct ucred *credential) {
LOGD("su: request from pid=[%d], client=[%d]\n", credential->pid, client); LOGD("su: request from pid=[%d], client=[%d]\n", credential->pid, client);
su_info *info = get_su_info(credential->uid); auto info = get_su_info(credential->uid);
// Fail fast // Fail fast
if (info->access.policy == DENY && info->str[SU_MANAGER][0] == '\0') { if (info->access.policy == DENY && info->str[SU_MANAGER][0] == '\0') {
LOGD("su: fast deny\n"); LOGD("su: fast deny\n");
info->deRef();
write_int(client, DENY); write_int(client, DENY);
close(client); close(client);
return; return;
@ -214,7 +204,7 @@ void su_daemon_handler(int client, struct ucred *credential) {
*/ */
int child = xfork(); int child = xfork();
if (child) { if (child) {
info->deRef(); info.reset();
// Wait result // Wait result
LOGD("su: waiting child pid=[%d]\n", child); LOGD("su: waiting child pid=[%d]\n", child);
@ -320,9 +310,9 @@ void su_daemon_handler(int client, struct ucred *credential) {
} }
if (info->access.log) if (info->access.log)
app_log(&ctx); app_log(ctx);
else if (info->access.notify) else if (info->access.notify)
app_notify(&ctx); app_notify(ctx);
if (info->access.policy == ALLOW) { if (info->access.policy == ALLOW) {
const char *argv[] = { nullptr, nullptr, nullptr, nullptr }; const char *argv[] = { nullptr, nullptr, nullptr, nullptr };