diff --git a/native/jni/magiskboot/bootimg.cpp b/native/jni/magiskboot/bootimg.cpp index 8e5cfe605..9978a7dc3 100644 --- a/native/jni/magiskboot/bootimg.cpp +++ b/native/jni/magiskboot/bootimg.cpp @@ -11,6 +11,7 @@ #include "bootimg.h" #include "magiskboot.h" +#include "compress.h" static void dump(void *buf, size_t size, const char *filename) { if (size == 0) @@ -163,11 +164,8 @@ int boot_img::parse_image(const char * image) { r_fmt = check_fmt(ramdisk, hdr->ramdisk_size); } - char fmt[16]; - get_fmt_name(k_fmt, fmt); - fprintf(stderr, "KERNEL_FMT [%s]\n", fmt); - get_fmt_name(r_fmt, fmt); - fprintf(stderr, "RAMDISK_FMT [%s]\n", fmt); + fprintf(stderr, "KERNEL_FMT [%s]\n", fmt2name[k_fmt]); + fprintf(stderr, "RAMDISK_FMT [%s]\n", fmt2name[r_fmt]); return flags & CHROMEOS_FLAG ? CHROMEOS_RET : 0; default: diff --git a/native/jni/magiskboot/compress.cpp b/native/jni/magiskboot/compress.cpp index f39aecd96..0f072efce 100644 --- a/native/jni/magiskboot/compress.cpp +++ b/native/jni/magiskboot/compress.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -10,142 +11,121 @@ #include "magiskboot.h" #include "compress.h" +using namespace std; + int64_t decompress(format_t type, int fd, const void *from, size_t size) { - auto cmp = get_decoder(type); - int64_t ret = cmp->one_step(fd, from, size); - delete cmp; - return ret; + unique_ptr cmp(get_decoder(type)); + return cmp->one_step(fd, from, size); } int64_t compress(format_t type, int fd, const void *from, size_t size) { - auto cmp = get_encoder(type); - int64_t ret = cmp->one_step(fd, from, size); - delete cmp; - return ret; + unique_ptr cmp(get_encoder(type)); + return cmp->one_step(fd, from, size); } -void decompress(char *from, const char *to) { - int strip = 1; - void *file; - size_t size = 0; - if (strcmp(from, "-") == 0) - stream_full_read(STDIN_FILENO, &file, &size); - else - mmap_ro(from, &file, &size); - format_t type = check_fmt(file, size); - char *ext; - ext = strrchr(from, '.'); - if (to == nullptr) - to = from; - if (ext != nullptr) { - // Strip out a matched file extension - switch (type) { - case GZIP: - if (strcmp(ext, ".gz") != 0) - strip = 0; - break; - case XZ: - if (strcmp(ext, ".xz") != 0) - strip = 0; - break; - case LZMA: - if (strcmp(ext, ".lzma") != 0) - strip = 0; - break; - case BZIP2: - if (strcmp(ext, ".bz2") != 0) - strip = 0; - break; - case LZ4_LEGACY: - case LZ4: - if (strcmp(ext, ".lz4") != 0) - strip = 0; - break; - default: - LOGE("Provided file \'%s\' is not a supported archive format\n", from); +static bool read_file(FILE *fp, const function &fn) { + char buf[4096]; + size_t len; + while ((len = fread(buf, 1, sizeof(buf), fp))) + fn(buf, len); + return true; +} + +void decompress(char *infile, const char *outfile) { + bool in_std = strcmp(infile, "-") == 0; + bool rm_in = false; + + FILE *in_file = in_std ? stdin : xfopen(infile, "re"); + int out_fd = -1; + unique_ptr cmp; + + read_file(in_file, [&](void *buf, size_t len) -> void { + if (out_fd < 0) { + format_t type = check_fmt(buf, len); + if (!COMPRESSED(type)) + LOGE("Input file is not a compressed type!\n"); + + cmp = std::move(unique_ptr(get_decoder(type))); + fprintf(stderr, "Detected format: [%s]\n", fmt2name[type]); + + /* If user does not provide outfile, infile has to be either + * .[ext], or '-'. Outfile will be either or '-'. + * If the input does not have proper format, abort */ + + char *ext = nullptr; + if (outfile == nullptr) { + outfile = infile; + if (!in_std) { + ext = strrchr(infile, '.'); + if (ext == nullptr || strcmp(ext, fmt2ext[type]) != 0) + LOGE("Input file is not a supported type!\n"); + + // Strip out extension and remove input + *ext = '\0'; + rm_in = true; + fprintf(stderr, "Decompressing to [%s]\n", outfile); + } + } + + out_fd = strcmp(outfile, "-") == 0 ? STDOUT_FILENO : creat(outfile, 0644); + cmp->set_outfd(out_fd); + if (ext) *ext = '.'; } - if (strip) - *ext = '\0'; - } + if (!cmp->update(buf, len)) + LOGE("Decompression error!\n"); + }); - int fd; + cmp->finalize(); + fclose(in_file); + close(out_fd); - if (strcmp(to, "-") == 0) { - fd = STDOUT_FILENO; - } else { - fd = creat(to, 0644); - fprintf(stderr, "Decompressing to [%s]\n", to); - } - - decompress(type, fd, file, size); - close(fd); - if (to == from && ext != nullptr) { - *ext = '.'; - unlink(from); - } - if (strcmp(from, "-") == 0) - free(file); - else - munmap(file, size); + if (rm_in) + unlink(infile); } -void compress(const char *method, const char *from, const char *to) { - format_t type; - const char *ext; - char dest[PATH_MAX]; - if (strcmp(method, "gzip") == 0) { - type = GZIP; - ext = "gz"; - } else if (strcmp(method, "xz") == 0) { - type = XZ; - ext = "xz"; - } else if (strcmp(method, "lzma") == 0) { - type = LZMA; - ext = "lzma"; - } else if (strcmp(method, "lz4") == 0) { - type = LZ4; - ext = "lz4"; - } else if (strcmp(method, "lz4_legacy") == 0) { - type = LZ4_LEGACY; - ext = "lz4"; - } else if (strcmp(method, "bzip2") == 0) { - type = BZIP2; - ext = "bz2"; +void compress(const char *method, const char *infile, const char *outfile) { + auto it = name2fmt.find(method); + if (it == name2fmt.end()) + LOGE("Unsupported compression method: [%s]\n", method); + + unique_ptr cmp(get_encoder(it->second)); + + bool in_std = strcmp(infile, "-") == 0; + bool rm_in = false; + + FILE *in_file = in_std ? stdin : xfopen(infile, "re"); + int out_fd; + + if (outfile == nullptr) { + if (in_std) { + out_fd = STDOUT_FILENO; + } else { + /* If user does not provide outfile and infile is not + * STDIN, output to .[ext] */ + char *tmp = new char[strlen(infile) + 5]; + sprintf(tmp, "%s%s", infile, fmt2ext[it->second]); + out_fd = creat(tmp, 0644); + fprintf(stderr, "Compressing to [%s]\n", tmp); + delete[] tmp; + rm_in = true; + } } else { - fprintf(stderr, "Only support following methods: "); - for (int i = 0; SUP_LIST[i]; ++i) - fprintf(stderr, "%s ", SUP_LIST[i]); - fprintf(stderr, "\n"); - exit(1); + out_fd = strcmp(infile, "-") == 0 ? STDOUT_FILENO : creat(infile, 0644); } - void *file; - size_t size; - if (strcmp(from, "-") == 0) - stream_full_read(STDIN_FILENO, &file, &size); - else - mmap_ro(from, &file, &size); - if (to == nullptr) { - if (strcmp(from, "-") == 0) - strcpy(dest, "-"); - else - snprintf(dest, sizeof(dest), "%s.%s", from, ext); - } else - strcpy(dest, to); - int fd; - if (strcmp(dest, "-") == 0) { - fd = STDOUT_FILENO; - } else { - fd = creat(dest, 0644); - fprintf(stderr, "Compressing to [%s]\n", dest); - } - compress(type, fd, file, size); - close(fd); - if (strcmp(from, "-") == 0) - free(file); - else - munmap(file, size); - if (to == nullptr) - unlink(from); + + cmp->set_outfd(out_fd); + + read_file(in_file, [&](void *buf, size_t len) -> void { + if (!cmp->update(buf, len)) + LOGE("Compression error!\n"); + }); + + cmp->finalize(); + fclose(in_file); + close(out_fd); + + if (rm_in) + unlink(infile); } diff --git a/native/jni/magiskboot/compress.h b/native/jni/magiskboot/compress.h index 4bb31c63c..7529915c1 100644 --- a/native/jni/magiskboot/compress.h +++ b/native/jni/magiskboot/compress.h @@ -185,3 +185,7 @@ private: Compression *get_encoder(format_t type); Compression *get_decoder(format_t type); +int64_t compress(format_t type, int fd, const void *from, size_t size); +int64_t decompress(format_t type, int fd, const void *from, size_t size); +void compress(const char *method, const char *infile, const char *outfile); +void decompress(char *infile, const char *outfile); diff --git a/native/jni/magiskboot/format.cpp b/native/jni/magiskboot/format.cpp index 5a4815ff0..23c441d64 100644 --- a/native/jni/magiskboot/format.cpp +++ b/native/jni/magiskboot/format.cpp @@ -2,6 +2,24 @@ #include "format.h" +std::map name2fmt; +Fmt2Name fmt2name; +Fmt2Ext fmt2ext; + +class FormatInit { +public: + FormatInit() { + name2fmt["gzip"] = GZIP; + name2fmt["xz"] = XZ; + name2fmt["lzma"] = LZMA; + name2fmt["bzip2"] = BZIP2; + name2fmt["lz4"] = LZ4; + name2fmt["lz4_legacy"] = LZ4_LEGACY; + } +}; + +static FormatInit init; + #define MATCH(s) (len >= (sizeof(s) - 1) && memcmp(buf, s, sizeof(s) - 1) == 0) format_t check_fmt(const void *buf, size_t len) { @@ -41,44 +59,51 @@ format_t check_fmt(const void *buf, size_t len) { } } -void get_fmt_name(format_t fmt, char *name) { - const char *s; +const char *Fmt2Name::operator[](format_t fmt) { switch (fmt) { case CHROMEOS: - s = "chromeos"; - break; + return "chromeos"; case AOSP: - s = "aosp"; - break; + return "aosp"; case GZIP: - s = "gzip"; - break; + return "gzip"; case LZOP: - s = "lzop"; - break; + return "lzop"; case XZ: - s = "xz"; - break; + return "xz"; case LZMA: - s = "lzma"; - break; + return "lzma"; case BZIP2: - s = "bzip2"; - break; + return "bzip2"; case LZ4: - s = "lz4"; - break; + return "lz4"; case LZ4_LEGACY: - s = "lz4_legacy"; - break; + return "lz4_legacy"; case MTK: - s = "mtk"; - break; + return "mtk"; case DTB: - s = "dtb"; - break; + return "dtb"; default: - s = "raw"; + return "raw"; + } +} + +const char *Fmt2Ext::operator[](format_t fmt) { + switch (fmt) { + case GZIP: + return ".gz"; + case LZOP: + return ".lzop"; + case XZ: + return ".xz"; + case LZMA: + return ".lzma"; + case BZIP2: + return ".bz2"; + case LZ4: + case LZ4_LEGACY: + return ".lz4"; + default: + return ""; } - strcpy(name, s); } diff --git a/native/jni/magiskboot/format.h b/native/jni/magiskboot/format.h index c0e22ee1e..6c51cf2ad 100644 --- a/native/jni/magiskboot/format.h +++ b/native/jni/magiskboot/format.h @@ -1,26 +1,32 @@ #ifndef _FORMAT_H_ #define _FORMAT_H_ +#include +#include + typedef enum { UNKNOWN, +/* Boot formats */ CHROMEOS, AOSP, ELF32, ELF64, + DHTB, + BLOB, +/* Compression formats */ GZIP, - LZOP, XZ, LZMA, BZIP2, LZ4, LZ4_LEGACY, +/* Misc */ + LZOP, MTK, DTB, - DHTB, - BLOB } format_t; -#define COMPRESSED(fmt) (fmt >= GZIP && fmt <= LZ4_LEGACY) +#define COMPRESSED(fmt) ((fmt) >= GZIP && (fmt) <= LZ4_LEGACY) #define BOOT_MAGIC "ANDROID!" #define CHROMEOS_MAGIC "CHROMEOS" @@ -47,7 +53,20 @@ typedef enum { #define SUP_LIST ((const char *[]) { "gzip", "xz", "lzma", "bzip2", "lz4", "lz4_legacy", NULL }) #define SUP_EXT_LIST ((const char *[]) { "gz", "xz", "lzma", "bz2", "lz4", "lz4", NULL }) +class Fmt2Name { +public: + const char *operator[](format_t fmt); +}; + +class Fmt2Ext { +public: + const char *operator[](format_t fmt); +}; + format_t check_fmt(const void *buf, size_t len); -void get_fmt_name(format_t fmt, char *name); + +extern std::map name2fmt; +extern Fmt2Name fmt2name; +extern Fmt2Ext fmt2ext; #endif diff --git a/native/jni/magiskboot/magiskboot.h b/native/jni/magiskboot/magiskboot.h index 3e3283d46..46c2be8f1 100644 --- a/native/jni/magiskboot/magiskboot.h +++ b/native/jni/magiskboot/magiskboot.h @@ -18,14 +18,8 @@ int unpack(const char *image); void repack(const char* orig_image, const char* out_image); void hexpatch(const char *image, const char *from, const char *to); int cpio_commands(int argc, char *argv[]); -void compress(const char *method, const char *from, const char *to); -void decompress(char *from, const char *to); int dtb_commands(const char *cmd, int argc, char *argv[]); -// Compressions -int64_t compress(format_t type, int fd, const void *from, size_t size); -int64_t decompress(format_t type, int fd, const void *from, size_t size); - // Pattern int patch_verity(void **buf, uint32_t *size, int patch); void patch_encryption(void **buf, uint32_t *size); diff --git a/native/jni/magiskboot/main.cpp b/native/jni/magiskboot/main.cpp index 45e2d5b54..964964d19 100644 --- a/native/jni/magiskboot/main.cpp +++ b/native/jni/magiskboot/main.cpp @@ -10,10 +10,7 @@ #include #include "magiskboot.h" - -/******************** - Patch Boot Image -*********************/ +#include "compress.h" static void usage(char *arg0) { fprintf(stderr, @@ -137,24 +134,19 @@ int main(int argc, char *argv[]) { } else if (argc > 2 && strcmp(argv[1], "--unpack") == 0) { return unpack(argv[2]); } else if (argc > 2 && strcmp(argv[1], "--repack") == 0) { - repack(argv[2], argc > 3 ? argv[3] : NEW_BOOT); + repack(argv[2], argv[3] ? argv[3] : NEW_BOOT); } else if (argc > 2 && strcmp(argv[1], "--decompress") == 0) { - decompress(argv[2], argc > 3 ? argv[3] : nullptr); + decompress(argv[2], argv[3]); } else if (argc > 2 && strncmp(argv[1], "--compress", 10) == 0) { - const char *method; - method = strchr(argv[1], '='); - if (method == nullptr) method = "gzip"; - else method++; - compress(method, argv[2], argc > 3 ? argv[3] : nullptr); + compress(argv[1][10] == '=' ? &argv[1][11] : "gzip", argv[2], argv[3]); } else if (argc > 4 && strcmp(argv[1], "--hexpatch") == 0) { hexpatch(argv[2], argv[3], argv[4]); } else if (argc > 2 && strcmp(argv[1], "--cpio") == 0) { if (cpio_commands(argc - 2, argv + 2)) usage(argv[0]); } else if (argc > 2 && strncmp(argv[1], "--dtb", 5) == 0) { - char *cmd = argv[1] + 5; - if (*cmd == '\0') usage(argv[0]); - else ++cmd; - if (dtb_commands(cmd, argc - 2, argv + 2)) + if (argv[1][5] != '-') + usage(argv[0]); + if (dtb_commands(&argv[1][6], argc - 2, argv + 2)) usage(argv[0]); } else { usage(argv[0]);