diff --git a/test/test_InfoExtractor.py b/test/test_InfoExtractor.py index 31e8f8244..54f35ef55 100644 --- a/test/test_InfoExtractor.py +++ b/test/test_InfoExtractor.py @@ -53,6 +53,18 @@ def setUp(self): def test_ie_key(self): self.assertEqual(get_info_extractor(YoutubeIE.ie_key()), YoutubeIE) + def test_get_netrc_login_info(self): + for params in [ + {'usenetrc': True, 'netrc_location': './test/testdata/netrc/netrc'}, + {'netrc_cmd': f'{sys.executable} ./test/testdata/netrc/print_netrc.py'}, + ]: + ie = DummyIE(FakeYDL(params)) + self.assertEqual(ie._get_netrc_login_info(netrc_machine='normal_use'), ('user', 'pass')) + self.assertEqual(ie._get_netrc_login_info(netrc_machine='empty_user'), ('', 'pass')) + self.assertEqual(ie._get_netrc_login_info(netrc_machine='empty_pass'), ('user', '')) + self.assertEqual(ie._get_netrc_login_info(netrc_machine='both_empty'), ('', '')) + self.assertEqual(ie._get_netrc_login_info(netrc_machine='nonexistent'), (None, None)) + def test_html_search_regex(self): html = '

Watch this video

' search = lambda re, *args: self.ie._html_search_regex(re, html, *args) diff --git a/test/test_traversal.py b/test/test_traversal.py index 9179dadda..f1d123bd6 100644 --- a/test/test_traversal.py +++ b/test/test_traversal.py @@ -12,9 +12,10 @@ str_or_none, ) from yt_dlp.utils.traversal import ( - traverse_obj, require, subs_list_to_dict, + traverse_obj, + trim_str, ) _TEST_DATA = { @@ -495,6 +496,20 @@ def test_subs_list_to_dict(self): {'url': 'https://example.com/subs/en2', 'ext': 'ext'}, ]}, '`quality` key should sort subtitle list accordingly' + def test_trim_str(self): + with pytest.raises(TypeError): + trim_str('positional') + + assert callable(trim_str(start='a')) + assert trim_str(start='ab')('abc') == 'c' + assert trim_str(end='bc')('abc') == 'a' + assert trim_str(start='a', end='c')('abc') == 'b' + assert trim_str(start='ab', end='c')('abc') == '' + assert trim_str(start='a', end='bc')('abc') == '' + assert trim_str(start='ab', end='bc')('abc') == '' + assert trim_str(start='abc', end='abc')('abc') == '' + assert trim_str(start='', end='')('abc') == 'abc' + class TestDictGet: def test_dict_get(self): diff --git a/test/test_utils.py b/test/test_utils.py index d4b846f56..04f91547a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,6 +4,7 @@ import os import sys import unittest +import unittest.mock import warnings import datetime as dt @@ -71,6 +72,7 @@ intlist_to_bytes, iri_to_uri, is_html, + join_nonempty, js_to_json, limit_length, locked_file, @@ -343,11 +345,13 @@ def test_remove_start(self): self.assertEqual(remove_start(None, 'A - '), None) self.assertEqual(remove_start('A - B', 'A - '), 'B') self.assertEqual(remove_start('B - A', 'A - '), 'B - A') + self.assertEqual(remove_start('non-empty', ''), 'non-empty') def test_remove_end(self): self.assertEqual(remove_end(None, ' - B'), None) self.assertEqual(remove_end('A - B', ' - B'), 'A') self.assertEqual(remove_end('B - A', ' - B'), 'B - A') + self.assertEqual(remove_end('non-empty', ''), 'non-empty') def test_remove_quotes(self): self.assertEqual(remove_quotes(None), None) @@ -2148,6 +2152,16 @@ def run_shell(args): assert run_shell(args) == expected assert run_shell(shell_quote(args, shell=True)) == expected + def test_partial_application(self): + assert callable(int_or_none(scale=10)), 'missing positional parameter should apply partially' + assert int_or_none(10, scale=0.1) == 100, 'positionally passed argument should call function' + assert int_or_none(v=10) == 10, 'keyword passed positional should call function' + assert int_or_none(scale=0.1)(10) == 100, 'call after partial applicatino should call the function' + + assert callable(join_nonempty(delim=', ')), 'varargs positional should apply partially' + assert callable(join_nonempty()), 'varargs positional should apply partially' + assert join_nonempty(None, delim=', ') == '', 'passed varargs should call the function' + if __name__ == '__main__': unittest.main() diff --git a/test/testdata/netrc/netrc b/test/testdata/netrc/netrc new file mode 100644 index 000000000..bafe92fe6 --- /dev/null +++ b/test/testdata/netrc/netrc @@ -0,0 +1,4 @@ +machine normal_use login user password pass +machine empty_user login "" password pass +machine empty_pass login user password "" +machine both_empty login "" password "" diff --git a/test/testdata/netrc/print_netrc.py b/test/testdata/netrc/print_netrc.py new file mode 100644 index 000000000..5c25814f8 --- /dev/null +++ b/test/testdata/netrc/print_netrc.py @@ -0,0 +1,2 @@ +with open('./test/testdata/netrc/netrc', encoding='utf-8') as fp: + print(fp.read()) diff --git a/yt_dlp/extractor/common.py b/yt_dlp/extractor/common.py index ecece85f5..7e6e6227d 100644 --- a/yt_dlp/extractor/common.py +++ b/yt_dlp/extractor/common.py @@ -1409,6 +1409,13 @@ def _get_netrc_login_info(self, netrc_machine=None): return None, None self.write_debug(f'Using netrc for {netrc_machine} authentication') + + # compat: <=py3.10: netrc cannot parse tokens as empty strings, will return `""` instead + # Ref: https://github.com/yt-dlp/yt-dlp/issues/11413 + # https://github.com/python/cpython/commit/15409c720be0503131713e3d3abc1acd0da07378 + if sys.version_info < (3, 11): + return tuple(x if x != '""' else '' for x in info[::2]) + return info[0], info[2] def _get_login_info(self, username_option='username', password_option='password', netrc_machine=None): diff --git a/yt_dlp/extractor/twitter.py b/yt_dlp/extractor/twitter.py index 5adaf1639..8196ce6c3 100644 --- a/yt_dlp/extractor/twitter.py +++ b/yt_dlp/extractor/twitter.py @@ -150,14 +150,6 @@ def _search_dimensions_in_video_url(a_format, video_url): def is_logged_in(self): return bool(self._get_cookies(self._API_BASE).get('auth_token')) - # XXX: Temporary workaround until twitter.com => x.com migration is completed - def _real_initialize(self): - if self.is_logged_in or not self._get_cookies('https://twitter.com/').get('auth_token'): - return - # User has not yet been migrated to x.com and has passed twitter.com cookies - TwitterBaseIE._API_BASE = 'https://api.twitter.com/1.1/' - TwitterBaseIE._GRAPHQL_API_BASE = 'https://twitter.com/i/api/graphql/' - @functools.cached_property def _selected_api(self): return self._configuration_arg('api', ['graphql'], ie_key='Twitter')[0] diff --git a/yt_dlp/extractor/youtube.py b/yt_dlp/extractor/youtube.py index 5148e8261..99b8bfecc 100644 --- a/yt_dlp/extractor/youtube.py +++ b/yt_dlp/extractor/youtube.py @@ -644,13 +644,14 @@ def _initialize_oauth(self, user, refresh_token): YoutubeBaseInfoExtractor._OAUTH_ACCESS_TOKEN_CACHE[self._OAUTH_PROFILE] = {} if refresh_token: - refresh_token = refresh_token.strip('\'') or None - - # Allow refresh token passed to initialize cache - if refresh_token: + msg = f'{self._OAUTH_DISPLAY_ID}: Using password input as refresh token' + if self.get_param('cachedir') is not False: + msg += ' and caching token to disk; you should supply an empty password next time' + self.to_screen(msg) self.cache.store(self._NETRC_MACHINE, self._oauth_cache_key, refresh_token) + else: + refresh_token = self.cache.load(self._NETRC_MACHINE, self._oauth_cache_key) - refresh_token = refresh_token or self.cache.load(self._NETRC_MACHINE, self._oauth_cache_key) if refresh_token: YoutubeBaseInfoExtractor._OAUTH_ACCESS_TOKEN_CACHE[self._OAUTH_PROFILE]['refresh_token'] = refresh_token try: diff --git a/yt_dlp/utils/_utils.py b/yt_dlp/utils/_utils.py index 8535d2830..e30008e93 100644 --- a/yt_dlp/utils/_utils.py +++ b/yt_dlp/utils/_utils.py @@ -212,6 +212,23 @@ def write_json_file(obj, fn): raise +def partial_application(func): + sig = inspect.signature(func) + required_args = [ + param.name for param in sig.parameters.values() + if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + if param.default is inspect.Parameter.empty + ] + + @functools.wraps(func) + def wrapped(*args, **kwargs): + if set(required_args[len(args):]).difference(kwargs): + return functools.partial(func, *args, **kwargs) + return func(*args, **kwargs) + + return wrapped + + def find_xpath_attr(node, xpath, key, val=None): """ Find the xpath xpath[@key=val] """ assert re.match(r'^[a-zA-Z_-]+$', key) @@ -1192,6 +1209,7 @@ def extract_timezone(date_str, default=None): return timezone, date_str +@partial_application def parse_iso8601(date_str, delimiter='T', timezone=None): """ Return a UNIX timestamp from the given date """ @@ -1269,6 +1287,7 @@ def unified_timestamp(date_str, day_first=True): return calendar.timegm(timetuple) + pm_delta * 3600 - timezone.total_seconds() +@partial_application def determine_ext(url, default_ext='unknown_video'): if url is None or '.' not in url: return default_ext @@ -1944,7 +1963,7 @@ def remove_start(s, start): def remove_end(s, end): - return s[:-len(end)] if s is not None and s.endswith(end) else s + return s[:-len(end)] if s is not None and end and s.endswith(end) else s def remove_quotes(s): @@ -1973,6 +1992,7 @@ def base_url(url): return re.match(r'https?://[^?#]+/', url).group() +@partial_application def urljoin(base, path): if isinstance(path, bytes): path = path.decode() @@ -1988,21 +2008,6 @@ def urljoin(base, path): return urllib.parse.urljoin(base, path) -def partial_application(func): - sig = inspect.signature(func) - - @functools.wraps(func) - def wrapped(*args, **kwargs): - try: - sig.bind(*args, **kwargs) - except TypeError: - return functools.partial(func, *args, **kwargs) - else: - return func(*args, **kwargs) - - return wrapped - - @partial_application def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None): if get_attr and v is not None: @@ -2583,6 +2588,7 @@ def urlencode_postdata(*args, **kargs): return urllib.parse.urlencode(*args, **kargs).encode('ascii') +@partial_application def update_url(url, *, query_update=None, **kwargs): """Replace URL components specified by kwargs @param url str or parse url tuple @@ -2603,6 +2609,7 @@ def update_url(url, *, query_update=None, **kwargs): return urllib.parse.urlunparse(url._replace(**kwargs)) +@partial_application def update_url_query(url, query): return update_url(url, query_update=query) @@ -2924,6 +2931,7 @@ def error_to_str(err): return f'{type(err).__name__}: {err}' +@partial_application def mimetype2ext(mt, default=NO_DEFAULT): if not isinstance(mt, str): if default is not NO_DEFAULT: @@ -4664,6 +4672,7 @@ def to_high_limit_path(path): return path +@partial_application def format_field(obj, field=None, template='%s', ignore=NO_DEFAULT, default='', func=IDENTITY): val = traversal.traverse_obj(obj, *variadic(field)) if not val if ignore is NO_DEFAULT else val in variadic(ignore): @@ -4828,6 +4837,7 @@ def number_of_digits(number): return len('%d' % number) +@partial_application def join_nonempty(*values, delim='-', from_dict=None): if from_dict is not None: values = (traversal.traverse_obj(from_dict, variadic(v)) for v in values) @@ -5278,6 +5288,7 @@ def report_retry(e, count, retries, *, sleep_func, info, warn, error=None, suffi time.sleep(delay) +@partial_application def make_archive_id(ie, video_id): ie_key = ie if isinstance(ie, str) else ie.ie_key() return f'{ie_key.lower()} {video_id}' diff --git a/yt_dlp/utils/traversal.py b/yt_dlp/utils/traversal.py index 0eef817ea..dd9b4690b 100644 --- a/yt_dlp/utils/traversal.py +++ b/yt_dlp/utils/traversal.py @@ -435,6 +435,20 @@ def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False): return functools.partial(func, cls) +def trim_str(*, start=None, end=None): + def trim(s): + if s is None: + return None + start_idx = 0 + if start and s.startswith(start): + start_idx = len(start) + if end and s.endswith(end): + return s[start_idx:-len(end)] + return s[start_idx:] + + return trim + + def get_first(obj, *paths, **kwargs): return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)