diff --git a/tdutils/td/utils/base64.cpp b/tdutils/td/utils/base64.cpp index e8e92795c..a1267eb7d 100644 --- a/tdutils/td/utils/base64.cpp +++ b/tdutils/td/utils/base64.cpp @@ -16,14 +16,17 @@ namespace td { //TODO: fix copypaste -static const char *const symbols64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -static const char *const url_symbols64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; +template +static const char *get_characters() { + return is_url ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +} template static const unsigned char *get_character_table() { static unsigned char char_to_value[256]; static bool is_inited = [] { - auto characters = is_url ? url_symbols64 : symbols64; + auto characters = get_characters(); std::fill(std::begin(char_to_value), std::end(char_to_value), static_cast(64)); for (unsigned char i = 0; i < 64; i++) { char_to_value[static_cast(characters[i])] = i; @@ -34,34 +37,40 @@ static const unsigned char *get_character_table() { return char_to_value; } -string base64_encode(Slice input) { +template +string base64_encode_impl(Slice input) { + auto characters = get_characters(); string base64; base64.reserve((input.size() + 2) / 3 * 4); for (size_t i = 0; i < input.size();) { size_t left = min(input.size() - i, static_cast(3)); int c = input.ubegin()[i++] << 16; - base64 += symbols64[c >> 18]; + base64 += characters[c >> 18]; if (left != 1) { c |= input.ubegin()[i++] << 8; } - base64 += symbols64[(c >> 12) & 63]; + base64 += characters[(c >> 12) & 63]; if (left == 3) { c |= input.ubegin()[i++]; } if (left != 1) { - base64 += symbols64[(c >> 6) & 63]; - } else { + base64 += characters[(c >> 6) & 63]; + } else if (!is_url) { base64 += '='; } if (left == 3) { - base64 += symbols64[c & 63]; - } else { + base64 += characters[c & 63]; + } else if (!is_url) { base64 += '='; } } return base64; } +string base64_encode(Slice input) { + return base64_encode_impl(input); +} + Result base64_drop_padding(Slice base64) { if ((base64.size() & 3) != 0) { return Status::Error("Wrong string length"); @@ -133,27 +142,7 @@ Result base64_decode_secure(Slice base64) { } string base64url_encode(Slice input) { - string base64; - base64.reserve((input.size() + 2) / 3 * 4); - for (size_t i = 0; i < input.size();) { - size_t left = min(input.size() - i, static_cast(3)); - int c = input.ubegin()[i++] << 16; - base64 += url_symbols64[c >> 18]; - if (left != 1) { - c |= input.ubegin()[i++] << 8; - } - base64 += url_symbols64[(c >> 12) & 63]; - if (left == 3) { - c |= input.ubegin()[i++]; - } - if (left != 1) { - base64 += url_symbols64[(c >> 6) & 63]; - } - if (left == 3) { - base64 += url_symbols64[c & 63]; - } - } - return base64; + return base64_encode_impl(input); } Result base64url_decode(Slice base64) { diff --git a/tdutils/test/misc.cpp b/tdutils/test/misc.cpp index 24953db1d..f06ce0fbd 100644 --- a/tdutils/test/misc.cpp +++ b/tdutils/test/misc.cpp @@ -231,6 +231,9 @@ TEST(Misc, base64) { ASSERT_TRUE(base64_encode(" /'.;.';≤.];,].',[.;/,.;/]/..;!@#!*(%?::;!%\";") == "ICAgICAgLycuOy4nO+KJpC5dOyxdLicsWy47LywuOy9dLy4uOyFAIyEqKCU/" "Ojo7ISUiOw=="); + ASSERT_TRUE(base64url_encode("ab><") == "YWI-PA"); + ASSERT_TRUE(base64url_encode("ab>