From ab029d7e9200a273d7204be68c0735b16971ff44 Mon Sep 17 00:00:00 2001 From: Simon Sawicki <37424085+Grub4K@users.noreply.github.com> Date: Sun, 25 Sep 2022 23:03:19 +0200 Subject: [PATCH] [utils] `traverse_obj`: Rewrite, document and add tests (#5024) Authored by: Grub4K --- test/test_utils.py | 187 ++++++++++++++++++++++++++++++++++ yt_dlp/utils.py | 247 ++++++++++++++++++++++++++------------------- 2 files changed, 332 insertions(+), 102 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 96477c53f..69313564a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,6 +109,7 @@ from yt_dlp.utils import ( strip_or_none, subtitles_filename, timeconvert, + traverse_obj, unescapeHTML, unified_strdate, unified_timestamp, @@ -1874,6 +1875,192 @@ Line 1 self.assertEqual(get_compatible_ext( vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv') + def test_traverse_obj(self): + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': ..., + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + } + + # Test base functionality + self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str', + msg='allow tuple path') + self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str', + msg='allow list path') + self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str', + msg='allow iterable path') + self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str', + msg='single items should be treated as a path') + self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA) + self.assertEqual(traverse_obj(_TEST_DATA, 100), 100) + self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2) + + # Test Ellipsis behavior + self.assertCountEqual(traverse_obj(_TEST_DATA, ...), + (item for item in _TEST_DATA.values() if item is not None), + msg='`...` should give all values 
except `None`') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), + msg='`...` selection for dicts should select all values') + self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='nested `...` queries should work') + self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4), + msg='`...` query result should be flattened') + + # Test function as key + self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)), + [_TEST_DATA['urls']], + msg='function as query key should perform a filter based on (key, value)') + self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, + msg='exceptions in the query function should be catched') + + # Test alternative paths + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', + msg='multiple `path_list` should be treated as alternative paths') + self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str', + msg='alternatives should exit early') + self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None, + msg='alternatives should return `default` if exhausted') + + # Test branch and path nesting + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'], + msg='tuple as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'], + msg='list as key should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'], + msg='double nesting in path should be treated as paths') + self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1], + msg='do not fail early on branching') + self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))), + ['https://www.example.com/0', 
'https://www.example.com/1'], + msg='tripple nesting in path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))), + ['https://www.example.com/0', 'https://www.example.com/1'], + msg='ellipsis as branch path start gets flattened') + + # Test dictionary as key + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2}, + msg='dict key should result in a dict with the same keys') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}), + {0: 'https://www.example.com/0'}, + msg='dict key should allow paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}), + {0: ['https://www.example.com/0']}, + msg='tuple in dict path should be treated as branches') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}), + {0: ['https://www.example.com/0']}, + msg='double nesting in dict path should be treated as paths') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}), + {0: ['https://www.example.com/1', 'https://www.example.com/0']}, + msg='tripple nesting in dict path should be treated as branches') + self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...}, + msg='do not remove `None` values when dict key') + + # Testing default parameter behavior + _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None, + msg='default value should be `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ..., + msg='chained fails should result in default') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0, + msg='should not short cirquit on `None`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1, + msg='invalid dict key should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1, + msg='`None` is a deliberate sentinel and should become 
`default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None, + msg='`IndexError` should result in `default`') + self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1, + msg='if branched but not successfull return `default`, not `[]`') + + # Testing expected_type behavior + _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str', + msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, + msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0', + msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', + expected_type=lambda _: 1 / 0), None, + msg='wrap expected_type fuction in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], + msg='eliminate items that expected_type fails on') + + # Test get_all behavior + _GET_ALL_DATA = {'key': [0, 1, 2]} + self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0, + msg='if not `get_all`, return only first matching value') + self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2], + msg='do not overflatten if not `get_all`') + + # Test casesense behavior + _CASESENSE_DATA = { + 'KeY': 'value0', + 0: { + 'KeY': 'value1', + 0: {'KeY': 'value2'}, + }, + } + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None, + msg='dict keys should be case sensitive unless `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY', + casesense=False), 'value0', + msg='allow non matching key case if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)), + casesense=False), ['value1'], + msg='allow non matching key case in branch if `casesense`') + self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 
'keY'),)), + casesense=False), ['value2'], + msg='allow non matching key case in branch path if `casesense`') + + # Test traverse_string behavior + _TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2} + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None, + msg='do not traverse into string if not `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0), + traverse_string=True), 's', + msg='traverse into string if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1), + traverse_string=True), '.', + msg='traverse into converted data if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), + traverse_string=True), list('str'), + msg='`...` branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), + traverse_string=True), ['s', 'r'], + msg='branching into string should result in list') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), + traverse_string=True), list('str'), + msg='function branching into string should result in list') + + # Test is_user_input behavior + _IS_USER_INPUT_DATA = {'range8': list(range(8))} + self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), + is_user_input=True), 3, + msg='allow for string indexing if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), + is_user_input=True), tuple(range(8))[3:], + msg='allow for string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), + is_user_input=True), tuple(range(8))[:4:2], + msg='allow step in string slice if `is_user_input`') + self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), + is_user_input=True), range(8), + msg='`:` should be treated as `...` if `is_user_input`') + with self.assertRaises(TypeError, msg='too many params should result in error'): + 
traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True) + if __name__ == '__main__': unittest.main() diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index f6ab9905d..bc100c9c3 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5,6 +5,7 @@ import binascii import calendar import codecs import collections +import collections.abc import contextlib import datetime import email.header @@ -3189,7 +3190,7 @@ def try_call(*funcs, expected_type=None, args=[], kwargs={}): for f in funcs: try: val = f(*args, **kwargs) - except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError): + except (AttributeError, KeyError, TypeError, IndexError, ValueError, ZeroDivisionError): pass else: if expected_type is None or isinstance(val, expected_type): @@ -5285,107 +5286,149 @@ def load_plugins(name, suffix, namespace): def traverse_obj( - obj, *path_list, default=None, expected_type=None, get_all=True, + obj, *paths, default=None, expected_type=None, get_all=True, casesense=True, is_user_input=False, traverse_string=False): - ''' Traverse nested list/dict/tuple - @param path_list A list of paths which are checked one by one. 
- Each path is a list of keys where each key is a: - - None: Do nothing - - string: A dictionary key / regex group - - int: An index into a list - - tuple: A list of keys all of which will be traversed - - Ellipsis: Fetch all values in the object - - Function: Takes the key and value as arguments - and returns whether the key matches or not - @param default Default value to return - @param expected_type Only accept final value of this type (Can also be any callable) - @param get_all Return all the values obtained from a path or only the first one - @param casesense Whether to consider dictionary keys as case sensitive + """ + Safely traverse nested `dict`s and `Sequence`s - The following are only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API + >>> obj = [{}, {"key": "value"}] + >>> traverse_obj(obj, (1, "key")) + 'value' - @param path_list In addition to the above, - - dict: Given {k:v, ...}; return {k: traverse_obj(obj, v), ...} - @param is_user_input Whether the keys are generated from user input. If True, - strings are converted to int/slice if necessary - @param traverse_string Whether to traverse inside strings. If True, any - non-compatible object will also be converted into a string - ''' # TODO: Write tests - if not casesense: - _lower = lambda k: (k.lower() if isinstance(k, str) else k) - path_list = (map(_lower, variadic(path)) for path in path_list) + Each of the provided `paths` is tested and the first producing a valid result will be returned. + A value of None is treated as the absence of a value. - def _traverse_obj(obj, path, _current_depth=0): - nonlocal depth - path = tuple(variadic(path)) - for i, key in enumerate(path): - if None in (key, obj): - return obj - if isinstance(key, (list, tuple)): - obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key] - key = ... + The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. 
- if key is ...: - obj = (obj.values() if isinstance(obj, dict) - else obj if isinstance(obj, (list, tuple, LazyList)) - else str(obj) if traverse_string else []) - _current_depth += 1 - depth = max(depth, _current_depth) - return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj] - elif isinstance(key, dict): - obj = filter_dict({k: _traverse_obj(obj, v, _current_depth) for k, v in key.items()}) - elif callable(key): - if isinstance(obj, (list, tuple, LazyList)): - obj = enumerate(obj) - elif isinstance(obj, dict): - obj = obj.items() - else: - if not traverse_string: - return None - obj = str(obj) - _current_depth += 1 - depth = max(depth, _current_depth) - return [_traverse_obj(v, path[i + 1:], _current_depth) for k, v in obj if try_call(key, args=(k, v))] - elif isinstance(obj, dict) and not (is_user_input and key == ':'): - obj = (obj.get(key) if casesense or (key in obj) - else next((v for k, v in obj.items() if _lower(k) == key), None)) - else: - if is_user_input: - key = (int_or_none(key) if ':' not in key - else slice(*map(int_or_none, key.split(':')))) - if key == slice(None): - return _traverse_obj(obj, (..., *path[i + 1:]), _current_depth) - if not isinstance(key, (int, slice)): - return None - if not isinstance(obj, (list, tuple, LazyList)): - if not traverse_string: - return None - obj = str(obj) - try: - obj = obj[key] - except IndexError: - return None - return obj + The keys in the path can be one of: + - `None`: Return the current object. + - `str`/`int`: Return `obj[key]`. + - `slice`: Branch out and return all values in `obj[key]`. + - `Ellipsis`: Branch out and return a list of all values. + - `tuple`/`list`: Branch out and return a list of all matching values. + Read as: `[traverse_obj(obj, branch) for branch in branches]`. + - `function`: Branch out and return values filtered by the function. + Read as: `[value for key, value in obj if function(key, value)]`. + For `Sequence`s, `key` is the index of the value. 
+ - `dict`: Transform the current object and return a matching dict. + Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + + `tuple`, `list`, and `dict` all support nested paths and branches. + + @param paths Paths by which to traverse. + @param default Value to return if the paths do not match. + @param expected_type If a `type`, only accept final values of this type. + If any other callable, try to call the function on each result. + @param get_all If `False`, return the first matching result, otherwise all matching ones. + @param casesense If `False`, consider string dictionary keys as case insensitive. + + The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + + @param is_user_input Whether the keys are generated from user input. + If `True`, strings get converted to `int`/`slice` if needed. + @param traverse_string Whether to traverse into objects as strings. + If `True`, any non-compatible object will first be + converted into a string and then traversed into. + + + @returns The result of the object traversal. + If successful, `get_all=True`, and the path branches at least once, + then a list of results is returned instead. 
+ """ + is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes)) + casefold = lambda k: k.casefold() if isinstance(k, str) else k if isinstance(expected_type, type): type_test = lambda val: val if isinstance(val, expected_type) else None else: - type_test = expected_type or IDENTITY + type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) - for path in path_list: - depth = 0 - val = _traverse_obj(obj, path) - if val is not None: - if depth: - for _ in range(depth - 1): - val = itertools.chain.from_iterable(v for v in val if v is not None) - val = [v for v in map(type_test, val) if v is not None] - if val: - return val if get_all else val[0] + def apply_key(key, obj): + if obj is None: + return + + elif key is None: + yield obj + + elif isinstance(key, (list, tuple)): + for branch in key: + _, result = apply_path(obj, branch) + yield from result + + elif key is ...: + if isinstance(obj, collections.abc.Mapping): + yield from obj.values() + elif is_sequence(obj): + yield from obj + elif traverse_string: + yield from str(obj) + + elif callable(key): + if is_sequence(obj): + iter_obj = enumerate(obj) + elif isinstance(obj, collections.abc.Mapping): + iter_obj = obj.items() + elif traverse_string: + iter_obj = enumerate(str(obj)) else: - val = type_test(val) - if val is not None: - return val + return + yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) + + elif isinstance(key, dict): + iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) + yield {k: v if v is not None else default for k, v in iter_obj + if v is not None or default is not None} + + elif isinstance(obj, dict): + yield (obj.get(key) if casesense or (key in obj) + else next((v for k, v in obj.items() if casefold(k) == key), None)) + + else: + if is_user_input: + key = (int_or_none(key) if ':' not in key + else slice(*map(int_or_none, key.split(':')))) + + if not isinstance(key, (int, slice)): + return + + if not 
is_sequence(obj): + if not traverse_string: + return + obj = str(obj) + + with contextlib.suppress(IndexError): + yield obj[key] + + def apply_path(start_obj, path): + objs = (start_obj,) + has_branched = False + + for key in variadic(path): + if is_user_input and key == ':': + key = ... + + if not casesense and isinstance(key, str): + key = key.casefold() + + if key is ... or isinstance(key, (list, tuple)) or callable(key): + has_branched = True + + key_func = functools.partial(apply_key, key) + objs = itertools.chain.from_iterable(map(key_func, objs)) + + return has_branched, objs + + def _traverse_obj(obj, path): + has_branched, results = apply_path(obj, path) + results = LazyList(x for x in map(type_test, results) if x is not None) + if results: + return results.exhaust() if get_all and has_branched else results[0] + + for path in paths: + result = _traverse_obj(obj, path) + if result is not None: + return result + return default @@ -5437,7 +5480,7 @@ def jwt_decode_hs256(jwt): WINDOWS_VT_MODE = False if compat_os_name == 'nt' else None -@functools.cache +@ functools.cache def supports_terminal_sequences(stream): if compat_os_name == 'nt': if not WINDOWS_VT_MODE: @@ -5587,7 +5630,7 @@ class Config: *(f'\n{c}'.replace('\n', '\n| ')[1:] for c in self.configs), delim='\n') - @staticmethod + @ staticmethod def read_file(filename, default=[]): try: optionf = open(filename, 'rb') @@ -5608,7 +5651,7 @@ class Config: optionf.close() return res - @staticmethod + @ staticmethod def hide_login_info(opts): PRIVATE_OPTS = {'-p', '--password', '-u', '--username', '--video-password', '--ap-password', '--ap-username'} eqre = re.compile('^(?P' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$') @@ -5632,7 +5675,7 @@ class Config: if config.init(*args): self.configs.append(config) - @property + @ property def all_args(self): for config in reversed(self.configs): yield from config.all_args @@ -5679,7 +5722,7 @@ class WebSocketsWrapper(): # taken from 
https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications # for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class - @staticmethod + @ staticmethod def run_with_loop(main, loop): if not asyncio.iscoroutine(main): raise ValueError(f'a coroutine was expected, got {main!r}') @@ -5691,7 +5734,7 @@ class WebSocketsWrapper(): if hasattr(loop, 'shutdown_default_executor'): loop.run_until_complete(loop.shutdown_default_executor()) - @staticmethod + @ staticmethod def _cancel_all_tasks(loop): to_cancel = asyncio.all_tasks(loop) @@ -5725,7 +5768,7 @@ def cached_method(f): """Cache a method""" signature = inspect.signature(f) - @functools.wraps(f) + @ functools.wraps(f) def wrapper(self, *args, **kwargs): bound_args = signature.bind(self, *args, **kwargs) bound_args.apply_defaults() @@ -5757,7 +5800,7 @@ class Namespace(types.SimpleNamespace): def __iter__(self): return iter(self.__dict__.values()) - @property + @ property def items_(self): return self.__dict__.items() @@ -5796,13 +5839,13 @@ class RetryManager: def _should_retry(self): return self._error is not NO_DEFAULT and self.attempt <= self.retries - @property + @ property def error(self): if self._error is NO_DEFAULT: return None return self._error - @error.setter + @ error.setter def error(self, value): self._error = value @@ -5814,7 +5857,7 @@ class RetryManager: if self.error: self.error_callback(self.error, self.attempt, self.retries) - @staticmethod + @ staticmethod def report_retry(e, count, retries, *, sleep_func, info, warn, error=None, suffix=None): """Utility function for reporting retries""" if count > retries: