Simplify base64url_decode implementation.

GitOrigin-RevId: 873483e61cc54fad78a09aa8a143070c5e018dfb
This commit is contained in:
levlam 2020-01-09 22:07:23 +03:00
parent 529f925d40
commit 7db7757d2d
2 changed files with 23 additions and 54 deletions

View File

@ -14,7 +14,6 @@
#include <iterator> #include <iterator>
namespace td { namespace td {
//TODO: fix copypaste
template <bool is_url> template <bool is_url>
static const char *get_characters() { static const char *get_characters() {
@ -67,15 +66,8 @@ string base64_encode_impl(Slice input) {
return base64; return base64;
} }
string base64_encode(Slice input) { template <bool is_url>
return base64_encode_impl<false>(input);
}
Result<Slice> base64_drop_padding(Slice base64) { Result<Slice> base64_drop_padding(Slice base64) {
if ((base64.size() & 3) != 0) {
return Status::Error("Wrong string length");
}
size_t padding_length = 0; size_t padding_length = 0;
while (!base64.empty() && base64.back() == '=') { while (!base64.empty() && base64.back() == '=') {
base64.remove_suffix(1); base64.remove_suffix(1);
@ -84,12 +76,18 @@ Result<Slice> base64_drop_padding(Slice base64) {
if (padding_length >= 3) { if (padding_length >= 3) {
return Status::Error("Wrong string padding"); return Status::Error("Wrong string padding");
} }
if ((!is_url || padding_length > 0) && ((base64.size() + padding_length) & 3) != 0) {
return Status::Error("Wrong padding length");
}
if (is_url && (base64.size() & 3) == 1) {
return Status::Error("Wrong string length");
}
return base64; return base64;
} }
template <class F> template <bool is_url, class F>
Status base64_do_decode(Slice base64, F &&append) { Status base64_do_decode(Slice base64, F &&append) {
auto table = get_character_table<false>(); auto table = get_character_table<is_url>();
for (size_t i = 0; i < base64.size();) { for (size_t i = 0; i < base64.size();) {
size_t left = min(base64.size() - i, static_cast<size_t>(4)); size_t left = min(base64.size() - i, static_cast<size_t>(4));
int c = 0; int c = 0;
@ -120,20 +118,20 @@ Status base64_do_decode(Slice base64, F &&append) {
} }
Result<string> base64_decode(Slice base64) { Result<string> base64_decode(Slice base64) {
TRY_RESULT_ASSIGN(base64, base64_drop_padding(base64)); TRY_RESULT_ASSIGN(base64, base64_drop_padding<false>(base64));
string output; string output;
output.reserve(((base64.size() + 3) >> 2) * 3); output.reserve(((base64.size() + 3) >> 2) * 3);
TRY_STATUS(base64_do_decode(base64, [&output](char c) { output += c; })); TRY_STATUS(base64_do_decode<false>(base64, [&output](char c) { output += c; }));
return output; return output;
} }
Result<SecureString> base64_decode_secure(Slice base64) { Result<SecureString> base64_decode_secure(Slice base64) {
TRY_RESULT_ASSIGN(base64, base64_drop_padding(base64)); TRY_RESULT_ASSIGN(base64, base64_drop_padding<false>(base64));
SecureString output(((base64.size() + 3) >> 2) * 3); SecureString output(((base64.size() + 3) >> 2) * 3);
char *ptr = output.as_mutable_slice().begin(); char *ptr = output.as_mutable_slice().begin();
TRY_STATUS(base64_do_decode(base64, [&ptr](char c) { *ptr++ = c; })); TRY_STATUS(base64_do_decode<false>(base64, [&ptr](char c) { *ptr++ = c; }));
size_t size = ptr - output.as_mutable_slice().begin(); size_t size = ptr - output.as_mutable_slice().begin();
if (size == output.size()) { if (size == output.size()) {
return std::move(output); return std::move(output);
@ -141,53 +139,20 @@ Result<SecureString> base64_decode_secure(Slice base64) {
return SecureString(output.as_slice().substr(0, size)); return SecureString(output.as_slice().substr(0, size));
} }
string base64_encode(Slice input) {
return base64_encode_impl<false>(input);
}
string base64url_encode(Slice input) { string base64url_encode(Slice input) {
return base64_encode_impl<true>(input); return base64_encode_impl<true>(input);
} }
Result<string> base64url_decode(Slice base64) { Result<string> base64url_decode(Slice base64) {
size_t padding_length = 0; TRY_RESULT_ASSIGN(base64, base64_drop_padding<true>(base64));
while (!base64.empty() && base64.back() == '=') {
base64.remove_suffix(1);
padding_length++;
}
if (padding_length >= 3 || (padding_length > 0 && ((base64.size() + padding_length) & 3) != 0)) {
return Status::Error("Wrong string padding");
}
if ((base64.size() & 3) == 1) {
return Status::Error("Wrong string length");
}
auto table = get_character_table<true>();
string output; string output;
output.reserve(((base64.size() + 3) >> 2) * 3); output.reserve(((base64.size() + 3) >> 2) * 3);
for (size_t i = 0; i < base64.size();) { TRY_STATUS(base64_do_decode<true>(base64, [&output](char c) { output += c; }));
size_t left = min(base64.size() - i, static_cast<size_t>(4));
int c = 0;
for (size_t t = 0; t < left; t++) {
auto value = table[base64.ubegin()[i++]];
if (value == 64) {
return Status::Error("Wrong character in the string");
}
c |= value << ((3 - t) * 6);
}
output += static_cast<char>(static_cast<unsigned char>(c >> 16)); // implementation-defined
if (left == 2) {
if ((c & ((1 << 16) - 1)) != 0) {
return Status::Error("Wrong padding in the string");
}
} else {
output += static_cast<char>(static_cast<unsigned char>(c >> 8)); // implementation-defined
if (left == 3) {
if ((c & ((1 << 8) - 1)) != 0) {
return Status::Error("Wrong padding in the string");
}
} else {
output += static_cast<char>(static_cast<unsigned char>(c)); // implementation-defined
}
}
}
return output; return output;
} }

View File

@ -163,7 +163,9 @@ TEST(Misc, base64) {
ASSERT_TRUE(is_base64("dGVzdB==") == false); ASSERT_TRUE(is_base64("dGVzdB==") == false);
ASSERT_TRUE(is_base64("dGVzdA=") == false); ASSERT_TRUE(is_base64("dGVzdA=") == false);
ASSERT_TRUE(is_base64("dGVzdA") == false); ASSERT_TRUE(is_base64("dGVzdA") == false);
ASSERT_TRUE(is_base64("dGVzd") == false);
ASSERT_TRUE(is_base64("dGVz") == true); ASSERT_TRUE(is_base64("dGVz") == true);
ASSERT_TRUE(is_base64("dGVz====") == false);
ASSERT_TRUE(is_base64("") == true); ASSERT_TRUE(is_base64("") == true);
ASSERT_TRUE(is_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") == true); ASSERT_TRUE(is_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") == true);
ASSERT_TRUE(is_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=") == false); ASSERT_TRUE(is_base64("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=") == false);
@ -175,7 +177,9 @@ TEST(Misc, base64) {
ASSERT_TRUE(is_base64url("dGVzdB==") == false); ASSERT_TRUE(is_base64url("dGVzdB==") == false);
ASSERT_TRUE(is_base64url("dGVzdA=") == false); ASSERT_TRUE(is_base64url("dGVzdA=") == false);
ASSERT_TRUE(is_base64url("dGVzdA") == true); ASSERT_TRUE(is_base64url("dGVzdA") == true);
ASSERT_TRUE(is_base64url("dGVzd") == false);
ASSERT_TRUE(is_base64url("dGVz") == true); ASSERT_TRUE(is_base64url("dGVz") == true);
ASSERT_TRUE(is_base64url("dGVz====") == false);
ASSERT_TRUE(is_base64url("") == true); ASSERT_TRUE(is_base64url("") == true);
ASSERT_TRUE(is_base64url("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") == true); ASSERT_TRUE(is_base64url("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") == true);
ASSERT_TRUE(is_base64url("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_=") == false); ASSERT_TRUE(is_base64url("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_=") == false);