diff --git a/test/conftest.py b/test/conftest.py index a8b92f811..96b86c328 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -23,7 +23,7 @@ class HandlerWrapper(handler): RH_KEY = handler.RH_KEY def __init__(self, **kwargs): - super().__init__(logger=FakeLogger, **kwargs) + super().__init__(logger=FakeLogger(), **kwargs) return HandlerWrapper diff --git a/test/test_http_proxy.py b/test/test_http_proxy.py index 2435c878a..1d17a8d00 100644 --- a/test/test_http_proxy.py +++ b/test/test_http_proxy.py @@ -7,10 +7,12 @@ import random import ssl import threading +import time from http.server import BaseHTTPRequestHandler -from socketserver import ThreadingTCPServer +from socketserver import BaseRequestHandler, ThreadingTCPServer import pytest +import platform from test.helper import http_server_port, verify_address_availability from test.test_networking import TEST_DIR @@ -46,6 +48,11 @@ def do_proxy_auth(self, username, password): except Exception: return self.proxy_auth_error() + if auth_username == 'http_error': + self.send_response(404) + self.end_headers() + return False + if auth_username != (username or '') or auth_password != (password or ''): return self.proxy_auth_error() return True @@ -119,6 +126,16 @@ def _io_refs(self, value): def shutdown(self, *args, **kwargs): self.socket.shutdown(*args, **kwargs) + + def _wrap_ssl_read(self, *args, **kwargs): + res = super()._wrap_ssl_read(*args, **kwargs) + if res == 0: + # Websockets does not treat 0 as an EOF, rather only b'' + return b'' + return res + + def getsockname(self): + return self.socket.getsockname() else: SSLTransport = None @@ -128,7 +145,40 @@ def __init__(self, request, *args, **kwargs): certfn = os.path.join(TEST_DIR, 'testcert.pem') sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx.load_cert_chain(certfn, None) - if isinstance(request, ssl.SSLSocket): + if SSLTransport: + request = SSLTransport(request, ssl_context=sslctx, server_side=True) + else: + request = sslctx.wrap_socket(request, 
server_side=True) + super().__init__(request, *args, **kwargs) + + +class WebSocketProxyHandler(BaseRequestHandler): + def __init__(self, *args, proxy_info=None, **kwargs): + self.proxy_info = proxy_info + super().__init__(*args, **kwargs) + + def handle(self): + import websockets.sync.server + self.request.settimeout(None) + protocol = websockets.ServerProtocol() + connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=10) + try: + connection.handshake(timeout=5.0) + for message in connection: + if message == 'proxy_info': + connection.send(json.dumps(self.proxy_info)) + except Exception as e: + print(f'Error in websocket proxy: {e}') + finally: + connection.close(code=1001) + + +class WebSocketSecureProxyHandler(WebSocketProxyHandler): + def __init__(self, request, *args, **kwargs): + certfn = os.path.join(TEST_DIR, 'testcert.pem') + sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + sslctx.load_cert_chain(certfn, None) + if isinstance(request, ssl.SSLSocket) and SSLTransport: request = SSLTransport(request, ssl_context=sslctx, server_side=True) else: request = sslctx.wrap_socket(request, server_side=True) @@ -197,7 +247,7 @@ def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_serv finally: server.shutdown() server.server_close() - server_thread.join(2.0) + server_thread.join() class HTTPProxyTestContext(abc.ABC): @@ -205,7 +255,9 @@ class HTTPProxyTestContext(abc.ABC): REQUEST_PROTO = None def http_server(self, server_class, *args, **kwargs): - return proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs) + server = proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs) + time.sleep(1) + return server @abc.abstractmethod def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict: @@ -234,9 +286,30 @@ def proxy_info_request(self, handler, target_domain=None, target_port=None, **re return 
json.loads(handler.send(request).read().decode()) +class HTTPProxyWebSocketTestContext(HTTPProxyTestContext): + REQUEST_HANDLER_CLASS = WebSocketProxyHandler + REQUEST_PROTO = 'ws' + + def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs): + request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs) + handler.validate(request) + ws = handler.send(request) + ws.send('proxy_info') + proxy_info = ws.recv() + ws.close() + return json.loads(proxy_info) + + +class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext): + REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler + REQUEST_PROTO = 'wss' + + CTX_MAP = { 'http': HTTPProxyHTTPTestContext, 'https': HTTPProxyHTTPSTestContext, + 'ws': HTTPProxyWebSocketTestContext, + 'wss': HTTPProxyWebSocketSecureTestContext, } @@ -272,6 +345,14 @@ def test_http_bad_auth(self, handler, ctx): assert exc_info.value.response.status == 407 exc_info.value.response.close() + def test_http_error(self, handler, ctx): + with ctx.http_server(HTTPProxyHandler, username='http_error', password='test') as server_address: + with handler(proxies={ctx.REQUEST_PROTO: f'http://http_error:test@{server_address}'}) as rh: + with pytest.raises(HTTPError) as exc_info: + ctx.proxy_info_request(rh) + assert exc_info.value.response.status == 404 + exc_info.value.response.close() + def test_http_source_address(self, handler, ctx): with ctx.http_server(HTTPProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' @@ -314,7 +395,13 @@ def test_http_with_idn(self, handler, ctx): 'handler,ctx', [ ('Requests', 'https'), ('CurlCFFI', 'https'), + ('Websockets', 'ws'), + ('Websockets', 'wss'), ], indirect=True) +@pytest.mark.skip_handler_if( + 'Websockets', lambda request: + platform.python_implementation() == 'PyPy', + 'Tests are flaky with PyPy, unknown reason') class TestHTTPConnectProxy: def test_http_connect_no_auth(self, handler, 
ctx): with ctx.http_server(HTTPConnectProxyHandler) as server_address: @@ -341,6 +428,16 @@ def test_http_connect_bad_auth(self, handler, ctx): with pytest.raises(ProxyError): ctx.proxy_info_request(rh) + @pytest.mark.skip_handler( + 'Requests', + 'bug in urllib3 causes unclosed socket: https://github.com/urllib3/urllib3/issues/3374', + ) + def test_http_connect_http_error(self, handler, ctx): + with ctx.http_server(HTTPConnectProxyHandler, username='http_error', password='test') as server_address: + with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://http_error:test@{server_address}'}) as rh: + with pytest.raises(ProxyError): + ctx.proxy_info_request(rh) + def test_http_connect_source_address(self, handler, ctx): with ctx.http_server(HTTPConnectProxyHandler) as server_address: source_address = f'127.0.0.{random.randint(5, 255)}' diff --git a/test/test_networking.py b/test/test_networking.py index 826f11a56..c12878a80 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -407,7 +407,7 @@ def test_percent_encode(self, handler): '/redirect_dotsegments_absolute', ]) def test_remove_dot_segments(self, handler, path): - with handler(verbose=True) as rh: + with handler() as rh: # This isn't a comprehensive test, # but it should be enough to check whether the handler is removing dot segments in required scenarios res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}{path}')) @@ -1224,8 +1224,8 @@ class HTTPSupportedRH(ValidationRH): ('socks5h', False), ]), ('Websockets', 'ws', [ - ('http', UnsupportedRequest), - ('https', UnsupportedRequest), + ('http', False), + ('https', False), ('socks4', False), ('socks4a', False), ('socks5', False), @@ -1318,8 +1318,8 @@ class HTTPSupportedRH(ValidationRH): ('Websockets', False, 'ws'), ], indirect=['handler']) def test_no_proxy(self, handler, fail, scheme): - run_validation(handler, fail, Request(f'{scheme}://', proxies={'no': '127.0.0.1,github.com'})) - run_validation(handler, fail, 
Request(f'{scheme}://'), proxies={'no': '127.0.0.1,github.com'}) + run_validation(handler, fail, Request(f'{scheme}://example.com', proxies={'no': '127.0.0.1,github.com'})) + run_validation(handler, fail, Request(f'{scheme}://example.com'), proxies={'no': '127.0.0.1,github.com'}) @pytest.mark.parametrize('handler,scheme', [ ('Urllib', 'http'), diff --git a/test/test_socks.py b/test/test_socks.py index 68af19d0c..f601fc8a5 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -216,7 +216,9 @@ def handle(self): protocol = websockets.ServerProtocol() connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) connection.handshake() - connection.send(json.dumps(self.socks_info)) + for message in connection: + if message == 'socks_info': + connection.send(json.dumps(self.socks_info)) connection.close() diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 9691a1ea7..61763bce6 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -4174,15 +4174,15 @@ def urlopen(self, req): 'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue if ( 'unsupported proxy type: "https"' in ue.msg.lower() - and 'requests' not in self._request_director.handlers - and 'curl_cffi' not in self._request_director.handlers + and 'Requests' not in self._request_director.handlers + and 'CurlCFFI' not in self._request_director.handlers ): raise RequestError( 'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi') elif ( re.match(r'unsupported url scheme: "wss?"', ue.msg.lower()) - and 'websockets' not in self._request_director.handlers + and 'Websockets' not in self._request_director.handlers ): raise RequestError( 'This request requires WebSocket support. 
' diff --git a/yt_dlp/networking/_helper.py b/yt_dlp/networking/_helper.py index fe3354ea2..8ee0aa460 100644 --- a/yt_dlp/networking/_helper.py +++ b/yt_dlp/networking/_helper.py @@ -1,5 +1,6 @@ from __future__ import annotations +import base64 import contextlib import functools import os @@ -9,8 +10,9 @@ import typing import urllib.parse import urllib.request +from http.client import HTTPConnection, HTTPResponse -from .exceptions import RequestError, UnsupportedRequest +from .exceptions import ProxyError, RequestError, UnsupportedRequest from ..dependencies import certifi from ..socks import ProxyType, sockssocket from ..utils import format_field, traverse_obj @@ -285,3 +287,65 @@ def create_connection( # Explicitly break __traceback__ reference cycle # https://bugs.python.org/issue36820 err = None + + +class NoCloseHTTPResponse(HTTPResponse): + def begin(self): + super().begin() + # Revert the default behavior of closing the connection after reading the response + if not self._check_close() and not self.chunked and self.length is None: + self.will_close = False + + +def create_http_connect_connection( + proxy_host, + proxy_port, + connect_host, + connect_port, + timeout=None, + ssl_context=None, + source_address=None, + username=None, + password=None, + debug=False, +): + + proxy_headers = dict() + + if username is not None or password is not None: + proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode( + f'{username or ""}:{password or ""}'.encode()).decode('utf-8') + + conn = HTTPConnection(proxy_host, port=proxy_port, timeout=timeout) + conn.set_debuglevel(int(debug)) + + conn.response_class = NoCloseHTTPResponse + + if hasattr(conn, '_create_connection'): + conn._create_connection = create_connection + + if source_address is not None: + conn.source_address = (source_address, 0) + + try: + conn.connect() + if ssl_context: + conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_host) + conn.request( + method='CONNECT', + 
url=f'{connect_host}:{connect_port}', + headers=proxy_headers) + response = conn.getresponse() + except OSError as e: + conn.close() + raise ProxyError('Unable to connect to proxy', cause=e) from e + + if response.status == 200: + sock = conn.sock + conn.sock = None + response.fp = None + return sock + else: + conn.close() + response.close() + raise ProxyError(f'Got HTTP Error {response.status} with CONNECT: {response.reason}') diff --git a/yt_dlp/networking/_requests.py b/yt_dlp/networking/_requests.py index 7de95ab3b..fcb52451a 100644 --- a/yt_dlp/networking/_requests.py +++ b/yt_dlp/networking/_requests.py @@ -243,14 +243,14 @@ def __init__(self, logger, *args, **kwargs): def emit(self, record): try: msg = self.format(record) + except Exception: + self.handleError(record) + else: if record.levelno >= logging.ERROR: self._logger.error(msg) else: self._logger.stdout(msg) - except Exception: - self.handleError(record) - @register_rh class RequestsRH(RequestHandler, InstanceStoreMixin): diff --git a/yt_dlp/networking/_websockets.py b/yt_dlp/networking/_websockets.py index 492af1154..b2918af77 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -5,10 +5,11 @@ import io import logging import ssl -import sys +import urllib.parse from ._helper import ( create_connection, + create_http_connect_connection, create_socks_proxy_socket, make_socks_proxy_opts, select_proxy, @@ -21,9 +22,10 @@ RequestError, SSLError, TransportError, + UnsupportedRequest, ) from .websocket import WebSocketRequestHandler, WebSocketResponse -from ..dependencies import websockets +from ..dependencies import urllib3, websockets from ..socks import ProxyError as SocksProxyError from ..utils import int_or_none @@ -36,6 +38,20 @@ if websockets_version < (12, 0): raise ImportError('Only websockets>=12.0 is supported') +urllib3_supported = False +urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) if urllib3 else None +if 
urllib3_version and urllib3_version >= (1, 26, 17): + urllib3_supported = True + + +# Disable apply_mask C implementation +# Seems to help reduce "Fatal Python error: Aborted" in CI +with contextlib.suppress(Exception): + import websockets.frames + import websockets.legacy.framing + import websockets.utils + websockets.frames.apply_mask = websockets.legacy.framing.apply_mask = websockets.utils.apply_mask + import websockets.sync.client from websockets.uri import parse_uri @@ -53,6 +69,22 @@ websockets.sync.connection.Connection.recv_events_exc = None +class WebsocketsLoggingHandler(logging.Handler): + """Redirect websocket logs to our logger""" + + def __init__(self, logger, *args, **kwargs): + super().__init__(*args, **kwargs) + self._logger = logger + + def emit(self, record): + try: + msg = self.format(record) + except Exception: + self.handleError(record) + else: + self._logger.stdout(msg) + + class WebsocketsResponseAdapter(WebSocketResponse): def __init__(self, ws: websockets.sync.client.ClientConnection, url): @@ -98,7 +130,7 @@ class WebsocketsRH(WebSocketRequestHandler): https://github.com/python-websockets/websockets """ _SUPPORTED_URL_SCHEMES = ('wss', 'ws') - _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h') + _SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https') _SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY) RH_NAME = 'websockets' @@ -107,13 +139,24 @@ def __init__(self, *args, **kwargs): self.__logging_handlers = {} for name in ('websockets.client', 'websockets.server'): logger = logging.getLogger(name) - handler = logging.StreamHandler(stream=sys.stdout) - handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + handler = WebsocketsLoggingHandler(logger=self._logger) + handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s')) self.__logging_handlers[name] = handler logger.addHandler(handler) if self.verbose: logger.setLevel(logging.DEBUG) + def 
_validate(self, request): + super()._validate(request) + proxy = select_proxy(request.url, self._get_proxies(request)) + if ( + proxy + and urllib.parse.urlparse(proxy).scheme.lower() == 'https' + and urllib.parse.urlparse(request.url).scheme.lower() == 'wss' + and not urllib3_supported + ): + raise UnsupportedRequest('WSS over HTTPS proxy requires a supported version of urllib3') + def _check_extensions(self, extensions): super()._check_extensions(extensions) extensions.pop('timeout', None) @@ -126,6 +169,41 @@ def close(self): for name, handler in self.__logging_handlers.items(): logging.getLogger(name).removeHandler(handler) + def _make_sock(self, proxy, url, timeout): + create_conn_kwargs = { + 'source_address': (self.source_address, 0) if self.source_address else None, + 'timeout': timeout, + } + parsed_url = parse_uri(url) + parsed_proxy_url = urllib.parse.urlparse(proxy) + if proxy: + if parsed_proxy_url.scheme.startswith('socks'): + socks_proxy_options = make_socks_proxy_opts(proxy) + return create_connection( + address=(socks_proxy_options['addr'], socks_proxy_options['port']), + _create_socket_func=functools.partial( + create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options), + **create_conn_kwargs, + ) + + elif parsed_proxy_url.scheme in ('http', 'https'): + return create_http_connect_connection( + proxy_port=parsed_proxy_url.port, + proxy_host=parsed_proxy_url.hostname, + connect_port=parsed_url.port, + connect_host=parsed_url.host, + timeout=timeout, + ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None, + source_address=self.source_address, + username=parsed_proxy_url.username, + password=parsed_proxy_url.password, + debug=self.verbose, + ) + return create_connection( + address=(parsed_url.host, parsed_url.port), + **create_conn_kwargs, + ) + def _send(self, request): timeout = self._calculate_timeout(request) headers = self._merge_headers(request.headers) @@ -135,35 +213,21 @@ def _send(self, 
request): if cookie_header: headers['cookie'] = cookie_header - wsuri = parse_uri(request.url) - create_conn_kwargs = { - 'source_address': (self.source_address, 0) if self.source_address else None, - 'timeout': timeout, - } proxy = select_proxy(request.url, self._get_proxies(request)) + try: - if proxy: - socks_proxy_options = make_socks_proxy_opts(proxy) - sock = create_connection( - address=(socks_proxy_options['addr'], socks_proxy_options['port']), - _create_socket_func=functools.partial( - create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options), - **create_conn_kwargs, - ) - else: - sock = create_connection( - address=(wsuri.host, wsuri.port), - **create_conn_kwargs, - ) - ssl_ctx = self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl')) + ssl_context = None + sock = self._make_sock(proxy, request.url, timeout) + if parse_uri(request.url).secure: + ssl_context = WebsocketsSSLContext(self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl'))) conn = websockets.sync.client.connect( sock=sock, uri=request.url, additional_headers=headers, open_timeout=timeout, user_agent_header=None, - ssl_context=ssl_ctx if wsuri.secure else None, - close_timeout=0, # not ideal, but prevents yt-dlp hanging + ssl_context=ssl_context, + close_timeout=0.1, # not ideal, but prevents yt-dlp hanging ) return WebsocketsResponseAdapter(conn, url=request.url) @@ -187,3 +251,43 @@ def _send(self, request): ) from e except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e: raise TransportError(cause=e) from e + + +if urllib3_supported: + from urllib3.util.ssltransport import SSLTransport + + class WebsocketsSSLTransport(SSLTransport): + """ + Modified version of urllib3 SSLTransport to support additional operations used by websockets + """ + + def setsockopt(self, *args, **kwargs): + self.socket.setsockopt(*args, **kwargs) + + def shutdown(self, *args, **kwargs): + self.unwrap() + self.socket.shutdown(*args, 
**kwargs) + + def _wrap_ssl_read(self, *args, **kwargs): + res = super()._wrap_ssl_read(*args, **kwargs) + if res == 0: + # Websockets does not treat 0 as an EOF, rather only b'' + return b'' + return res +else: + WebsocketsSSLTransport = None + + +class WebsocketsSSLContext: + """ + Dummy SSL Context for websockets which returns a WebsocketsSSLTransport instance + for wrap socket when using TLS-in-TLS. + """ + + def __init__(self, ssl_context: ssl.SSLContext): + self.ssl_context = ssl_context + + def wrap_socket(self, sock, server_hostname=None): + if isinstance(sock, ssl.SSLSocket) and WebsocketsSSLTransport: + return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname) + return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname) diff --git a/yt_dlp/networking/websocket.py b/yt_dlp/networking/websocket.py index 0e7e73c9e..d407cadad 100644 --- a/yt_dlp/networking/websocket.py +++ b/yt_dlp/networking/websocket.py @@ -1,8 +1,9 @@ from __future__ import annotations import abc +import urllib.parse -from .common import RequestHandler, Response +from .common import RequestHandler, Response, register_preference class WebSocketResponse(Response): @@ -21,3 +22,10 @@ def recv(self): class WebSocketRequestHandler(RequestHandler, abc.ABC): pass + + +@register_preference(WebSocketRequestHandler) +def websocket_preference(_, request): + if urllib.parse.urlparse(request.url).scheme in ('ws', 'wss'): + return 200 + return 0