This commit is contained in:
coletdjnz 2024-05-16 11:59:57 +00:00 committed by GitHub
commit fa546b46bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 231 additions and 33 deletions

View File

@ -8,7 +8,7 @@ import random
import ssl
import threading
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
from socketserver import BaseRequestHandler, ThreadingTCPServer
import pytest
@ -134,6 +134,34 @@ class HTTPSProxyHandler(HTTPProxyHandler):
super().__init__(request, *args, **kwargs)
class WebSocketProxyHandler(BaseRequestHandler):
def __init__(self, *args, proxy_info=None, **kwargs):
self.proxy_info = proxy_info
super().__init__(*args, **kwargs)
def handle(self):
import websockets.sync.server
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
for message in connection:
if message == 'proxy_info':
connection.send(json.dumps(self.proxy_info))
connection.close()
class WebSocketSecureProxyHandler(WebSocketProxyHandler):
def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
if SSLTransport:
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
super().__init__(request, *args, **kwargs)
class HTTPConnectProxyHandler(BaseHTTPRequestHandler, HTTPProxyAuthMixin):
protocol_version = 'HTTP/1.1'
default_request_version = 'HTTP/1.1'
@ -233,9 +261,30 @@ class HTTPProxyHTTPSTestContext(HTTPProxyTestContext):
return json.loads(handler.send(request).read().decode())
class HTTPProxyWebSocketTestContext(HTTPProxyTestContext):
REQUEST_HANDLER_CLASS = WebSocketProxyHandler
REQUEST_PROTO = 'ws'
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('proxy_info')
socks_info = ws.recv()
ws.close()
return json.loads(socks_info)
class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext):
REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler
REQUEST_PROTO = 'wss'
CTX_MAP = {
'http': HTTPProxyHTTPTestContext,
'https': HTTPProxyHTTPSTestContext,
'ws': HTTPProxyWebSocketTestContext,
'wss': HTTPProxyWebSocketSecureTestContext,
}
@ -313,6 +362,8 @@ class TestHTTPProxy:
'handler,ctx', [
('Requests', 'https'),
('CurlCFFI', 'https'),
('Websockets', 'ws'),
('Websockets', 'wss')
], indirect=True)
class TestHTTPConnectProxy:
def test_http_connect_no_auth(self, handler, ctx):

View File

@ -1159,8 +1159,8 @@ class TestRequestHandlerValidation:
('socks5h', False),
]),
('Websockets', 'ws', [
('http', UnsupportedRequest),
('https', UnsupportedRequest),
('http', False),
('https', False),
('socks4', False),
('socks4a', False),
('socks5', False),
@ -1241,8 +1241,8 @@ class TestRequestHandlerValidation:
('Websockets', False, 'ws')
], indirect=['handler'])
def test_no_proxy(self, handler, fail, scheme):
run_validation(handler, fail, Request(f'{scheme}://', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request(f'{scheme}://'), proxies={'no': '127.0.0.1,github.com'})
run_validation(handler, fail, Request(f'{scheme}://example.com', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request(f'{scheme}://example.com'), proxies={'no': '127.0.0.1,github.com'})
@pytest.mark.parametrize('handler,scheme', [
('Urllib', 'http'),

View File

@ -216,7 +216,9 @@ class SocksWebSocketTestRequestHandler(SocksTestRequestHandler):
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
connection.send(json.dumps(self.socks_info))
for message in connection:
if message == 'socks_info':
connection.send(json.dumps(self.socks_info))
connection.close()

View File

@ -4151,15 +4151,15 @@ class YoutubeDL:
'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue
if (
'unsupported proxy type: "https"' in ue.msg.lower()
and 'requests' not in self._request_director.handlers
and 'curl_cffi' not in self._request_director.handlers
and 'Requests' not in self._request_director.handlers
and 'CurlCFFI' not in self._request_director.handlers
):
raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi')
elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
and 'websockets' not in self._request_director.handlers
and 'Websockets' not in self._request_director.handlers
):
raise RequestError(
'This request requires WebSocket support. '

View File

@ -1,10 +1,13 @@
from __future__ import annotations
import base64
import contextlib
import io
import logging
import ssl
import sys
import urllib.parse
from http.client import HTTPConnection, HTTPResponse
from ._helper import (
create_connection,
@ -20,12 +23,14 @@ from .exceptions import (
RequestError,
SSLError,
TransportError,
UnsupportedRequest,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..compat import functools
from ..dependencies import websockets
from ..dependencies import urllib3, websockets
from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none
from ..utils.networking import HTTPHeaderDict
if not websockets:
raise ImportError('websockets is not installed')
@ -36,6 +41,11 @@ websockets_version = tuple(map(int_or_none, websockets.version.version.split('.'
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
urllib3_supported = False
urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) if urllib3 else None
if urllib3_version and urllib3_version >= (1, 26, 17):
urllib3_supported = True
import websockets.sync.client
from websockets.uri import parse_uri
@ -98,7 +108,7 @@ class WebsocketsRH(WebSocketRequestHandler):
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
@ -108,12 +118,23 @@ class WebsocketsRH(WebSocketRequestHandler):
for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s'))
self.__logging_handlers[name] = handler
logger.addHandler(handler)
if self.verbose:
logger.setLevel(logging.DEBUG)
def _validate(self, request):
super()._validate(request)
proxy = select_proxy(request.url, self._get_proxies(request))
if (
proxy
and urllib.parse.urlparse(proxy).scheme.lower() == 'https'
and urllib.parse.urlparse(request.url).scheme.lower() == 'wss'
and not urllib3_supported
):
raise UnsupportedRequest('WSS over HTTPS proxies requires a supported version of urllib3')
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('timeout', None)
@ -125,6 +146,38 @@ class WebsocketsRH(WebSocketRequestHandler):
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
def _make_sock(self, proxy, url, timeout):
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
parsed_url = parse_uri(url)
parsed_proxy_url = urllib.parse.urlparse(proxy)
if proxy:
if parsed_proxy_url.scheme.startswith('socks'):
socks_proxy_options = make_socks_proxy_opts(proxy)
return create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options),
**create_conn_kwargs
)
elif parsed_proxy_url.scheme in ('http', 'https'):
return create_http_connect_conn(
proxy_url=proxy,
url=url,
timeout=timeout,
ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None,
source_address=self.source_address,
username=parsed_proxy_url.username,
password=parsed_proxy_url.password,
)
return create_connection(
address=(parsed_url.host, parsed_url.port),
**create_conn_kwargs
)
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._merge_headers(request.headers)
@ -134,33 +187,22 @@ class WebsocketsRH(WebSocketRequestHandler):
if cookie_header:
headers['cookie'] = cookie_header
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout
}
proxy = select_proxy(request.url, self._get_proxies(request))
try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs
)
ssl_context = None
if parse_uri(request.url).secure:
if WebsocketsSSLContext is not None:
ssl_context = WebsocketsSSLContext(self._make_sslcontext())
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs
)
ssl_context = self._make_sslcontext()
try:
conn = websockets.sync.client.connect(
sock=sock,
sock=self._make_sock(proxy, request.url, timeout),
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
ssl_context=self._make_sslcontext() if wsuri.secure else None,
ssl_context=ssl_context,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
)
return WebsocketsResponseAdapter(conn, url=request.url)
@ -185,3 +227,98 @@ class WebsocketsRH(WebSocketRequestHandler):
) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e
class NoCloseHTTPResponse(HTTPResponse):
def begin(self):
super().begin()
# Revert the default behavior of closing the connection after reading the response
if not self._check_close() and not self.chunked and self.length is None:
self.will_close = False
if urllib3_supported:
from urllib3.util.ssltransport import SSLTransport
class WebsocketsSSLTransport(SSLTransport):
"""
Modified version of urllib3 SSLTransport to support additional operations used by websockets
"""
def setsockopt(self, *args, **kwargs):
self.socket.setsockopt(*args, **kwargs)
def shutdown(self, *args, **kwargs):
self.unwrap()
self.socket.shutdown(*args, **kwargs)
else:
WebsocketsSSLTransport = None
class WebsocketsSSLContext:
"""
Dummy SSL Context for websockets which returns a WebsocketsSSLTransport instance
for wrap socket when using TLS-in-TLS.
"""
def __init__(self, ssl_context: ssl.SSLContext):
self.ssl_context = ssl_context
def wrap_socket(self, sock, server_hostname=None):
if isinstance(sock, ssl.SSLSocket):
return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname)
return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname)
def create_http_connect_conn(
proxy_url,
url,
timeout=None,
ssl_context=None,
source_address=None,
username=None,
password=None,
):
proxy_headers = HTTPHeaderDict()
if username is not None or password is not None:
proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode(
f'{username or ""}:{password or ""}'.encode('utf-8')).decode('utf-8')
proxy_url_parsed = urllib.parse.urlparse(proxy_url)
request_url_parsed = parse_uri(url)
conn = HTTPConnection(proxy_url_parsed.hostname, port=proxy_url_parsed.port, timeout=timeout)
conn.response_class = NoCloseHTTPResponse
if hasattr(conn, '_create_connection'):
conn._create_connection = create_connection
if source_address is not None:
conn.source_address = (source_address, 0)
try:
conn.connect()
if ssl_context:
conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_url_parsed.hostname)
conn.request(
method='CONNECT',
url=f'{request_url_parsed.host}:{request_url_parsed.port}',
headers=proxy_headers)
response = conn.getresponse()
except OSError as e:
conn.close()
raise ProxyError('Unable to connect to proxy', cause=e) from e
if response.status == 200:
return conn.sock
elif response.status == 407:
conn.close()
raise ProxyError('Got HTTP Error 407 with CONNECT: Proxy Authentication Required')
else:
conn.close()
res_adapter = Response(
fp=io.BytesIO(b''),
url=proxy_url, headers=response.headers,
status=response.status,
reason=response.reason)
raise HTTPError(response=res_adapter)

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import abc
import urllib.parse
from .common import RequestHandler, Response
from .common import RequestHandler, Response, register_preference
class WebSocketResponse(Response):
@ -21,3 +22,10 @@ class WebSocketResponse(Response):
class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass
@register_preference(WebSocketRequestHandler)
def websocket_preference(_, request):
if urllib.parse.urlparse(request.url).scheme in ('ws', 'wss'):
return 200
return 0