diff --git a/relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py b/relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py index d506233..67becd0 100644 --- a/relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py +++ b/relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py @@ -14,6 +14,7 @@ """Operation level interface to the RelationalAI within Snowflake Snowpark""" +import json import logging from snowflake.snowpark import Session from snowflake.snowpark.dataframe import DataFrame @@ -40,46 +41,106 @@ def create_database(session: Session, database: str) -> bool: return create_db_res[0][0] == '"ok"' + def delete_database(session: Session, database: str) -> bool: delete_db_res = session.sql(f"SELECT RAI.DELETE_RAI_DATABASE('{database}')").collect() return delete_db_res[0][0] == '"ok"' + def get_database(session: Session, database: str) -> DataFrame: - return session.sql(f"SELECT RAI.GET_RAI_DATABASE('{database}')") + return session.sql(f""" + select + res:id::string as id, + res:account_name::string as account_name, + res:created_by::string as created_by, + res:name::string as name, + res:region::string as region, + res:state::string as state + from + (select RAI.GET_RAI_DATABASE('{database}') as res); + """) + def list_databases(session: Session) -> DataFrame: - return session.sql("""SELECT - value:id::string as ID - ,value:account_name::string as ACCOUNT_NAME - ,value:created_by::string as CREATED_BY - ,value:name::string as NAME - ,value:region::string as REGION - ,value:state::string as STATE - FROM - (SELECT RAI.LIST_RAI_DATABASES() as LRD) - ,LATERAL FLATTEN (input => LRD)""") + return session.sql(""" + select + value:id::string as id, + value:account_name::string as account_name, + value:created_by::string as created_by, + value:name::string as name, + value:region::string as region, + value:state::string as state + from + (select RAI.LIST_RAI_DATABASES() as lrd), + lateral flatten (input => lrd) + """) + def use_database(session: Session, database: str) -> DataFrame: - return session.sql(f"CALL RAI.USE_RAI_DATABASE('{database}')") + res = session.sql(f"CALL RAI.USE_RAI_DATABASE('{database}')").collect() + + if res[0][0] != database: + rsp = json.loads(res[0][0]) + + if not rsp['success']: + raise Exception(rsp['message']) + def create_engine(session: Session, engine: str, size: str = 'S') -> DataFrame: - return session.sql(f"SELECT RAI.CREATE_RAI_ENGINE('{engine}', '{size}')") + res = session.sql(f"SELECT RAI.CREATE_RAI_ENGINE('{engine}', '{size}')").collect() + + return res[0][0] == '"ok"' + def delete_engine(session: Session, engine: str) -> DataFrame: - return session.sql(f"SELECT RAI.DELETE_RAI_ENGINE('{engine}')") + res = session.sql(f"SELECT RAI.DELETE_RAI_ENGINE('{engine}')").collect() + + return res[0][0] == '"ok"' + def get_engine(session: Session, engine: str) -> DataFrame: - return session.sql(f"SELECT RAI.GET_RAI_ENGINE('{engine}')") + return session.sql(f""" + select + res:id::string as id, + res:account_name::string as account_name, + res:created_by::string as created_by, + res:created_on::string as created_on, + res:name::string as name, + res:region::string as region, + res:size::string as size, + res:state::string as state + from + (select RAI.GET_RAI_ENGINE('{engine}') as res); + """) + def list_engines(session: Session) -> DataFrame: - return session.sql(f"SELECT RAI.LIST_RAI_ENGINES()") + return session.sql(f""" + select + value:id::string as id, + value:account_name::string as account_name, + value:created_by::string as created_by, + value:created_on::string as created_on, + value:name::string as name, + value:region::string as region, + value:size::string as size, + value:state::string as state + from + (select RAI.LIST_RAI_ENGINES() as res), + lateral flatten (input => res); +""") + def use_engine(session: Session, engine: str) -> DataFrame: - return session.sql(f"CALL RAI.USE_RAI_ENGINE('{engine}')") + res = session.sql(f"CALL RAI.USE_RAI_ENGINE('{engine}')").collect() + + if res[0][0] != engine: + rsp = json.loads(res[0][0]) + + if not rsp['success']: + raise Exception(rsp['message']) + def exec(session: Session, database: str, engine: str, query: str) -> DataFrame: return session.sql(f"SELECT RAI.EXEC('{database}', '{engine}', '{query}', null, true)") - -# def exec(session: Session, query: str) -> DataFrame: -# return session.sql(f"SELECT RAI.EXEC('{query}')")