diff --git a/flojoy/utils.py b/flojoy/utils.py index 034e017..6dae05d 100644 --- a/flojoy/utils.py +++ b/flojoy/utils.py @@ -1,24 +1,72 @@ import decimal import json as _json +import os +import sys +from pathlib import Path +from typing import Any, Callable, Optional + import logging import numpy as np import pandas as pd -from pathlib import Path -import os -import yaml -from typing import Union, Any import requests +import yaml from dotenv import dotenv_values # type:ignore +from huggingface_hub import hf_hub_download as _hf_hub_download +from huggingface_hub import snapshot_download as _snapshot_download from .dao import Dao from .config import FlojoyConfig, logger +from .node_init import NodeInit, NodeInitService +import keyring + __all__ = [ "send_to_socket", - "get_frontier_api_key", - "set_frontier_api_key", - "set_frontier_s3_key", + "get_env_var", + "set_env_var", + "delete_env_var", + "hf_hub_download", + "snapshot_download", + "get_node_init_function", + "get_credentials", + "clear_flojoy_memory", ] +FLOJOY_DIR = ".flojoy" + + +if sys.platform == "win32": + FLOJOY_CACHE_DIR = os.path.join(os.environ["APPDATA"], FLOJOY_DIR) +else: + FLOJOY_CACHE_DIR = os.path.join(os.environ["HOME"], FLOJOY_DIR) + + +# Make as a function to mock at test-time +def _get_hf_hub_cache_path() -> str: + return os.path.join(FLOJOY_CACHE_DIR, "cache", "hf_hub") + + +def hf_hub_download(*args, **kwargs): + if "cache_dir" not in kwargs: + kwargs["cache_dir"] = _get_hf_hub_cache_path() + else: + if kwargs["cache_dir"] != _get_hf_hub_cache_path(): + raise ValueError( + f"Attempted to override cache_dir parameter, received {kwargs['cache_dir']} while the only alloed value is {_get_hf_hub_cache_path()}" + ) + return _hf_hub_download(*args, **kwargs) + + +def snapshot_download(*args, **kwargs): + if "cache_dir" not in kwargs: + kwargs["cache_dir"] = _get_hf_hub_cache_path() + else: + if kwargs["cache_dir"] != _get_hf_hub_cache_path(): + raise ValueError( + f"Attempted to override cache_dir parameter, received {kwargs['cache_dir']} while the only alloed value is {_get_hf_hub_cache_path()}" + ) + return _snapshot_download(*args, **kwargs) + + env_vars = dotenv_values("../.env") port = env_vars.get("VITE_BACKEND_PORT", "8000") BACKEND_URL = os.environ.get("BACKEND_URL", f"http://127.0.0.1:{port}") @@ -251,88 +299,74 @@ def get_flojoy_root_dir() -> str: stream = open(path, "r") yaml_dict = yaml.load(stream, Loader=yaml.FullLoader) root_dir = "" + if isinstance(yaml_dict, str): root_dir = yaml_dict.split(":")[1] else: root_dir = yaml_dict["PATH"] + return root_dir -def get_frontier_api_key() -> Union[str, None]: +def get_env_var(key: str) -> Optional[str]: + return keyring.get_password("flojoy", key) + + +def set_env_var(key: str, value: str): + keyring.set_password("flojoy", key, value) home = str(Path.home()) - api_key = None - path = os.path.join(home, ".flojoy/credentials") - if not os.path.exists(path): - return api_key + file_path = os.path.join(home, os.path.join(FLOJOY_DIR, "credentials.txt")) - stream = open(path, "r", encoding="utf-8") - yaml_dict = yaml.load(stream, Loader=yaml.FullLoader) - if yaml_dict is None: - return api_key - if isinstance(yaml_dict, str) == True: - split_by_line = yaml_dict.split("\n") - for line in split_by_line: - if "FRONTIER_API_KEY" in line: - api_key = line.split(":")[1] - else: - api_key = yaml_dict.get("FRONTIER_API_KEY", None) - return api_key + if not os.path.exists(file_path): + with open(file_path, "w") as f: + f.write(key) + return + with open(file_path, "r") as f: + keys = f.read().strip().split(",") + if key not in keys: + keys.append(key) -def set_frontier_api_key(api_key: str): - try: - home = str(Path.home()) - file_path = os.path.join(home, ".flojoy/credentials") + with open(file_path, "a") as f: + f.write(",".join(keys)) - if not os.path.exists(file_path): - # Create a new file and write the API_KEY to it - with open(file_path, "w") as file: - file.write(f"FRONTIER_API_KEY:{api_key}\n") - else: - # Read the contents of the file - with open(file_path, "r") as file: - lines = file.readlines() - # Update the API key if it exists, otherwise append a new line - updated = False - for i, line in enumerate(lines): - if line.startswith("FRONTIER_API_KEY:"): - lines[i] = f"FRONTIER_API_KEY:{api_key}\n" - updated = True - break +def delete_env_var(key: str): + home = str(Path.home()) + file_path = os.path.join(home, os.path.join(FLOJOY_DIR, "credentials.txt")) + + if not os.path.exists(file_path): + return + + with open(file_path, "r") as f: + keys = f.read().strip().split(",") - if not updated: - lines.append(f"FRONTIER_API_KEY:{api_key}\n") - # Write the updated contents to the file - with open(file_path, "w") as file: - file.writelines(lines) + if key not in keys: + return - except Exception as e: - raise e + keys.remove(key) + with open(file_path, "w") as f: + f.write(",".join(keys)) -def set_frontier_s3_key(s3_name: str, s3_access_key: str, s3_secret_key: str): + keyring.delete_password("flojoy", key) + + +def get_credentials() -> list[dict[str, str]]: home = str(Path.home()) - file_path = os.path.join(home, os.path.join(".flojoy", "credentials.yaml")) + file_path = os.path.join(home, os.path.join(FLOJOY_DIR, "credentials.txt")) - data = { - f"{s3_name}": s3_name, - f"{s3_name}accessKey": s3_access_key, - f"{s3_name}secretKey": s3_secret_key, - } - if not os.path.exists(file_path): - # Create a new file and write the ACCSS_KEY to it - with open(file_path, "w") as file: - yaml.dump(data, file) - return + with open(file_path, "r") as f: + keys = f.read().strip().split(",") + + credentials_list: list[dict[str, str]] = [] + for key in keys: + val = get_env_var(key) + if val: + credentials_list.append({"key": key, "value": val}) - # Read the contents of the file - with open(file_path, "r") as file: - load = yaml.safe_load(file) + return credentials_list - load[f"{s3_name}"] = s3_name - load[f"{s3_name}accessKey"] = s3_access_key - load[f"{s3_name}secretKey"] = s3_secret_key - with open(file_path, "w") as file: - yaml.dump(load, file) +def get_node_init_function(node_func: Callable) -> NodeInit: + return NodeInitService().get_node_init_function(node_func) diff --git a/requirements.txt b/requirements.txt index 066058b..570ca00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -58,3 +58,5 @@ uptime==3.0.1 urllib3==1.26.15 wrapt==1.15.0 zope.interface==6.0 +keyring==24.2.0 +huggingface_hub==0.16.4