From 3999a510f7c77fab1cb785964c665acacf1a061a Mon Sep 17 00:00:00 2001 From: coletdjnz Date: Sat, 6 Apr 2024 15:14:59 +1300 Subject: [PATCH] Working websockets HTTP/S proxy --- test/test_http_proxy.py | 29 ++++++++++++++------ test/test_socks.py | 4 ++- yt_dlp/networking/_websockets.py | 45 +++++++++++++++++++++++++------- 3 files changed, 60 insertions(+), 18 deletions(-) diff --git a/test/test_http_proxy.py b/test/test_http_proxy.py index c72edc472..34dee4ab5 100644 --- a/test/test_http_proxy.py +++ b/test/test_http_proxy.py @@ -116,6 +116,9 @@ def _io_refs(self): @_io_refs.setter def _io_refs(self, value): self.socket._io_refs = value + + def shutdown(self, *args, **kwargs): + self.socket.shutdown(*args, **kwargs) else: SSLTransport = None @@ -142,13 +145,14 @@ def handle(self): protocol = websockets.ServerProtocol() connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) connection.handshake() - connection.send(json.dumps(self.proxy_info)) + for message in connection: + if message == 'proxy_info': + connection.send(json.dumps(self.proxy_info)) connection.close() class WebSocketSecureProxyHandler(WebSocketProxyHandler): - def __init__(self, request, *args, proxy_info=None, **kwargs): - self.proxy_info = proxy_info + def __init__(self, request, *args, **kwargs): certfn = os.path.join(TEST_DIR, 'testcert.pem') sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx.load_cert_chain(certfn, None) @@ -218,7 +222,7 @@ def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_serv finally: server.shutdown() server.server_close() - server_thread.join(2.0) + server_thread.join() class HTTPProxyTestContext(abc.ABC): @@ -297,6 +301,7 @@ def test_http_no_auth(self, handler, ctx): proxy_info = ctx.proxy_info_request(rh) assert proxy_info['connect'] is False assert 'Proxy-Authorization' not in proxy_info['headers'] + assert proxy_info['proxy'] == server_address def test_http_auth(self, handler, ctx): with 
ctx.http_server(HTTPProxyHandler, username='test', password='test') as server_address: @@ -318,8 +323,9 @@ def test_http_source_address(self, handler, ctx): verify_address_availability(source_address) with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}, source_address=source_address) as rh: - response = ctx.proxy_info_request(rh) - assert response['client_address'][0] == source_address + proxy_info = ctx.proxy_info_request(rh) + assert proxy_info['client_address'][0] == source_address + assert proxy_info['proxy'] == server_address @pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies') def test_https(self, handler, ctx): @@ -328,6 +334,7 @@ def test_https(self, handler, ctx): proxy_info = ctx.proxy_info_request(rh) assert proxy_info['connect'] is False assert 'Proxy-Authorization' not in proxy_info['headers'] + assert proxy_info['proxy'] == server_address @pytest.mark.skip_handler('Urllib', 'urllib does not support https proxies') def test_https_verify_failed(self, handler, ctx): @@ -345,6 +352,7 @@ def test_http_with_idn(self, handler, ctx): proxy_info = ctx.proxy_info_request(rh, target_domain='中文.tw') assert proxy_info['path'].startswith('http://xn--fiq228c.tw') assert proxy_info['headers']['Host'].split(':', 1)[0] == 'xn--fiq228c.tw' + assert proxy_info['proxy'] == server_address @pytest.mark.parametrize( @@ -361,12 +369,14 @@ def test_http_connect_no_auth(self, handler, ctx): proxy_info = ctx.proxy_info_request(rh) assert proxy_info['connect'] is True assert 'Proxy-Authorization' not in proxy_info['headers'] + assert proxy_info['proxy'] == server_address def test_http_connect_auth(self, handler, ctx): with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address: with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'http://test:test@{server_address}'}) as rh: proxy_info = ctx.proxy_info_request(rh) assert 'Proxy-Authorization' in proxy_info['headers'] + assert proxy_info['proxy'] == 
server_address def test_http_connect_bad_auth(self, handler, ctx): with ctx.http_server(HTTPConnectProxyHandler, username='test', password='test') as server_address: @@ -381,8 +391,9 @@ def test_http_connect_source_address(self, handler, ctx): with handler(proxies={ctx.REQUEST_PROTO: f'http://{server_address}'}, source_address=source_address, verify=False) as rh: - response = ctx.proxy_info_request(rh) - assert response['client_address'][0] == source_address + proxy_info = ctx.proxy_info_request(rh) + assert proxy_info['client_address'][0] == source_address + assert proxy_info['proxy'] == server_address @pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test') def test_https_connect_proxy(self, handler, ctx): @@ -391,6 +402,7 @@ def test_https_connect_proxy(self, handler, ctx): proxy_info = ctx.proxy_info_request(rh) assert proxy_info['connect'] is True assert 'Proxy-Authorization' not in proxy_info['headers'] + assert proxy_info['proxy'] == server_address @pytest.mark.skipif(urllib3 is None, reason='requires urllib3 to test') def test_https_connect_verify_failed(self, handler, ctx): @@ -408,3 +420,4 @@ def test_https_connect_proxy_auth(self, handler, ctx): with handler(verify=False, proxies={ctx.REQUEST_PROTO: f'https://test:test@{server_address}'}) as rh: proxy_info = ctx.proxy_info_request(rh) assert 'Proxy-Authorization' in proxy_info['headers'] + assert proxy_info['proxy'] == server_address diff --git a/test/test_socks.py b/test/test_socks.py index 43d612d85..20237dc76 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -216,7 +216,9 @@ def handle(self): protocol = websockets.ServerProtocol() connection = websockets.sync.server.ServerConnection(socket=self.request, protocol=protocol, close_timeout=0) connection.handshake() - connection.send(json.dumps(self.socks_info)) + for message in connection: + if message == 'socks_info': + connection.send(json.dumps(self.socks_info)) connection.close() diff --git a/yt_dlp/networking/_websockets.py 
b/yt_dlp/networking/_websockets.py index 776662c3a..37da6b102 100644 --- a/yt_dlp/networking/_websockets.py +++ b/yt_dlp/networking/_websockets.py @@ -118,7 +118,7 @@ def __init__(self, *args, **kwargs): for name in ('websockets.client', 'websockets.server'): logger = logging.getLogger(name) handler = logging.StreamHandler(stream=sys.stdout) - handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: %(message)s')) + handler.setFormatter(logging.Formatter(f'{self.RH_NAME}: [{name}] %(message)s')) self.__logging_handlers[name] = handler logger.addHandler(handler) if self.verbose: @@ -152,7 +152,7 @@ def _make_sock(self, proxy, url, timeout): **create_conn_kwargs ) - elif parsed_proxy_url.scheme.startswith('http'): + elif parsed_proxy_url.scheme in ('http', 'https'): return create_http_connect_conn( proxy_url=proxy, url=url, @@ -177,6 +177,7 @@ def _send(self, request): headers['cookie'] = cookie_header proxy = select_proxy(request.url, self._get_proxies(request)) + try: conn = websockets.sync.client.connect( sock=self._make_sock(proxy, request.url, timeout), @@ -184,7 +185,10 @@ def _send(self, request): additional_headers=headers, open_timeout=timeout, user_agent_header=None, - ssl_context=self._make_sslcontext() if parse_uri(request.url).secure else None, + ssl_context=( + WebsocketsSSLContext(self._make_sslcontext()) + if parse_uri(request.url).secure else None + ), close_timeout=0, # not ideal, but prevents yt-dlp hanging ) return WebsocketsResponseAdapter(conn, url=request.url) @@ -218,12 +222,34 @@ def begin(self): if not self._check_close() and not self.chunked and self.length is None: self.will_close = False -class CustomSSLTransport(SSLTransport): + +# TODO: only define if urllib3 is available (SSLTransport is urllib3's class) +class WebsocketsSSLTransport(SSLTransport): + """ + Modified version of urllib3 SSLTransport to support additional operations used by websockets + """ def setsockopt(self, *args, **kwargs): self.socket.setsockopt(*args, **kwargs) + def shutdown(self, *args, **kwargs): 
+ self.unwrap() self.socket.shutdown(*args, **kwargs) + +class WebsocketsSSLContext: + """ + Dummy SSL context for websockets whose wrap_socket() returns a WebsocketsSSLTransport instance + when the socket is already an SSLSocket (TLS-in-TLS), and a normally wrapped socket otherwise. + """ + def __init__(self, ssl_context: ssl.SSLContext): + self.ssl_context = ssl_context + + def wrap_socket(self, sock, server_hostname=None): + if isinstance(sock, ssl.SSLSocket): + return WebsocketsSSLTransport(sock, self.ssl_context, server_hostname=server_hostname) + return self.ssl_context.wrap_socket(sock, server_hostname=server_hostname) + + def create_http_connect_conn( proxy_url, url, @@ -256,17 +282,18 @@ def create_http_connect_conn( if source_address is not None: conn.source_address = (source_address, 0) - conn.debuglevel=2 try: conn.connect() if ssl_context: - conn.sock = CustomSSLTransport(conn.sock, ssl_context, server_hostname=proxy_url_parsed.hostname) - - conn.request(method='CONNECT', url=f'{request_url_parsed.host}:{request_url_parsed.port}', headers=proxy_headers) + conn.sock = ssl_context.wrap_socket(conn.sock, server_hostname=proxy_url_parsed.hostname) + conn.request( + method='CONNECT', + url=f'{request_url_parsed.host}:{request_url_parsed.port}', + headers=proxy_headers) response = conn.getresponse() except OSError as e: conn.close() - raise TransportError('Unable to connect to proxy', cause=e) from e + raise ProxyError('Unable to connect to proxy', cause=e) from e if response.status == 200: return conn.sock