diff --git a/td/telegram/InlineQueriesManager.cpp b/td/telegram/InlineQueriesManager.cpp index c07a31ef..a3e9723c 100644 --- a/td/telegram/InlineQueriesManager.cpp +++ b/td/telegram/InlineQueriesManager.cpp @@ -330,7 +330,7 @@ void InlineQueriesManager::answer_inline_query(int64 inline_query_id, bool is_pe if (switch_pm_parameter.size() > 64) { return promise.set_error(Status::Error(400, "Too long switch_pm_parameter specified")); } - if (!is_base64url(switch_pm_parameter)) { + if (!is_base64url_characters(switch_pm_parameter)) { return promise.set_error(Status::Error(400, "Unallowed characters in switch_pm_parameter are used")); } } diff --git a/tdutils/td/utils/base64.cpp b/tdutils/td/utils/base64.cpp index e4ecb5e7..35d91690 100644 --- a/tdutils/td/utils/base64.cpp +++ b/tdutils/td/utils/base64.cpp @@ -216,6 +216,17 @@ Result base64url_decode(Slice base64) { return output; } +template +static const unsigned char *get_character_table() { + if (is_url) { + init_base64url_table(); + return url_char_to_value; + } else { + init_base64_table(); + return char_to_value; + } +} + template static bool is_base64_impl(Slice input) { size_t padding_length = 0; @@ -233,14 +244,7 @@ static bool is_base64_impl(Slice input) { return false; } - unsigned char *table; - if (is_url) { - init_base64url_table(); - table = url_char_to_value; - } else { - init_base64_table(); - table = char_to_value; - } + auto table = get_character_table(); for (auto c : input) { if (table[static_cast(c)] == 64) { return false; @@ -271,6 +275,25 @@ bool is_base64url(Slice input) { return is_base64_impl(input); } +template +static bool is_base64_characters_impl(Slice input) { + auto table = get_character_table(); + for (auto c : input) { + if (table[static_cast(c)] == 64) { + return false; + } + } + return true; +} + +bool is_base64_characters(Slice input) { + return is_base64_characters_impl(input); +} + +bool is_base64url_characters(Slice input) { + return is_base64_characters_impl(input); +} + string base64_filter(Slice input) { string res; res.reserve(input.size()); diff --git a/tdutils/td/utils/base64.h b/tdutils/td/utils/base64.h index 369c02a3..05ef1c7d 100644 --- a/tdutils/td/utils/base64.h +++ b/tdutils/td/utils/base64.h @@ -23,6 +23,9 @@ Result base64url_decode(Slice base64); bool is_base64(Slice input); bool is_base64url(Slice input); +bool is_base64_characters(Slice input); +bool is_base64url_characters(Slice input); + string base64_filter(Slice input); } // namespace td diff --git a/tdutils/test/misc.cpp b/tdutils/test/misc.cpp index acda4e38..24953db1 100644 --- a/tdutils/test/misc.cpp +++ b/tdutils/test/misc.cpp @@ -183,6 +183,30 @@ TEST(Misc, base64) { ASSERT_TRUE(is_base64url("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") == false); ASSERT_TRUE(is_base64url("====") == false); + ASSERT_TRUE(is_base64_characters("dGVzdA==") == false); + ASSERT_TRUE(is_base64_characters("dGVzdB==") == false); + ASSERT_TRUE(is_base64_characters("dGVzdA=") == false); + ASSERT_TRUE(is_base64_characters("dGVzdA") == true); + ASSERT_TRUE(is_base64_characters("dGVz") == true); + ASSERT_TRUE(is_base64_characters("") == true); + ASSERT_TRUE(is_base64_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") == true); + ASSERT_TRUE(is_base64_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=") == false); + ASSERT_TRUE(is_base64_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-/") == false); + ASSERT_TRUE(is_base64_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") == false); + ASSERT_TRUE(is_base64_characters("====") == false); + + ASSERT_TRUE(is_base64url_characters("dGVzdA==") == false); + ASSERT_TRUE(is_base64url_characters("dGVzdB==") == false); + ASSERT_TRUE(is_base64url_characters("dGVzdA=") == false); + ASSERT_TRUE(is_base64url_characters("dGVzdA") == true); + ASSERT_TRUE(is_base64url_characters("dGVz") == true); + ASSERT_TRUE(is_base64url_characters("") == true); + ASSERT_TRUE(is_base64url_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") == true); + ASSERT_TRUE(is_base64url_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_=") == false); + ASSERT_TRUE(is_base64url_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-/") == false); + ASSERT_TRUE(is_base64url_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") == false); + ASSERT_TRUE(is_base64url_characters("====") == false); + for (int l = 0; l < 300000; l += l / 20 + l / 1000 * 500 + 1) { for (int t = 0; t < 10; t++) { string s = rand_string(std::numeric_limits::min(), std::numeric_limits::max(), l);