diff --git a/test/test_utils.py b/test/test_utils.py
index edc712f07..3cdb21d40 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -34,6 +34,9 @@
find_xpath_attr,
fix_xml_ampersands,
get_element_by_class,
+ get_element_by_attribute,
+ get_elements_by_class,
+ get_elements_by_attribute,
InAdvancePagedList,
intlist_to_bytes,
is_html,
@@ -1124,6 +1127,32 @@ def test_get_element_by_class(self):
self.assertEqual(get_element_by_class('foo', html), 'nice')
self.assertEqual(get_element_by_class('no-such-class', html), None)
+ def test_get_element_by_attribute(self):
+ html = '''
+ nice
+ '''
+
+ self.assertEqual(get_element_by_attribute('class', 'foo bar', html), 'nice')
+ self.assertEqual(get_element_by_attribute('class', 'foo', html), None)
+ self.assertEqual(get_element_by_attribute('class', 'no-such-foo', html), None)
+
+ def test_get_elements_by_class(self):
+ html = '''
+ nicealso nice
+ '''
+
+ self.assertEqual(get_elements_by_class('foo', html), ['nice', 'also nice'])
+ self.assertEqual(get_elements_by_class('no-such-class', html), [])
+
+ def test_get_elements_by_attribute(self):
+ html = '''
+ nicealso nice
+ '''
+
+ self.assertEqual(get_elements_by_attribute('class', 'foo bar', html), ['nice', 'also nice'])
+ self.assertEqual(get_elements_by_attribute('class', 'foo', html), [])
+ self.assertEqual(get_elements_by_attribute('class', 'no-such-foo', html), [])
+
if __name__ == '__main__':
unittest.main()
diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py
index 67a847eba..a81fe7d30 100644
--- a/youtube_dl/utils.py
+++ b/youtube_dl/utils.py
@@ -337,17 +337,30 @@ def get_element_by_id(id, html):
def get_element_by_class(class_name, html):
- return get_element_by_attribute(
+ """Return the content of the first tag with the specified class in the passed HTML document"""
+ retval = get_elements_by_class(class_name, html)
+ return retval[0] if retval else None
+
+
+def get_element_by_attribute(attribute, value, html, escape_value=True):
+ retval = get_elements_by_attribute(attribute, value, html, escape_value)
+ return retval[0] if retval else None
+
+
+def get_elements_by_class(class_name, html):
+ """Return the content of all tags with the specified class in the passed HTML document as a list"""
+ return get_elements_by_attribute(
'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
html, escape_value=False)
-def get_element_by_attribute(attribute, value, html, escape_value=True):
+def get_elements_by_attribute(attribute, value, html, escape_value=True):
"""Return the content of the tag with the specified attribute in the passed HTML document"""
value = re.escape(value) if escape_value else value
- m = re.search(r'''(?xs)
+ retlist = []
+ for m in re.finditer(r'''(?xs)
<([a-zA-Z0-9:._-]+)
(?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
\s+%s=['"]?%s['"]?
@@ -355,16 +368,15 @@ def get_element_by_attribute(attribute, value, html, escape_value=True):
\s*>
(?P.*?)
\1>
- ''' % (re.escape(attribute), value), html)
+ ''' % (re.escape(attribute), value), html):
+ res = m.group('content')
- if not m:
- return None
- res = m.group('content')
+ if res.startswith('"') or res.startswith("'"):
+ res = res[1:-1]
- if res.startswith('"') or res.startswith("'"):
- res = res[1:-1]
+ retlist.append(unescapeHTML(res))
- return unescapeHTML(res)
+ return retlist
class HTMLAttributeParser(compat_HTMLParser):