Skip to content

Commit

Permalink
updated apis
Browse files Browse the repository at this point in the history
  • Loading branch information
antikus committed Oct 18, 2023
1 parent 9575acb commit 9538e4c
Showing 1 changed file with 81 additions and 20 deletions.
101 changes: 81 additions & 20 deletions relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Operation level interface to the RelationalAI within Snowflake Snowpark"""

import json
import logging
from snowflake.snowpark import Session
from snowflake.snowpark.dataframe import DataFrame
Expand All @@ -40,46 +41,106 @@ def create_database(session: Session, database: str) -> bool:

return create_db_res[0][0] == '"ok"'


def delete_database(session: Session, database: str) -> bool:
delete_db_res = session.sql(f"SELECT RAI.DELETE_RAI_DATABASE('{database}')").collect()

return delete_db_res[0][0] == '"ok"'


def get_database(session: Session, database: str) -> DataFrame:
return session.sql(f"SELECT RAI.GET_RAI_DATABASE('{database}')")
return session.sql(f"""
select
res:id::string as id,
res:account_name::string as account_name,
res:created_by::string as created_by,
res:name::string as name,
res:region::string as region,
res:state::string as state
from
(select RAI.GET_RAI_DATABASE('{database}') as res);
""")


def list_databases(session: Session) -> DataFrame:
return session.sql("""SELECT
value:id::string as ID
,value:account_name::string as ACCOUNT_NAME
,value:created_by::string as CREATED_BY
,value:name::string as NAME
,value:region::string as REGION
,value:state::string as STATE
FROM
(SELECT RAI.LIST_RAI_DATABASES() as LRD)
,LATERAL FLATTEN (input => LRD)""")
return session.sql("""
select
value:id::string as id,
value:account_name::string as account_name,
value:created_by::string as created_by,
value:name::string as name,
value:region::string as region,
value:state::string as state
from
(select RAI.LIST_RAI_DATABASES() as lrd),
lateral flatten (input => lrd)
""")


def use_database(session: Session, database: str) -> DataFrame:
return session.sql(f"CALL RAI.USE_RAI_DATABASE('{database}')")
res = session.sql(f"CALL RAI.USE_RAI_DATABASE('{database}')").collect()

if res[0][0] != database:
rsp = json.loads(res[0][0])

if not rsp['success']:
raise Exception(rsp['message'])


def create_engine(session: Session, engine: str, size: str = 'S') -> DataFrame:
return session.sql(f"SELECT RAI.CREATE_RAI_ENGINE('{engine}', '{size}')")
res = session.sql(f"SELECT RAI.CREATE_RAI_ENGINE('{engine}', '{size}')").collect()

return res[0][0] == '"ok"'


def delete_engine(session: Session, engine: str) -> DataFrame:
return session.sql(f"SELECT RAI.DELETE_RAI_ENGINE('{engine}')")
res = session.sql(f"SELECT RAI.DELETE_RAI_ENGINE('{engine}')").collect()

return res[0][0] == '"ok"'


def get_engine(session: Session, engine: str) -> DataFrame:
return session.sql(f"SELECT RAI.GET_RAI_ENGINE('{engine}')")
return session.sql(f"""
select
res:id::string as id,
res:account_name::string as account_name,
res:created_by::string as created_by,
res:created_on::string as created_on,
res:name::string as name,
res:region::string as region,
res:size::string as size,
res:state::string as state
from
(select RAI.GET_RAI_ENGINE('{engine}') as res);
""")


def list_engines(session: Session) -> DataFrame:
return session.sql(f"SELECT RAI.LIST_RAI_ENGINES()")
return session.sql(f"""
select
value:id::string as id,
value:account_name::string as account_name,
value:created_by::string as created_by,
value:created_on::string as created_on,
value:name::string as name,
value:region::string as region,
value:size::string as size,
value:state::string as state
from
(select RAI.LIST_RAI_ENGINES() as res),
lateral flatten (input => res);
""")


def use_engine(session: Session, engine: str) -> DataFrame:
return session.sql(f"CALL RAI.USE_RAI_ENGINE('{engine}')")
res = session.sql(f"CALL RAI.USE_RAI_ENGINE('{engine}')").collect()

if res[0][0] != engine:
rsp = json.loads(res[0][0])

if not rsp['success']:
raise Exception(rsp['message'])


def exec(session: Session, database: str, engine: str, query: str) -> DataFrame:
return session.sql(f"SELECT RAI.EXEC('{database}', '{engine}', '{query}', null, true)")

# def exec(session: Session, query: str) -> DataFrame:
# return session.sql(f"SELECT RAI.EXEC('{query}')")

0 comments on commit 9538e4c

Please sign in to comment.