mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2025-01-11 21:15:53 +01:00
[utils] traverse_obj
: Fix several behavioral problems
See #6180 for further info Authored by: Grub4K
This commit is contained in:
parent
88426d9446
commit
b1bde57bef
@ -2000,8 +2000,8 @@ def test_traverse_obj(self):
|
|||||||
|
|
||||||
# Test Ellipsis behavior
|
# Test Ellipsis behavior
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
|
||||||
(item for item in _TEST_DATA.values() if item is not None),
|
(item for item in _TEST_DATA.values() if item not in (None, [], {})),
|
||||||
msg='`...` should give all values except `None`')
|
msg='`...` should give all non discarded values')
|
||||||
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
|
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
|
||||||
msg='`...` selection for dicts should select all values')
|
msg='`...` selection for dicts should select all values')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
|
self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
|
||||||
@ -2084,15 +2084,23 @@ def test_traverse_obj(self):
|
|||||||
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
|
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
|
||||||
msg='tripple nesting in dict path should be treated as branches')
|
msg='tripple nesting in dict path should be treated as branches')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
|
||||||
msg='remove `None` values when dict key')
|
msg='remove `None` values when top level dict key fails')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
|
||||||
msg='do not remove `None` values if `default`')
|
msg='use `default` if key fails and `default`')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
|
||||||
msg='do not remove empty values when dict key')
|
msg='remove empty values when dict key')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...},
|
||||||
msg='do not remove empty values when dict key and a default')
|
msg='use `default` when dict key and `default`')
|
||||||
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []},
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
|
||||||
msg='if branch in dict key not successful, return `[]`')
|
msg='remove empty values when nested dict key fails')
|
||||||
|
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
|
||||||
|
msg='default to dict if pruned')
|
||||||
|
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {},
|
||||||
|
msg='default to dict if pruned and default is given')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
|
||||||
|
msg='use nested `default` when nested dict key fails and `default`')
|
||||||
|
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {},
|
||||||
|
msg='remove key if branch in dict key not successful')
|
||||||
|
|
||||||
# Testing default parameter behavior
|
# Testing default parameter behavior
|
||||||
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
|
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
|
||||||
@ -2183,14 +2191,17 @@ def test_traverse_obj(self):
|
|||||||
traverse_string=True), '.',
|
traverse_string=True), '.',
|
||||||
msg='traverse into converted data if `traverse_string`')
|
msg='traverse into converted data if `traverse_string`')
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
|
||||||
traverse_string=True), list('str'),
|
traverse_string=True), 'str',
|
||||||
msg='`...` branching into string should result in list')
|
msg='`...` should result in string (same value) if `traverse_string`')
|
||||||
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
|
||||||
|
traverse_string=True), 'sr',
|
||||||
|
msg='`slice` should result in string if `traverse_string`')
|
||||||
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
|
||||||
|
traverse_string=True), 'str',
|
||||||
|
msg='function should result in string if `traverse_string`')
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
|
||||||
traverse_string=True), ['s', 'r'],
|
traverse_string=True), ['s', 'r'],
|
||||||
msg='branching into string should result in list')
|
msg='branching should result in list if `traverse_string`')
|
||||||
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
|
|
||||||
traverse_string=True), list('str'),
|
|
||||||
msg='function branching into string should result in list')
|
|
||||||
|
|
||||||
# Test is_user_input behavior
|
# Test is_user_input behavior
|
||||||
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
|
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
|
||||||
|
141
yt_dlp/utils.py
141
yt_dlp/utils.py
@ -5420,7 +5420,7 @@ def traverse_obj(
|
|||||||
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
Each of the provided `paths` is tested and the first producing a valid result will be returned.
|
||||||
The next path will also be tested if the path branched but no results could be found.
|
The next path will also be tested if the path branched but no results could be found.
|
||||||
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
|
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
|
||||||
A value of None is treated as the absence of a value.
|
Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded.
|
||||||
|
|
||||||
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
|
||||||
|
|
||||||
@ -5446,6 +5446,8 @@ def traverse_obj(
|
|||||||
|
|
||||||
@params paths Paths which to traverse by.
|
@params paths Paths which to traverse by.
|
||||||
@param default Value to return if the paths do not match.
|
@param default Value to return if the paths do not match.
|
||||||
|
If the last key in the path is a `dict`, it will apply to each value inside
|
||||||
|
the dict instead, depth first. Try to avoid if using nested `dict` keys.
|
||||||
@param expected_type If a `type`, only accept final values of this type.
|
@param expected_type If a `type`, only accept final values of this type.
|
||||||
If any other callable, try to call the function on each result.
|
If any other callable, try to call the function on each result.
|
||||||
If the last key in the path is a `dict`, it will apply to each value inside
|
If the last key in the path is a `dict`, it will apply to each value inside
|
||||||
@ -5460,12 +5462,15 @@ def traverse_obj(
|
|||||||
@param traverse_string Whether to traverse into objects as strings.
|
@param traverse_string Whether to traverse into objects as strings.
|
||||||
If `True`, any non-compatible object will first be
|
If `True`, any non-compatible object will first be
|
||||||
converted into a string and then traversed into.
|
converted into a string and then traversed into.
|
||||||
|
The return value of that path will be a string instead,
|
||||||
|
not respecting any further branching.
|
||||||
|
|
||||||
|
|
||||||
@returns The result of the object traversal.
|
@returns The result of the object traversal.
|
||||||
If successful, `get_all=True`, and the path branches at least once,
|
If successful, `get_all=True`, and the path branches at least once,
|
||||||
then a list of results is returned instead.
|
then a list of results is returned instead.
|
||||||
A list is always returned if the last path branches and no `default` is given.
|
If no `default` is given and the last path branches, a `list` of results
|
||||||
|
is always returned. If a path ends on a `dict` that result will always be a `dict`.
|
||||||
"""
|
"""
|
||||||
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
|
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
|
||||||
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
casefold = lambda k: k.casefold() if isinstance(k, str) else k
|
||||||
@ -5475,87 +5480,94 @@ def traverse_obj(
|
|||||||
else:
|
else:
|
||||||
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
|
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
|
||||||
|
|
||||||
def apply_key(key, test_type, obj):
|
def apply_key(key, obj, is_last):
|
||||||
|
branching = False
|
||||||
|
result = None
|
||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return
|
pass
|
||||||
|
|
||||||
elif key is None:
|
elif key is None:
|
||||||
yield obj
|
result = obj
|
||||||
|
|
||||||
elif isinstance(key, set):
|
elif isinstance(key, set):
|
||||||
assert len(key) == 1, 'Set should only be used to wrap a single item'
|
assert len(key) == 1, 'Set should only be used to wrap a single item'
|
||||||
item = next(iter(key))
|
item = next(iter(key))
|
||||||
if isinstance(item, type):
|
if isinstance(item, type):
|
||||||
if isinstance(obj, item):
|
if isinstance(obj, item):
|
||||||
yield obj
|
result = obj
|
||||||
else:
|
else:
|
||||||
yield try_call(item, args=(obj,))
|
result = try_call(item, args=(obj,))
|
||||||
|
|
||||||
elif isinstance(key, (list, tuple)):
|
elif isinstance(key, (list, tuple)):
|
||||||
for branch in key:
|
branching = True
|
||||||
_, result = apply_path(obj, branch, test_type)
|
result = itertools.chain.from_iterable(
|
||||||
yield from result
|
apply_path(obj, branch, is_last)[0] for branch in key)
|
||||||
|
|
||||||
elif key is ...:
|
elif key is ...:
|
||||||
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
yield from obj.values()
|
result = obj.values()
|
||||||
elif is_sequence(obj):
|
elif is_sequence(obj):
|
||||||
yield from obj
|
result = obj
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
yield from obj.groups()
|
result = obj.groups()
|
||||||
elif traverse_string:
|
elif traverse_string:
|
||||||
yield from str(obj)
|
branching = False
|
||||||
|
result = str(obj)
|
||||||
|
else:
|
||||||
|
result = ()
|
||||||
|
|
||||||
elif callable(key):
|
elif callable(key):
|
||||||
if is_sequence(obj):
|
branching = True
|
||||||
iter_obj = enumerate(obj)
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
elif isinstance(obj, collections.abc.Mapping):
|
|
||||||
iter_obj = obj.items()
|
iter_obj = obj.items()
|
||||||
|
elif is_sequence(obj):
|
||||||
|
iter_obj = enumerate(obj)
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
iter_obj = itertools.chain(
|
iter_obj = itertools.chain(
|
||||||
enumerate((obj.group(), *obj.groups())),
|
enumerate((obj.group(), *obj.groups())),
|
||||||
obj.groupdict().items())
|
obj.groupdict().items())
|
||||||
elif traverse_string:
|
elif traverse_string:
|
||||||
|
branching = False
|
||||||
iter_obj = enumerate(str(obj))
|
iter_obj = enumerate(str(obj))
|
||||||
else:
|
else:
|
||||||
return
|
iter_obj = ()
|
||||||
yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
|
|
||||||
|
result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
|
||||||
|
if not branching: # string traversal
|
||||||
|
result = ''.join(result)
|
||||||
|
|
||||||
elif isinstance(key, dict):
|
elif isinstance(key, dict):
|
||||||
iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())
|
iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
|
||||||
yield {k: v if v is not None else default for k, v in iter_obj
|
result = {
|
||||||
if v is not None or default is not NO_DEFAULT}
|
k: v if v is not None else default for k, v in iter_obj
|
||||||
|
if v is not None or default is not NO_DEFAULT
|
||||||
|
} or None
|
||||||
|
|
||||||
elif isinstance(obj, collections.abc.Mapping):
|
elif isinstance(obj, collections.abc.Mapping):
|
||||||
yield (obj.get(key) if casesense or (key in obj)
|
result = (obj.get(key) if casesense or (key in obj) else
|
||||||
else next((v for k, v in obj.items() if casefold(k) == key), None))
|
next((v for k, v in obj.items() if casefold(k) == key), None))
|
||||||
|
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
if isinstance(key, int) or casesense:
|
if isinstance(key, int) or casesense:
|
||||||
with contextlib.suppress(IndexError):
|
with contextlib.suppress(IndexError):
|
||||||
yield obj.group(key)
|
result = obj.group(key)
|
||||||
return
|
|
||||||
|
|
||||||
if not isinstance(key, str):
|
elif isinstance(key, str):
|
||||||
return
|
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
||||||
|
|
||||||
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if is_user_input:
|
|
||||||
key = (int_or_none(key) if ':' not in key
|
|
||||||
else slice(*map(int_or_none, key.split(':'))))
|
|
||||||
|
|
||||||
if not isinstance(key, (int, slice)):
|
|
||||||
return
|
|
||||||
|
|
||||||
|
elif isinstance(key, (int, slice)):
|
||||||
if not is_sequence(obj):
|
if not is_sequence(obj):
|
||||||
if not traverse_string:
|
if traverse_string:
|
||||||
return
|
with contextlib.suppress(IndexError):
|
||||||
obj = str(obj)
|
result = str(obj)[key]
|
||||||
|
else:
|
||||||
|
branching = isinstance(key, slice)
|
||||||
|
with contextlib.suppress(IndexError):
|
||||||
|
result = obj[key]
|
||||||
|
|
||||||
with contextlib.suppress(IndexError):
|
return branching, result if branching else (result,)
|
||||||
yield obj[key]
|
|
||||||
|
|
||||||
def lazy_last(iterable):
|
def lazy_last(iterable):
|
||||||
iterator = iter(iterable)
|
iterator = iter(iterable)
|
||||||
@ -5569,45 +5581,54 @@ def lazy_last(iterable):
|
|||||||
|
|
||||||
yield True, prev
|
yield True, prev
|
||||||
|
|
||||||
def apply_path(start_obj, path, test_type=False):
|
def apply_path(start_obj, path, test_type):
|
||||||
objs = (start_obj,)
|
objs = (start_obj,)
|
||||||
has_branched = False
|
has_branched = False
|
||||||
|
|
||||||
key = None
|
key = None
|
||||||
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
|
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
|
||||||
if is_user_input and key == ':':
|
if is_user_input and isinstance(key, str):
|
||||||
key = ...
|
if key == ':':
|
||||||
|
key = ...
|
||||||
|
elif ':' in key:
|
||||||
|
key = slice(*map(int_or_none, key.split(':')))
|
||||||
|
elif int_or_none(key) is not None:
|
||||||
|
key = int(key)
|
||||||
|
|
||||||
if not casesense and isinstance(key, str):
|
if not casesense and isinstance(key, str):
|
||||||
key = key.casefold()
|
key = key.casefold()
|
||||||
|
|
||||||
if key is ... or isinstance(key, (list, tuple)) or callable(key):
|
|
||||||
has_branched = True
|
|
||||||
|
|
||||||
if __debug__ and callable(key):
|
if __debug__ and callable(key):
|
||||||
# Verify function signature
|
# Verify function signature
|
||||||
inspect.signature(key).bind(None, None)
|
inspect.signature(key).bind(None, None)
|
||||||
|
|
||||||
key_func = functools.partial(apply_key, key, last)
|
new_objs = []
|
||||||
objs = itertools.chain.from_iterable(map(key_func, objs))
|
for obj in objs:
|
||||||
|
branching, results = apply_key(key, obj, last)
|
||||||
|
has_branched |= branching
|
||||||
|
new_objs.append(results)
|
||||||
|
|
||||||
|
objs = itertools.chain.from_iterable(new_objs)
|
||||||
|
|
||||||
if test_type and not isinstance(key, (dict, list, tuple)):
|
if test_type and not isinstance(key, (dict, list, tuple)):
|
||||||
objs = map(type_test, objs)
|
objs = map(type_test, objs)
|
||||||
|
|
||||||
return has_branched, objs
|
return objs, has_branched, isinstance(key, dict)
|
||||||
|
|
||||||
def _traverse_obj(obj, path, use_list=True, test_type=True):
|
|
||||||
has_branched, results = apply_path(obj, path, test_type)
|
|
||||||
results = LazyList(x for x in results if x is not None)
|
|
||||||
|
|
||||||
|
def _traverse_obj(obj, path, allow_empty, test_type):
|
||||||
|
results, has_branched, is_dict = apply_path(obj, path, test_type)
|
||||||
|
results = LazyList(item for item in results if item not in (None, [], {}))
|
||||||
if get_all and has_branched:
|
if get_all and has_branched:
|
||||||
return results.exhaust() if results or use_list else None
|
if results:
|
||||||
|
return results.exhaust()
|
||||||
|
if allow_empty:
|
||||||
|
return [] if default is NO_DEFAULT else default
|
||||||
|
return None
|
||||||
|
|
||||||
return results[0] if results else None
|
return results[0] if results else {} if allow_empty and is_dict else None
|
||||||
|
|
||||||
for index, path in enumerate(paths, 1):
|
for index, path in enumerate(paths, 1):
|
||||||
use_list = default is NO_DEFAULT and index == len(paths)
|
result = _traverse_obj(obj, path, index == len(paths), True)
|
||||||
result = _traverse_obj(obj, path, use_list)
|
|
||||||
if result is not None:
|
if result is not None:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user