[utils] `traverse_obj`: Rewrite, document and add tests (#5024)

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2022-09-25 23:03:19 +02:00 committed by GitHub
parent 0bd5a039ea
commit ab029d7e92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 332 additions and 102 deletions

View File

@ -109,6 +109,7 @@ from yt_dlp.utils import (
strip_or_none,
subtitles_filename,
timeconvert,
traverse_obj,
unescapeHTML,
unified_strdate,
unified_timestamp,
@ -1874,6 +1875,192 @@ Line 1
self.assertEqual(get_compatible_ext(
vcodecs=['av1'], acodecs=['mp4a'], vexts=['webm'], aexts=['m4a'], preferences=('webm', 'mkv')), 'mkv')
def test_traverse_obj(self):
_TEST_DATA = {
100: 100,
1.2: 1.2,
'str': 'str',
'None': None,
'...': ...,
'urls': [
{'index': 0, 'url': 'https://www.example.com/0'},
{'index': 1, 'url': 'https://www.example.com/1'},
],
'data': (
{'index': 2},
{'index': 3},
),
}
# Test base functionality
self.assertEqual(traverse_obj(_TEST_DATA, ('str',)), 'str',
msg='allow tuple path')
self.assertEqual(traverse_obj(_TEST_DATA, ['str']), 'str',
msg='allow list path')
self.assertEqual(traverse_obj(_TEST_DATA, (value for value in ("str",))), 'str',
msg='allow iterable path')
self.assertEqual(traverse_obj(_TEST_DATA, 'str'), 'str',
msg='single items should be treated as a path')
self.assertEqual(traverse_obj(_TEST_DATA, None), _TEST_DATA)
self.assertEqual(traverse_obj(_TEST_DATA, 100), 100)
self.assertEqual(traverse_obj(_TEST_DATA, 1.2), 1.2)
# Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
(item for item in _TEST_DATA.values() if item is not None),
msg='`...` should give all values except `None`')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
['https://www.example.com/0', 'https://www.example.com/1'],
msg='nested `...` queries should work')
self.assertCountEqual(traverse_obj(_TEST_DATA, (..., ..., 'index')), range(4),
msg='`...` query result should be flattened')
# Test function as key
self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
[_TEST_DATA['urls']],
msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be catched')
# Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
msg='multiple `path_list` should be treated as alternative paths')
self.assertEqual(traverse_obj(_TEST_DATA, 'str', 100), 'str',
msg='alternatives should exit early')
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'fail'), None,
msg='alternatives should return `default` if exhausted')
# Test branch and path nesting
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', (3, 0), 'url')), ['https://www.example.com/0'],
msg='tuple as key should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', [3, 0], 'url')), ['https://www.example.com/0'],
msg='list as key should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ((1, 'fail'), (0, 'url')))), ['https://www.example.com/0'],
msg='double nesting in path should be treated as paths')
self.assertEqual(traverse_obj(['0', [1, 2]], [(0, 1), 0]), [1],
msg='do not fail early on branching')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', ((1, ('fail', 'url')), (0, 'url')))),
['https://www.example.com/0', 'https://www.example.com/1'],
msg='tripple nesting in path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, ('urls', ('fail', (..., 'url')))),
['https://www.example.com/0', 'https://www.example.com/1'],
msg='ellipsis as branch path start gets flattened')
# Test dictionary as key
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}), {0: 100, 1: 1.2},
msg='dict key should result in a dict with the same keys')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', 0, 'url')}),
{0: 'https://www.example.com/0'},
msg='dict key should allow paths')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', (3, 0), 'url')}),
{0: ['https://www.example.com/0']},
msg='tuple in dict path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, 'fail'), (0, 'url')))}),
{0: ['https://www.example.com/0']},
msg='double nesting in dict path should be treated as paths')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('urls', ((1, ('fail', 'url')), (0, 'url')))}),
{0: ['https://www.example.com/1', 'https://www.example.com/0']},
msg='tripple nesting in dict path should be treated as branches')
self.assertEqual(traverse_obj({}, {0: 1}, default=...), {0: ...},
msg='do not remove `None` values when dict key')
# Testing default parameter behavior
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail'), None,
msg='default value should be `None`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', 'fail', default=...), ...,
msg='chained fails should result in default')
self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', 'int'), 0,
msg='should not short cirquit on `None`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, 'fail', default=1), 1,
msg='invalid dict key should result in `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, 'None', default=1), 1,
msg='`None` is a deliberate sentinel and should become `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', 10)), None,
msg='`IndexError` should result in `default`')
self.assertEqual(traverse_obj(_DEFAULT_DATA, (..., 'fail'), default=1), 1,
msg='if branched but not successfull return `default`, not `[]`')
# Testing expected_type behavior
_EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str',
msg='accept matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
msg='reject non matching `expected_type` type')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0',
msg='transform type using type function')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
expected_type=lambda _: 1 / 0), None,
msg='wrap expected_type fuction in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],
msg='eliminate items that expected_type fails on')
# Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]}
self.assertEqual(traverse_obj(_GET_ALL_DATA, ('key', ...), get_all=False), 0,
msg='if not `get_all`, return only first matching value')
self.assertEqual(traverse_obj(_GET_ALL_DATA, ..., get_all=False), [0, 1, 2],
msg='do not overflatten if not `get_all`')
# Test casesense behavior
_CASESENSE_DATA = {
'KeY': 'value0',
0: {
'KeY': 'value1',
0: {'KeY': 'value2'},
},
}
self.assertEqual(traverse_obj(_CASESENSE_DATA, 'key'), None,
msg='dict keys should be case sensitive unless `casesense`')
self.assertEqual(traverse_obj(_CASESENSE_DATA, 'keY',
casesense=False), 'value0',
msg='allow non matching key case if `casesense`')
self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ('keY',)),
casesense=False), ['value1'],
msg='allow non matching key case in branch if `casesense`')
self.assertEqual(traverse_obj(_CASESENSE_DATA, (0, ((0, 'keY'),)),
casesense=False), ['value2'],
msg='allow non matching key case in branch path if `casesense`')
# Test traverse_string behavior
_TRAVERSE_STRING_DATA = {'str': 'str', 1.2: 1.2}
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0)), None,
msg='do not traverse into string if not `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', 0),
traverse_string=True), 's',
msg='traverse into string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, (1.2, 1),
traverse_string=True), '.',
msg='traverse into converted data if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
traverse_string=True), list('str'),
msg='`...` branching into string should result in list')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
traverse_string=True), ['s', 'r'],
msg='branching into string should result in list')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
traverse_string=True), list('str'),
msg='function branching into string should result in list')
# Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))}
self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'),
is_user_input=True), 3,
msg='allow for string indexing if `is_user_input`')
self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'),
is_user_input=True), tuple(range(8))[3:],
msg='allow for string slice if `is_user_input`')
self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'),
is_user_input=True), tuple(range(8))[:4:2],
msg='allow step in string slice if `is_user_input`')
self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'),
is_user_input=True), range(8),
msg='`:` should be treated as `...` if `is_user_input`')
with self.assertRaises(TypeError, msg='too many params should result in error'):
traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), is_user_input=True)
if __name__ == '__main__':
unittest.main()

View File

@ -5,6 +5,7 @@ import binascii
import calendar
import codecs
import collections
import collections.abc
import contextlib
import datetime
import email.header
@ -3189,7 +3190,7 @@ def try_call(*funcs, expected_type=None, args=[], kwargs={}):
for f in funcs:
try:
val = f(*args, **kwargs)
except (AttributeError, KeyError, TypeError, IndexError, ZeroDivisionError):
except (AttributeError, KeyError, TypeError, IndexError, ValueError, ZeroDivisionError):
pass
else:
if expected_type is None or isinstance(val, expected_type):
@ -5285,107 +5286,149 @@ def load_plugins(name, suffix, namespace):
def traverse_obj(
obj, *path_list, default=None, expected_type=None, get_all=True,
obj, *paths, default=None, expected_type=None, get_all=True,
casesense=True, is_user_input=False, traverse_string=False):
''' Traverse nested list/dict/tuple
@param path_list A list of paths which are checked one by one.
Each path is a list of keys where each key is a:
- None: Do nothing
- string: A dictionary key / regex group
- int: An index into a list
- tuple: A list of keys all of which will be traversed
- Ellipsis: Fetch all values in the object
- Function: Takes the key and value as arguments
and returns whether the key matches or not
@param default Default value to return
@param expected_type Only accept final value of this type (Can also be any callable)
@param get_all Return all the values obtained from a path or only the first one
@param casesense Whether to consider dictionary keys as case sensitive
"""
Safely traverse nested `dict`s and `Sequence`s
The following are only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API
>>> obj = [{}, {"key": "value"}]
>>> traverse_obj(obj, (1, "key"))
"value"
@param path_list In addition to the above,
- dict: Given {k:v, ...}; return {k: traverse_obj(obj, v), ...}
@param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary
@param traverse_string Whether to traverse inside strings. If True, any
non-compatible object will also be converted into a string
''' # TODO: Write tests
if not casesense:
_lower = lambda k: (k.lower() if isinstance(k, str) else k)
path_list = (map(_lower, variadic(path)) for path in path_list)
Each of the provided `paths` is tested and the first producing a valid result will be returned.
A value of None is treated as the absence of a value.
def _traverse_obj(obj, path, _current_depth=0):
nonlocal depth
path = tuple(variadic(path))
for i, key in enumerate(path):
if None in (key, obj):
return obj
if isinstance(key, (list, tuple)):
obj = [_traverse_obj(obj, sub_key, _current_depth) for sub_key in key]
key = ...
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
if key is ...:
obj = (obj.values() if isinstance(obj, dict)
else obj if isinstance(obj, (list, tuple, LazyList))
else str(obj) if traverse_string else [])
_current_depth += 1
depth = max(depth, _current_depth)
return [_traverse_obj(inner_obj, path[i + 1:], _current_depth) for inner_obj in obj]
elif isinstance(key, dict):
obj = filter_dict({k: _traverse_obj(obj, v, _current_depth) for k, v in key.items()})
elif callable(key):
if isinstance(obj, (list, tuple, LazyList)):
obj = enumerate(obj)
elif isinstance(obj, dict):
obj = obj.items()
else:
if not traverse_string:
return None
obj = str(obj)
_current_depth += 1
depth = max(depth, _current_depth)
return [_traverse_obj(v, path[i + 1:], _current_depth) for k, v in obj if try_call(key, args=(k, v))]
elif isinstance(obj, dict) and not (is_user_input and key == ':'):
obj = (obj.get(key) if casesense or (key in obj)
else next((v for k, v in obj.items() if _lower(k) == key), None))
else:
if is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if key == slice(None):
return _traverse_obj(obj, (..., *path[i + 1:]), _current_depth)
if not isinstance(key, (int, slice)):
return None
if not isinstance(obj, (list, tuple, LazyList)):
if not traverse_string:
return None
obj = str(obj)
try:
obj = obj[key]
except IndexError:
return None
return obj
The keys in the path can be one of:
- `None`: Return the current object.
- `str`/`int`: Return `obj[key]`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
- `tuple`/`list`: Branch out and return a list of all matching values.
Read as: `[traverse_obj(obj, branch) for branch in branches]`.
- `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value.
- `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
`tuple`, `list`, and `dict` all support nested paths and branches
@params paths Paths which to traverse by.
@param default Value to return if the paths do not match.
@param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result.
@param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive.
The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API
@param is_user_input Whether the keys are generated from user input.
If `True` strings get converted to `int`/`slice` if needed.
@param traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be
converted into a string and then traversed into.
@returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead.
"""
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: k.casefold() if isinstance(k, str) else k
if isinstance(expected_type, type):
type_test = lambda val: val if isinstance(val, expected_type) else None
else:
type_test = expected_type or IDENTITY
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
for path in path_list:
depth = 0
val = _traverse_obj(obj, path)
if val is not None:
if depth:
for _ in range(depth - 1):
val = itertools.chain.from_iterable(v for v in val if v is not None)
val = [v for v in map(type_test, val) if v is not None]
if val:
return val if get_all else val[0]
def apply_key(key, obj):
if obj is None:
return
elif key is None:
yield obj
elif isinstance(key, (list, tuple)):
for branch in key:
_, result = apply_path(obj, branch)
yield from result
elif key is ...:
if isinstance(obj, collections.abc.Mapping):
yield from obj.values()
elif is_sequence(obj):
yield from obj
elif traverse_string:
yield from str(obj)
elif callable(key):
if is_sequence(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
elif traverse_string:
iter_obj = enumerate(str(obj))
else:
val = type_test(val)
if val is not None:
return val
return
yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
yield {k: v if v is not None else default for k, v in iter_obj
if v is not None or default is not None}
elif isinstance(obj, dict):
yield (obj.get(key) if casesense or (key in obj)
else next((v for k, v in obj.items() if casefold(k) == key), None))
else:
if is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if not isinstance(key, (int, slice)):
return
if not is_sequence(obj):
if not traverse_string:
return
obj = str(obj)
with contextlib.suppress(IndexError):
yield obj[key]
def apply_path(start_obj, path):
objs = (start_obj,)
has_branched = False
for key in variadic(path):
if is_user_input and key == ':':
key = ...
if not casesense and isinstance(key, str):
key = key.casefold()
if key is ... or isinstance(key, (list, tuple)) or callable(key):
has_branched = True
key_func = functools.partial(apply_key, key)
objs = itertools.chain.from_iterable(map(key_func, objs))
return has_branched, objs
def _traverse_obj(obj, path):
has_branched, results = apply_path(obj, path)
results = LazyList(x for x in map(type_test, results) if x is not None)
if results:
return results.exhaust() if get_all and has_branched else results[0]
for path in paths:
result = _traverse_obj(obj, path)
if result is not None:
return result
return default
@ -5437,7 +5480,7 @@ def jwt_decode_hs256(jwt):
WINDOWS_VT_MODE = False if compat_os_name == 'nt' else None
@functools.cache
@ functools.cache
def supports_terminal_sequences(stream):
if compat_os_name == 'nt':
if not WINDOWS_VT_MODE:
@ -5587,7 +5630,7 @@ class Config:
*(f'\n{c}'.replace('\n', '\n| ')[1:] for c in self.configs),
delim='\n')
@staticmethod
@ staticmethod
def read_file(filename, default=[]):
try:
optionf = open(filename, 'rb')
@ -5608,7 +5651,7 @@ class Config:
optionf.close()
return res
@staticmethod
@ staticmethod
def hide_login_info(opts):
PRIVATE_OPTS = {'-p', '--password', '-u', '--username', '--video-password', '--ap-password', '--ap-username'}
eqre = re.compile('^(?P<key>' + ('|'.join(re.escape(po) for po in PRIVATE_OPTS)) + ')=.+$')
@ -5632,7 +5675,7 @@ class Config:
if config.init(*args):
self.configs.append(config)
@property
@ property
def all_args(self):
for config in reversed(self.configs):
yield from config.all_args
@ -5679,7 +5722,7 @@ class WebSocketsWrapper():
# taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
# for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
@staticmethod
@ staticmethod
def run_with_loop(main, loop):
if not asyncio.iscoroutine(main):
raise ValueError(f'a coroutine was expected, got {main!r}')
@ -5691,7 +5734,7 @@ class WebSocketsWrapper():
if hasattr(loop, 'shutdown_default_executor'):
loop.run_until_complete(loop.shutdown_default_executor())
@staticmethod
@ staticmethod
def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
@ -5725,7 +5768,7 @@ def cached_method(f):
"""Cache a method"""
signature = inspect.signature(f)
@functools.wraps(f)
@ functools.wraps(f)
def wrapper(self, *args, **kwargs):
bound_args = signature.bind(self, *args, **kwargs)
bound_args.apply_defaults()
@ -5757,7 +5800,7 @@ class Namespace(types.SimpleNamespace):
def __iter__(self):
return iter(self.__dict__.values())
@property
@ property
def items_(self):
return self.__dict__.items()
@ -5796,13 +5839,13 @@ class RetryManager:
def _should_retry(self):
return self._error is not NO_DEFAULT and self.attempt <= self.retries
@property
@ property
def error(self):
if self._error is NO_DEFAULT:
return None
return self._error
@error.setter
@ error.setter
def error(self, value):
self._error = value
@ -5814,7 +5857,7 @@ class RetryManager:
if self.error:
self.error_callback(self.error, self.attempt, self.retries)
@staticmethod
@ staticmethod
def report_retry(e, count, retries, *, sleep_func, info, warn, error=None, suffix=None):
"""Utility function for reporting retries"""
if count > retries: