diff --git a/test/test_utils.py b/test/test_utils.py index 3d5a6ea6b..ffe1b729f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -105,6 +105,7 @@ from yt_dlp.utils import ( sanitized_Request, shell_quote, smuggle_url, + str_or_none, str_to_int, strip_jsonp, strip_or_none, @@ -2015,6 +2016,29 @@ Line 1 msg='function as query key should perform a filter based on (key, value)') self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, msg='exceptions in the query function should be catched') + if __debug__: + with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): + traverse_obj(_TEST_DATA, lambda a: ...) + with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'): + traverse_obj(_TEST_DATA, lambda a, b, c: ...) + + # Test set as key (transformation/type, like `expected_type`) + self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'], + msg='Function in set should be a transformation') + self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'], + msg='Type in set should be a type filter') + self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA, + msg='A single set should be wrapped into a path') + self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'], + msg='Transformation function should not raise') + self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})), + [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], + msg='Function in set should be a transformation') + if __debug__: + with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): + traverse_obj(_TEST_DATA, set()) + with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): + traverse_obj(_TEST_DATA, {str.upper, str}) # Test alternative paths self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', @@ -2106,6 +2130,20 @@ Line 1 msg='wrap expected_type fuction in try_call') self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], msg='eliminate items that expected_type fails on') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100}, + msg='type as expected_type should filter dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'}, + msg='function as expected_type should transform dict values') + self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1, + msg='expected_type should not filter non final dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}}, + msg='expected_type should transform deep dict values') + self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}], + msg='expected_type should transform branched dict values') + self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4], + msg='expected_type regression for type matching in tuple branching') + self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [], + msg='expected_type regression for type matching in dict result') # Test get_all behavior _GET_ALL_DATA = {'key': [0, 1, 2]} @@ -2189,6 +2227,8 @@ Line 1 msg='failing str key on a `re.Match` should return `default`') self.assertEqual(traverse_obj(mobj, 8), None, msg='failing int key on a `re.Match` should return `default`') + self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], + msg='function on a `re.Match` should give group name as well') if __name__ == '__main__': diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 7d51fe472..55e1c4415 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5424,6 +5424,9 @@ def traverse_obj( The keys in the path can be one of: - `None`: Return the current object. + - `set`: Requires the only item in the set to be a type or function, + like `{type}`/`{func}`. If a `type`, returns only values + of this type. If a function, returns `func(obj)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. @@ -5432,6 +5435,8 @@ def traverse_obj( - `function`: Branch out and return values filtered by the function. Read as: `[value for key, value in obj if function(key, value)]`. For `Sequence`s, `key` is the index of the value. + For `re.Match`es, `key` is the group number (0 = full match) + as well as additionally any group names, if given. - `dict` Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. @@ -5441,6 +5446,8 @@ def traverse_obj( @param default Value to return if the paths do not match. @param expected_type If a `type`, only accept final values of this type. If any other callable, try to call the function on each result. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, recursively. This does respect branching paths. @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. @@ -5466,16 +5473,25 @@ def traverse_obj( else: type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) - def apply_key(key, obj): + def apply_key(key, test_type, obj): if obj is None: return elif key is None: yield obj + elif isinstance(key, set): + assert len(key) == 1, 'Set should only be used to wrap a single item' + item = next(iter(key)) + if isinstance(item, type): + if isinstance(obj, item): + yield obj + else: + yield try_call(item, args=(obj,)) + elif isinstance(key, (list, tuple)): for branch in key: - _, result = apply_path(obj, branch) + _, result = apply_path(obj, branch, test_type) yield from result elif key is ...: @@ -5494,7 +5510,9 @@ def traverse_obj( elif isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() elif isinstance(obj, re.Match): - iter_obj = enumerate((obj.group(), *obj.groups())) + iter_obj = itertools.chain( + enumerate((obj.group(), *obj.groups())), + obj.groupdict().items()) elif traverse_string: iter_obj = enumerate(str(obj)) else: @@ -5502,7 +5520,7 @@ def traverse_obj( yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) + iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items()) yield {k: v if v is not None else default for k, v in iter_obj if v is not None or default is not NO_DEFAULT} @@ -5537,11 +5555,24 @@ def traverse_obj( with contextlib.suppress(IndexError): yield obj[key] - def apply_path(start_obj, path): + def lazy_last(iterable): + iterator = iter(iterable) + prev = next(iterator, NO_DEFAULT) + if prev is NO_DEFAULT: + return + + for item in iterator: + yield False, prev + prev = item + + yield True, prev + + def apply_path(start_obj, path, test_type=False): objs = (start_obj,) has_branched = False - for key in variadic(path): + key = None + for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): if is_user_input and key == ':': key = ... @@ -5551,14 +5582,21 @@ def traverse_obj( if key is ... or isinstance(key, (list, tuple)) or callable(key): has_branched = True - key_func = functools.partial(apply_key, key) + if __debug__ and callable(key): + # Verify function signature + inspect.signature(key).bind(None, None) + + key_func = functools.partial(apply_key, key, last) objs = itertools.chain.from_iterable(map(key_func, objs)) + if test_type and not isinstance(key, (dict, list, tuple)): + objs = map(type_test, objs) + return has_branched, objs - def _traverse_obj(obj, path, use_list=True): - has_branched, results = apply_path(obj, path) - results = LazyList(x for x in map(type_test, results) if x is not None) + def _traverse_obj(obj, path, use_list=True, test_type=True): + has_branched, results = apply_path(obj, path, test_type) + results = LazyList(x for x in results if x is not None) if get_all and has_branched: return results.exhaust() if results or use_list else None