Use smart pointers

This commit is contained in:
topjohnwu 2019-11-23 04:57:52 -05:00
parent 5bee1c56a9
commit 01253f050a
8 changed files with 87 additions and 81 deletions

View File

@ -22,15 +22,15 @@ uint32_t dyn_img_hdr::j32 = 0;
uint64_t dyn_img_hdr::j64 = 0; uint64_t dyn_img_hdr::j64 = 0;
static void decompress(format_t type, int fd, const void *in, size_t size) { static void decompress(format_t type, int fd, const void *in, size_t size) {
unique_ptr<stream> ptr(get_decoder(type, open_stream<fd_stream>(fd))); auto ptr = get_decoder(type, make_stream<fd_stream>(fd));
ptr->write(in, size); ptr->write(in, size);
} }
static int64_t compress(format_t type, int fd, const void *in, size_t size) { static off_t compress(format_t type, int fd, const void *in, size_t size) {
auto prev = lseek(fd, 0, SEEK_CUR); auto prev = lseek(fd, 0, SEEK_CUR);
{ {
unique_ptr<stream> ptr(get_encoder(type, open_stream<fd_stream>(fd))); auto strm = get_encoder(type, make_stream<fd_stream>(fd));
ptr->write(in, size); strm->write(in, size);
} }
auto now = lseek(fd, 0, SEEK_CUR); auto now = lseek(fd, 0, SEEK_CUR);
return now - prev; return now - prev;

View File

@ -29,11 +29,8 @@ constexpr size_t LZ4_COMPRESSED = LZ4_COMPRESSBOUND(LZ4_UNCOMPRESSED);
class cpr_stream : public filter_stream { class cpr_stream : public filter_stream {
public: public:
explicit cpr_stream(FILE *fp) : filter_stream(fp) {} using filter_stream::filter_stream;
using stream::read;
int read(void *buf, size_t len) final {
return stream::read(buf, len);
}
}; };
class gz_strm : public cpr_stream { class gz_strm : public cpr_stream {
@ -60,7 +57,7 @@ protected:
ENCODE ENCODE
} mode; } mode;
gz_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode) { gz_strm(mode_t mode, sFILE &&fp) : cpr_stream(std::move(fp)), mode(mode) {
switch(mode) { switch(mode) {
case DECODE: case DECODE:
inflateInit2(&strm, 15 | 16); inflateInit2(&strm, 15 | 16);
@ -102,12 +99,12 @@ private:
class gz_decoder : public gz_strm { class gz_decoder : public gz_strm {
public: public:
explicit gz_decoder(FILE *fp) : gz_strm(DECODE, fp) {}; explicit gz_decoder(sFILE &&fp) : gz_strm(DECODE, std::move(fp)) {};
}; };
class gz_encoder : public gz_strm { class gz_encoder : public gz_strm {
public: public:
explicit gz_encoder(FILE *fp) : gz_strm(ENCODE, fp) {}; explicit gz_encoder(sFILE &&fp) : gz_strm(ENCODE, std::move(fp)) {};
}; };
class bz_strm : public cpr_stream { class bz_strm : public cpr_stream {
@ -134,7 +131,7 @@ protected:
ENCODE ENCODE
} mode; } mode;
bz_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode) { bz_strm(mode_t mode, sFILE &&fp) : cpr_stream(std::move(fp)), mode(mode) {
switch(mode) { switch(mode) {
case DECODE: case DECODE:
BZ2_bzDecompressInit(&strm, 0, 0); BZ2_bzDecompressInit(&strm, 0, 0);
@ -176,12 +173,12 @@ private:
class bz_decoder : public bz_strm { class bz_decoder : public bz_strm {
public: public:
explicit bz_decoder(FILE *fp) : bz_strm(DECODE, fp) {}; explicit bz_decoder(sFILE &&fp) : bz_strm(DECODE, std::move(fp)) {};
}; };
class bz_encoder : public bz_strm { class bz_encoder : public bz_strm {
public: public:
explicit bz_encoder(FILE *fp) : bz_strm(ENCODE, fp) {}; explicit bz_encoder(sFILE &&fp) : bz_strm(ENCODE, std::move(fp)) {};
}; };
class lzma_strm : public cpr_stream { class lzma_strm : public cpr_stream {
@ -202,7 +199,8 @@ protected:
ENCODE_LZMA ENCODE_LZMA
} mode; } mode;
lzma_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode), strm(LZMA_STREAM_INIT) { lzma_strm(mode_t mode, sFILE &&fp)
: cpr_stream(std::move(fp)), mode(mode), strm(LZMA_STREAM_INIT) {
lzma_options_lzma opt; lzma_options_lzma opt;
// Initialize preset // Initialize preset
@ -249,22 +247,22 @@ private:
class lzma_decoder : public lzma_strm { class lzma_decoder : public lzma_strm {
public: public:
explicit lzma_decoder(FILE *fp) : lzma_strm(DECODE, fp) {} explicit lzma_decoder(sFILE &&fp) : lzma_strm(DECODE, std::move(fp)) {}
}; };
class xz_encoder : public lzma_strm { class xz_encoder : public lzma_strm {
public: public:
explicit xz_encoder(FILE *fp) : lzma_strm(ENCODE_XZ, fp) {} explicit xz_encoder(sFILE &&fp) : lzma_strm(ENCODE_XZ, std::move(fp)) {}
}; };
class lzma_encoder : public lzma_strm { class lzma_encoder : public lzma_strm {
public: public:
explicit lzma_encoder(FILE *fp) : lzma_strm(ENCODE_LZMA, fp) {} explicit lzma_encoder(sFILE &&fp) : lzma_strm(ENCODE_LZMA, std::move(fp)) {}
}; };
class LZ4F_decoder : public cpr_stream { class LZ4F_decoder : public cpr_stream {
public: public:
explicit LZ4F_decoder(FILE *fp) : cpr_stream(fp), outbuf(nullptr) { explicit LZ4F_decoder(sFILE &&fp) : cpr_stream(std::move(fp)), outbuf(nullptr) {
LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION); LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION);
} }
@ -319,7 +317,8 @@ private:
class LZ4F_encoder : public cpr_stream { class LZ4F_encoder : public cpr_stream {
public: public:
explicit LZ4F_encoder(FILE *fp) : cpr_stream(fp), outbuf(nullptr), outCapacity(0) { explicit LZ4F_encoder(sFILE &&fp)
: cpr_stream(std::move(fp)), outbuf(nullptr), outCapacity(0) {
LZ4F_createCompressionContext(&ctx, LZ4F_VERSION); LZ4F_createCompressionContext(&ctx, LZ4F_VERSION);
} }
@ -379,9 +378,9 @@ private:
class LZ4_decoder : public cpr_stream { class LZ4_decoder : public cpr_stream {
public: public:
explicit LZ4_decoder(FILE *fp) explicit LZ4_decoder(sFILE &&fp)
: cpr_stream(fp), out_buf(new char[LZ4_UNCOMPRESSED]), buffer(new char[LZ4_COMPRESSED]), : cpr_stream(std::move(fp)), out_buf(new char[LZ4_UNCOMPRESSED]),
init(false), block_sz(0), buf_off(0) {} buffer(new char[LZ4_COMPRESSED]), init(false), block_sz(0), buf_off(0) {}
~LZ4_decoder() override { ~LZ4_decoder() override {
delete[] out_buf; delete[] out_buf;
@ -440,8 +439,8 @@ private:
class LZ4_encoder : public cpr_stream { class LZ4_encoder : public cpr_stream {
public: public:
explicit LZ4_encoder(FILE *fp) explicit LZ4_encoder(sFILE &&fp)
: cpr_stream(fp), outbuf(new char[LZ4_COMPRESSED]), buf(new char[LZ4_UNCOMPRESSED]), : cpr_stream(std::move(fp)), outbuf(new char[LZ4_COMPRESSED]), buf(new char[LZ4_UNCOMPRESSED]),
init(false), buf_off(0), in_total(0) {} init(false), buf_off(0), in_total(0) {}
int write(const void *in, size_t size) override { int write(const void *in, size_t size) override {
@ -501,38 +500,38 @@ private:
unsigned in_total; unsigned in_total;
}; };
filter_stream *get_encoder(format_t type, FILE *fp) { stream_ptr get_encoder(format_t type, sFILE &&fp) {
switch (type) { switch (type) {
case XZ: case XZ:
return new xz_encoder(fp); return make_unique<xz_encoder>(std::move(fp));
case LZMA: case LZMA:
return new lzma_encoder(fp); return make_unique<lzma_encoder>(std::move(fp));
case BZIP2: case BZIP2:
return new bz_encoder(fp); return make_unique<bz_encoder>(std::move(fp));
case LZ4: case LZ4:
return new LZ4F_encoder(fp); return make_unique<LZ4F_encoder>(std::move(fp));
case LZ4_LEGACY: case LZ4_LEGACY:
return new LZ4_encoder(fp); return make_unique<LZ4_encoder>(std::move(fp));
case GZIP: case GZIP:
default: default:
return new gz_encoder(fp); return make_unique<gz_encoder>(std::move(fp));
} }
} }
filter_stream *get_decoder(format_t type, FILE *fp) { stream_ptr get_decoder(format_t type, sFILE &&fp) {
switch (type) { switch (type) {
case XZ: case XZ:
case LZMA: case LZMA:
return new lzma_decoder(fp); return make_unique<lzma_decoder>(std::move(fp));
case BZIP2: case BZIP2:
return new bz_decoder(fp); return make_unique<bz_decoder>(std::move(fp));
case LZ4: case LZ4:
return new LZ4F_decoder(fp); return make_unique<LZ4F_decoder>(std::move(fp));
case LZ4_LEGACY: case LZ4_LEGACY:
return new LZ4_decoder(fp); return make_unique<LZ4_decoder>(std::move(fp));
case GZIP: case GZIP:
default: default:
return new gz_decoder(fp); return make_unique<gz_decoder>(std::move(fp));
} }
} }
@ -541,7 +540,7 @@ void decompress(char *infile, const char *outfile) {
bool rm_in = false; bool rm_in = false;
FILE *in_fp = in_std ? stdin : xfopen(infile, "re"); FILE *in_fp = in_std ? stdin : xfopen(infile, "re");
unique_ptr<stream> strm; stream_ptr strm;
char buf[4096]; char buf[4096];
size_t len; size_t len;
@ -574,7 +573,7 @@ void decompress(char *infile, const char *outfile) {
} }
FILE *out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we"); FILE *out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
strm.reset(get_decoder(type, out_fp)); strm = get_decoder(type, make_sFILE(out_fp));
if (ext) *ext = '.'; if (ext) *ext = '.';
} }
if (strm->write(buf, len) < 0) if (strm->write(buf, len) < 0)
@ -615,7 +614,7 @@ void compress(const char *method, const char *infile, const char *outfile) {
out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we"); out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
} }
unique_ptr<stream> strm(get_encoder(it->second, out_fp)); auto strm = get_encoder(it->second, make_sFILE(out_fp));
char buf[4096]; char buf[4096];
size_t len; size_t len;

View File

@ -4,7 +4,10 @@
#include "format.h" #include "format.h"
filter_stream *get_encoder(format_t type, FILE *fp = nullptr); stream_ptr get_encoder(format_t type, sFILE &&fp);
filter_stream *get_decoder(format_t type, FILE *fp = nullptr);
stream_ptr get_decoder(format_t type, sFILE &&fp);
void compress(const char *method, const char *infile, const char *outfile); void compress(const char *method, const char *infile, const char *outfile);
void decompress(char *infile, const char *outfile); void decompress(char *infile, const char *outfile);

View File

@ -244,8 +244,8 @@ void magisk_cpio::compress() {
uint8_t *data; uint8_t *data;
size_t len; size_t len;
FILE *fp = open_stream(get_encoder(XZ, open_stream<byte_stream>(data, len))); auto strm = make_stream(get_encoder(XZ, make_stream<byte_stream>(data, len)));
dump(fp); dump(strm.release());
entries.clear(); entries.clear();
entries.insert(std::move(init)); entries.insert(std::move(init));
@ -263,9 +263,10 @@ void magisk_cpio::decompress() {
char *data; char *data;
size_t len; size_t len;
auto strm = get_decoder(XZ, open_stream<byte_stream>(data, len)); {
auto strm = get_decoder(XZ, make_stream<byte_stream>(data, len));
strm->write(it->second->data, it->second->filesize); strm->write(it->second->data, it->second->filesize);
delete strm; }
entries.erase(it); entries.erase(it);
load_cpio(data, len); load_cpio(data, len);

View File

@ -174,19 +174,20 @@ int compile_split_cil() {
} }
int dump_policydb(const char *file) { int dump_policydb(const char *file) {
struct policy_file pf;
policy_file_init(&pf);
uint8_t *data; uint8_t *data;
size_t len; size_t len;
{
auto fp = make_stream<byte_stream>(data, len);
struct policy_file pf;
policy_file_init(&pf);
pf.type = PF_USE_STDIO; pf.type = PF_USE_STDIO;
pf.fp = open_stream<byte_stream>(data, len); pf.fp = fp.get();
if (policydb_write(magisk_policydb, &pf)) { if (policydb_write(magisk_policydb, &pf)) {
LOGE("Fail to create policy image\n"); LOGE("Fail to create policy image\n");
return 1; return 1;
} }
fclose(pf.fp); }
int fd = xopen(file, O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644); int fd = xopen(file, O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644);
if (fd < 0) if (fd < 0)

View File

@ -6,6 +6,11 @@
#define do_align(p, a) (((p) + (a) - 1) / (a) * (a)) #define do_align(p, a) (((p) + (a) - 1) / (a) * (a))
#define align_off(p, a) (do_align(p, a) - (p)) #define align_off(p, a) (do_align(p, a) - (p))
using sFILE = std::unique_ptr<FILE, decltype(&fclose)>;
static inline sFILE make_sFILE(FILE *fp = nullptr) {
return sFILE(fp, fclose);
}
struct file_attr { struct file_attr {
struct stat st; struct stat st;
char con[128]; char con[128];

View File

@ -1,16 +1,19 @@
#pragma once #pragma once
#include <unistd.h>
#include <stdio.h> #include <stdio.h>
#include <memory> #include <memory>
#include "../files.h"
class stream; class stream;
FILE *open_stream(stream *strm); using stream_ptr = std::unique_ptr<stream>;
sFILE make_stream(stream_ptr &&strm);
template <class T, class... Args> template <class T, class... Args>
FILE *open_stream(Args &&... args) { sFILE make_stream(Args &&... args) {
return open_stream(new T(args...)); return make_stream(stream_ptr(new T(std::forward<Args>(args)...)));
} }
class stream { class stream {
@ -24,20 +27,20 @@ public:
// Delegates all operations to the base FILE pointer // Delegates all operations to the base FILE pointer
class filter_stream : public stream { class filter_stream : public stream {
public: public:
filter_stream(FILE *fp) : fp(fp) {} filter_stream() = default;
~filter_stream() override; filter_stream(sFILE &&fp) : fp(std::move(fp)) {}
int read(void *buf, size_t len) override; int read(void *buf, size_t len) override;
int write(const void *buf, size_t len) override; int write(const void *buf, size_t len) override;
void set_base(FILE *f); void set_base(sFILE &&f);
template <class T, class... Args > template <class T, class... Args >
void set_base(Args&&... args) { void set_base(Args&&... args) {
set_base(open_stream<T>(args...)); set_base(make_stream<T>(std::forward<Args>(args)...));
} }
protected: protected:
FILE *fp; sFILE fp;
}; };
// Handy interface for classes that need custom seek logic // Handy interface for classes that need custom seek logic
@ -65,7 +68,7 @@ private:
size_t _cap = 0; size_t _cap = 0;
void resize(size_t new_pos, bool zero = false); void resize(size_t new_pos, bool zero = false);
size_t end_pos() override { return _len; } size_t end_pos() final { return _len; }
}; };
// File stream but does not close the file descriptor at any time // File stream but does not close the file descriptor at any time

View File

@ -23,10 +23,9 @@ static int strm_close(void *v) {
return 0; return 0;
} }
FILE *open_stream(stream *strm) { sFILE make_stream(stream_ptr &&strm) {
FILE *fp = funopen(strm, strm_read, strm_write, strm_seek, strm_close); sFILE fp(funopen(strm.release(), strm_read, strm_write, strm_seek, strm_close), fclose);
// Disable buffering setbuf(fp.get(), nullptr);
setbuf(fp, nullptr);
return fp; return fp;
} }
@ -45,21 +44,16 @@ off_t stream::seek(off_t off, int whence) {
return -1; return -1;
} }
filter_stream::~filter_stream() {
if (fp) fclose(fp);
}
int filter_stream::read(void *buf, size_t len) { int filter_stream::read(void *buf, size_t len) {
return fread(buf, 1, len, fp); return fread(buf, 1, len, fp.get());
} }
int filter_stream::write(const void *buf, size_t len) { int filter_stream::write(const void *buf, size_t len) {
return fwrite(buf, 1, len, fp); return fwrite(buf, 1, len, fp.get());
} }
void filter_stream::set_base(FILE *f) { void filter_stream::set_base(sFILE &&f) {
if (fp) fclose(fp); fp = std::move(f);
fp = f;
} }
off_t seekable_stream::seek_pos(off_t off, int whence) { off_t seekable_stream::seek_pos(off_t off, int whence) {