Only use SSLTransport where tls-in-tls will be used

This commit is contained in:
coletdjnz 2024-05-18 17:18:21 +12:00
parent 82cceaed31
commit d274eb1f53
No known key found for this signature in database
GPG Key ID: 91984263BB39894A
2 changed files with 14 additions and 17 deletions

View File

@ -172,7 +172,7 @@ def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None) sslctx.load_cert_chain(certfn, None)
if SSLTransport: if isinstance(request, ssl.SSLSocket) and SSLTransport:
request = SSLTransport(request, ssl_context=sslctx, server_side=True) request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else: else:
request = sslctx.wrap_socket(request, server_side=True) request = sslctx.wrap_socket(request, server_side=True)
@ -213,10 +213,7 @@ def __init__(self, request, *args, **kwargs):
certfn = os.path.join(TEST_DIR, 'testcert.pem') certfn = os.path.join(TEST_DIR, 'testcert.pem')
sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
sslctx.load_cert_chain(certfn, None) sslctx.load_cert_chain(certfn, None)
if SSLTransport: request = sslctx.wrap_socket(request, server_side=True)
request = SSLTransport(request, ssl_context=sslctx, server_side=True)
else:
request = sslctx.wrap_socket(request, server_side=True)
self._original_request = request self._original_request = request
super().__init__(request, *args, **kwargs) super().__init__(request, *args, **kwargs)
@ -244,7 +241,6 @@ def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_serv
finally: finally:
server.shutdown() server.shutdown()
server.server_close() server.server_close()
server_thread.join(2.0)
class HTTPProxyTestContext(abc.ABC): class HTTPProxyTestContext(abc.ABC):
@ -393,11 +389,11 @@ def test_http_with_idn(self, handler, ctx):
('Websockets', 'ws'), ('Websockets', 'ws'),
('Websockets', 'wss') ('Websockets', 'wss')
], indirect=True) ], indirect=True)
@pytest.mark.skip_handler_if( # @pytest.mark.skip_handler_if(
'Websockets', lambda request: # 'Websockets', lambda request:
(platform.python_implementation() == 'PyPy' # (platform.python_implementation() == 'PyPy'
and request.getfixturevalue('ctx').REQUEST_PROTO == 'wss'), # and request.getfixturevalue('ctx').REQUEST_PROTO == 'wss'),
'PyPy sometimes fails with wss tests, unknown reason') # 'PyPy sometimes fails with wss tests, unknown reason')
class TestHTTPConnectProxy: class TestHTTPConnectProxy:
def test_http_connect_no_auth(self, handler, ctx): def test_http_connect_no_auth(self, handler, ctx):
with ctx.http_server(HTTPConnectProxyHandler) as server_address: with ctx.http_server(HTTPConnectProxyHandler) as server_address:

View File

@ -199,20 +199,21 @@ def _send(self, request):
proxy = select_proxy(request.url, self._get_proxies(request)) proxy = select_proxy(request.url, self._get_proxies(request))
ssl_context = None ssl_context = None
sock = self._make_sock(proxy, request.url, timeout)
if parse_uri(request.url).secure: if parse_uri(request.url).secure:
if WebsocketsSSLContext is not None: ssl_context = self._make_sslcontext()
ssl_context = WebsocketsSSLContext(self._make_sslcontext()) if isinstance(sock, ssl.SSLSocket) and WebsocketsSSLContext: # tls in tls
else: ssl_context = WebsocketsSSLContext(ssl_context)
ssl_context = self._make_sslcontext()
try: try:
conn = websockets.sync.client.connect( conn = websockets.sync.client.connect(
sock=self._make_sock(proxy, request.url, timeout), sock=sock,
uri=request.url, uri=request.url,
additional_headers=headers, additional_headers=headers,
open_timeout=timeout, open_timeout=timeout,
user_agent_header=None, user_agent_header=None,
ssl_context=ssl_context, ssl_context=ssl_context,
close_timeout=0.1, # not ideal, but prevents yt-dlp hanging close_timeout=0, # not ideal, but prevents yt-dlp hanging
) )
return WebsocketsResponseAdapter(conn, url=request.url) return WebsocketsResponseAdapter(conn, url=request.url)