Skip to content

Commit

Permalink
updated create/delete db methods, added a few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
antikus committed Oct 17, 2023
1 parent 32099d0 commit 9575acb
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 7 deletions.
15 changes: 8 additions & 7 deletions relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,20 @@
]


def create_database(session: Session, database: str) -> DataFrame:
return session.sql(f"SELECT RAI.CREATE_RAI_DATABASE('{database}')")
def create_database(session: Session, database: str) -> bool:
create_db_res = session.sql(f"SELECT RAI.CREATE_RAI_DATABASE('{database}')").collect()

def delete_database(session: Session, database: str) -> DataFrame:
return session.sql(f"SELECT RAI.DELETE_RAI_DATABASE('{database}')")
return create_db_res[0][0] == '"ok"'

def delete_database(session: Session, database: str) -> bool:
delete_db_res = session.sql(f"SELECT RAI.DELETE_RAI_DATABASE('{database}')").collect()

return delete_db_res[0][0] == '"ok"'

def get_database(session: Session, database: str) -> DataFrame:
return session.sql(f"SELECT RAI.GET_RAI_DATABASE('{database}')")

def list_databases(session: Session) -> DataFrame:
return session.sql(f"SELECT RAI.LIST_RAI_DATABASES()")

def tabular_list_databases(session: Session) -> DataFrame:
return session.sql("""SELECT
value:id::string as ID
,value:account_name::string as ACCOUNT_NAME
Expand Down
60 changes: 60 additions & 0 deletions test/test_integration_snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest
import uuid

from snowflake.snowpark import Session

from logging.config import fileConfig
from relationalai.snowflake_sdk import api


connection_parameters = {
"user": "[email protected]",
"authenticator": "externalbrowser",
"account": "NDSOEBE-DR75630",
"database": "SNOWFLAKE_INTEGRATION_SANDBOX",
"schema": "ANATOLI",
}



suffix = uuid.uuid4()
engine = f"sf-python-sdk-{suffix}"
dbname = f"sf-python-sdk-{suffix}"

# init "rai" logger
fileConfig("./test/logger.config")


class TestDatabase(unittest.TestCase):
def setUp(self):
self.session = Session.builder.configs(connection_parameters).create()
# api.delete_database(session, dbname)

def test_create_delete_database_api(self):
create_database_res = api.create_database(self.session, dbname)
self.assertTrue(create_database_res)

delete_database_res = api.delete_database(self.session, dbname)
self.assertTrue(delete_database_res)

def test_list_databases(self):
df_list_dbs = api.list_databases(self.session)
list_dbs_res = df_list_dbs.collect()

self.assertTrue(len(list_dbs_res) > 0)

fist_el = list_dbs_res[0]

self.assertTrue(hasattr(fist_el, 'ID'))
self.assertTrue(hasattr(fist_el, 'ACCOUNT_NAME'))
self.assertTrue(hasattr(fist_el, 'CREATED_BY'))
self.assertTrue(hasattr(fist_el, 'NAME'))
self.assertTrue(hasattr(fist_el, 'REGION'))
self.assertTrue(hasattr(fist_el, 'STATE'))

def tearDown(self):
api.delete_database(self.session, dbname)


if __name__ == '__main__':
unittest.main()

0 comments on commit 9575acb

Please sign in to comment.