diff --git a/test/test_networking.py b/test/test_networking.py index 8cadd86f5..10534242a 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -13,6 +13,7 @@ import http.client import http.cookiejar import http.server import io +import logging import pathlib import random import ssl @@ -752,6 +753,25 @@ class TestClientCertificate: }) +class TestRequestHandlerMisc: + """Misc generic tests for request handlers, not related to request or validation testing""" + @pytest.mark.parametrize('handler,logger_name', [ + ('Requests', 'urllib3'), + ('Websockets', 'websockets.client'), + ('Websockets', 'websockets.server') + ], indirect=['handler']) + def test_remove_logging_handler(self, handler, logger_name): + # Ensure any logging handlers, which may contain a YoutubeDL instance, + # are removed when we close the request handler + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + logging_handlers = logging.getLogger(logger_name).handlers + before_count = len(logging_handlers) + rh = handler() + assert len(logging_handlers) == before_count + 1 + rh.close() + assert len(logging_handlers) == before_count + + class TestUrllibRequestHandler(TestRequestHandlerBase): @pytest.mark.parametrize('handler', ['Urllib'], indirect=True) def test_file_urls(self, handler): @@ -827,6 +847,7 @@ class TestUrllibRequestHandler(TestRequestHandlerBase): assert not isinstance(exc_info.value, TransportError) +@pytest.mark.parametrize('handler', ['Requests'], indirect=True) class TestRequestsRequestHandler(TestRequestHandlerBase): @pytest.mark.parametrize('raised,expected', [ (lambda: requests.exceptions.ConnectTimeout(), TransportError), @@ -843,7 +864,6 @@ class TestRequestsRequestHandler(TestRequestHandlerBase): (lambda: requests.exceptions.RequestException(), RequestError) # (lambda: requests.exceptions.TooManyRedirects(), HTTPError) - Needs a response object ]) - @pytest.mark.parametrize('handler', ['Requests'], indirect=True) def test_request_error_mapping(self, handler, monkeypatch, raised, expected): with handler() as rh: def mock_get_instance(*args, **kwargs): @@ -877,7 +897,6 @@ class TestRequestsRequestHandler(TestRequestHandlerBase): '3 bytes read, 5 more expected' ), ]) - @pytest.mark.parametrize('handler', ['Requests'], indirect=True) def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match): from requests.models import Response as RequestsResponse from urllib3.response import HTTPResponse as Urllib3Response @@ -896,6 +915,21 @@ class TestRequestsRequestHandler(TestRequestHandlerBase): assert exc_info.type is expected + def test_close(self, handler, monkeypatch): + rh = handler() + session = rh._get_instance(cookiejar=rh.cookiejar) + called = False + original_close = session.close + + def mock_close(*args, **kwargs): + nonlocal called + called = True + return original_close(*args, **kwargs) + + monkeypatch.setattr(session, 'close', mock_close) + rh.close() + assert called + def run_validation(handler, error, req, **handler_kwargs): with handler(**handler_kwargs) as rh: @@ -1205,6 +1239,19 @@ class TestRequestDirector: assert director.send(Request('http://')).read() == b'' assert director.send(Request('http://', headers={'prefer': '1'})).read() == b'supported' + def test_close(self, monkeypatch): + director = RequestDirector(logger=FakeLogger()) + director.add_handler(FakeRH(logger=FakeLogger())) + called = False + + def mock_close(*args, **kwargs): + nonlocal called + called = True + + monkeypatch.setattr(director.handlers[FakeRH.RH_KEY], 'close', mock_close) + director.close() + assert called + # XXX: do we want to move this to test_YoutubeDL.py? class TestYoutubeDLNetworking: diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 00e4bdb49..7b19029bf 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -258,10 +258,10 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): # Forward urllib3 debug messages to our logger logger = logging.getLogger('urllib3') - handler = Urllib3LoggingHandler(logger=self._logger) - handler.setFormatter(logging.Formatter('requests: %(message)s')) - handler.addFilter(Urllib3LoggingFilter()) - logger.addHandler(handler) + self.__logging_handler = Urllib3LoggingHandler(logger=self._logger) + self.__logging_handler.setFormatter(logging.Formatter('requests: %(message)s')) + self.__logging_handler.addFilter(Urllib3LoggingFilter()) + logger.addHandler(self.__logging_handler) # TODO: Use a logger filter to suppress pool reuse warning instead logger.setLevel(logging.ERROR) @@ -276,6 +276,9 @@ class RequestsRH(RequestHandler, InstanceStoreMixin): def close(self): self._clear_instances() + # Remove the logging handler that contains a reference to our logger + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + logging.getLogger('urllib3').removeHandler(self.__logging_handler) def _check_extensions(self, extensions): super()._check_extensions(extensions) diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index ed64080d6..159793204 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -90,10 +90,12 @@ class WebsocketsRH(WebSocketRequestHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.__logging_handlers = {} for name in ('websockets.client', 'websockets.server'): logger = logging.getLogger(name) handler = logging.StreamHandler(stream=sys.stdout) handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + self.__logging_handlers[name] = handler logger.addHandler(handler) if self.verbose: logger.setLevel(logging.DEBUG) @@ -103,6 +105,12 @@ class WebsocketsRH(WebSocketRequestHandler): extensions.pop('timeout', None) extensions.pop('cookiejar', None) + def close(self): + # Remove the logging handler that contains a reference to our logger + # See: https://github.com/yt-dlp/yt-dlp/issues/8922 + for name, handler in self.__logging_handlers.items(): + logging.getLogger(name).removeHandler(handler) + def _send(self, request): timeout = float(request.extensions.get('timeout') or self.timeout) headers = self._merge_headers(request.headers)