Skip to content

Commit

Permalink
"added more functions and tests, initial README"
Browse files Browse the repository at this point in the history
  • Loading branch information
antikus committed Oct 25, 2023
1 parent 7c85644 commit e76d2ae
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 62 deletions.
40 changes: 40 additions & 0 deletions relationalai-snowflake-sdk/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# The RelationalAI Software Development Kit for Snowflake Snowpark

The RelationalAI (RAI) SDK for Python enables developers to access the RAI
REST APIs from Snowflake Snowpark.

* You can find RelationalAI Python SDK documentation at <https://docs.relational.ai/rkgms/sdk/python-sdk>
* You can find RelationalAI product documentation at <https://docs.relational.ai>
* You can learn more about RelationalAI at <https://relational.ai>

## Getting started

### Requirements

* Python 3.7+

### Installing the SDK

Install from source in `editable` mode.

```console
$ git clone [email protected]:RelationalAI/rai-sdk-python.git
$ cd rai-sdk-python
$ [sudo] pip install -e relationalai-snowflake-sdk
```

## Support

You can reach the RAI developer support team at `[email protected]`

## Contributing

We value feedback and contributions from our developer community. Feel free
to submit an issue or a PR here.

## License

The RelationalAI Software Development Kit for Python is licensed under the
Apache License 2.0. See:
https://github.com/RelationalAI/rai-sdk-python/blob/master/LICENSE

137 changes: 114 additions & 23 deletions relationalai-snowflake-sdk/relationalai/snowflake_sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

import json
import logging
from typing import List

from snowflake.snowpark import Session
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.row import Row

# logger
logger = logging.getLogger(__package__)
Expand All @@ -33,19 +36,34 @@
"list_engines",
"use_database",
"use_engine",
"get_current_database",
"get_current_engine",
"exec",
"exec_into",
"load_data",
"load_model",
"load_model_code",
"load_model_query",
"create_data_stream",
"delete_data_stream",
"get_data_stream",
"get_data_stream_status",
"list_data_streams",
"ping",
]


def create_database(session: Session, database: str) -> bool:
create_db_res = session.sql(f"SELECT RAI.CREATE_RAI_DATABASE('{database}')").collect()
#################################
# DATABASE
#################################

return create_db_res[0][0] == '"ok"'

def create_database(session: Session, database: str) -> List[Row]:
return session.sql(f"select RAI.CREATE_RAI_DATABASE('{database}') as status").collect()

def delete_database(session: Session, database: str) -> bool:
delete_db_res = session.sql(f"SELECT RAI.DELETE_RAI_DATABASE('{database}')").collect()

return delete_db_res[0][0] == '"ok"'
def delete_database(session: Session, database: str) -> List[Row]:
return session.sql(f"select RAI.DELETE_RAI_DATABASE('{database}') as status").collect()


def get_database(session: Session, database: str) -> DataFrame:
Expand All @@ -72,31 +90,37 @@ def list_databases(session: Session) -> DataFrame:
value:region::string as region,
value:state::string as state
from
(select RAI.LIST_RAI_DATABASES() as lrd),
lateral flatten (input => lrd)
(select RAI.LIST_RAI_DATABASES() as res),
lateral flatten (input => res)
""")


def use_database(session: Session, database: str) -> DataFrame:
res = session.sql(f"CALL RAI.USE_RAI_DATABASE('{database}')").collect()
def use_database(session: Session, database: str) -> List[Row]:
res = session.sql(f"call RAI.USE_RAI_DATABASE('{database}')").collect()

if res[0][0] != database:
rsp = json.loads(res[0][0])

if not rsp['success']:
raise Exception(rsp['message'])
if not rsp["success"]:
raise Exception(rsp["message"])

return res


def get_current_database(session: Session) -> DataFrame:
return session.sql("select RAI.CURRENT_RAI_DATABASE() as current_database")

def create_engine(session: Session, engine: str, size: str = 'S') -> DataFrame:
res = session.sql(f"SELECT RAI.CREATE_RAI_ENGINE('{engine}', '{size}')").collect()

return res[0][0] == '"ok"'
#################################
# ENGINE
#################################

def create_engine(session: Session, engine: str, size: str = 'XS') -> List[Row]:
return session.sql(f"select RAI.CREATE_RAI_ENGINE('{engine}', '{size}') as status").collect()

def delete_engine(session: Session, engine: str) -> DataFrame:
res = session.sql(f"SELECT RAI.DELETE_RAI_ENGINE('{engine}')").collect()

return res[0][0] == '"ok"'
def delete_engine(session: Session, engine: str) -> List[Row]:
return session.sql(f"select RAI.DELETE_RAI_ENGINE('{engine}') as status").collect()


def get_engine(session: Session, engine: str) -> DataFrame:
Expand Down Expand Up @@ -132,15 +156,82 @@ def list_engines(session: Session) -> DataFrame:
""")


def use_engine(session: Session, engine: str) -> DataFrame:
def use_engine(session: Session, engine: str) -> List[Row]:
res = session.sql(f"CALL RAI.USE_RAI_ENGINE('{engine}')").collect()

if res[0][0] != engine:
rsp = json.loads(res[0][0])

if not rsp['success']:
raise Exception(rsp['message'])
if not rsp["success"]:
raise Exception(rsp["message"])

return res


def get_current_engine(session: Session) -> DataFrame:
return session.sql("select RAI.CURRENT_RAI_ENGINE() as current_engine")


#################################
# Transaction
#################################


def exec(session: Session, database: str, engine: str, query: str, data=None, readonly: bool = True) -> DataFrame:
return session.sql(f"select RAI.EXEC('{database}', '{engine}', '{query}', {data if data else 'null'}, {readonly})")


def exec_into(session: Session, database: str, engine: str, query: str, warehouse: str, target: str, data=None, readonly: bool = True) -> DataFrame:
return session.sql(f"select RAI.EXEC_INTO('{database}', '{engine}', '{query}', '{data if data else 'null'}', {readonly}, '{warehouse}', '{target}')")


#################################
# Model
#################################

def load_model(session: Session, database: str, engine: str, name: str, path: str) -> List[Row]:
return session.sql(f"select RAI.LOAD_MODEL('{database}', '{engine}', '{name}', '{path}')").collect()


def load_model_code(session: Session, database: str, engine: str, name: str, code: str) -> List[Row]:
return session.sql(f"select RAI.LOAD_MODEL_CODE('{database}', '{engine}', '{name}', '{code}')").collect()


def load_model_query(session: Session, name: str, path: str) -> List[Row]:
return session.sql(f"select RAI.LOAD_MODEL_QUERY('{name}', '{path}')").collect()

#################################
# Data Stream
#################################


def create_data_stream(session: Session, data_source: str, database: str, base_relation: str) -> List[Row]:
return session.sql(f"select RAI.CREATE_DATA_STREAM('{data_source}', '{database}', '{base_relation}') as status").collect()


def delete_data_stream(session: Session, data_source: str) -> List[Row]:
return session.sql(f"select RAI.DELETE_DATA_STREAM('{data_source}') as status").collect()


def get_data_stream(session: Session, data_source: str) -> DataFrame:
return session.sql(f"select RAI.GET_DATA_STREAM('{data_source}')")


def get_data_stream_status(session: Session, data_source: str) -> DataFrame:
return session.sql(f"select RAI.GET_DATA_STREAM_STATUS('{data_source}') as status")


def list_data_streams(session: Session) -> DataFrame:
return session.sql(f"select RAI.LIST_DATA_STREAMS()")

#################################
# Misc
#################################


def load_data(session: Session, database: str, relation: str, primary_key: str, query: str) -> List[Row]:
return session.sql(f"select RAI.LOAD_DATA('{database}', '{relation}', '{primary_key}', '{query}')").collect()


def exec(session: Session, database: str, engine: str, query: str) -> DataFrame:
return session.sql(f"SELECT RAI.EXEC('{database}', '{engine}', '{query}', null, true)")
def ping(session: Session) -> List[Row]:
return session.sql("select RAI.PING() as result").collect()
Loading

0 comments on commit e76d2ae

Please sign in to comment.