1
1
mirror of https://github.com/ytdl-org/youtube-dl synced 2024-11-27 06:36:49 +01:00

[utils] Align traverse_obj() with yt-dlp

Thanks Grub4k for these:
* traverse `Iterable`s, from https://github.com/yt-dlp/yt-dlp/pull/6902, etc
* traverse `set` key for transformations/filters, `re.Match` group names, from
  776995bc10, etc
* traverse `re.Match`es, from https://github.com/yt-dlp/yt-dlp/pull/5174
* always return list when branching, from https://github.com/yt-dlp/yt-dlp/pull/5170
This commit is contained in:
dirkf 2023-05-03 12:40:09 +01:00
parent 47214e46d8
commit 825a40744b
2 changed files with 23 additions and 23 deletions

View File

@ -20,7 +20,7 @@ import xml.etree.ElementTree
from youtube_dl.utils import ( from youtube_dl.utils import (
age_restricted, age_restricted,
args_to_str, args_to_str,
encode_base_n, base_url,
caesar, caesar,
clean_html, clean_html,
clean_podcast_url, clean_podcast_url,
@ -29,10 +29,12 @@ from youtube_dl.utils import (
detect_exe_version, detect_exe_version,
determine_ext, determine_ext,
dict_get, dict_get,
encode_base_n,
encode_compat_str, encode_compat_str,
encodeFilename, encodeFilename,
escape_rfc3986, escape_rfc3986,
escape_url, escape_url,
expand_path,
extract_attributes, extract_attributes,
ExtractorError, ExtractorError,
find_xpath_attr, find_xpath_attr,
@ -51,6 +53,7 @@ from youtube_dl.utils import (
js_to_json, js_to_json,
LazyList, LazyList,
limit_length, limit_length,
lowercase_escape,
merge_dicts, merge_dicts,
mimetype2ext, mimetype2ext,
month_by_name, month_by_name,
@ -66,17 +69,16 @@ from youtube_dl.utils import (
parse_resolution, parse_resolution,
parse_bitrate, parse_bitrate,
pkcs1pad, pkcs1pad,
read_batch_urls,
sanitize_filename,
sanitize_path,
sanitize_url,
expand_path,
prepend_extension, prepend_extension,
replace_extension, read_batch_urls,
remove_start, remove_start,
remove_end, remove_end,
remove_quotes, remove_quotes,
replace_extension,
rot47, rot47,
sanitize_filename,
sanitize_path,
sanitize_url,
shell_quote, shell_quote,
smuggle_url, smuggle_url,
str_or_none, str_or_none,
@ -93,10 +95,8 @@ from youtube_dl.utils import (
unified_timestamp, unified_timestamp,
unsmuggle_url, unsmuggle_url,
uppercase_escape, uppercase_escape,
lowercase_escape,
url_basename, url_basename,
url_or_none, url_or_none,
base_url,
urljoin, urljoin,
urlencode_postdata, urlencode_postdata,
urshift, urshift,
@ -1586,6 +1586,11 @@ Line 1
'dict': {}, 'dict': {},
} }
# define a pukka Iterable
def iter_range(stop):
for from_ in range(stop):
yield from_
# Test base functionality # Test base functionality
self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
msg='allow tuple path') msg='allow tuple path')
@ -1602,13 +1607,13 @@ Line 1
# Test Ellipsis behavior # Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis), self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
(item for item in _TEST_DATA.values() if item not in (None, {})), (item for item in _TEST_DATA.values() if item not in (None, {})),
msg='`...` should give all non discarded values') msg='`...` should give all non-discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(), self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values') msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')), self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
['https://www.example.com/0', 'https://www.example.com/1'], ['https://www.example.com/0', 'https://www.example.com/1'],
msg='nested `...` queries should work') msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4), self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), iter_range(4),
msg='`...` query result should be flattened') msg='`...` query result should be flattened')
self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)), self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
msg='`...` should accept iterables') msg='`...` should accept iterables')
@ -1618,7 +1623,7 @@ Line 1
[_TEST_DATA['urls']], [_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)), self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), set(('str',)),
msg='exceptions in the query function should be catched') msg='exceptions in the query function should be caught')
self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2], self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
msg='function key should accept iterables') msg='function key should accept iterables')
if __debug__: if __debug__:
@ -1706,7 +1711,7 @@ Line 1
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
msg='remove empty values when dict key') msg='remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
msg='use `default` when dict key and `default`') msg='use `default` when dict key and a default')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
msg='remove empty values when nested dict key fails') msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
@ -1768,7 +1773,7 @@ Line 1
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
'str', msg='accept matching `expected_type` type') 'str', msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
None, msg='reject non matching `expected_type` type') None, msg='reject non-matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
'0', msg='transform type using type function') '0', msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
@ -1780,7 +1785,7 @@ Line 1
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
{0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int), self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, set((int_or_none,))), expected_type=int),
1, msg='expected_type should not filter non final dict values') 1, msg='expected_type should not filter non-final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
{0: {0: 100}}, msg='expected_type should transform deep dict values') {0: {0: 100}}, msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)), self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
@ -1838,7 +1843,7 @@ Line 1
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
_traverse_string=True), 'sr', _traverse_string=True), 'sr',
msg='`slice` should result in string if `traverse_string`') msg='`slice` should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == 's'),
_traverse_string=True), 'str', _traverse_string=True), 'str',
msg='function should result in string if `traverse_string`') msg='function should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),

View File

@ -4268,13 +4268,8 @@ def variadic(x, allowed_types=NO_DEFAULT):
def dict_get(d, key_or_keys, default=None, skip_false_values=True): def dict_get(d, key_or_keys, default=None, skip_false_values=True):
if isinstance(key_or_keys, (list, tuple)): exp = (lambda x: x or None) if skip_false_values else IDENTITY
for key in key_or_keys: return traverse_obj(d, *variadic(key_or_keys), expected_type=exp, default=default)
if key not in d or d[key] is None or skip_false_values and not d[key]:
continue
return d[key]
return default
return d.get(key_or_keys, default)
def try_call(*funcs, **kwargs): def try_call(*funcs, **kwargs):