Compare commits

...

3 Commits

Author SHA1 Message Date
bashonly 6c7bfb5c7c
oops
Authored by: bashonly
2024-04-23 18:28:43 -05:00
bashonly 805f40a001
add tests
Authored by: bashonly
2024-04-23 18:22:58 -05:00
bashonly b77ccbc67b
new impl
Authored by: bashonly
2024-04-23 18:04:12 -05:00
4 changed files with 45 additions and 28 deletions

View File

@ -785,6 +785,25 @@ class TestHTTPImpersonateRequestHandler(TestRequestHandlerBase):
assert res.status == 200 assert res.status == 200
assert std_headers['user-agent'].lower() not in res.read().decode().lower() assert std_headers['user-agent'].lower() not in res.read().decode().lower()
def test_response_extensions(self, handler):
with handler() as rh:
for target in rh.supported_targets:
request = Request(
f'http://127.0.0.1:{self.http_port}/gen_200', extensions={'impersonate': target})
res = validate_and_send(rh, request)
assert res.extensions['impersonate'] == rh._get_request_target(request)
def test_http_error_response_extensions(self, handler):
with handler() as rh:
for target in rh.supported_targets:
request = Request(
f'http://127.0.0.1:{self.http_port}/gen_404', extensions={'impersonate': target})
try:
validate_and_send(rh, request)
except HTTPError as e:
res = e.response
assert res.extensions['impersonate'] == rh._get_request_target(request)
class TestRequestHandlerMisc: class TestRequestHandlerMisc:
"""Misc generic tests for request handlers, not related to request or validation testing""" """Misc generic tests for request handlers, not related to request or validation testing"""
@ -1013,7 +1032,7 @@ class TestCurlCFFIRequestHandler(TestRequestHandlerBase):
from yt_dlp.networking._curlcffi import CurlCFFIResponseAdapter from yt_dlp.networking._curlcffi import CurlCFFIResponseAdapter
curl_res = curl_cffi.requests.Response() curl_res = curl_cffi.requests.Response()
res = CurlCFFIResponseAdapter(curl_res, None) res = CurlCFFIResponseAdapter(curl_res)
def mock_read(*args, **kwargs): def mock_read(*args, **kwargs):
try: try:

View File

@ -5,7 +5,13 @@ import math
import urllib.parse import urllib.parse
from ._helper import InstanceStoreMixin, select_proxy from ._helper import InstanceStoreMixin, select_proxy
from .common import Features, Request, register_preference, register_rh from .common import (
Features,
Request,
Response,
register_preference,
register_rh,
)
from .exceptions import ( from .exceptions import (
CertificateVerifyError, CertificateVerifyError,
HTTPError, HTTPError,
@ -14,11 +20,7 @@ from .exceptions import (
SSLError, SSLError,
TransportError, TransportError,
) )
from .impersonate import ( from .impersonate import ImpersonateRequestHandler, ImpersonateTarget
ImpersonateRequestHandler,
ImpersonateResponse,
ImpersonateTarget,
)
from ..dependencies import curl_cffi from ..dependencies import curl_cffi
from ..utils import int_or_none from ..utils import int_or_none
@ -78,16 +80,15 @@ class CurlCFFIResponseReader(io.IOBase):
super().close() super().close()
class CurlCFFIResponseAdapter(ImpersonateResponse): class CurlCFFIResponseAdapter(Response):
fp: CurlCFFIResponseReader fp: CurlCFFIResponseReader
def __init__(self, response: curl_cffi.requests.Response, impersonate: ImpersonateTarget | None): def __init__(self, response: curl_cffi.requests.Response):
super().__init__( super().__init__(
fp=CurlCFFIResponseReader(response), fp=CurlCFFIResponseReader(response),
headers=response.headers, headers=response.headers,
url=response.url, url=response.url,
status=response.status_code, status=response.status_code)
impersonate=impersonate)
def read(self, amt=None): def read(self, amt=None):
try: try:
@ -131,6 +132,15 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
extensions.pop('cookiejar', None) extensions.pop('cookiejar', None)
extensions.pop('timeout', None) extensions.pop('timeout', None)
def send(self, request: Request) -> Response:
try:
response = super().send(request)
except HTTPError as e:
e.response.extensions['impersonate'] = self._get_request_target(request)
raise
response.extensions['impersonate'] = self._get_request_target(request)
return response
def _send(self, request: Request): def _send(self, request: Request):
max_redirects_exceeded = False max_redirects_exceeded = False
session: curl_cffi.requests.Session = self._get_instance( session: curl_cffi.requests.Session = self._get_instance(
@ -177,8 +187,6 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
session.curl.setopt(CurlOpt.LOW_SPEED_LIMIT, 1) # 1 byte per second session.curl.setopt(CurlOpt.LOW_SPEED_LIMIT, 1) # 1 byte per second
session.curl.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(timeout)) session.curl.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(timeout))
impersonate_target = self._get_request_target(request)
try: try:
curl_response = session.request( curl_response = session.request(
method=request.method, method=request.method,
@ -188,7 +196,8 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
verify=self.verify, verify=self.verify,
max_redirects=5, max_redirects=5,
timeout=timeout, timeout=timeout,
impersonate=self._SUPPORTED_IMPERSONATE_TARGET_MAP.get(impersonate_target), impersonate=self._SUPPORTED_IMPERSONATE_TARGET_MAP.get(
self._get_request_target(request)),
interface=self.source_address, interface=self.source_address,
stream=True stream=True
) )
@ -208,7 +217,7 @@ class CurlCFFIRH(ImpersonateRequestHandler, InstanceStoreMixin):
else: else:
raise TransportError(cause=e) from e raise TransportError(cause=e) from e
response = CurlCFFIResponseAdapter(curl_response, impersonate_target) response = CurlCFFIResponseAdapter(curl_response)
if not 200 <= response.status < 300: if not 200 <= response.status < 300:
raise HTTPError(response, redirect_loop=max_redirects_exceeded) raise HTTPError(response, redirect_loop=max_redirects_exceeded)

View File

@ -517,6 +517,7 @@ class Response(io.IOBase):
self.reason = reason or HTTPStatus(status).phrase self.reason = reason or HTTPStatus(status).phrase
except ValueError: except ValueError:
self.reason = None self.reason = None
self.extensions = {}
def readable(self): def readable(self):
return self.fp.readable() return self.fp.readable()

View File

@ -5,7 +5,7 @@ from abc import ABC
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from .common import RequestHandler, Response, register_preference from .common import RequestHandler, register_preference
from .exceptions import UnsupportedRequest from .exceptions import UnsupportedRequest
from ..compat.types import NoneType from ..compat.types import NoneType
from ..utils import classproperty, join_nonempty from ..utils import classproperty, join_nonempty
@ -58,18 +58,6 @@ class ImpersonateTarget:
return cls(**mobj.groupdict()) return cls(**mobj.groupdict())
class ImpersonateResponse(Response):
"""
Wrapper class for a Response to an impersonated Request.
Parameters:
@param impersonate: the ImpersonateTarget used in the originating Request.
"""
def __init__(self, *args, impersonate: ImpersonateTarget | None, **kwargs):
super().__init__(*args, **kwargs)
self.impersonate = impersonate
class ImpersonateRequestHandler(RequestHandler, ABC): class ImpersonateRequestHandler(RequestHandler, ABC):
""" """
Base class for request handlers that support browser impersonation. Base class for request handlers that support browser impersonation.