diff --git a/test/test_http_proxy.py b/test/test_http_proxy.py index 191b6baf8..c72edc472 100644 --- a/test/test_http_proxy.py +++ b/test/test_http_proxy.py @@ -60,7 +60,9 @@ def __init__(self, *args, proxy_info=None, username=None, password=None, request super().__init__(*args, **kwargs) def do_GET(self): - self.do_proxy_auth(self.username, self.password) + if not self.do_proxy_auth(self.username, self.password): + self.server.close_request(self.request) + return if self.path.endswith('/proxy_info'): payload = json.dumps(self.proxy_info or { 'client_address': self.client_address, @@ -76,6 +78,11 @@ def do_GET(self): self.send_header('Content-Length', str(len(payload))) self.end_headers() self.wfile.write(payload.encode()) + else: + self.send_response(404) + self.end_headers() + + self.server.close_request(self.request) if urllib3: @@ -160,7 +167,9 @@ def __init__(self, *args, username=None, password=None, request_handler=None, ** super().__init__(*args, **kwargs) def do_CONNECT(self): - self.do_proxy_auth(self.username, self.password) + if not self.do_proxy_auth(self.username, self.password): + self.server.close_request(self.request) + return self.send_response(200) self.end_headers() proxy_info = { @@ -173,6 +182,7 @@ def do_CONNECT(self): 'proxy': ':'.join(str(y) for y in self.connection.getsockname()), } self.request_handler(self.request, self.client_address, self.server, proxy_info=proxy_info) + self.server.close_request(self.request) class HTTPSConnectProxyHandler(HTTPConnectProxyHandler): @@ -181,8 +191,13 @@ def __init__(self, request, *args, **kwargs): sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) sslctx.load_cert_chain(certfn, None) request = sslctx.wrap_socket(request, server_side=True) + self._original_request = request super().__init__(request, *args, **kwargs) + def do_CONNECT(self): + super().do_CONNECT() + self.server.close_request(self._original_request) + @contextlib.contextmanager def proxy_server(proxy_server_class, request_handler, bind_ip=None, **proxy_server_kwargs): @@ -295,6 +310,7 @@ def test_http_bad_auth(self, handler, ctx): with pytest.raises(HTTPError) as exc_info: ctx.proxy_info_request(rh) assert exc_info.value.response.status == 407 + exc_info.value.response.close() def test_http_source_address(self, handler, ctx): with ctx.http_server(HTTPProxyHandler) as server_address: