This commit is contained in:
coletdjnz 2024-07-27 18:52:48 +12:00 committed by GitHub
commit f1fdb20476
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 320 additions and 45 deletions

View File

@ -23,7 +23,7 @@ class HandlerWrapper(handler):
RH_KEY = handler.RH_KEY
def __init__(self, **kwargs):
super().__init__(logger=FakeLogger, **kwargs)
super().__init__(logger=FakeLogger(), **kwargs)
return HandlerWrapper

View File

@ -7,10 +7,12 @@
import random
import ssl
import threading
import time
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
from socketserver import BaseRequestHandler, ThreadingTCPServer
import pytest
import platform
from test.helper import http_server_port, verify_address_availability
from test.test_networking import TEST_DIR
@ -46,6 +48,11 @@ def do_proxy_auth(self, username, password):
except Exception:
return self.proxy_auth_error()
if auth_username == 'http_error':
self.send_response(404)
self.end_headers()
return False
if auth_username != (username or '') or auth_password != (password or ''):
return self.proxy_auth_error()
return True
@ -119,6 +126,16 @@ def _io_refs(self, value):
def shutdown(self, *args, **kwargs):
self.socket.shutdown(*args, **kwargs)
def _wrap_ssl_read(self, *args, **kwargs):
res = super()._wrap_ssl_read(*args, **kwargs)
if res == 0:
# Websockets does not treat 0 as an EOF, rather only b''
return b''
return res
def getsockname(self):
return self.socket.getsockname()
else:
SSLTransport = None
@ -128,7 +145,40 @@ def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
if isinstance(request, ssl.SSLSocket):
if SSLTransport:
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
super().__init__(request, *args, **kwargs)
class WebSocketProxyHandler(BaseRequestHandler):
def __init__(self, *args, proxy_info=None, **kwargs):
self.proxy_info = proxy_info
super().__init__(*args, **kwargs)
def handle(self):
import websockets.sync.server
self.request.settimeout(None)
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=10)
try:
connection.handshake(timeout=5.0)
for message in connection:
if message == 'proxy_info':
connection.send(json.dumps(self.proxy_info))
except Exception as e:
print(f'Error in websocket proxy: {e}')
finally:
connection.close(code=1001)
class WebSocketSecureProxyHandler(WebSocketProxyHandler):
def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None)
if isinstance(request, ssl.SSLSocket) and SSLTransport:
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
@ -197,7 +247,7 @@ def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_serv
finally:
server.shutdown()
server.server_close()
server_thread.join(2.0)
server_thread.join()
class HTTPProxyTestContext(abc.ABC):
@ -205,7 +255,9 @@ class HTTPProxyTestContext(abc.ABC):
REQUEST_PROTO = None
def http_server(self, server_class, *args, **kwargs):
return proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
server = proxy_server(server_class, self.REQUEST_HANDLER_CLASS, *args, **kwargs)
time.sleep(1)
return server
@abc.abstractmethod
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs) -> dict:
@ -234,9 +286,30 @@ def proxy_info_request(self, handler, target_domain=None, target_port=None, **re
return json.loads(handler.send(request).read().decode())
class HTTPProxyWebSocketTestContext(HTTPProxyTestContext):
REQUEST_HANDLER_CLASS = WebSocketProxyHandler
REQUEST_PROTO = 'ws'
def proxy_info_request(self, handler, target_domain=None, target_port=None, **req_kwargs):
request = Request(f'{self.REQUEST_PROTO}://{target_domain or "127.0.0.1"}:{target_port or "40000"}', **req_kwargs)
handler.validate(request)
ws = handler.send(request)
ws.send('proxy_info')
proxy_info = ws.recv()
ws.close()
return json.loads(proxy_info)
class HTTPProxyWebSocketSecureTestContext(HTTPProxyWebSocketTestContext):
REQUEST_HANDLER_CLASS = WebSocketSecureProxyHandler
REQUEST_PROTO = 'wss'
CTX_MAP = {
'http': HTTPProxyHTTPTestContext,
'https': HTTPProxyHTTPSTestContext,
'ws': HTTPProxyWebSocketTestContext,
'wss': HTTPProxyWebSocketSecureTestContext,
}
@ -272,6 +345,14 @@ def test_http_bad_auth(self, handler, ctx):
assert exc_info.value.response.status == 407
exc_info.value.response.close()
def test_http_error(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler, username='http_error', password='test') as server_address:
with handler(proxies={ctx.REQUEST_PROTO: f'http://http_error:test@{server_address}'}) as rh:
with pytest.raises(HTTPError) as exc_info:
ctx.proxy_info_request(rh)
assert exc_info.value.response.status == 404
exc_info.value.response.close()
def test_http_source_address(self, handler, ctx):
with ctx.http_server(HTTPProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
@ -314,7 +395,13 @@ def test_http_with_idn(self, handler, ctx):
'handler,ctx', [
('Requests', 'https'),
('CurlCFFI', 'https'),
('Websockets', 'ws'),
('Websockets', 'wss'),
], indirect=True)
@pytest.mark.skip_handler_if(
'Websockets', lambda request:
platform.python_implementation() == 'PyPy',
'Tests are flaky with PyPy, unknown reason')
class TestHTTPConnectProxy:
def test_http_connect_no_auth(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
@ -341,6 +428,16 @@ def test_http_connect_bad_auth(self, handler, ctx):
with pytest.raises(ProxyError):
ctx.proxy_info_request(rh)
@pytest.mark.skip_handler(
'Requests',
'bug in urllib3 causes unclosed socket: https://github.com/urllib3/urllib3/issues/3374',
)
def test_http_connect_http_error(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler, username='http_error', password='test') as server_address:
with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://http_error:test@{server_address}'}) as rh:
with pytest.raises(ProxyError):
ctx.proxy_info_request(rh)
def test_http_connect_source_address(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'

View File

@ -407,7 +407,7 @@ def test_percent_encode(self, handler):
'/redirect_dotsegments_absolute',
])
def test_remove_dot_segments(self, handler, path):
with handler(verbose=True) as rh:
with handler() as rh:
# This isn't a comprehensive test,
# but it should be enough to check whether the handler is removing dot segments in required scenarios
res = validate_and_send(rh, Request(f'http://127.0.0.1:{self.http_port}{path}'))
@ -1224,8 +1224,8 @@ class HTTPSupportedRH(ValidationRH):
('socks5h', False),
]),
('Websockets', 'ws', [
('http', UnsupportedRequest),
('https', UnsupportedRequest),
('http', False),
('https', False),
('socks4', False),
('socks4a', False),
('socks5', False),
@ -1318,8 +1318,8 @@ class HTTPSupportedRH(ValidationRH):
('Websockets', False, 'ws'),
], indirect=['handler'])
def test_no_proxy(self, handler, fail, scheme):
run_validation(handler, fail, Request(f'{scheme}://', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request(f'{scheme}://'), proxies={'no': '127.0.0.1,github.com'})
run_validation(handler, fail, Request(f'{scheme}://example.com', proxies={'no': '127.0.0.1,github.com'}))
run_validation(handler, fail, Request(f'{scheme}://example.com'), proxies={'no': '127.0.0.1,github.com'})
@pytest.mark.parametrize('handler,scheme', [
('Urllib', 'http'),

View File

@ -216,6 +216,8 @@ def handle(self):
protocol = websockets.ServerProtocol()
connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0)
connection.handshake()
for message in connection:
if message == 'socks_info':
connection.send(json.dumps(self.socks_info))
connection.close()

View File

@ -4174,15 +4174,15 @@ def urlopen(self, req):
'Use --enable-file-urls to enable at your own risk.', cause=ue) from ue
if (
'unsupported proxy type: "https"' in ue.msg.lower()
and 'requests' not in self._request_director.handlers
and 'curl_cffi' not in self._request_director.handlers
and 'Requests' not in self._request_director.handlers
and 'CurlCFFI' not in self._request_director.handlers
):
raise RequestError(
'To use an HTTPS proxy for this request, one of the following dependencies needs to be installed: requests, curl_cffi')
elif (
re.match(r'unsupported url scheme: "wss?"', ue.msg.lower())
and 'websockets' not in self._request_director.handlers
and 'Websockets' not in self._request_director.handlers
):
raise RequestError(
'This request requires WebSocket support. '

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import base64
import contextlib
import functools
import os
@ -9,8 +10,9 @@
import typing
import urllib.parse
import urllib.request
from http.client import HTTPConnection, HTTPResponse
from .exceptions import RequestError, UnsupportedRequest
from .exceptions import ProxyError, RequestError, UnsupportedRequest
from ..dependencies import certifi
from ..socks import ProxyType, sockssocket
from ..utils import format_field, traverse_obj
@ -285,3 +287,65 @@ def create_connection(
# Explicitly break __traceback__ reference cycle
# https://bugs.python.org/issue36820
err = None
class NoCloseHTTPResponse(HTTPResponse):
def begin(self):
super().begin()
# Revert the default behavior of closing the connection after reading the response
if not self._check_close() and not self.chunked and self.length is None:
self.will_close = False
def create_http_connect_connection(
proxy_host,
proxy_port,
connect_host,
connect_port,
timeout=None,
ssl_context=None,
source_address=None,
username=None,
password=None,
debug=False,
):
proxy_headers = dict()
if username is not None or password is not None:
proxy_headers['Proxy-Authorization'] = 'Basic ' + base64.b64encode(
f'{username or ""}:{password or ""}'.encode()).decode('utf-8')
conn = HTTPConnection(proxy_host, port=proxy_port, timeout=timeout)
conn.set_debuglevel(int(debug))
conn.response_class = NoCloseHTTPResponse
if hasattr(conn, '_create_connection'):
conn._create_connection = create_connection
if source_address is not None:
conn.source_address = (source_address, 0)
try:
conn.connect()
if ssl_context:
conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_host)
conn.request(
method='CONNECT',
url=f'{connect_host}:{connect_port}',
headers=proxy_headers)
response = conn.getresponse()
except OSError as e:
conn.close()
raise ProxyError('Unable to connect to proxy', cause=e) from e
if response.status == 200:
sock = conn.sock
conn.sock = None
response.fp = None
return sock
else:
conn.close()
response.close()
raise ProxyError(f'Got HTTP Error {response.status} with CONNECT: {response.reason}')

View File

@ -243,14 +243,14 @@ def __init__(self, logger, *args, **kwargs):
def emit(self, record):
try:
msg = self.format(record)
except Exception:
self.handleError(record)
else:
if record.levelno >= logging.ERROR:
self._logger.error(msg)
else:
self._logger.stdout(msg)
except Exception:
self.handleError(record)
@register_rh
class RequestsRH(RequestHandler, InstanceStoreMixin):

View File

@ -5,10 +5,11 @@
import io
import logging
import ssl
import sys
import urllib.parse
from ._helper import (
create_connection,
create_http_connect_connection,
create_socks_proxy_socket,
make_socks_proxy_opts,
select_proxy,
@ -21,9 +22,10 @@
RequestError,
SSLError,
TransportError,
UnsupportedRequest,
)
from .websocket import WebSocketRequestHandler, WebSocketResponse
from ..dependencies import websockets
from ..dependencies import urllib3, websockets
from ..socks import ProxyError as SocksProxyError
from ..utils import int_or_none
@ -36,6 +38,20 @@
if websockets_version < (12, 0):
raise ImportError('Only websockets>=12.0 is supported')
urllib3_supported = False
urllib3_version = tuple(int_or_none(x, default=0) for x in urllib3.__version__.split('.')) if urllib3 else None
if urllib3_version and urllib3_version >= (1, 26, 17):
urllib3_supported = True
# Disable apply_mask C implementation
# Seems to help reduce "Fatal Python error: Aborted" in CI
with contextlib.suppress(Exception):
import websockets.frames
import websockets.legacy.framing
import websockets.utils
websockets.frames.apply_mask = websockets.legacy.framing = websockets.utils.apply_mask
import websockets.sync.client
from websockets.uri import parse_uri
@ -53,6 +69,22 @@
websockets.sync.connection.Connection.recv_events_exc = None
class WebsocketsLoggingHandler(logging.Handler):
"""Redirect websocket logs to our logger"""
def __init__(self, logger, *args, **kwargs):
super().__init__(*args, **kwargs)
self._logger = logger
def emit(self, record):
try:
msg = self.format(record)
except Exception:
self.handleError(record)
else:
self._logger.stdout(msg)
class WebsocketsResponseAdapter(WebSocketResponse):
def __init__(self, ws: websockets.sync.client.ClientConnection, url):
@ -98,7 +130,7 @@ class WebsocketsRH(WebSocketRequestHandler):
https://github.com/python-websockets/websockets
"""
_SUPPORTED_URL_SCHEMES = ('wss', 'ws')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h')
_SUPPORTED_PROXY_SCHEMES = ('socks4', 'socks4a', 'socks5', 'socks5h', 'http', 'https')
_SUPPORTED_FEATURES = (Features.ALL_PROXY, Features.NO_PROXY)
RH_NAME = 'websockets'
@ -107,13 +139,24 @@ def __init__(self, *args, **kwargs):
self.__logging_handlers = {}
for name in ('websockets.client', 'websockets.server'):
logger = logging.getLogger(name)
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s'))
handler = WebsocketsLoggingHandler(logger=self._logger)
handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s'))
self.__logging_handlers[name] = handler
logger.addHandler(handler)
if self.verbose:
logger.setLevel(logging.DEBUG)
def _validate(self, request):
super()._validate(request)
proxy = select_proxy(request.url, self._get_proxies(request))
if (
proxy
and urllib.parse.urlparse(proxy).scheme.lower() == 'https'
and urllib.parse.urlparse(request.url).scheme.lower() == 'wss'
and not urllib3_supported
):
raise UnsupportedRequest('WSS over HTTPS proxy requires a supported version of urllib3')
def _check_extensions(self, extensions):
super()._check_extensions(extensions)
extensions.pop('timeout', None)
@ -126,6 +169,41 @@ def close(self):
for name, handler in self.__logging_handlers.items():
logging.getLogger(name).removeHandler(handler)
def _make_sock(self, proxy, url, timeout):
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout,
}
parsed_url = parse_uri(url)
parsed_proxy_url = urllib.parse.urlparse(proxy)
if proxy:
if parsed_proxy_url.scheme.startswith('socks'):
socks_proxy_options = make_socks_proxy_opts(proxy)
return create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (parsed_url.host, parsed_url.port), socks_proxy_options),
**create_conn_kwargs,
)
elif parsed_proxy_url.scheme in ('http', 'https'):
return create_http_connect_connection(
proxy_port=parsed_proxy_url.port,
proxy_host=parsed_proxy_url.hostname,
connect_port=parsed_url.port,
connect_host=parsed_url.host,
timeout=timeout,
ssl_context=self._make_sslcontext() if parsed_proxy_url.scheme == 'https' else None,
source_address=self.source_address,
username=parsed_proxy_url.username,
password=parsed_proxy_url.password,
debug=self.verbose,
)
return create_connection(
address=(parsed_url.host, parsed_url.port),
**create_conn_kwargs,
)
def _send(self, request):
timeout = self._calculate_timeout(request)
headers = self._merge_headers(request.headers)
@ -135,35 +213,21 @@ def _send(self, request):
if cookie_header:
headers['cookie'] = cookie_header
wsuri = parse_uri(request.url)
create_conn_kwargs = {
'source_address': (self.source_address, 0) if self.source_address else None,
'timeout': timeout,
}
proxy = select_proxy(request.url, self._get_proxies(request))
try:
if proxy:
socks_proxy_options = make_socks_proxy_opts(proxy)
sock = create_connection(
address=(socks_proxy_options['addr'], socks_proxy_options['port']),
_create_socket_func=functools.partial(
create_socks_proxy_socket, (wsuri.host, wsuri.port), socks_proxy_options),
**create_conn_kwargs,
)
else:
sock = create_connection(
address=(wsuri.host, wsuri.port),
**create_conn_kwargs,
)
ssl_ctx = self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl'))
ssl_context = None
sock = self._make_sock(proxy, request.url, timeout)
if parse_uri(request.url).secure:
ssl_context = WebsocketsSSLContext(self._make_sslcontext(legacy_ssl_support=request.extensions.get('legacy_ssl')))
conn = websockets.sync.client.connect(
sock=sock,
uri=request.url,
additional_headers=headers,
open_timeout=timeout,
user_agent_header=None,
ssl_context=ssl_ctx if wsuri.secure else None,
close_timeout=0, # not ideal, but prevents yt-dlp hanging
ssl_context=ssl_context,
close_timeout=0.1, # not ideal, but prevents yt-dlp hanging
)
return WebsocketsResponseAdapter(conn, url=request.url)
@ -187,3 +251,43 @@ def _send(self, request):
) from e
except (OSError, TimeoutError, websockets.exceptions.WebSocketException) as e:
raise TransportError(cause=e) from e
if urllib3_supported:
from urllib3.util.ssltransport import SSLTransport
class WebsocketsSSLTransport(SSLTransport):
"""
Modified version of urllib3 SSLTransport to support additional operations used by websockets
"""
def setsockopt(self, *args, **kwargs):
self.socket.setsockopt(*args, **kwargs)
def shutdown(self, *args, **kwargs):
self.unwrap()
self.socket.shutdown(*args, **kwargs)
def _wrap_ssl_read(self, *args, **kwargs):
res = super()._wrap_ssl_read(*args, **kwargs)
if res == 0:
# Websockets does not treat 0 as an EOF, rather only b''
return b''
return res
else:
WebsocketsSSLTransport = None
class WebsocketsSSLContext:
"""
Dummy SSL Context for websockets which returns a WebsocketsSSLTransport instance
for wrap socket when using TLS-in-TLS.
"""
def __init__(self, ssl_context: ssl.SSLContext):
self.ssl_context = ssl_context
def wrap_socket(self, sock, server_hostname=None):
if isinstance(sock, ssl.SSLSocket) and WebsocketsSSLTransport:
return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname)
return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname)

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import abc
import urllib.parse
from .common import RequestHandler, Response
from .common import RequestHandler, Response, register_preference
class WebSocketResponse(Response):
@ -21,3 +22,10 @@ def recv(self):
class WebSocketRequestHandler(RequestHandler, abc.ABC):
pass
@register_preference(WebSocketRequestHandler)
def websocket_preference(_, request):
if urllib.parse.urlparse(request.url).scheme in ('ws', 'wss'):
return 200
return 0