mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-11-06 13:07:07 +01:00
[rh:websockets] Migrate websockets to networking framework (#7720)
* Adds a basic WebSocket framework * Introduces new minimum `websockets` version of 12.0 * Deprecates `WebSocketsWrapper` Fixes https://github.com/yt-dlp/yt-dlp/issues/8439 Authored by: coletdjnz
This commit is contained in:
parent
45d82be65f
commit
ccfd70f4c2
@ -6,3 +6,4 @@ brotlicffi; implementation_name!='cpython'
|
|||||||
certifi
|
certifi
|
||||||
requests>=2.31.0,<3
|
requests>=2.31.0,<3
|
||||||
urllib3>=1.26.17,<3
|
urllib3>=1.26.17,<3
|
||||||
|
websockets>=12.0
|
||||||
|
@ -19,3 +19,8 @@ def handler(request):
|
|||||||
pytest.skip(f'{RH_KEY} request handler is not available')
|
pytest.skip(f'{RH_KEY} request handler is not available')
|
||||||
|
|
||||||
return functools.partial(handler, logger=FakeLogger)
|
return functools.partial(handler, logger=FakeLogger)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_send(rh, req):
|
||||||
|
rh.validate(req)
|
||||||
|
return rh.send(req)
|
||||||
|
@ -52,6 +52,8 @@
|
|||||||
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
from yt_dlp.utils._utils import _YDLLogger as FakeLogger
|
||||||
from yt_dlp.utils.networking import HTTPHeaderDict
|
from yt_dlp.utils.networking import HTTPHeaderDict
|
||||||
|
|
||||||
|
from test.conftest import validate_and_send
|
||||||
|
|
||||||
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
@ -275,11 +277,6 @@ def send_header(self, keyword, value):
|
|||||||
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
|
self._headers_buffer.append(f'{keyword}: {value}\r\n'.encode())
|
||||||
|
|
||||||
|
|
||||||
def validate_and_send(rh, req):
|
|
||||||
rh.validate(req)
|
|
||||||
return rh.send(req)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRequestHandlerBase:
|
class TestRequestHandlerBase:
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
@ -872,8 +869,9 @@ def request(self, *args, **kwargs):
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Requests'], indirect=True)
|
||||||
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
|
def test_response_error_mapping(self, handler, monkeypatch, raised, expected, match):
|
||||||
from urllib3.response import HTTPResponse as Urllib3Response
|
|
||||||
from requests.models import Response as RequestsResponse
|
from requests.models import Response as RequestsResponse
|
||||||
|
from urllib3.response import HTTPResponse as Urllib3Response
|
||||||
|
|
||||||
from yt_dlp.networking._requests import RequestsResponseAdapter
|
from yt_dlp.networking._requests import RequestsResponseAdapter
|
||||||
requests_res = RequestsResponse()
|
requests_res = RequestsResponse()
|
||||||
requests_res.raw = Urllib3Response(body=b'', status=200)
|
requests_res.raw = Urllib3Response(body=b'', status=200)
|
||||||
@ -929,13 +927,17 @@ class HTTPSupportedRH(ValidationRH):
|
|||||||
('http', False, {}),
|
('http', False, {}),
|
||||||
('https', False, {}),
|
('https', False, {}),
|
||||||
]),
|
]),
|
||||||
|
('Websockets', [
|
||||||
|
('ws', False, {}),
|
||||||
|
('wss', False, {}),
|
||||||
|
]),
|
||||||
(NoCheckRH, [('http', False, {})]),
|
(NoCheckRH, [('http', False, {})]),
|
||||||
(ValidationRH, [('http', UnsupportedRequest, {})])
|
(ValidationRH, [('http', UnsupportedRequest, {})])
|
||||||
]
|
]
|
||||||
|
|
||||||
PROXY_SCHEME_TESTS = [
|
PROXY_SCHEME_TESTS = [
|
||||||
# scheme, expected to fail
|
# scheme, expected to fail
|
||||||
('Urllib', [
|
('Urllib', 'http', [
|
||||||
('http', False),
|
('http', False),
|
||||||
('https', UnsupportedRequest),
|
('https', UnsupportedRequest),
|
||||||
('socks4', False),
|
('socks4', False),
|
||||||
@ -944,7 +946,7 @@ class HTTPSupportedRH(ValidationRH):
|
|||||||
('socks5h', False),
|
('socks5h', False),
|
||||||
('socks', UnsupportedRequest),
|
('socks', UnsupportedRequest),
|
||||||
]),
|
]),
|
||||||
('Requests', [
|
('Requests', 'http', [
|
||||||
('http', False),
|
('http', False),
|
||||||
('https', False),
|
('https', False),
|
||||||
('socks4', False),
|
('socks4', False),
|
||||||
@ -952,8 +954,11 @@ class HTTPSupportedRH(ValidationRH):
|
|||||||
('socks5', False),
|
('socks5', False),
|
||||||
('socks5h', False),
|
('socks5h', False),
|
||||||
]),
|
]),
|
||||||
(NoCheckRH, [('http', False)]),
|
(NoCheckRH, 'http', [('http', False)]),
|
||||||
(HTTPSupportedRH, [('http', UnsupportedRequest)]),
|
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
|
||||||
|
('Websockets', 'ws', [('http', UnsupportedRequest)]),
|
||||||
|
(NoCheckRH, 'http', [('http', False)]),
|
||||||
|
(HTTPSupportedRH, 'http', [('http', UnsupportedRequest)]),
|
||||||
]
|
]
|
||||||
|
|
||||||
PROXY_KEY_TESTS = [
|
PROXY_KEY_TESTS = [
|
||||||
@ -972,7 +977,7 @@ class HTTPSupportedRH(ValidationRH):
|
|||||||
]
|
]
|
||||||
|
|
||||||
EXTENSION_TESTS = [
|
EXTENSION_TESTS = [
|
||||||
('Urllib', [
|
('Urllib', 'http', [
|
||||||
({'cookiejar': 'notacookiejar'}, AssertionError),
|
({'cookiejar': 'notacookiejar'}, AssertionError),
|
||||||
({'cookiejar': YoutubeDLCookieJar()}, False),
|
({'cookiejar': YoutubeDLCookieJar()}, False),
|
||||||
({'cookiejar': CookieJar()}, AssertionError),
|
({'cookiejar': CookieJar()}, AssertionError),
|
||||||
@ -980,17 +985,21 @@ class HTTPSupportedRH(ValidationRH):
|
|||||||
({'timeout': 'notatimeout'}, AssertionError),
|
({'timeout': 'notatimeout'}, AssertionError),
|
||||||
({'unsupported': 'value'}, UnsupportedRequest),
|
({'unsupported': 'value'}, UnsupportedRequest),
|
||||||
]),
|
]),
|
||||||
('Requests', [
|
('Requests', 'http', [
|
||||||
({'cookiejar': 'notacookiejar'}, AssertionError),
|
({'cookiejar': 'notacookiejar'}, AssertionError),
|
||||||
({'cookiejar': YoutubeDLCookieJar()}, False),
|
({'cookiejar': YoutubeDLCookieJar()}, False),
|
||||||
({'timeout': 1}, False),
|
({'timeout': 1}, False),
|
||||||
({'timeout': 'notatimeout'}, AssertionError),
|
({'timeout': 'notatimeout'}, AssertionError),
|
||||||
({'unsupported': 'value'}, UnsupportedRequest),
|
({'unsupported': 'value'}, UnsupportedRequest),
|
||||||
]),
|
]),
|
||||||
(NoCheckRH, [
|
(NoCheckRH, 'http', [
|
||||||
({'cookiejar': 'notacookiejar'}, False),
|
({'cookiejar': 'notacookiejar'}, False),
|
||||||
({'somerandom': 'test'}, False), # but any extension is allowed through
|
({'somerandom': 'test'}, False), # but any extension is allowed through
|
||||||
]),
|
]),
|
||||||
|
('Websockets', 'ws', [
|
||||||
|
({'cookiejar': YoutubeDLCookieJar()}, False),
|
||||||
|
({'timeout': 2}, False),
|
||||||
|
]),
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
|
@pytest.mark.parametrize('handler,scheme,fail,handler_kwargs', [
|
||||||
@ -1016,14 +1025,14 @@ def test_proxy_key(self, handler, proxy_key, fail):
|
|||||||
run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
|
run_validation(handler, fail, Request('http://', proxies={proxy_key: 'http://example.com'}))
|
||||||
run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
|
run_validation(handler, fail, Request('http://'), proxies={proxy_key: 'http://example.com'})
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,scheme,fail', [
|
@pytest.mark.parametrize('handler,req_scheme,scheme,fail', [
|
||||||
(handler_tests[0], scheme, fail)
|
(handler_tests[0], handler_tests[1], scheme, fail)
|
||||||
for handler_tests in PROXY_SCHEME_TESTS
|
for handler_tests in PROXY_SCHEME_TESTS
|
||||||
for scheme, fail in handler_tests[1]
|
for scheme, fail in handler_tests[2]
|
||||||
], indirect=['handler'])
|
], indirect=['handler'])
|
||||||
def test_proxy_scheme(self, handler, scheme, fail):
|
def test_proxy_scheme(self, handler, req_scheme, scheme, fail):
|
||||||
run_validation(handler, fail, Request('http://', proxies={'http': f'{scheme}://example.com'}))
|
run_validation(handler, fail, Request(f'{req_scheme}://', proxies={req_scheme: f'{scheme}://example.com'}))
|
||||||
run_validation(handler, fail, Request('http://'), proxies={'http': f'{scheme}://example.com'})
|
run_validation(handler, fail, Request(f'{req_scheme}://'), proxies={req_scheme: f'{scheme}://example.com'})
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True)
|
@pytest.mark.parametrize('handler', ['Urllib', HTTPSupportedRH, 'Requests'], indirect=True)
|
||||||
def test_empty_proxy(self, handler):
|
def test_empty_proxy(self, handler):
|
||||||
@ -1035,14 +1044,14 @@ def test_empty_proxy(self, handler):
|
|||||||
def test_invalid_proxy_url(self, handler, proxy_url):
|
def test_invalid_proxy_url(self, handler, proxy_url):
|
||||||
run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
|
run_validation(handler, UnsupportedRequest, Request('http://', proxies={'http': proxy_url}))
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,extensions,fail', [
|
@pytest.mark.parametrize('handler,scheme,extensions,fail', [
|
||||||
(handler_tests[0], extensions, fail)
|
(handler_tests[0], handler_tests[1], extensions, fail)
|
||||||
for handler_tests in EXTENSION_TESTS
|
for handler_tests in EXTENSION_TESTS
|
||||||
for extensions, fail in handler_tests[1]
|
for extensions, fail in handler_tests[2]
|
||||||
], indirect=['handler'])
|
], indirect=['handler'])
|
||||||
def test_extension(self, handler, extensions, fail):
|
def test_extension(self, handler, scheme, extensions, fail):
|
||||||
run_validation(
|
run_validation(
|
||||||
handler, fail, Request('http://', extensions=extensions))
|
handler, fail, Request(f'{scheme}://', extensions=extensions))
|
||||||
|
|
||||||
def test_invalid_request_type(self):
|
def test_invalid_request_type(self):
|
||||||
rh = self.ValidationRH(logger=FakeLogger())
|
rh = self.ValidationRH(logger=FakeLogger())
|
||||||
@ -1075,6 +1084,22 @@ def __init__(self, *args, **kwargs):
|
|||||||
self._request_director = self.build_request_director([FakeRH])
|
self._request_director = self.build_request_director([FakeRH])
|
||||||
|
|
||||||
|
|
||||||
|
class AllUnsupportedRHYDL(FakeYDL):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
||||||
|
class UnsupportedRH(RequestHandler):
|
||||||
|
def _send(self, request: Request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
_SUPPORTED_FEATURES = ()
|
||||||
|
_SUPPORTED_PROXY_SCHEMES = ()
|
||||||
|
_SUPPORTED_URL_SCHEMES = ()
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._request_director = self.build_request_director([UnsupportedRH])
|
||||||
|
|
||||||
|
|
||||||
class TestRequestDirector:
|
class TestRequestDirector:
|
||||||
|
|
||||||
def test_handler_operations(self):
|
def test_handler_operations(self):
|
||||||
@ -1234,6 +1259,12 @@ def test_file_urls_error(self):
|
|||||||
with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
|
with pytest.raises(RequestError, match=r'file:// URLs are disabled by default'):
|
||||||
ydl.urlopen('file://')
|
ydl.urlopen('file://')
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('scheme', (['ws', 'wss']))
|
||||||
|
def test_websocket_unavailable_error(self, scheme):
|
||||||
|
with AllUnsupportedRHYDL() as ydl:
|
||||||
|
with pytest.raises(RequestError, match=r'This request requires WebSocket support'):
|
||||||
|
ydl.urlopen(f'{scheme}://')
|
||||||
|
|
||||||
def test_legacy_server_connect_error(self):
|
def test_legacy_server_connect_error(self):
|
||||||
with FakeRHYDL() as ydl:
|
with FakeRHYDL() as ydl:
|
||||||
for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
|
for error in ('UNSAFE_LEGACY_RENEGOTIATION_DISABLED', 'SSLV3_ALERT_HANDSHAKE_FAILURE'):
|
||||||
|
@ -210,6 +210,16 @@ def do_GET(self):
|
|||||||
self.wfile.write(payload.encode())
|
self.wfile.write(payload.encode())
|
||||||
|
|
||||||
|
|
||||||
|
class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
|
||||||
|
def handle(self):
|
||||||
|
import websockets.sync.server
|
||||||
|
protocol = websockets.ServerProtocol()
|
||||||
|
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
|
||||||
|
connection.handshake()
|
||||||
|
connection.send(json.dumps(self.socks_info))
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
|
def socks_server(socks_server_class, request_handler, bind_ip=None, **socks_server_kwargs):
|
||||||
server = server_thread = None
|
server = server_thread = None
|
||||||
@ -252,8 +262,22 @@ def socks_info_request(self, handler, target_domain=None, target_port=None, **re
|
|||||||
return json.loads(handler.send(request).read().decode())
|
return json.loads(handler.send(request).read().decode())
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketSocksTestProxyContext(SocksProxyTestContext):
|
||||||
|
REQUEST_HANDLER_CLASS = SocksWebSocketTestRequestHandler
|
||||||
|
|
||||||
|
def socks_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
|
||||||
|
request = Request(f'ws://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
|
||||||
|
handler.validate(request)
|
||||||
|
ws = handler.send(request)
|
||||||
|
ws.send('socks_info')
|
||||||
|
socks_info = ws.recv()
|
||||||
|
ws.close()
|
||||||
|
return json.loads(socks_info)
|
||||||
|
|
||||||
|
|
||||||
CTX_MAP = {
|
CTX_MAP = {
|
||||||
'http': HTTPSocksTestProxyContext,
|
'http': HTTPSocksTestProxyContext,
|
||||||
|
'ws': WebSocketSocksTestProxyContext,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -263,7 +287,7 @@ def ctx(request):
|
|||||||
|
|
||||||
|
|
||||||
class TestSocks4Proxy:
|
class TestSocks4Proxy:
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks4_no_auth(self, handler, ctx):
|
def test_socks4_no_auth(self, handler, ctx):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
@ -271,7 +295,7 @@ def test_socks4_no_auth(self, handler, ctx):
|
|||||||
rh, proxies={'all': f'socks4://{server_address}'})
|
rh, proxies={'all': f'socks4://{server_address}'})
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks4_auth(self, handler, ctx):
|
def test_socks4_auth(self, handler, ctx):
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, user_id='user') as server_address:
|
||||||
@ -281,7 +305,7 @@ def test_socks4_auth(self, handler, ctx):
|
|||||||
rh, proxies={'all': f'socks4://user:@{server_address}'})
|
rh, proxies={'all': f'socks4://user:@{server_address}'})
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks4a_ipv4_target(self, handler, ctx):
|
def test_socks4a_ipv4_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
||||||
@ -289,7 +313,7 @@ def test_socks4a_ipv4_target(self, handler, ctx):
|
|||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
|
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks4a_domain_target(self, handler, ctx):
|
def test_socks4a_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
|
||||||
@ -298,7 +322,7 @@ def test_socks4a_domain_target(self, handler, ctx):
|
|||||||
assert response['ipv4_address'] is None
|
assert response['ipv4_address'] is None
|
||||||
assert response['domain_address'] == 'localhost'
|
assert response['domain_address'] == 'localhost'
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_ipv4_client_source_address(self, handler, ctx):
|
def test_ipv4_client_source_address(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler) as server_address:
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
@ -308,7 +332,7 @@ def test_ipv4_client_source_address(self, handler, ctx):
|
|||||||
assert response['client_address'][0] == source_address
|
assert response['client_address'][0] == source_address
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
@pytest.mark.parametrize('reply_code', [
|
@pytest.mark.parametrize('reply_code', [
|
||||||
Socks4CD.REQUEST_REJECTED_OR_FAILED,
|
Socks4CD.REQUEST_REJECTED_OR_FAILED,
|
||||||
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
|
Socks4CD.REQUEST_REJECTED_CANNOT_CONNECT_TO_IDENTD,
|
||||||
@ -320,7 +344,7 @@ def test_socks4_errors(self, handler, ctx, reply_code):
|
|||||||
with pytest.raises(ProxyError):
|
with pytest.raises(ProxyError):
|
||||||
ctx.socks_info_request(rh)
|
ctx.socks_info_request(rh)
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_ipv6_socks4_proxy(self, handler, ctx):
|
def test_ipv6_socks4_proxy(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
|
||||||
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
|
||||||
@ -329,7 +353,7 @@ def test_ipv6_socks4_proxy(self, handler, ctx):
|
|||||||
assert response['ipv4_address'] == '127.0.0.1'
|
assert response['ipv4_address'] == '127.0.0.1'
|
||||||
assert response['version'] == 4
|
assert response['version'] == 4
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_timeout(self, handler, ctx):
|
def test_timeout(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
|
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
|
||||||
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
|
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
|
||||||
@ -339,7 +363,7 @@ def test_timeout(self, handler, ctx):
|
|||||||
|
|
||||||
class TestSocks5Proxy:
|
class TestSocks5Proxy:
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5_no_auth(self, handler, ctx):
|
def test_socks5_no_auth(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
@ -347,7 +371,7 @@ def test_socks5_no_auth(self, handler, ctx):
|
|||||||
assert response['auth_methods'] == [0x0]
|
assert response['auth_methods'] == [0x0]
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5_user_pass(self, handler, ctx):
|
def test_socks5_user_pass(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, auth=('test', 'testpass')) as server_address:
|
||||||
with handler() as rh:
|
with handler() as rh:
|
||||||
@ -360,7 +384,7 @@ def test_socks5_user_pass(self, handler, ctx):
|
|||||||
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
|
assert response['auth_methods'] == [Socks5Auth.AUTH_NONE, Socks5Auth.AUTH_USER_PASS]
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5_ipv4_target(self, handler, ctx):
|
def test_socks5_ipv4_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
@ -368,7 +392,7 @@ def test_socks5_ipv4_target(self, handler, ctx):
|
|||||||
assert response['ipv4_address'] == '127.0.0.1'
|
assert response['ipv4_address'] == '127.0.0.1'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5_domain_target(self, handler, ctx):
|
def test_socks5_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
@ -376,7 +400,7 @@ def test_socks5_domain_target(self, handler, ctx):
|
|||||||
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
|
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5h_domain_target(self, handler, ctx):
|
def test_socks5h_domain_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
||||||
@ -385,7 +409,7 @@ def test_socks5h_domain_target(self, handler, ctx):
|
|||||||
assert response['domain_address'] == 'localhost'
|
assert response['domain_address'] == 'localhost'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5h_ip_target(self, handler, ctx):
|
def test_socks5h_ip_target(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5h://{server_address}'}) as rh:
|
||||||
@ -394,7 +418,7 @@ def test_socks5h_ip_target(self, handler, ctx):
|
|||||||
assert response['domain_address'] is None
|
assert response['domain_address'] is None
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_socks5_ipv6_destination(self, handler, ctx):
|
def test_socks5_ipv6_destination(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
@ -402,7 +426,7 @@ def test_socks5_ipv6_destination(self, handler, ctx):
|
|||||||
assert response['ipv6_address'] == '::1'
|
assert response['ipv6_address'] == '::1'
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_ipv6_socks5_proxy(self, handler, ctx):
|
def test_ipv6_socks5_proxy(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
|
||||||
@ -413,7 +437,7 @@ def test_ipv6_socks5_proxy(self, handler, ctx):
|
|||||||
|
|
||||||
# XXX: is there any feasible way of testing IPv6 source addresses?
|
# XXX: is there any feasible way of testing IPv6 source addresses?
|
||||||
# Same would go for non-proxy source_address test...
|
# Same would go for non-proxy source_address test...
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_ipv4_client_source_address(self, handler, ctx):
|
def test_ipv4_client_source_address(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler) as server_address:
|
||||||
source_address = f'127.0.0.{random.randint(5, 255)}'
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
@ -422,7 +446,7 @@ def test_ipv4_client_source_address(self, handler, ctx):
|
|||||||
assert response['client_address'][0] == source_address
|
assert response['client_address'][0] == source_address
|
||||||
assert response['version'] == 5
|
assert response['version'] == 5
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Requests', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
@pytest.mark.parametrize('reply_code', [
|
@pytest.mark.parametrize('reply_code', [
|
||||||
Socks5Reply.GENERAL_FAILURE,
|
Socks5Reply.GENERAL_FAILURE,
|
||||||
Socks5Reply.CONNECTION_NOT_ALLOWED,
|
Socks5Reply.CONNECTION_NOT_ALLOWED,
|
||||||
@ -439,7 +463,7 @@ def test_socks5_errors(self, handler, ctx, reply_code):
|
|||||||
with pytest.raises(ProxyError):
|
with pytest.raises(ProxyError):
|
||||||
ctx.socks_info_request(rh)
|
ctx.socks_info_request(rh)
|
||||||
|
|
||||||
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
|
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http'), ('Websockets', 'ws')], indirect=True)
|
||||||
def test_timeout(self, handler, ctx):
|
def test_timeout(self, handler, ctx):
|
||||||
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
|
with ctx.socks_server(Socks5ProxyHandler, sleep=2) as server_address:
|
||||||
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
|
with handler(proxies={'all': f'socks5://{server_address}'}, timeout=1) as rh:
|
||||||
|
380
test/test_websockets.py
Normal file
380
test/test_websockets.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Allow direct execution
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
import http.client
|
||||||
|
import http.cookiejar
|
||||||
|
import http.server
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import ssl
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from yt_dlp import socks
|
||||||
|
from yt_dlp.cookies import YoutubeDLCookieJar
|
||||||
|
from yt_dlp.dependencies import websockets
|
||||||
|
from yt_dlp.networking import Request
|
||||||
|
from yt_dlp.networking.exceptions import (
|
||||||
|
CertificateVerifyError,
|
||||||
|
HTTPError,
|
||||||
|
ProxyError,
|
||||||
|
RequestError,
|
||||||
|
SSLError,
|
||||||
|
TransportError,
|
||||||
|
)
|
||||||
|
from yt_dlp.utils.networking import HTTPHeaderDict
|
||||||
|
|
||||||
|
from test.conftest import validate_and_send
|
||||||
|
|
||||||
|
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def websocket_handler(websocket):
|
||||||
|
for message in websocket:
|
||||||
|
if isinstance(message, bytes):
|
||||||
|
if message == b'bytes':
|
||||||
|
return websocket.send('2')
|
||||||
|
elif isinstance(message, str):
|
||||||
|
if message == 'headers':
|
||||||
|
return websocket.send(json.dumps(dict(websocket.request.headers)))
|
||||||
|
elif message == 'path':
|
||||||
|
return websocket.send(websocket.request.path)
|
||||||
|
elif message == 'source_address':
|
||||||
|
return websocket.send(websocket.remote_address[0])
|
||||||
|
elif message == 'str':
|
||||||
|
return websocket.send('1')
|
||||||
|
return websocket.send(message)
|
||||||
|
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
if request.path.startswith('/gen_'):
|
||||||
|
status = http.HTTPStatus(int(request.path[5:]))
|
||||||
|
if 300 <= status.value <= 300:
|
||||||
|
return websockets.http11.Response(
|
||||||
|
status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
|
||||||
|
return self.protocol.reject(status.value, status.phrase)
|
||||||
|
return self.protocol.accept(request)
|
||||||
|
|
||||||
|
|
||||||
|
def create_websocket_server(**ws_kwargs):
|
||||||
|
import websockets.sync.server
|
||||||
|
wsd = websockets.sync.server.serve(websocket_handler, '127.0.0.1', 0, process_request=process_request, **ws_kwargs)
|
||||||
|
ws_port = wsd.socket.getsockname()[1]
|
||||||
|
ws_server_thread = threading.Thread(target=wsd.serve_forever)
|
||||||
|
ws_server_thread.daemon = True
|
||||||
|
ws_server_thread.start()
|
||||||
|
return ws_server_thread, ws_port
|
||||||
|
|
||||||
|
|
||||||
|
def create_ws_websocket_server():
|
||||||
|
return create_websocket_server()
|
||||||
|
|
||||||
|
|
||||||
|
def create_wss_websocket_server():
|
||||||
|
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||||
|
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
sslctx.load_cert_chain(certfn, None)
|
||||||
|
return create_websocket_server(ssl_context=sslctx)
|
||||||
|
|
||||||
|
|
||||||
|
MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
|
||||||
|
|
||||||
|
|
||||||
|
def create_mtls_wss_websocket_server():
|
||||||
|
certfn = os.path.join(TEST_DIR, 'testcert.pem')
|
||||||
|
cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
|
||||||
|
|
||||||
|
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
sslctx.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
sslctx.load_verify_locations(cafile=cacertfn)
|
||||||
|
sslctx.load_cert_chain(certfn, None)
|
||||||
|
|
||||||
|
return create_websocket_server(ssl_context=sslctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
|
||||||
|
class TestWebsSocketRequestHandlerConformance:
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
cls.ws_thread, cls.ws_port = create_ws_websocket_server()
|
||||||
|
cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
|
||||||
|
|
||||||
|
cls.wss_thread, cls.wss_port = create_wss_websocket_server()
|
||||||
|
cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
|
||||||
|
|
||||||
|
cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
|
||||||
|
cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
|
||||||
|
|
||||||
|
cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
|
||||||
|
cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_basic_websockets(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
assert 'upgrade' in ws.headers
|
||||||
|
assert ws.status == 101
|
||||||
|
ws.send('foo')
|
||||||
|
assert ws.recv() == 'foo'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
# https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
|
||||||
|
@pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_send_types(self, handler, msg, opcode):
|
||||||
|
with handler() as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
ws.send(msg)
|
||||||
|
assert int(ws.recv()) == opcode
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_verify_cert(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
with pytest.raises(CertificateVerifyError):
|
||||||
|
validate_and_send(rh, Request(self.wss_base_url))
|
||||||
|
|
||||||
|
with handler(verify=False) as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.wss_base_url))
|
||||||
|
assert ws.status == 101
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_ssl_error(self, handler):
|
||||||
|
with handler(verify=False) as rh:
|
||||||
|
with pytest.raises(SSLError, match='sslv3 alert handshake failure') as exc_info:
|
||||||
|
validate_and_send(rh, Request(self.bad_wss_host))
|
||||||
|
assert not issubclass(exc_info.type, CertificateVerifyError)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@pytest.mark.parametrize('path,expected', [
|
||||||
|
# Unicode characters should be encoded with uppercase percent-encoding
|
||||||
|
('/中文', '/%E4%B8%AD%E6%96%87'),
|
||||||
|
# don't normalize existing percent encodings
|
||||||
|
('/%c7%9f', '/%c7%9f'),
|
||||||
|
])
|
||||||
|
def test_percent_encode(self, handler, path, expected):
|
||||||
|
with handler() as rh:
|
||||||
|
ws = validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
|
||||||
|
ws.send('path')
|
||||||
|
assert ws.recv() == expected
|
||||||
|
assert ws.status == 101
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_remove_dot_segments(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
# This isn't a comprehensive test,
|
||||||
|
# but it should be enough to check whether the handler is removing dot segments
|
||||||
|
ws = validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
|
||||||
|
assert ws.status == 101
|
||||||
|
ws.send('path')
|
||||||
|
assert ws.recv() == '/test'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
# We are restricted to known HTTP status codes in http.HTTPStatus
|
||||||
|
# Redirects are not supported for websockets
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
|
||||||
|
def test_raise_http_error(self, handler, status):
|
||||||
|
with handler() as rh:
|
||||||
|
with pytest.raises(HTTPError) as exc_info:
|
||||||
|
validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
|
||||||
|
assert exc_info.value.status == status
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
@pytest.mark.parametrize('params,extensions', [
|
||||||
|
({'timeout': 0.00001}, {}),
|
||||||
|
({}, {'timeout': 0.00001}),
|
||||||
|
])
|
||||||
|
def test_timeout(self, handler, params, extensions):
|
||||||
|
with handler(**params) as rh:
|
||||||
|
with pytest.raises(TransportError):
|
||||||
|
validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_cookies(self, handler):
|
||||||
|
cookiejar = YoutubeDLCookieJar()
|
||||||
|
cookiejar.set_cookie(http.cookiejar.Cookie(
|
||||||
|
version=0, name='test', value='ytdlp', port=None, port_specified=False,
|
||||||
|
domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
|
||||||
|
path_specified=True, secure=False, expires=None, discard=False, comment=None,
|
||||||
|
comment_url=None, rest={}))
|
||||||
|
|
||||||
|
with handler(cookiejar=cookiejar) as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
ws.send('headers')
|
||||||
|
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
with handler() as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
ws.send('headers')
|
||||||
|
assert 'cookie' not in json.loads(ws.recv())
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
|
||||||
|
ws.send('headers')
|
||||||
|
assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_source_address(self, handler):
|
||||||
|
source_address = f'127.0.0.{random.randint(5, 255)}'
|
||||||
|
with handler(source_address=source_address) as rh:
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
ws.send('source_address')
|
||||||
|
assert source_address == ws.recv()
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_response_url(self, handler):
|
||||||
|
with handler() as rh:
|
||||||
|
url = f'{self.ws_base_url}/something'
|
||||||
|
ws = validate_and_send(rh, Request(url))
|
||||||
|
assert ws.url == url
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_request_headers(self, handler):
|
||||||
|
with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
|
||||||
|
# Global Headers
|
||||||
|
ws = validate_and_send(rh, Request(self.ws_base_url))
|
||||||
|
ws.send('headers')
|
||||||
|
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||||
|
assert headers['test1'] == 'test'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
# Per request headers, merged with global
|
||||||
|
ws = validate_and_send(rh, Request(
|
||||||
|
self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
|
||||||
|
ws.send('headers')
|
||||||
|
headers = HTTPHeaderDict(json.loads(ws.recv()))
|
||||||
|
assert headers['test1'] == 'test'
|
||||||
|
assert headers['test2'] == 'changed'
|
||||||
|
assert headers['test3'] == 'test3'
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('client_cert', (
|
||||||
|
{'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
|
||||||
|
{
|
||||||
|
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
|
||||||
|
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
|
||||||
|
'client_certificate_password': 'foobar',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
|
||||||
|
'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
|
||||||
|
'client_certificate_password': 'foobar',
|
||||||
|
}
|
||||||
|
))
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
def test_mtls(self, handler, client_cert):
|
||||||
|
with handler(
|
||||||
|
# Disable client-side validation of unacceptable self-signed testcert.pem
|
||||||
|
# The test is of a check on the server side, so unaffected
|
||||||
|
verify=False,
|
||||||
|
client_cert=client_cert
|
||||||
|
) as rh:
|
||||||
|
validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_fake_ws_connection(raised):
|
||||||
|
import websockets.sync.client
|
||||||
|
|
||||||
|
class FakeWsConnection(websockets.sync.client.ClientConnection):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
class FakeResponse:
|
||||||
|
body = b''
|
||||||
|
headers = {}
|
||||||
|
status_code = 101
|
||||||
|
reason_phrase = 'test'
|
||||||
|
|
||||||
|
self.response = FakeResponse()
|
||||||
|
|
||||||
|
def send(self, *args, **kwargs):
|
||||||
|
raise raised()
|
||||||
|
|
||||||
|
def recv(self, *args, **kwargs):
|
||||||
|
raise raised()
|
||||||
|
|
||||||
|
def close(self, *args, **kwargs):
|
||||||
|
return
|
||||||
|
|
||||||
|
return FakeWsConnection()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
|
||||||
|
class TestWebsocketsRequestHandler:
|
||||||
|
@pytest.mark.parametrize('raised,expected', [
|
||||||
|
# https://websockets.readthedocs.io/en/stable/reference/exceptions.html
|
||||||
|
(lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
|
||||||
|
# Requires a response object. Should be covered by HTTP error tests.
|
||||||
|
# (lambda: websockets.exceptions.InvalidStatus(), TransportError),
|
||||||
|
(lambda: websockets.exceptions.InvalidHandshake(), TransportError),
|
||||||
|
# These are subclasses of InvalidHandshake
|
||||||
|
(lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
|
||||||
|
(lambda: websockets.exceptions.NegotiationError(), TransportError),
|
||||||
|
# Catch-all
|
||||||
|
(lambda: websockets.exceptions.WebSocketException(), TransportError),
|
||||||
|
(lambda: TimeoutError(), TransportError),
|
||||||
|
# These may be raised by our create_connection implementation, which should also be caught
|
||||||
|
(lambda: OSError(), TransportError),
|
||||||
|
(lambda: ssl.SSLError(), SSLError),
|
||||||
|
(lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
|
||||||
|
(lambda: socks.ProxyError(), ProxyError),
|
||||||
|
])
|
||||||
|
def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
|
||||||
|
import websockets.sync.client
|
||||||
|
|
||||||
|
import yt_dlp.networking._websockets
|
||||||
|
with handler() as rh:
|
||||||
|
def fake_connect(*args, **kwargs):
|
||||||
|
raise raised()
|
||||||
|
monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
|
||||||
|
with pytest.raises(expected) as exc_info:
|
||||||
|
rh.send(Request('ws://fake-url'))
|
||||||
|
assert exc_info.type is expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('raised,expected,match', [
|
||||||
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
|
||||||
|
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
|
||||||
|
(lambda: RuntimeError(), TransportError, None),
|
||||||
|
(lambda: TimeoutError(), TransportError, None),
|
||||||
|
(lambda: TypeError(), RequestError, None),
|
||||||
|
(lambda: socks.ProxyError(), ProxyError, None),
|
||||||
|
# Catch-all
|
||||||
|
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
|
||||||
|
])
|
||||||
|
def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
|
||||||
|
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
|
||||||
|
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
|
||||||
|
with pytest.raises(expected, match=match) as exc_info:
|
||||||
|
ws.send('test')
|
||||||
|
assert exc_info.type is expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('raised,expected,match', [
|
||||||
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
|
||||||
|
(lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
|
||||||
|
(lambda: RuntimeError(), TransportError, None),
|
||||||
|
(lambda: TimeoutError(), TransportError, None),
|
||||||
|
(lambda: socks.ProxyError(), ProxyError, None),
|
||||||
|
# Catch-all
|
||||||
|
(lambda: websockets.exceptions.WebSocketException(), TransportError, None),
|
||||||
|
])
|
||||||
|
def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
|
||||||
|
from yt_dlp.networking._websockets import WebsocketsResponseAdapter
|
||||||
|
ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
|
||||||
|
with pytest.raises(expected, match=match) as exc_info:
|
||||||
|
ws.recv()
|
||||||
|
assert exc_info.type is expected
|
@ -4052,6 +4052,7 @@ def urlopen(self, req):
|
|||||||
return self._request_director.send(req)
|
return self._request_director.send(req)
|
||||||
except NoSupportingHandlers as e:
|
except NoSupportingHandlers as e:
|
||||||
for ue in e.unsupported_errors:
|
for ue in e.unsupported_errors:
|
||||||
|
# FIXME: This depends on the order of errors.
|
||||||
if not (ue.handler and ue.msg):
|
if not (ue.handler and ue.msg):
|
||||||
continue
|
continue
|
||||||
if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
|
if ue.handler.RH_KEY == 'Urllib' and 'unsupported url scheme: "file"' in ue.msg.lower():
|
||||||
@ -4061,6 +4062,15 @@ def urlopen(self, req):
|
|||||||
if 'unsupported proxy type: "https"' in ue.msg.lower():
|
if 'unsupported proxy type: "https"' in ue.msg.lower():
|
||||||
raise RequestError(
|
raise RequestError(
|
||||||
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
|
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests')
|
||||||
|
|
||||||
|
elif (
|
||||||
|
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
|
||||||
|
and 'websockets' not in self._request_director.handlers
|
||||||
|
):
|
||||||
|
raise RequestError(
|
||||||
|
'This request requires WebSocket support. '
|
||||||
|
'Ensure one of the following dependencies are installed: websockets',
|
||||||
|
cause=ue) from ue
|
||||||
raise
|
raise
|
||||||
except SSLError as e:
|
except SSLError as e:
|
||||||
if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):
|
if 'UNSAFE_LEGACY_RENEGOTIATION_DISABLED' in str(e):
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
from .common import FileDownloader
|
from .common import FileDownloader
|
||||||
from .external import FFmpegFD
|
from .external import FFmpegFD
|
||||||
from ..networking import Request
|
from ..networking import Request
|
||||||
from ..utils import DownloadError, WebSocketsWrapper, str_or_none, try_get
|
from ..utils import DownloadError, str_or_none, try_get
|
||||||
|
|
||||||
|
|
||||||
class NiconicoDmcFD(FileDownloader):
|
class NiconicoDmcFD(FileDownloader):
|
||||||
@ -64,7 +64,6 @@ def real_download(self, filename, info_dict):
|
|||||||
ws_url = info_dict['url']
|
ws_url = info_dict['url']
|
||||||
ws_extractor = info_dict['ws']
|
ws_extractor = info_dict['ws']
|
||||||
ws_origin_host = info_dict['origin']
|
ws_origin_host = info_dict['origin']
|
||||||
cookies = info_dict.get('cookies')
|
|
||||||
live_quality = info_dict.get('live_quality', 'high')
|
live_quality = info_dict.get('live_quality', 'high')
|
||||||
live_latency = info_dict.get('live_latency', 'high')
|
live_latency = info_dict.get('live_latency', 'high')
|
||||||
dl = FFmpegFD(self.ydl, self.params or {})
|
dl = FFmpegFD(self.ydl, self.params or {})
|
||||||
@ -76,12 +75,7 @@ def real_download(self, filename, info_dict):
|
|||||||
|
|
||||||
def communicate_ws(reconnect):
|
def communicate_ws(reconnect):
|
||||||
if reconnect:
|
if reconnect:
|
||||||
ws = WebSocketsWrapper(ws_url, {
|
ws = self.ydl.urlopen(Request(ws_url, headers={'Origin': f'https://{ws_origin_host}'}))
|
||||||
'Cookies': str_or_none(cookies) or '',
|
|
||||||
'Origin': f'https://{ws_origin_host}',
|
|
||||||
'Accept': '*/*',
|
|
||||||
'User-Agent': self.params['http_headers']['User-Agent'],
|
|
||||||
})
|
|
||||||
if self.ydl.params.get('verbose', False):
|
if self.ydl.params.get('verbose', False):
|
||||||
self.to_screen('[debug] Sending startWatching request')
|
self.to_screen('[debug] Sending startWatching request')
|
||||||
ws.send(json.dumps({
|
ws.send(json.dumps({
|
||||||
|
@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
from .common import InfoExtractor
|
from .common import InfoExtractor
|
||||||
from ..compat import compat_parse_qs
|
from ..compat import compat_parse_qs
|
||||||
from ..dependencies import websockets
|
|
||||||
from ..networking import Request
|
from ..networking import Request
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
ExtractorError,
|
ExtractorError,
|
||||||
WebSocketsWrapper,
|
|
||||||
js_to_json,
|
js_to_json,
|
||||||
traverse_obj,
|
traverse_obj,
|
||||||
update_url_query,
|
update_url_query,
|
||||||
@ -167,8 +165,6 @@ class FC2LiveIE(InfoExtractor):
|
|||||||
}]
|
}]
|
||||||
|
|
||||||
def _real_extract(self, url):
|
def _real_extract(self, url):
|
||||||
if not websockets:
|
|
||||||
raise ExtractorError('websockets library is not available. Please install it.', expected=True)
|
|
||||||
video_id = self._match_id(url)
|
video_id = self._match_id(url)
|
||||||
webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id)
|
webpage = self._download_webpage('https://live.fc2.com/%s/' % video_id, video_id)
|
||||||
|
|
||||||
@ -199,13 +195,9 @@ def _real_extract(self, url):
|
|||||||
ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']})
|
ws_url = update_url_query(control_server['url'], {'control_token': control_server['control_token']})
|
||||||
playlist_data = None
|
playlist_data = None
|
||||||
|
|
||||||
self.to_screen('%s: Fetching HLS playlist info via WebSocket' % video_id)
|
ws = self._request_webpage(Request(ws_url, headers={
|
||||||
ws = WebSocketsWrapper(ws_url, {
|
|
||||||
'Cookie': str(self._get_cookies('https://live.fc2.com/'))[12:],
|
|
||||||
'Origin': 'https://live.fc2.com',
|
'Origin': 'https://live.fc2.com',
|
||||||
'Accept': '*/*',
|
}), video_id, note='Fetching HLS playlist info via WebSocket')
|
||||||
'User-Agent': self.get_param('http_headers')['User-Agent'],
|
|
||||||
})
|
|
||||||
|
|
||||||
self.write_debug('Sending HLS server request')
|
self.write_debug('Sending HLS server request')
|
||||||
|
|
||||||
|
@ -8,12 +8,11 @@
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from .common import InfoExtractor, SearchInfoExtractor
|
from .common import InfoExtractor, SearchInfoExtractor
|
||||||
from ..dependencies import websockets
|
from ..networking import Request
|
||||||
from ..networking.exceptions import HTTPError
|
from ..networking.exceptions import HTTPError
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
ExtractorError,
|
ExtractorError,
|
||||||
OnDemandPagedList,
|
OnDemandPagedList,
|
||||||
WebSocketsWrapper,
|
|
||||||
bug_reports_message,
|
bug_reports_message,
|
||||||
clean_html,
|
clean_html,
|
||||||
float_or_none,
|
float_or_none,
|
||||||
@ -934,8 +933,6 @@ class NiconicoLiveIE(InfoExtractor):
|
|||||||
_KNOWN_LATENCY = ('high', 'low')
|
_KNOWN_LATENCY = ('high', 'low')
|
||||||
|
|
||||||
def _real_extract(self, url):
|
def _real_extract(self, url):
|
||||||
if not websockets:
|
|
||||||
raise ExtractorError('websockets library is not available. Please install it.', expected=True)
|
|
||||||
video_id = self._match_id(url)
|
video_id = self._match_id(url)
|
||||||
webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id)
|
webpage, urlh = self._download_webpage_handle(f'https://live.nicovideo.jp/watch/{video_id}', video_id)
|
||||||
|
|
||||||
@ -950,17 +947,13 @@ def _real_extract(self, url):
|
|||||||
})
|
})
|
||||||
|
|
||||||
hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
|
hostname = remove_start(urlparse(urlh.url).hostname, 'sp.')
|
||||||
cookies = try_get(urlh.url, self._downloader._calc_cookies)
|
|
||||||
latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
|
latency = try_get(self._configuration_arg('latency'), lambda x: x[0])
|
||||||
if latency not in self._KNOWN_LATENCY:
|
if latency not in self._KNOWN_LATENCY:
|
||||||
latency = 'high'
|
latency = 'high'
|
||||||
|
|
||||||
ws = WebSocketsWrapper(ws_url, {
|
ws = self._request_webpage(
|
||||||
'Cookies': str_or_none(cookies) or '',
|
Request(ws_url, headers={'Origin': f'https://{hostname}'}),
|
||||||
'Origin': f'https://{hostname}',
|
video_id=video_id, note='Connecting to WebSocket server')
|
||||||
'Accept': '*/*',
|
|
||||||
'User-Agent': self.get_param('http_headers')['User-Agent'],
|
|
||||||
})
|
|
||||||
|
|
||||||
self.write_debug('[debug] Sending HLS server request')
|
self.write_debug('[debug] Sending HLS server request')
|
||||||
ws.send(json.dumps({
|
ws.send(json.dumps({
|
||||||
@ -1034,7 +1027,6 @@ def _real_extract(self, url):
|
|||||||
'protocol': 'niconico_live',
|
'protocol': 'niconico_live',
|
||||||
'ws': ws,
|
'ws': ws,
|
||||||
'video_id': video_id,
|
'video_id': video_id,
|
||||||
'cookies': cookies,
|
|
||||||
'live_latency': latency,
|
'live_latency': latency,
|
||||||
'origin': hostname,
|
'origin': hostname,
|
||||||
})
|
})
|
||||||
|
@ -21,3 +21,11 @@
|
|||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
|
warnings.warn(f'Failed to import "requests" request handler: {e}' + bug_reports_message())
|
||||||
|
|
||||||
|
try:
|
||||||
|
from . import _websockets
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
warnings.warn(f'Failed to import "websockets" request handler: {e}' + bug_reports_message())
|
||||||
|
|
||||||
|
159
yt_dlp/networking/_websockets.py
Normal file
159
yt_dlp/networking/_websockets.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import ssl
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from ._helper import create_connection, select_proxy, make_socks_proxy_opts, create_socks_proxy_socket
|
||||||
|
from .common import Response, register_rh, Features
|
||||||
|
from .exceptions import (
|
||||||
|
CertificateVerifyError,
|
||||||
|
HTTPError,
|
||||||
|
RequestError,
|
||||||
|
SSLError,
|
||||||
|
TransportError, ProxyError,
|
||||||
|
)
|
||||||
|
from .websocket import WebSocketRequestHandler, WebSocketResponse
|
||||||
|
from ..compat import functools
|
||||||
|
from ..dependencies import websockets
|
||||||
|
from ..utils import int_or_none
|
||||||
|
from ..socks import ProxyError as SocksProxyError
|
||||||
|
|
||||||
|
if not websockets:
|
||||||
|
raise ImportError('websockets is not installed')
|
||||||
|
|
||||||
|
import websockets.version
|
||||||
|
|
||||||
|
websockets_version = tuple(map(int_or_none, websockets.version.version.split('.')))
|
||||||
|
if websockets_version < (12, 0):
|
||||||
|
raise ImportError('Only websockets>=12.0 is supported')
|
||||||
|
|
||||||
|
import websockets.sync.client
|
||||||
|
from websockets.uri import parse_uri
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketsResponseAdapter(WebSocketResponse):
|
||||||
|
|
||||||
|
def __init__(self, wsw: websockets.sync.client.ClientConnection, url):
|
||||||
|
super().__init__(
|
||||||
|
fp=io.BytesIO(wsw.response.body or b''),
|
||||||
|
url=url,
|
||||||
|
headers=wsw.response.headers,
|
||||||
|
status=wsw.response.status_code,
|
||||||
|
reason=wsw.response.reason_phrase,
|
||||||
|
)
|
||||||
|
self.wsw = wsw
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.wsw.close()
|
||||||
|
super().close()
|
||||||
|
|
||||||
|
def send(self, message):
|
||||||
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
|
||||||
|
try:
|
||||||
|
return self.wsw.send(message)
|
||||||
|
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
|
||||||
|
raise TransportError(cause=e) from e
|
||||||
|
except SocksProxyError as e:
|
||||||
|
raise ProxyError(cause=e) from e
|
||||||
|
except TypeError as e:
|
||||||
|
raise RequestError(cause=e) from e
|
||||||
|
|
||||||
|
def recv(self):
|
||||||
|
# https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
|
||||||
|
try:
|
||||||
|
return self.wsw.recv()
|
||||||
|
except SocksProxyError as e:
|
||||||
|
raise ProxyError(cause=e) from e
|
||||||
|
except (websockets.exceptions.WebSocketException, RuntimeError, TimeoutError) as e:
|
||||||
|
raise TransportError(cause=e) from e
|
||||||
|
|
||||||
|
|
||||||
|
@register_rh
|
||||||
|
class WebsocketsRH(WebSocketRequestHandler):
|
||||||
|
"""
|
||||||
|
Websockets request handler
|
||||||
|
https://websockets.readthedocs.io
|
||||||
|
https://github.com/python-websockets/websockets
|
||||||
|
"""
|
||||||
|
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
|
||||||
|
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
|
||||||
|
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
|
||||||
|
RH_NAME = 'websockets'
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
for name in ('websockets.client', 'websockets.server'):
|
||||||
|
logger = logging.getLogger(name)
|
||||||
|
handler = logging.StreamHandler(stream=sys.stdout)
|
||||||
|
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
|
||||||
|
logger.addHandler(handler)
|
||||||
|
if self.verbose:
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
def _check_extensions(self, extensions):
|
||||||
|
super()._check_extensions(extensions)
|
||||||
|
extensions.pop('timeout', None)
|
||||||
|
extensions.pop('cookiejar', None)
|
||||||
|
|
||||||
|
def _send(self, request):
|
||||||
|
timeout = float(request.extensions.get('timeout') or self.timeout)
|
||||||
|
headers = self._merge_headers(request.headers)
|
||||||
|
if 'cookie' not in headers:
|
||||||
|
cookiejar = request.extensions.get('cookiejar') or self.cookiejar
|
||||||
|
cookie_header = cookiejar.get_cookie_header(request.url)
|
||||||
|
if cookie_header:
|
||||||
|
headers['cookie'] = cookie_header
|
||||||
|
|
||||||
|
wsuri = parse_uri(request.url)
|
||||||
|
create_conn_kwargs = {
|
||||||
|
'source_address': (self.source_address, 0) if self.source_address else None,
|
||||||
|
'timeout': timeout
|
||||||
|
}
|
||||||
|
proxy = select_proxy(request.url, request.proxies or self.proxies or {})
|
||||||
|
try:
|
||||||
|
if proxy:
|
||||||
|
socks_proxy_options = make_socks_proxy_opts(proxy)
|
||||||
|
sock = create_connection(
|
||||||
|
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
|
||||||
|
_create_socket_func=functools.partial(
|
||||||
|
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
|
||||||
|
**create_conn_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sock = create_connection(
|
||||||
|
address=(wsuri.host, wsuri.port),
|
||||||
|
**create_conn_kwargs
|
||||||
|
)
|
||||||
|
conn = websockets.sync.client.connect(
|
||||||
|
sock=sock,
|
||||||
|
uri=request.url,
|
||||||
|
additional_headers=headers,
|
||||||
|
open_timeout=timeout,
|
||||||
|
user_agent_header=None,
|
||||||
|
ssl_context=self._make_sslcontext() if wsuri.secure else None,
|
||||||
|
close_timeout=0, # not ideal, but prevents yt-dlp hanging
|
||||||
|
)
|
||||||
|
return WebsocketsResponseAdapter(conn, url=request.url)
|
||||||
|
|
||||||
|
# Exceptions as per https://websockets.readthedocs.io/en/stable/reference/sync/client.html
|
||||||
|
except SocksProxyError as e:
|
||||||
|
raise ProxyError(cause=e) from e
|
||||||
|
except websockets.exceptions.InvalidURI as e:
|
||||||
|
raise RequestError(cause=e) from e
|
||||||
|
except ssl.SSLCertVerificationError as e:
|
||||||
|
raise CertificateVerifyError(cause=e) from e
|
||||||
|
except ssl.SSLError as e:
|
||||||
|
raise SSLError(cause=e) from e
|
||||||
|
except websockets.exceptions.InvalidStatus as e:
|
||||||
|
raise HTTPError(
|
||||||
|
Response(
|
||||||
|
fp=io.BytesIO(e.response.body),
|
||||||
|
url=request.url,
|
||||||
|
headers=e.response.headers,
|
||||||
|
status=e.response.status_code,
|
||||||
|
reason=e.response.reason_phrase),
|
||||||
|
) from e
|
||||||
|
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
|
||||||
|
raise TransportError(cause=e) from e
|
23
yt_dlp/networking/websocket.py
Normal file
23
yt_dlp/networking/websocket.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
|
||||||
|
from .common import Response, RequestHandler
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketResponse(Response):
|
||||||
|
|
||||||
|
def send(self, message: bytes | str):
|
||||||
|
"""
|
||||||
|
Send a message to the server.
|
||||||
|
|
||||||
|
@param message: The message to send. A string (str) is sent as a text frame, bytes is sent as a binary frame.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def recv(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketRequestHandler(RequestHandler, abc.ABC):
|
||||||
|
pass
|
@ -1,4 +1,6 @@
|
|||||||
"""No longer used and new code should not use. Exists only for API compat."""
|
"""No longer used and new code should not use. Exists only for API compat."""
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
import platform
|
import platform
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
@ -32,6 +34,77 @@
|
|||||||
has_websockets = bool(websockets)
|
has_websockets = bool(websockets)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketsWrapper:
|
||||||
|
"""Wraps websockets module to use in non-async scopes"""
|
||||||
|
pool = None
|
||||||
|
|
||||||
|
def __init__(self, url, headers=None, connect=True, **ws_kwargs):
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
# XXX: "loop" is deprecated
|
||||||
|
self.conn = websockets.connect(
|
||||||
|
url, extra_headers=headers, ping_interval=None,
|
||||||
|
close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'), **ws_kwargs)
|
||||||
|
if connect:
|
||||||
|
self.__enter__()
|
||||||
|
atexit.register(self.__exit__, None, None, None)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if not self.pool:
|
||||||
|
self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def send(self, *args):
|
||||||
|
self.run_with_loop(self.pool.send(*args), self.loop)
|
||||||
|
|
||||||
|
def recv(self, *args):
|
||||||
|
return self.run_with_loop(self.pool.recv(*args), self.loop)
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
try:
|
||||||
|
return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
|
||||||
|
finally:
|
||||||
|
self.loop.close()
|
||||||
|
self._cancel_all_tasks(self.loop)
|
||||||
|
|
||||||
|
# taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
|
||||||
|
# for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
|
||||||
|
@staticmethod
|
||||||
|
def run_with_loop(main, loop):
|
||||||
|
if not asyncio.iscoroutine(main):
|
||||||
|
raise ValueError(f'a coroutine was expected, got {main!r}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(main)
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
if hasattr(loop, 'shutdown_default_executor'):
|
||||||
|
loop.run_until_complete(loop.shutdown_default_executor())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cancel_all_tasks(loop):
|
||||||
|
to_cancel = asyncio.all_tasks(loop)
|
||||||
|
|
||||||
|
if not to_cancel:
|
||||||
|
return
|
||||||
|
|
||||||
|
for task in to_cancel:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# XXX: "loop" is removed in python 3.10+
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
|
||||||
|
|
||||||
|
for task in to_cancel:
|
||||||
|
if task.cancelled():
|
||||||
|
continue
|
||||||
|
if task.exception() is not None:
|
||||||
|
loop.call_exception_handler({
|
||||||
|
'message': 'unhandled exception during asyncio.run() shutdown',
|
||||||
|
'exception': task.exception(),
|
||||||
|
'task': task,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def load_plugins(name, suffix, namespace):
|
def load_plugins(name, suffix, namespace):
|
||||||
from ..plugins import load_plugins
|
from ..plugins import load_plugins
|
||||||
ret = load_plugins(name, suffix)
|
ret = load_plugins(name, suffix)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import atexit
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import calendar
|
import calendar
|
||||||
@ -54,7 +52,7 @@
|
|||||||
compat_os_name,
|
compat_os_name,
|
||||||
compat_shlex_quote,
|
compat_shlex_quote,
|
||||||
)
|
)
|
||||||
from ..dependencies import websockets, xattr
|
from ..dependencies import xattr
|
||||||
|
|
||||||
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module
|
__name__ = __name__.rsplit('.', 1)[0] # Pretend to be the parent module
|
||||||
|
|
||||||
@ -4923,77 +4921,6 @@ def parse_args(self):
|
|||||||
return self.parser.parse_args(self.all_args)
|
return self.parser.parse_args(self.all_args)
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsWrapper:
|
|
||||||
"""Wraps websockets module to use in non-async scopes"""
|
|
||||||
pool = None
|
|
||||||
|
|
||||||
def __init__(self, url, headers=None, connect=True):
|
|
||||||
self.loop = asyncio.new_event_loop()
|
|
||||||
# XXX: "loop" is deprecated
|
|
||||||
self.conn = websockets.connect(
|
|
||||||
url, extra_headers=headers, ping_interval=None,
|
|
||||||
close_timeout=float('inf'), loop=self.loop, ping_timeout=float('inf'))
|
|
||||||
if connect:
|
|
||||||
self.__enter__()
|
|
||||||
atexit.register(self.__exit__, None, None, None)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
if not self.pool:
|
|
||||||
self.pool = self.run_with_loop(self.conn.__aenter__(), self.loop)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def send(self, *args):
|
|
||||||
self.run_with_loop(self.pool.send(*args), self.loop)
|
|
||||||
|
|
||||||
def recv(self, *args):
|
|
||||||
return self.run_with_loop(self.pool.recv(*args), self.loop)
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
|
||||||
try:
|
|
||||||
return self.run_with_loop(self.conn.__aexit__(type, value, traceback), self.loop)
|
|
||||||
finally:
|
|
||||||
self.loop.close()
|
|
||||||
self._cancel_all_tasks(self.loop)
|
|
||||||
|
|
||||||
# taken from https://github.com/python/cpython/blob/3.9/Lib/asyncio/runners.py with modifications
|
|
||||||
# for contributors: If there's any new library using asyncio needs to be run in non-async, move these function out of this class
|
|
||||||
@staticmethod
|
|
||||||
def run_with_loop(main, loop):
|
|
||||||
if not asyncio.iscoroutine(main):
|
|
||||||
raise ValueError(f'a coroutine was expected, got {main!r}')
|
|
||||||
|
|
||||||
try:
|
|
||||||
return loop.run_until_complete(main)
|
|
||||||
finally:
|
|
||||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
|
||||||
if hasattr(loop, 'shutdown_default_executor'):
|
|
||||||
loop.run_until_complete(loop.shutdown_default_executor())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _cancel_all_tasks(loop):
|
|
||||||
to_cancel = asyncio.all_tasks(loop)
|
|
||||||
|
|
||||||
if not to_cancel:
|
|
||||||
return
|
|
||||||
|
|
||||||
for task in to_cancel:
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
# XXX: "loop" is removed in python 3.10+
|
|
||||||
loop.run_until_complete(
|
|
||||||
asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
|
|
||||||
|
|
||||||
for task in to_cancel:
|
|
||||||
if task.cancelled():
|
|
||||||
continue
|
|
||||||
if task.exception() is not None:
|
|
||||||
loop.call_exception_handler({
|
|
||||||
'message': 'unhandled exception during asyncio.run() shutdown',
|
|
||||||
'exception': task.exception(),
|
|
||||||
'task': task,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def merge_headers(*dicts):
|
def merge_headers(*dicts):
|
||||||
"""Merge dicts of http headers case insensitively, prioritizing the latter ones"""
|
"""Merge dicts of http headers case insensitively, prioritizing the latter ones"""
|
||||||
return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
|
return {k.title(): v for k, v in itertools.chain.from_iterable(map(dict.items, dicts))}
|
||||||
|
Loading…
Reference in New Issue
Block a user