0
0
Fork 0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2025-01-05 06:21:01 +00:00

[utils] traverse_obj: Fix several behavioral problems

See #6180 for further info

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2023-02-08 04:11:08 +01:00 committed by GitHub
parent 88426d9446
commit b1bde57bef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 76 deletions

View file

@ -2000,8 +2000,8 @@ def test_traverse_obj(self):
# Test Ellipsis behavior # Test Ellipsis behavior
self.assertCountEqual(traverse_obj(_TEST_DATA, ...), self.assertCountEqual(traverse_obj(_TEST_DATA, ...),
(item for item in _TEST_DATA.values() if item is not None), (item for item in _TEST_DATA.values() if item not in (None, [], {})),
msg='`...` should give all values except `None`') msg='`...` should give all non discarded values')
self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(),
msg='`...` selection for dicts should select all values') msg='`...` selection for dicts should select all values')
self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')), self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')),
@ -2084,15 +2084,23 @@ def test_traverse_obj(self):
{0: ['https://www.example.com/1', 'https://www.example.com/0']}, {0: ['https://www.example.com/1', 'https://www.example.com/0']},
msg='tripple nesting in dict path should be treated as branches') msg='tripple nesting in dict path should be treated as branches')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
msg='remove `None` values when dict key') msg='remove `None` values when top level dict key fails')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...},
msg='do not remove `None` values if `default`') msg='use `default` if key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
msg='do not remove empty values when dict key') msg='remove empty values when dict key')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}}, self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...},
msg='do not remove empty values when dict key and a default') msg='use `default` when dict key and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []}, self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
msg='if branch in dict key not successful, return `[]`') msg='remove empty values when nested dict key fails')
self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
msg='default to dict if pruned')
self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {},
msg='default to dict if pruned and default is given')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}},
msg='use nested `default` when nested dict key fails and `default`')
self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {},
msg='remove key if branch in dict key not successful')
# Testing default parameter behavior # Testing default parameter behavior
_DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@ -2183,14 +2191,17 @@ def test_traverse_obj(self):
traverse_string=True), '.', traverse_string=True), '.',
msg='traverse into converted data if `traverse_string`') msg='traverse into converted data if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...),
traverse_string=True), list('str'), traverse_string=True), 'str',
msg='`...` branching into string should result in list') msg='`...` should result in string (same value) if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
traverse_string=True), 'sr',
msg='`slice` should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
traverse_string=True), 'str',
msg='function should result in string if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
traverse_string=True), ['s', 'r'], traverse_string=True), ['s', 'r'],
msg='branching into string should result in list') msg='branching should result in list if `traverse_string`')
self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
traverse_string=True), list('str'),
msg='function branching into string should result in list')
# Test is_user_input behavior # Test is_user_input behavior
_IS_USER_INPUT_DATA = {'range8': list(range(8))} _IS_USER_INPUT_DATA = {'range8': list(range(8))}

View file

@ -5420,7 +5420,7 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned. Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found. The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
A value of None is treated as the absence of a value. Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@ -5446,6 +5446,8 @@ def traverse_obj(
@params paths Paths which to traverse by. @params paths Paths which to traverse by.
@param default Value to return if the paths do not match. @param default Value to return if the paths do not match.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, depth first. Try to avoid if using nested `dict` keys.
@param expected_type If a `type`, only accept final values of this type. @param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result. If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside If the last key in the path is a `dict`, it will apply to each value inside
@ -5460,12 +5462,15 @@ def traverse_obj(
@param traverse_string Whether to traverse into objects as strings. @param traverse_string Whether to traverse into objects as strings.
If `True`, any non-compatible object will first be If `True`, any non-compatible object will first be
converted into a string and then traversed into. converted into a string and then traversed into.
The return value of that path will be a string instead,
not respecting any further branching.
@returns The result of the object traversal. @returns The result of the object traversal.
If successful, `get_all=True`, and the path branches at least once, If successful, `get_all=True`, and the path branches at least once,
then a list of results is returned instead. then a list of results is returned instead.
A list is always returned if the last path branches and no `default` is given. If no `default` is given and the last path branches, a `list` of results
is always returned. If a path ends on a `dict` that result will always be a `dict`.
""" """
is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes)) is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes))
casefold = lambda k: k.casefold() if isinstance(k, str) else k casefold = lambda k: k.casefold() if isinstance(k, str) else k
@ -5475,87 +5480,94 @@ def traverse_obj(
else: else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def apply_key(key, test_type, obj): def apply_key(key, obj, is_last):
branching = False
result = None
if obj is None: if obj is None:
return pass
elif key is None: elif key is None:
yield obj result = obj
elif isinstance(key, set): elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item' assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key)) item = next(iter(key))
if isinstance(item, type): if isinstance(item, type):
if isinstance(obj, item): if isinstance(obj, item):
yield obj result = obj
else: else:
yield try_call(item, args=(obj,)) result = try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)): elif isinstance(key, (list, tuple)):
for branch in key: branching = True
_, result = apply_path(obj, branch, test_type) result = itertools.chain.from_iterable(
yield from result apply_path(obj, branch, is_last)[0] for branch in key)
elif key is ...: elif key is ...:
branching = True
if isinstance(obj, collections.abc.Mapping): if isinstance(obj, collections.abc.Mapping):
yield from obj.values() result = obj.values()
elif is_sequence(obj): elif is_sequence(obj):
yield from obj result = obj
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
yield from obj.groups() result = obj.groups()
elif traverse_string: elif traverse_string:
yield from str(obj) branching = False
result = str(obj)
else:
result = ()
elif callable(key): elif callable(key):
if is_sequence(obj): branching = True
iter_obj = enumerate(obj) if isinstance(obj, collections.abc.Mapping):
elif isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif is_sequence(obj):
iter_obj = enumerate(obj)
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
iter_obj = itertools.chain( iter_obj = itertools.chain(
enumerate((obj.group(), *obj.groups())), enumerate((obj.group(), *obj.groups())),
obj.groupdict().items()) obj.groupdict().items())
elif traverse_string: elif traverse_string:
branching = False
iter_obj = enumerate(str(obj)) iter_obj = enumerate(str(obj))
else: else:
return iter_obj = ()
yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
if not branching: # string traversal
result = ''.join(result)
elif isinstance(key, dict): elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items()) iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
yield {k: v if v is not None else default for k, v in iter_obj result = {
if v is not None or default is not NO_DEFAULT} k: v if v is not None else default for k, v in iter_obj
if v is not None or default is not NO_DEFAULT
} or None
elif isinstance(obj, collections.abc.Mapping): elif isinstance(obj, collections.abc.Mapping):
yield (obj.get(key) if casesense or (key in obj) result = (obj.get(key) if casesense or (key in obj) else
else next((v for k, v in obj.items() if casefold(k) == key), None)) next((v for k, v in obj.items() if casefold(k) == key), None))
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
if isinstance(key, int) or casesense: if isinstance(key, int) or casesense:
with contextlib.suppress(IndexError): with contextlib.suppress(IndexError):
yield obj.group(key) result = obj.group(key)
return
if not isinstance(key, str): elif isinstance(key, str):
return result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
else:
if is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if not isinstance(key, (int, slice)):
return
elif isinstance(key, (int, slice)):
if not is_sequence(obj): if not is_sequence(obj):
if not traverse_string: if traverse_string:
return
obj = str(obj)
with contextlib.suppress(IndexError): with contextlib.suppress(IndexError):
yield obj[key] result = str(obj)[key]
else:
branching = isinstance(key, slice)
with contextlib.suppress(IndexError):
result = obj[key]
return branching, result if branching else (result,)
def lazy_last(iterable): def lazy_last(iterable):
iterator = iter(iterable) iterator = iter(iterable)
@ -5569,45 +5581,54 @@ def lazy_last(iterable):
yield True, prev yield True, prev
def apply_path(start_obj, path, test_type=False): def apply_path(start_obj, path, test_type):
objs = (start_obj,) objs = (start_obj,)
has_branched = False has_branched = False
key = None key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if is_user_input and key == ':': if is_user_input and isinstance(key, str):
if key == ':':
key = ... key = ...
elif ':' in key:
key = slice(*map(int_or_none, key.split(':')))
elif int_or_none(key) is not None:
key = int(key)
if not casesense and isinstance(key, str): if not casesense and isinstance(key, str):
key = key.casefold() key = key.casefold()
if key is ... or isinstance(key, (list, tuple)) or callable(key):
has_branched = True
if __debug__ and callable(key): if __debug__ and callable(key):
# Verify function signature # Verify function signature
inspect.signature(key).bind(None, None) inspect.signature(key).bind(None, None)
key_func = functools.partial(apply_key, key, last) new_objs = []
objs = itertools.chain.from_iterable(map(key_func, objs)) for obj in objs:
branching, results = apply_key(key, obj, last)
has_branched |= branching
new_objs.append(results)
objs = itertools.chain.from_iterable(new_objs)
if test_type and not isinstance(key, (dict, list, tuple)): if test_type and not isinstance(key, (dict, list, tuple)):
objs = map(type_test, objs) objs = map(type_test, objs)
return has_branched, objs return objs, has_branched, isinstance(key, dict)
def _traverse_obj(obj, path, use_list=True, test_type=True):
has_branched, results = apply_path(obj, path, test_type)
results = LazyList(x for x in results if x is not None)
def _traverse_obj(obj, path, allow_empty, test_type):
results, has_branched, is_dict = apply_path(obj, path, test_type)
results = LazyList(item for item in results if item not in (None, [], {}))
if get_all and has_branched: if get_all and has_branched:
return results.exhaust() if results or use_list else None if results:
return results.exhaust()
if allow_empty:
return [] if default is NO_DEFAULT else default
return None
return results[0] if results else None return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1): for index, path in enumerate(paths, 1):
use_list = default is NO_DEFAULT and index == len(paths) result = _traverse_obj(obj, path, index == len(paths), True)
result = _traverse_obj(obj, path, use_list)
if result is not None: if result is not None:
return result return result