0
0
Fork 0
mirror of https://github.com/yt-dlp/yt-dlp.git synced 2024-11-22 02:15:12 +00:00

[utils] traverse_obj: Convenience improvements (#9577)

Add support for:
- `http.cookies.Morsel`
- Multi type filters (`{type, type}`)

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2024-04-01 02:12:03 +02:00 committed by GitHub
parent c305a25c1b
commit 32abfb00bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 51 additions and 10 deletions

View file

@ -1,3 +1,4 @@
import http.cookies
import re
import xml.etree.ElementTree
@ -94,6 +95,8 @@ def test_traversal_set(self):
'Function in set should be a transformation'
assert traverse_obj(_TEST_DATA, (..., {str})) == ['str'], \
'Type in set should be a type filter'
assert traverse_obj(_TEST_DATA, (..., {str, int})) == [100, 'str'], \
'Multiple types in set should be a type filter'
assert traverse_obj(_TEST_DATA, {dict}) == _TEST_DATA, \
'A single set should be wrapped into a path'
assert traverse_obj(_TEST_DATA, (..., {str.upper})) == ['STR'], \
@ -103,7 +106,7 @@ def test_traversal_set(self):
'Function in set should be a transformation'
assert traverse_obj(_TEST_DATA, ('fail', {lambda _: 'const'})) == 'const', \
'Function in set should always be called'
# Sets with length != 1 should raise in debug
# Sets with length < 1 or > 1 not including only types should raise
with pytest.raises(Exception):
traverse_obj(_TEST_DATA, set())
with pytest.raises(Exception):
@ -409,3 +412,31 @@ def test_traversal_unbranching(self):
'`all` should allow further branching'
assert traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, ..., 'index']) == [0, 1], \
'`any` should allow further branching'
def test_traversal_morsel(self):
values = {
'expires': 'a',
'path': 'b',
'comment': 'c',
'domain': 'd',
'max-age': 'e',
'secure': 'f',
'httponly': 'g',
'version': 'h',
'samesite': 'i',
}
morsel = http.cookies.Morsel()
morsel.set('item_key', 'item_value', 'coded_value')
morsel.update(values)
values['key'] = 'item_key'
values['value'] = 'item_value'
for key, value in values.items():
assert traverse_obj(morsel, key) == value, \
'Morsel should provide access to all values'
assert traverse_obj(morsel, ...) == list(values.values()), \
'`...` should yield all values'
assert traverse_obj(morsel, lambda k, v: True) == list(values.values()), \
'function key should yield all values'
assert traverse_obj(morsel, [(None,), any]) == morsel, \
'Morsel should not be implicitly changed to dict on usage'

View file

@ -1,5 +1,6 @@
import collections.abc
import contextlib
import http.cookies
import inspect
import itertools
import re
@ -28,7 +29,8 @@ def traverse_obj(
Each of the provided `paths` is tested and the first producing a valid result will be returned.
The next path will also be tested if the path branched but no results could be found.
Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
Supported values for traversal are `Mapping`, `Iterable`, `re.Match`,
`xml.etree.ElementTree` (xpath) and `http.cookies.Morsel`.
Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
@ -36,8 +38,8 @@ def traverse_obj(
The keys in the path can be one of:
- `None`: Return the current object.
- `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
like `{type}`/`{type, type, ...}/`{func}`. If a `type`, return only
values of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values.
@ -48,8 +50,10 @@ def traverse_obj(
For `Iterable`s, `key` is the index of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict.
- `dict`: Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
- `any`-builtin: Take the first matching object and return it, resetting branching.
- `all`-builtin: Take all matching objects and return them as a list, resetting branching.
`tuple`, `list`, and `dict` all support nested paths and branches.
@ -102,10 +106,10 @@ def apply_key(key, obj, is_last):
result = obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
if isinstance(obj, item):
if len(key) > 1 or isinstance(item, type):
assert all(isinstance(item, type) for item in key)
if isinstance(obj, tuple(key)):
result = obj
else:
result = try_call(item, args=(obj,))
@ -117,6 +121,8 @@ def apply_key(key, obj, is_last):
elif key is ...:
branching = True
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, collections.abc.Mapping):
result = obj.values()
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
@ -131,6 +137,8 @@ def apply_key(key, obj, is_last):
elif callable(key):
branching = True
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
if isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items()
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
@ -157,6 +165,8 @@ def apply_key(key, obj, is_last):
} or None
elif isinstance(obj, collections.abc.Mapping):
if isinstance(obj, http.cookies.Morsel):
obj = dict(obj, key=obj.key, value=obj.value)
result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else
next((v for k, v in obj.items() if casefold(k) == key), None))
@ -179,7 +189,7 @@ def apply_key(key, obj, is_last):
elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
xpath, _, special = key.rpartition('/')
if not special.startswith('@') and special != 'text()':
if not special.startswith('@') and not special.endswith('()'):
xpath = key
special = None
@ -198,7 +208,7 @@ def apply_specials(element):
return try_call(element.attrib.get, args=(special[1:],))
if special == 'text()':
return element.text
assert False, f'apply_specials is missing case for {special!r}'
raise SyntaxError(f'apply_specials is missing case for {special!r}')
if xpath:
result = list(map(apply_specials, obj.iterfind(xpath)))