mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-11-27 15:16:52 +01:00
[utils] traverse_obj
: Always return list when branching (#5170)
Fixes #5162 Authored by: Grub4K
This commit is contained in:
parent
3b55aaac59
commit
f99bbfc983
@ -1890,6 +1890,7 @@ def test_traverse_obj(self):
|
|||||||
{'index': 2},
|
{'index': 2},
|
||||||
{'index': 3},
|
{'index': 3},
|
||||||
),
|
),
|
||||||
|
'dict': {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Test base functionality
|
# Test base functionality
|
||||||
@ -1926,11 +1927,15 @@ def test_traverse_obj(self):
|
|||||||
|
|
||||||
# Test alternative paths
|
# Test alternative paths
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
|
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
|
||||||
msg='multiple `path_list` should be treated as alternative paths')
|
msg='multiple `paths` should be treated as alternative paths')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
|
self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
|
||||||
msg='alternatives should exit early')
|
msg='alternatives should exit early')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
|
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
|
||||||
msg='alternatives should return `default` if exhausted')
|
msg='alternatives should return `default` if exhausted')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, (..., 'fail'), 100), 100,
|
||||||
|
msg='alternatives should track their own branching return')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, ('dict', ...), ('data', ...)), list(_TEST_DATA['data']),
|
||||||
|
msg='alternatives on empty objects should search further')
|
||||||
|
|
||||||
# Test branch and path nesting
|
# Test branch and path nesting
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
|
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
|
||||||
@ -1963,8 +1968,16 @@ def test_traverse_obj(self):
|
|||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
|
||||||
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
|
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
|
||||||
msg='tripple nesting in dict path should be treated as branches')
|
msg='tripple nesting in dict path should be treated as branches')
|
||||||
self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
|
||||||
msg='do not remove `None` values when dict key')
|
msg='remove `None` values when dict key')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
|
||||||
|
msg='do not remove `None` values if `default`')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
|
||||||
|
msg='do not remove empty values when dict key')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}},
|
||||||
|
msg='do not remove empty values when dict key and a default')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []},
|
||||||
|
msg='if branch in dict key not successful, return `[]`')
|
||||||
|
|
||||||
# Testing default parameter behavior
|
# Testing default parameter behavior
|
||||||
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
|
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
|
||||||
@ -1981,7 +1994,13 @@ def test_traverse_obj(self):
|
|||||||
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
|
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
|
||||||
msg='`IndexError` should result in `default`')
|
msg='`IndexError` should result in `default`')
|
||||||
self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
|
self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
|
||||||
msg='if branched but not successfull return `default`, not `[]`')
|
msg='if branched but not successful return `default` if defined, not `[]`')
|
||||||
|
self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=None), None,
|
||||||
|
msg='if branched but not successful return `default` even if `default` is `None`')
|
||||||
|
self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail')), [],
|
||||||
|
msg='if branched but not successful return `[]`, not `default`')
|
||||||
|
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [],
|
||||||
|
msg='if branched but object is empty return `[]`, not `default`')
|
||||||
|
|
||||||
# Testing expected_type behavior
|
# Testing expected_type behavior
|
||||||
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
|
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
|
||||||
|
@ -5294,7 +5294,7 @@ def load_plugins(name, suffix, namespace):
|
|||||||
|
|
||||||
|
|
||||||
def traverse_obj(
|
def traverse_obj(
|
||||||
obj, *paths, default=None, expected_type=None, get_all=True,
|
obj, *paths, default=NO_DEFAULT, expected_type=None, get_all=True,
|
||||||
casesense=True, is_user_input=False, traverse_string=False):
|
casesense=True, is_user_input=False, traverse_string=False):
|
||||||
"""
|
"""
|
||||||
Safely traverse nested `dict`s and `Sequence`s
|
Safely traverse nested `dict`s and `Sequence`s
|
||||||
@ -5304,6 +5304,7 @@ def traverse_obj(
|
|||||||
"value"
|
"value"
|
||||||
|
|
||||||
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
||||||
|
The next path will also be tested if the path branched but no results could be found.
|
||||||
A value of None is treated as the absence of a value.
|
A value of None is treated as the absence of a value.
|
||||||
|
|
||||||
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
||||||
@ -5342,6 +5343,7 @@ def traverse_obj(
|
|||||||
@returns The result of the object traversal.
|
@returns The result of the object traversal.
|
||||||
If successful, `get_all=True`, and the path branches at least once,
|
If successful, `get_all=True`, and the path branches at least once,
|
||||||
then a list of results is returned instead.
|
then a list of results is returned instead.
|
||||||
|
A list is always returned if the last path branches and no `default` is given.
|
||||||
"""
|
"""
|
||||||
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
|
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
|
||||||
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
||||||
@ -5385,7 +5387,7 @@ def apply_key(key, obj):
|
|||||||
elif isinstance(key, dict):
|
elif isinstance(key, dict):
|
||||||
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
|
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
|
||||||
yield {k: v if v is not None else default for k, v in iter_obj
|
yield {k: v if v is not None else default for k, v in iter_obj
|
||||||
if v is not None or default is not None}
|
if v is not None or default is not NO_DEFAULT}
|
||||||
|
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
yield (obj.get(key) if casesense or (key in obj)
|
yield (obj.get(key) if casesense or (key in obj)
|
||||||
@ -5426,18 +5428,22 @@ def apply_path(start_obj, path):
|
|||||||
|
|
||||||
return has_branched, objs
|
return has_branched, objs
|
||||||
|
|
||||||
def _traverse_obj(obj, path):
|
def _traverse_obj(obj, path, use_list=True):
|
||||||
has_branched, results = apply_path(obj, path)
|
has_branched, results = apply_path(obj, path)
|
||||||
results = LazyList(x for x in map(type_test, results) if x is not None)
|
results = LazyList(x for x in map(type_test, results) if x is not None)
|
||||||
if results:
|
|
||||||
return results.exhaust() if get_all and has_branched else results[0]
|
|
||||||
|
|
||||||
for path in paths:
|
if get_all and has_branched:
|
||||||
result = _traverse_obj(obj, path)
|
return results.exhaust() if results or use_list else None
|
||||||
|
|
||||||
|
return results[0] if results else None
|
||||||
|
|
||||||
|
for index, path in enumerate(paths, 1):
|
||||||
|
use_list = default is NO_DEFAULT and index == len(paths)
|
||||||
|
result = _traverse_obj(obj, path, use_list)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return default
|
return None if default is NO_DEFAULT else default
|
||||||
|
|
||||||
|
|
||||||
def traverse_dict(dictn, keys, casesense=True):
|
def traverse_dict(dictn, keys, casesense=True):
|
||||||
|
Loading…
Reference in New Issue
Block a user