diff --git a/test/test_websockets.py b/test/test_websockets.py index 13b3a1e76..b294b0932 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -32,8 +32,6 @@ ) from yt_dlp.utils.networking import HTTPHeaderDict -from test.conftest import validate_and_send - TEST_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -66,7 +64,9 @@ def process_request(self, request): def create_websocket_server(**ws_kwargs): import websockets.sync.server - wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs) + wsd = websockets.sync.server.serve( + websocket_handler, '127.0.0.1', 0, + process_request=process_request, open_timeout=2, **ws_kwargs) ws_port = wsd.socket.getsockname()[1] ws_server_thread = threading.Thread(target=wsd.serve_forever) ws_server_thread.daemon = True @@ -100,6 +100,19 @@ def create_mtls_wss_websocket_server(): return create_websocket_server(ssl_context=sslctx) +def ws_validate_and_send(rh, req): + rh.validate(req) + max_tries = 3 + for i in range(max_tries): + try: + return rh.send(req) + except TransportError as e: + if i < (max_tries - 1) and 'connection closed during handshake' in str(e): + # websockets server sometimes hangs on new connections + continue + raise + + @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers') class TestWebsSocketRequestHandlerConformance: @classmethod @@ -119,7 +132,7 @@ def setup_class(cls): @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) def test_basic_websockets(self, handler): with handler() as rh: - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) assert 'upgrade' in ws.headers assert ws.status == 101 ws.send('foo') @@ -131,7 +144,7 @@ def test_basic_websockets(self, handler): @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) def test_send_types(self, handler, msg, opcode): with handler() as rh: - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send(msg) assert int(ws.recv()) == opcode ws.close() @@ -140,10 +153,10 @@ def test_send_types(self, handler, msg, opcode): def test_verify_cert(self, handler): with handler() as rh: with pytest.raises(CertificateVerifyError): - validate_and_send(rh, Request(self.wss_base_url)) + ws_validate_and_send(rh, Request(self.wss_base_url)) with handler(verify=False) as rh: - ws = validate_and_send(rh, Request(self.wss_base_url)) + ws = ws_validate_and_send(rh, Request(self.wss_base_url)) assert ws.status == 101 ws.close() @@ -151,7 +164,7 @@ def test_verify_cert(self, handler): def test_ssl_error(self, handler): with handler(verify=False) as rh: with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info: - validate_and_send(rh, Request(self.bad_wss_host)) + ws_validate_and_send(rh, Request(self.bad_wss_host)) assert not issubclass(exc_info.type, CertificateVerifyError) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @@ -163,7 +176,7 @@ def test_ssl_error(self, handler): ]) def test_percent_encode(self, handler, path, expected): with handler() as rh: - ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) + ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}')) ws.send('path') assert ws.recv() == expected assert ws.status == 101 @@ -174,7 +187,7 @@ def test_remove_dot_segments(self, handler): with handler() as rh: # This isn't a comprehensive test, # but it should be enough to check whether the handler is removing dot segments - ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) + ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test')) assert ws.status == 101 ws.send('path') assert ws.recv() == '/test' @@ -187,7 +200,7 @@ def test_remove_dot_segments(self, handler): def test_raise_http_error(self, handler, status): with handler() as rh: with pytest.raises(HTTPError) as exc_info: - validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) + ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}')) assert exc_info.value.status == status @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) @@ -198,7 +211,7 @@ def test_raise_http_error(self, handler, status): def test_timeout(self, handler, params, extensions): with handler(**params) as rh: with pytest.raises(TransportError): - validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) + ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions)) @pytest.mark.parametrize('handler', ['Websockets'], indirect=True) def test_cookies(self, handler): @@ -210,18 +223,18 @@ def test_cookies(self, handler): comment_url=None, rest={})) with handler(cookiejar=cookiejar) as rh: - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' ws.close() with handler() as rh: - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') assert 'cookie' not in json.loads(ws.recv()) ws.close() - ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) + ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar})) ws.send('headers') assert json.loads(ws.recv())['cookie'] == 'test=ytdlp' ws.close() @@ -231,7 +244,7 @@ def test_source_address(self, handler): source_address = f'127.0.0.{random.randint(5, 255)}' verify_address_availability(source_address) with handler(source_address=source_address) as rh: - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('source_address') assert source_address == ws.recv() ws.close() @@ -240,7 +253,7 @@ def test_source_address(self, handler): def test_response_url(self, handler): with handler() as rh: url = f'{self.ws_base_url}/something' - ws = validate_and_send(rh, Request(url)) + ws = ws_validate_and_send(rh, Request(url)) assert ws.url == url ws.close() @@ -248,14 +261,14 @@ def test_response_url(self, handler): def test_request_headers(self, handler): with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh: # Global Headers - ws = validate_and_send(rh, Request(self.ws_base_url)) + ws = ws_validate_and_send(rh, Request(self.ws_base_url)) ws.send('headers') headers = HTTPHeaderDict(json.loads(ws.recv())) assert headers['test1'] == 'test' ws.close() # Per request headers, merged with global - ws = validate_and_send(rh, Request( + ws = ws_validate_and_send(rh, Request( self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'})) ws.send('headers') headers = HTTPHeaderDict(json.loads(ws.recv())) @@ -288,7 +301,7 @@ def test_mtls(self, handler, client_cert): verify=False, client_cert=client_cert ) as rh: - validate_and_send(rh, Request(self.mtls_wss_base_url)).close() + ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close() def create_fake_ws_connection(raised):