diff --git a/cli-e2e-test/test_e2e.py b/cli-e2e-test/test_e2e.py index bcb0929..05c7d1c 100644 --- a/cli-e2e-test/test_e2e.py +++ b/cli-e2e-test/test_e2e.py @@ -4,9 +4,9 @@ import shutil import uuid -import test_query as q import workflow.manager import workflow.rai +from workflow.constants import RESOURCES_TO_DELETE_REL from csv_diff import load_csv, compare, human_text from subprocess import call @@ -49,7 +49,7 @@ def test_scenario2_model_no_data_changes(self): # then self.assertNotEqual(rsp, 1) rai_config = self.resource_manager.get_rai_config() - rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, q.RESOURCES_TO_DELETE) + rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, RESOURCES_TO_DELETE_REL) self.assertEqual(rsp_json, {}) def test_scenario2_model_force_reimport(self): @@ -65,7 +65,7 @@ def test_scenario2_model_force_reimport(self): # then self.assertNotEqual(rsp, 1) rai_config = self.resource_manager.get_rai_config() - rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, q.RESOURCES_TO_DELETE) + rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, RESOURCES_TO_DELETE_REL) self.assertEqual(rsp_json, [{'partition': 2023090800001, 'relation': 'city_data'}, {'partition': 2023090800002, 'relation': 'city_data'}, {'partition': 2023090900001, 'relation': 'city_data'}, @@ -84,7 +84,7 @@ def test_scenario2_model_force_reimport_chunk_partitioned(self): # then self.assertNotEqual(rsp, 1) rai_config = self.resource_manager.get_rai_config() - rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, q.RESOURCES_TO_DELETE) + rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, RESOURCES_TO_DELETE_REL) self.assertEqual(rsp_json, [{'relation': 'zip_city_state_master_data'}]) def test_scenario3_model_single_partition_change(self): @@ -106,7 +106,7 @@ def test_scenario3_model_single_partition_change(self): # then self.assertNotEqual(rsp, 1) rai_config = self.resource_manager.get_rai_config() - rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, q.RESOURCES_TO_DELETE) + rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, RESOURCES_TO_DELETE_REL) self.assertEqual(rsp_json, [{'partition': 2023090800001, 'relation': 'city_data'}]) def test_scenario3_model_two_partitions_overriden_by_one(self): @@ -129,7 +129,7 @@ def test_scenario3_model_two_partitions_overriden_by_one(self): # then self.assertNotEqual(rsp, 1) rai_config = self.resource_manager.get_rai_config() - rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, q.RESOURCES_TO_DELETE) + rsp_json = workflow.rai.execute_relation_json(self.logger, rai_config, RESOURCES_TO_DELETE_REL) self.assertEqual(rsp_json, [{'partition': 2023090800001, 'relation': 'city_data'}, {'partition': 2023090800002, 'relation': 'city_data'}]) diff --git a/cli-e2e-test/test_query.py b/cli-e2e-test/test_query.py deleted file mode 100644 index 8979066..0000000 --- a/cli-e2e-test/test_query.py +++ /dev/null @@ -1 +0,0 @@ -RESOURCES_TO_DELETE = "resources_data_to_delete_json" diff --git a/cli/README.md b/cli/README.md index 9ba3b65..ee98c78 100644 --- a/cli/README.md +++ b/cli/README.md @@ -6,7 +6,7 @@ This Command-Line Interface (CLI) is designed to provide an easy and interactive 1. Create a batch configuration (ex. `poc.json`) file using the syntax and structure outlined in the [RAI Workflow Framework README](../workflow/README.md). 2. 
Add `rai-workflow-manager` as dependency to your `requirements.txt` file: ```txt -rai-workflow-manager==0.0.13 +rai-workflow-manager==0.0.14 ``` 3. Build the project: ```bash @@ -50,18 +50,23 @@ where ``, `` are the names of some RAI resources to use ## `loader.toml` Configuration The `loader.toml` file is used to specify static properties for the RAI Workflow Framework. It contains the following properties: -| Description | Property | -|:---------------------------------------------------------------------------------|------------------------| -| RAI profile. | `rai_profile` | -| Path to RAI config. | `rai_profile_path` | -| HTTP retries for RAI sdk in case of errors. (Can be overridden by CLI argument) | `rai_sdk_http_retries` | -| A list of containers to use for loading and exporting data. | `container` | -| The name of the container. | `container.name` | -| The type of the container. Supported types: `local`, `azure` | `container.type` | -| The path in the container. | `container.data_path` | -| Remote container account | `container.account` | -| Remote container SAS token. | `container.sas` | -| Container for remote container SAS token. (Ex. Azure Blob container) | `container.container` | +| Description | Property | +|:--------------------------------------------------------------------------------------------|------------------------| +| RAI profile. | `rai_profile` | +| Path to RAI config. | `rai_profile_path` | +| HTTP retries for RAI sdk in case of errors. (Can be overridden by CLI argument) | `rai_sdk_http_retries` | +| A list of containers to use for loading and exporting data. | `container` | +| The name of the container. | `container.name` | +| The type of the container. Supported types: `local`, `azure`, `snowflake`(only data import) | `container.type` | +| The path in the container. | `container.data_path` | +| Remote container account | `container.account` | +| Remote container SAS token. | `container.sas` | +| User for remote container. | `container.user` | +| Password for remote container. | `container.password` | +| User role for remote container (e.g. Snowflake user role). | `container.role` | +| Database for remote container. | `container.database` | +| Schema for remote container. | `container.schema` | +| Warehouse for Snowflake container. 
| `container.warehouse` | ### Azure container example ```toml @@ -73,6 +78,19 @@ account="account_name" sas="sas_token" container="container_name" ``` +### Snowflake container example +```toml +[[container]] +name="input" +type="snowflake" +account="account" +user="use" +password="password" +role="snowflake role" +warehouse="warehouse" +database="database" +schema="schema" +``` ### Local container example ```toml [[container]] diff --git a/cli/logger.py b/cli/logger.py index 17ed186..15bbd49 100644 --- a/cli/logger.py +++ b/cli/logger.py @@ -5,6 +5,8 @@ def configure(level=logging.INFO) -> logging.Logger: # override default logging level for azure logger = logging.getLogger('azure.core') logger.setLevel(logging.ERROR) + logger = logging.getLogger('snowflake') + logger.setLevel(logging.ERROR) logger = logging.getLogger() logger.setLevel(level) diff --git a/rel/batch_config/workflow/steps/configure_sources.rel b/rel/batch_config/workflow/steps/configure_sources.rel index 105a1de..fe89490 100644 --- a/rel/batch_config/workflow/steps/configure_sources.rel +++ b/rel/batch_config/workflow/steps/configure_sources.rel @@ -4,6 +4,8 @@ module batch_workflow_step module configure_sources def config_files(st in BatchWorkflowConfigureSourcesStep, f) { batch_workflow_step:configFiles(st, f) } + def default_container(st in BatchWorkflowConfigureSourcesStep, c) { batch_workflow_step:defaultContainer(st, c) } + def sources = transpose[batch_source:step] end end @@ -28,8 +30,10 @@ end module batch_source def step(src, st) { batch_source_name:step_source_name_to_source(st, _, src) } def extensions(src, e) { extract_value:extensions(src, :[], _, e) } - def partitioned(src) { extract_value:partitioned(src, boolean_true) } + def date_partitioned(src) { extract_value:isDatePartitioned(src, boolean_true) } + def chunk_partitioned(src) { extract_value:isChunkPartitioned(src, boolean_true) } def relation = extract_value:relation + def container = extract_value:container def relative_path = extract_value:relativePath def input_format = extract_value:inputFormat def loads_number_of_days = extract_value:loadsNumberOfDays @@ -44,4 +48,8 @@ module batch_source } end -// TODO: declare ICs \ No newline at end of file +ic configure_sources_step_default_container_is_mandatory(s) { + BatchWorkflowConfigureSourcesStep(s) + implies + batch_workflow_step:configure_sources:default_container(s, _) +} diff --git a/rel/batch_config/workflow/workflow.rel b/rel/batch_config/workflow/workflow.rel index a165b9e..09fc20d 100644 --- a/rel/batch_config/workflow/workflow.rel +++ b/rel/batch_config/workflow/workflow.rel @@ -69,6 +69,7 @@ module batch_workflow_step end def batch_workflow_step(part, s, v) { batch_workflow_step:json_data(s, part, :[], _, v) } +def batch_workflow_step(part, s, v) { batch_workflow_step:json_data(s, part, v) } module batch_workflow_step_order def step_order_to_workflow(o, s, w) { diff --git a/rel/source_configs/config.rel b/rel/source_configs/config.rel index 4253024..4f1a857 100644 --- a/rel/source_configs/config.rel +++ b/rel/source_configs/config.rel @@ -3,6 +3,7 @@ bound multi_part_relation = String bound date_partitioned_source = String bound source_declares_resource = String, String, String bound source_has_input_format = String, String +bound source_has_container_type = String, String bound source_catalog bound simple_source_catalog bound part_resource_date_pattern = String @@ -36,6 +37,31 @@ end def input_format_code_to_string = transpose[^InputFormatCode] +/** + * Container types + */ + +value type 
ContainerTypeCode = String +entity type ContainerType = ContainerTypeCode + +def ContainerType(t) { container_type:id(t, _) } + +module container_type + def id = transpose[container_type_code:identifies] +end + +module container_type_code + def value = { "AZURE" ; "LOCAL" ; "SNOWFLAKE" } + + def identifies(c, t) { + value(v) and + ^ContainerTypeCode(v, c) and + ^ContainerType(c, t) + from v + } +end + +def container_type_code_to_string = transpose[^ContainerTypeCode] /** * Resources */ @@ -54,13 +80,6 @@ module resource def id = transpose[uri:identifies] def part_of = transpose[source:declares] - - def local(r) { - str = uri_to_string[resource:id[r]] and - not regex_match( "^azure://(.+)$", str ) and - not regex_match( "^https://(.+)$", str ) - from str - } end module part_resource @@ -167,24 +186,16 @@ ic multi_part_source_declares_part_resources(s) { forall(r in source:declares[s]: PartResource(r)) } -ic source_is_local_or_remote(s, rel) { - source:declares(s, _) and - source:populates(s, rel) - implies - source:local(s) or source:remote(s) -} - -ic not_local_and_remote(s, rel) { - source:populates(s, rel) and - source:local(s) +ic source_has_unique_input_format(s) { + Source(s) implies - not source:remote(s) + source:format(s, _) } -ic source_has_unique_input_format(s) { +ic source_has_container_type(s) { Source(s) implies - source:format(s, _) + source:container_type(s, _) } // currently we do not support chunk partitioning for not date partitioned source ic multi_part_source_is_date_partitioned(s) { @@ -205,15 +216,6 @@ module source def spans[s] = part_resource:hashes_to_date[declares[s]] - def local(s) { - declares(s, _) and - forall(r: declares(s, r) implies resource:local(r) ) - } - def remote(s) { - declares(s, _) and - forall(r: declares(s, r) implies not resource:local(r) ) - } - def container(s, c) { source_declares_resource(rel, c, _) and s = relation:identifies[ rel_name:identifies[ ^RelName[rel] ] ] @@ -235,6 +237,13 @@ module source CSVInputFormat(f) from rel } + + def container_type(s, typ) { + source_has_container_type(rel, raw_container_type) and + s = relation:identifies[ rel_name:identifies[ ^RelName[rel] ] ] and + typ = container_type_code:identifies[^ContainerTypeCode[raw_container_type]] + from raw_container_type, rel + } end /** @@ -344,24 +353,17 @@ def missing_resources_json(:[], n, :is_multi_part, "Y") { from s } -def missing_resources_json(:[], n, :is_remote, "Y") { - source:needs_resource(s) and - source:index[s] = n and - source:remote(s) - from s -} - -def missing_resources_json(:[], n, :is_local, "Y") { +def missing_resources_json(:[], n, :container, v) { source:needs_resource(s) and source:index[s] = n and - source:local(s) + source:container(s, v) from s } -def missing_resources_json(:[], n, :container, v) { +def missing_resources_json(:[], n, :container_type, typ) { source:needs_resource(s) and source:index[s] = n and - source:container(s, v) + typ = container_type_code_to_string[ container_type:id[ source:container_type[s] ] ] from s } diff --git a/requirements.txt b/requirements.txt index 27e78eb..3c6e967 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ requests-toolbelt==1.0.0 urllib3==1.26.6 more-itertools==10.1.0 azure-storage-blob==12.17.0 +snowflake-connector-python==3.2.0 csv-diff==1.1 \ No newline at end of file diff --git a/test/test_cfg_src_step.py b/test/test_cfg_src_step.py index 85146ba..51c0bb8 100644 --- a/test/test_cfg_src_step.py +++ b/test/test_cfg_src_step.py @@ -6,7 +6,7 @@ from unittest.mock import Mock from 
workflow import paths -from workflow.common import Source +from workflow.common import Source, Container, ContainerType from workflow.executor import ConfigureSourcesWorkflowStep, WorkflowStepState @@ -402,7 +402,7 @@ def _create_test_source(is_chunk_partitioned: bool = True, is_date_partitioned: loads_number_of_days: int = 1, offset_by_number_of_days: int = 0, snapshot_validity_days=None) -> Source: return Source( - container="default", + container=Container("default", ContainerType.LOCAL, {}), relation="test", relative_path="test", input_format="test", diff --git a/test/test_cfg_src_step_factory.py b/test/test_cfg_src_step_factory.py index 1f51906..6ba4c67 100644 --- a/test/test_cfg_src_step_factory.py +++ b/test/test_cfg_src_step_factory.py @@ -14,13 +14,16 @@ class TestConfigureSourcesWorkflowStepFactory(unittest.TestCase): def test_get_step(self): # Setup factory spy factory = ConfigureSourcesWorkflowStepFactory() + config = _create_wf_cfg(EnvConfig({"azure": _create_container("default", ContainerType.AZURE), + "local": _create_container("local", ContainerType.LOCAL)}), Mock()) spy = MagicMock(wraps=factory._parse_sources) - sources = [_create_test_source("azure", "src1"), _create_test_source("local", "src2"), - _create_test_source("local", "src3"), _create_test_source("local", "src4")] + sources = [_create_test_source(config.env.get_container("azure"), "src1"), + _create_test_source(config.env.get_container("local"), "src2"), + _create_test_source(config.env.get_container("local"), "src3"), + _create_test_source(config.env.get_container("local"), "src4")] spy.return_value = sources factory._parse_sources = spy - config = _create_wf_cfg(EnvConfig({"azure": _create_container("default", ContainerType.AZURE), - "local": _create_container("local", ContainerType.LOCAL)}), Mock()) + # When step = factory._get_step(self.logger, config, "1", "name", WorkflowStepState.INIT, 0, None, {"configFiles": []}) # Then @@ -34,14 +37,14 @@ def test_get_step(self): self.assertEqual(sources, step.sources) self.assertEqual(2, len(step.paths_builders.keys())) self.assertIsInstance(step.paths_builders.get("local"), paths.LocalPathsBuilder) - self.assertIsInstance(step.paths_builders.get("azure"), paths.AzurePathsBuilder) + self.assertIsInstance(step.paths_builders.get("default"), paths.AzurePathsBuilder) self.assertEqual("2021-01-01", step.start_date) self.assertEqual("2021-01-01", step.end_date) self.assertFalse(step.force_reimport) self.assertFalse(step.force_reimport_not_chunk_partitioned) -def _create_test_source(container: str, relation: str) -> Source: +def _create_test_source(container: Container, relation: str) -> Source: return Source( container=container, relation=relation, diff --git a/workflow/README.md b/workflow/README.md index 3c20c52..d009cf3 100644 --- a/workflow/README.md +++ b/workflow/README.md @@ -13,6 +13,7 @@ This framework is designed to simplify the process of managing and executing com - [Load Data](#load-data) - [Materialize](#materialize) - [Export](#export) + - [Snowflake integration](#snowflake-integration) - [Framework extension](#framework-extension) - [Custom workflow steps](#custom-workflow-steps) @@ -299,6 +300,18 @@ look like the one shown below: "metaKey": [ "Symbol", "Symbol" ] } ``` +## Snowflake Integration +Workflow manager supports Snowflake as a data source container only for data ingestion. +### Integration Details +Workflow Manager uses RAI integration for data sync from Snowflake. 
Workflow Manager creates a data stream for each source in the batch config that uses a Snowflake container. The Integration Service creates an ingestion engine per RAI account with the prefix `ingestion-engine-*` and uses this engine for data ingestion. The relation used for data ingestion is `simple_source_catalog`. Once data sync is completed, Workflow Manager deletes the data stream. + +**Note:** Workflow Manager is not responsible for creating or deleting the ingestion engine. The ingestion engine is not deleted automatically after data sync. +### Configure RAI Integration +To use Snowflake as a data source container, you need to configure Snowflake using the following guides: +* [RAI Integration for Snowflake: Quick Start for Administrators](https://docs.relational.ai/preview/snowflake/quickstart-admin) +* [RAI Integration for Snowflake: Quick Start for Users](https://docs.relational.ai/preview/snowflake/quickstart-user) +### Configure Snowflake Container +Snowflake container configuration is described in the [Snowflake container example](../cli/README.md#snowflake-container-example) section. # Framework extension ## Custom workflow steps diff --git a/workflow/__init__.py b/workflow/__init__.py index 250d4ad..29ac885 100644 --- a/workflow/__init__.py +++ b/workflow/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version_info__ = (0, 0, 13) +__version_info__ = (0, 0, 14) __version__ = ".".join(map(str, __version_info__)) diff --git a/workflow/blob.py b/workflow/blob.py index b6dafc5..458de51 100644 --- a/workflow/blob.py +++ b/workflow/blob.py @@ -11,6 +11,7 @@ def list_files_in_containers(logger: logging.Logger, config: AzureConfig, path_p blob_service_client = BlobServiceClient(account_url=f"https://{config.account}.blob.core.windows.net", credential=config.sas) container_client = blob_service_client.get_container_client(config.container) + logger = logger.getChild("blob") # Get a list of blobs in the folder logger.debug(f"Path prefix to list blob files: {path_prefix}") diff --git a/workflow/common.py b/workflow/common.py index 84511ba..4c704b0 100644 --- a/workflow/common.py +++ b/workflow/common.py @@ -4,8 +4,8 @@ from railib import api -from workflow.constants import AZURE_ACCOUNT, AZURE_CONTAINER, AZURE_DATA_PATH, AZURE_SAS, LOCAL_DATA_PATH, CONTAINER, \ - CONTAINER_TYPE, CONTAINER_NAME +from workflow.constants import ACCOUNT_PARAM, CONTAINER_PARAM, DATA_PATH_PARAM, AZURE_SAS, CONTAINER, CONTAINER_TYPE, \ + CONTAINER_NAME, USER_PARAM, PASSWORD_PARAM, SNOWFLAKE_ROLE, SNOWFLAKE_WAREHOUSE, DATABASE_PARAM, SCHEMA_PARAM class MetaEnum(EnumMeta): @@ -51,18 +51,17 @@ class FileType(str, BaseEnum): class ContainerType(str, BaseEnum): LOCAL = 'local' AZURE = 'azure' + SNOWFLAKE = 'snowflake' def __str__(self): return self.value @staticmethod def from_source(src): - if 'is_local' in src and src['is_local'] == 'Y': - return ContainerType.LOCAL - elif 'is_remote' in src and src['is_remote'] == 'Y': - return ContainerType.AZURE - else: - raise ValueError("Source is neither local nor remote.") + try: + return ContainerType[src["container_type"]] + except KeyError as ex: + raise ValueError(f"Container type is not supported: {ex}") @dataclasses.dataclass @@ -85,25 +84,48 @@ class LocalConfig: data_path: str +@dataclasses.dataclass +class SnowflakeConfig: + account: str + user: str + password: str + role: str + warehouse: str + database: str + schema: str + + class ConfigExtractor: @staticmethod def azure_from_env_vars(env_vars: dict[str,
Any]): return AzureConfig( - account=env_vars.get(AZURE_ACCOUNT, ""), - container=env_vars.get(AZURE_CONTAINER, ""), - data_path=env_vars.get(AZURE_DATA_PATH, ""), + account=env_vars.get(ACCOUNT_PARAM, ""), + container=env_vars.get(CONTAINER_PARAM, ""), + data_path=env_vars.get(DATA_PATH_PARAM, ""), sas=env_vars.get(AZURE_SAS, "") ) + @staticmethod + def snowflake_from_env_vars(env_vars: dict[str, Any]): + return SnowflakeConfig( + account=env_vars.get(ACCOUNT_PARAM, ""), + user=env_vars.get(USER_PARAM, ""), + password=env_vars.get(PASSWORD_PARAM, ""), + role=env_vars.get(SNOWFLAKE_ROLE, ""), + warehouse=env_vars.get(SNOWFLAKE_WAREHOUSE, ""), + database=env_vars.get(DATABASE_PARAM, ""), + schema=env_vars.get(SCHEMA_PARAM, "") + ) + @staticmethod def local_from_env_vars(env_vars: dict[str, Any]): - return LocalConfig(data_path=env_vars.get(LOCAL_DATA_PATH, "")) + return LocalConfig(data_path=env_vars.get(DATA_PATH_PARAM, "")) @dataclasses.dataclass class Source: - container: str + container: Container relation: str relative_path: str input_format: str @@ -116,7 +138,7 @@ class Source: paths: List[str] = dataclasses.field(default_factory=list) def to_paths_csv(self) -> str: - return "\n".join([f"{self.relation},{self.container},{p}" for p in self.paths]) + return "\n".join([f"{self.relation},{self.container.name},{p}" for p in self.paths]) def to_chunk_partitioned_paths_csv(self) -> str: return "\n".join([f"{self.relation},{path},{self.is_chunk_partitioned}" for path in self.paths]) @@ -124,6 +146,9 @@ def to_chunk_partitioned_paths_csv(self) -> str: def to_formats_csv(self) -> str: return f"{self.relation},{self.input_format.upper()}" + def to_container_type_csv(self) -> str: + return f"{self.relation},{self.container.type.name}" + @dataclasses.dataclass class RaiConfig: @@ -136,16 +161,21 @@ class RaiConfig: class EnvConfig: containers: dict[str, Container] - EXTRACTORS = { + __EXTRACTORS = { ContainerType.AZURE: lambda env_vars: ConfigExtractor.azure_from_env_vars(env_vars), - ContainerType.LOCAL: lambda env_vars: ConfigExtractor.local_from_env_vars(env_vars) + ContainerType.LOCAL: lambda env_vars: ConfigExtractor.local_from_env_vars(env_vars), + ContainerType.SNOWFLAKE: lambda env_vars: ConfigExtractor.snowflake_from_env_vars(env_vars) } - def container_name_to_type(self) -> dict[str, ContainerType]: - return {container.name: container.type for container in self.containers.values()} - def get_container(self, name: str) -> Container: - return self.containers[name] + try: + return self.containers[name] + except KeyError: + raise ValueError(f"Container `{name}` is missed in Environment Config.") + + @staticmethod + def get_config(container: Container): + return EnvConfig.__EXTRACTORS[container.type](container.params) @staticmethod def from_env_vars(env_vars: dict[str, Any]): @@ -165,7 +195,7 @@ class Export: relative_path: str file_type: FileType snapshot_binding: str - container: str + container: Container offset_by_number_of_days: int = 0 diff --git a/workflow/constants.py b/workflow/constants.py index 47ab6bb..9c8a1d2 100644 --- a/workflow/constants.py +++ b/workflow/constants.py @@ -26,6 +26,7 @@ IMPORT_CONFIG_REL = "import_config" MISSED_RESOURCES_REL = "missing_resources_json" +RESOURCES_TO_DELETE_REL = "resources_data_to_delete_json" WORKFLOW_JSON_REL = "workflow_json" BATCH_CONFIG_REL = "batch:config" @@ -41,11 +42,18 @@ RAI_PROFILE = "rai_profile" RAI_PROFILE_PATH = "rai_profile_path" RAI_SDK_HTTP_RETRIES = "rai_sdk_http_retries" -AZURE_ACCOUNT = "account" -AZURE_CONTAINER = 
"container" -AZURE_DATA_PATH = "data_path" +# Generic container params +ACCOUNT_PARAM = "account" +USER_PARAM = "user" +PASSWORD_PARAM = "password" +SCHEMA_PARAM = "schema" +DATABASE_PARAM = "database" +CONTAINER_PARAM = "container" +DATA_PATH_PARAM = "data_path" +# Datasource specific params AZURE_SAS = "sas" -LOCAL_DATA_PATH = "data_path" +SNOWFLAKE_ROLE = "role" +SNOWFLAKE_WAREHOUSE = "warehouse" # Step parameters REL_CONFIG_DIR = "rel_config_dir" @@ -54,3 +62,13 @@ FORCE_REIMPORT = "force_reimport" FORCE_REIMPORT_NOT_CHUNK_PARTITIONED = "force_reimport_not_chunk_partitioned" COLLAPSE_PARTITIONS_ON_LOAD = "collapse_partitions_on_load" + +# Snowflake constants + +# Properties +SNOWFLAKE_SYNC_STATUS = "Data sync status" +SNOWFLAKE_STREAM_HEALTH_STATUS = "Data stream health" +SNOWFLAKE_TOTAL_ROWS = "Latest changes written to RAI - Total rows" +# Values +SNOWFLAKE_FINISHED_SYNC_STATUS = "\"Fully synced\"" +SNOWFLAKE_HEALTHY_STREAM_STATUS = "\"Healthy\"" diff --git a/workflow/executor.py b/workflow/executor.py index c5d4799..c324ca7 100644 --- a/workflow/executor.py +++ b/workflow/executor.py @@ -1,16 +1,17 @@ import dataclasses import logging import time +from workflow import snow from datetime import datetime from enum import Enum from itertools import groupby from types import MappingProxyType -from typing import List, Dict +from typing import List from more_itertools import peekable from workflow import query as q, paths, rai, constants -from workflow.common import EnvConfig, RaiConfig, Source, BatchConfig, Export, FileType, ContainerType +from workflow.common import EnvConfig, RaiConfig, Source, BatchConfig, Export, FileType, ContainerType, Container from workflow.manager import ResourceManager from workflow.utils import save_csv_output, format_duration, build_models, extract_date_range, build_relation_path, \ get_common_model_relative_path @@ -153,8 +154,9 @@ def _inflate_sources(self, logger: logging.Logger): for src in self.sources: logger.info(f"Inflating source: '{src.relation}'") days = self._get_date_range(logger, src) - inflated_paths = self.paths_builders[src.container].build(logger, days, src.relative_path, src.extensions, - src.is_date_partitioned) + inflated_paths = self.paths_builders[src.container.name].build(logger, days, src.relative_path, + src.extensions, + src.is_date_partitioned) if src.is_date_partitioned: # after inflating we take the last `src.loads_number_of_days` days and reduce into an array of paths inflated_paths.sort(key=lambda v: v.as_of_date) @@ -189,7 +191,7 @@ class ConfigureSourcesWorkflowStepFactory(WorkflowStepFactory): def _validate_params(self, config: WorkflowConfig, step: dict) -> None: super()._validate_params(config, step) end_date = config.step_params[constants.END_DATE] - sources = self._parse_sources(step) + sources = self._parse_sources(step, config.env) if not end_date: for s in sources: if s.is_date_partitioned: @@ -201,25 +203,23 @@ def _required_params(self, config: WorkflowConfig) -> List[str]: def _get_step(self, logger: logging.Logger, config: WorkflowConfig, idt, name, state, timing, engine_size, step: dict) -> ConfigureSourcesWorkflowStep: rel_config_dir = config.step_params[constants.REL_CONFIG_DIR] - sources = self._parse_sources(step) + sources = self._parse_sources(step, config.env) start_date = config.step_params[constants.START_DATE] end_date = config.step_params[constants.END_DATE] force_reimport = config.step_params.get(constants.FORCE_REIMPORT, False) force_reimport_not_chunk_partitioned = 
config.step_params.get(constants.FORCE_REIMPORT_NOT_CHUNK_PARTITIONED, False) - env_config = config.env paths_builders = {} for src in sources: - container_name = src.container - if container_name not in paths_builders: - container = env_config.get_container(container_name) - paths_builders[container_name] = paths.PathsBuilderFactory.get_path_builder(container) + container = src.container + if container.name not in paths_builders: + paths_builders[container.name] = paths.PathsBuilderFactory.get_path_builder(container) return ConfigureSourcesWorkflowStep(idt, name, state, timing, engine_size, step["configFiles"], rel_config_dir, sources, paths_builders, start_date, end_date, force_reimport, force_reimport_not_chunk_partitioned) @staticmethod - def _parse_sources(step: dict) -> List[Source]: + def _parse_sources(step: dict, env_config: EnvConfig) -> List[Source]: sources = step["sources"] default_container = step["defaultContainer"] result = [] @@ -234,9 +234,9 @@ def _parse_sources(step: dict) -> List[Source]: loads_number_of_days = source.get("loadsNumberOfDays") offset_by_number_of_days = source.get("offsetByNumberOfDays") snapshot_validity_days = source.get("snapshotValidityDays") - container = source.get("container", default_container) + container_name = source.get("container", default_container) result.append(Source( - container, + env_config.get_container(container_name), relation, relative_path, input_format, @@ -300,10 +300,11 @@ def _load_source(self, logger: logging.Logger, env_config: EnvConfig, rai_config def _load_resource(logger: logging.Logger, env_config: EnvConfig, rai_config: RaiConfig, resources, src) -> None: try: container = env_config.get_container(src["container"]) - rai.execute_query(logger, rai_config, - q.load_resources(logger, EnvConfig.EXTRACTORS[container.type](container.params), - resources, src), readonly=False) - + config = EnvConfig.get_config(container) + if ContainerType.LOCAL == container.type or ContainerType.AZURE == container.type: + rai.execute_query(logger, rai_config, q.load_resources(logger, config, resources, src), readonly=False) + elif ContainerType.SNOWFLAKE == container.type: + snow.sync_data(logger, config, rai_config, resources, src) except KeyError as e: logger.error(f"Unsupported file type: {src['file_type']}. 
Skip the source: {src}", e) except ValueError as e: @@ -358,12 +359,11 @@ class ExportWorkflowStep(WorkflowStep): ContainerType.LOCAL: lambda logger, rai_config, exports, end_date, date_format, container: save_csv_output( rai.execute_query_csv(logger, rai_config, q.export_relations_local(logger, exports)), - EnvConfig.EXTRACTORS[container.type](container.params)), + EnvConfig.get_config(container)), ContainerType.AZURE: lambda logger, rai_config, exports, end_date, date_format, container: rai.execute_query( logger, rai_config, - q.export_relations_to_azure(logger, EnvConfig.EXTRACTORS[container.type](container.params), exports, - end_date, date_format)) + q.export_relations_to_azure(logger, EnvConfig.get_config(container), exports, end_date, date_format)) } def __init__(self, idt, name, state, timing, engine_size, exports, export_jointly, date_format, end_date): @@ -377,19 +377,25 @@ def execute(self, logger: logging.Logger, env_config: EnvConfig, rai_config: Rai logger.info("Executing Export step..") exports = list(filter(lambda e: self._should_export(logger, rai_config, e), self.exports)) - name_to_type = env_config.container_name_to_type() if self.export_jointly: - exports.sort(key=lambda e: e.container) + exports.sort(key=lambda e: e.container.name) container_groups = {container: list(group) for container, group in - groupby(exports, key=lambda e: e.container)} + groupby(exports, key=lambda e: e.container.name)} for container, grouped_exports in container_groups.items(): - self.EXPORT_FUNCTION[name_to_type[container]](logger, rai_config, grouped_exports, self.end_date, - self.date_format, env_config.get_container(container)) + ExportWorkflowStep.get_export_function(container)(logger, rai_config, grouped_exports, self.end_date, + self.date_format, container) else: for export in exports: container = export.container - self.EXPORT_FUNCTION[name_to_type[container]](logger, rai_config, [export], self.end_date, - self.date_format, env_config.get_container(container)) + ExportWorkflowStep.get_export_function(container)(logger, rai_config, [export], self.end_date, + self.date_format, container) + + @staticmethod + def get_export_function(container: Container): + try: + return ExportWorkflowStep.EXPORT_FUNCTION[container.type] + except KeyError as ex: + raise ValueError(f"Container type is not supported: {ex}") def _should_export(self, logger: logging.Logger, rai_config: RaiConfig, export: Export) -> bool: if export.snapshot_binding is None: @@ -416,13 +422,13 @@ def _required_params(self, config: WorkflowConfig) -> List[str]: def _get_step(self, logger: logging.Logger, config: WorkflowConfig, idt, name, state, timing, engine_size, step: dict) -> WorkflowStep: - exports = self._load_exports(logger, step) + exports = self._load_exports(logger, config.env, step) end_date = config.step_params[constants.END_DATE] return ExportWorkflowStep(idt, name, state, timing, engine_size, exports, step["exportJointly"], step["dateFormat"], end_date) @staticmethod - def _load_exports(logger: logging.Logger, src) -> List[Export]: + def _load_exports(logger: logging.Logger, env_config: EnvConfig, src) -> List[Export]: exports_json = src["exports"] default_container = src["defaultContainer"] exports = [] @@ -434,7 +440,7 @@ def _load_exports(logger: logging.Logger, src) -> List[Export]: relative_path=e["relativePath"], file_type=FileType[e["type"].upper()], snapshot_binding=e.get("snapshotBinding"), - container=e.get("container", default_container), + container=env_config.get_container(e.get("container", 
default_container)), offset_by_number_of_days=e.get("offsetByNumberOfDays", 0))) except KeyError as ex: logger.warning(f"Unsupported FileType: {ex}. Skipping export: {e}.") diff --git a/workflow/paths.py b/workflow/paths.py index 5c33c85..fad49f5 100644 --- a/workflow/paths.py +++ b/workflow/paths.py @@ -6,7 +6,7 @@ from typing import List from workflow import blob, constants -from workflow.common import EnvConfig, AzureConfig, LocalConfig, Container, ContainerType +from workflow.common import EnvConfig, AzureConfig, LocalConfig, SnowflakeConfig, Container, ContainerType @dataclasses.dataclass @@ -86,14 +86,24 @@ def _build(self, logger: logging.Logger, days: List[str], relative_path, extensi return paths -class PathsBuilderFactory: +class SnowflakePathsBuilder(PathsBuilder): + config: SnowflakeConfig + + def __init__(self, config: SnowflakeConfig): + self.config = config + def _build(self, logger: logging.Logger, days: List[str], relative_path, extensions: List[str], + is_date_partitioned: bool) -> List[FilePath]: + return [FilePath(path=f"{self.config.database}.{self.config.schema}.{relative_path}")] + + +class PathsBuilderFactory: __CONTAINER_TYPE_TO_BUILDER = { - ContainerType.LOCAL: lambda container: LocalPathsBuilder(EnvConfig.EXTRACTORS[container.type](container.params)), - ContainerType.AZURE: lambda container: AzurePathsBuilder(EnvConfig.EXTRACTORS[container.type](container.params)) + ContainerType.LOCAL: lambda container: LocalPathsBuilder(EnvConfig.get_config(container)), + ContainerType.AZURE: lambda container: AzurePathsBuilder(EnvConfig.get_config(container)), + ContainerType.SNOWFLAKE: lambda container: SnowflakePathsBuilder(EnvConfig.get_config(container)), } @staticmethod def get_path_builder(container: Container) -> PathsBuilder: return PathsBuilderFactory.__CONTAINER_TYPE_TO_BUILDER[container.type](container) - diff --git a/workflow/query.py b/workflow/query.py index 93e0e0b..b570532 100644 --- a/workflow/query.py +++ b/workflow/query.py @@ -54,6 +54,7 @@ def install_model(models: dict) -> QueryWithInputs: def populate_source_configs(sources: List[Source]) -> str: source_config_csv = "\n".join([source.to_paths_csv() for source in sources]) data_formats_csv = "\n".join([source.to_formats_csv() for source in sources]) + container_types_csv = "\n".join([source.to_container_type_csv() for source in sources]) simple_sources = list(filter(lambda source: not source.is_chunk_partitioned, sources)) multipart_sources = list(filter(lambda source: source.is_chunk_partitioned, sources)) @@ -71,9 +72,7 @@ def resource_config[:syntax, :header] = (1, :Relation); (2, :Container); (3, :Pa def resource_config[:schema, :Relation] = "string" def resource_config[:schema, :Container] = "string" def resource_config[:schema, :Path] = "string" - def source_config_csv = load_csv[resource_config] - def insert:source_declares_resource(r, c, p) = exists(i : source_config_csv(:Relation, i, r) and @@ -82,16 +81,22 @@ def insert:source_declares_resource(r, c, p) = ) def input_format_config[:data] = \"\"\"{data_formats_csv}\"\"\" - def input_format_config[:syntax, :header_row] = -1 def input_format_config[:syntax, :header] = (1, :Relation); (2, :InputFormatCode) def input_format_config[:schema, :Relation] = "string" def input_format_config[:schema, :InputFormatCode] = "string" - def input_format_config_csv = load_csv[input_format_config] - def insert:source_has_input_format(r, p) = exists(i : input_format_config_csv(:Relation, i, r) and input_format_config_csv(:InputFormatCode, i, p)) + + def 
container_type_config[:data] = \"\"\"{container_types_csv}\"\"\" + def container_type_config[:syntax, :header_row] = -1 + def container_type_config[:syntax, :header] = (1, :Relation); (2, :ContainerType) + def container_type_config[:schema, :Relation] = "string" + def container_type_config[:schema, :ContainerType] = "string" + def container_type_config_csv = load_csv[container_type_config] + def insert:source_has_container_type(r, t) = + exists(i : container_type_config_csv(:Relation, i, r) and container_type_config_csv(:ContainerType, i, t)) {f"def insert:simple_relation = {_to_rel_literal_relation([source.relation for source in simple_sources])}" if len(simple_sources) > 0 else ""} {f"def insert:multi_part_relation = {_to_rel_literal_relation([source.relation for source in multipart_sources])}" if len(multipart_sources) > 0 else ""} diff --git a/workflow/snow.py b/workflow/snow.py new file mode 100644 index 0000000..7dee625 --- /dev/null +++ b/workflow/snow.py @@ -0,0 +1,81 @@ +import logging + +import snowflake.connector + +from workflow import constants +from workflow.utils import call_with_overhead +from workflow.common import SnowflakeConfig, RaiConfig + + +def sync_data(logger: logging.Logger, snowflake_config: SnowflakeConfig, rai_config: RaiConfig, resources, src): + conn = __get_connection(snowflake_config) + cursor = conn.cursor() + logger = logger.getChild("snowflake") + + destination_rel = src['source'] + source_table = resources[0]['uri'] + database = rai_config.database + engine = rai_config.engine + # Start data stream commands + commands = ( + f"CALL RAI.use_rai_database('{database}');", + f"CALL RAI.use_rai_engine('{engine}');", + f"CALL RAI.create_data_stream('{source_table}', '{database}', 'simple_source_catalog, :{destination_rel}');" + ) + try: + for command in commands: + logger.info(f"Executing Snowflake command: `{command.strip()}`") + cursor.execute(command) + except Exception as e: + cursor.close() + conn.close() + raise e + + # Wait for the data sync to finish + try: + logger.info(f"Waiting for Snowflake data sync to finish for `{source_table}`...") + call_with_overhead( + f=lambda: sync_finished(logger, cursor, source_table), + logger=logger, + overhead_rate=0.5, + timeout=30 * 60, # 30 min + first_delay=10, # 10 sec, since it can take some time for the Ingestion Service to start the job on Snowflake + max_delay=55 # 55 sec, since the Snowflake warehouse can be suspended after 60 sec of inactivity + ) + except Exception as e: + raise e + finally: + # Clean up data stream after sync + cursor.execute(f"CALL RAI.delete_data_stream('{source_table}')") + cursor.close() + conn.close() + + +def sync_finished(logger: logging.Logger, cursor, source_table: str): + cursor.execute(f"CALL RAI.get_data_stream_status('{source_table}');") + properties = cursor.fetchall() + key_value_pairs = {stream[0]: stream[1] for stream in properties} + health_status = key_value_pairs.get(constants.SNOWFLAKE_STREAM_HEALTH_STATUS) + if health_status != constants.SNOWFLAKE_HEALTHY_STREAM_STATUS: + raise Exception(f"Snowflake sync for `{source_table}` has failed. Health status: {health_status}") + if key_value_pairs.get(constants.SNOWFLAKE_SYNC_STATUS) == constants.SNOWFLAKE_FINISHED_SYNC_STATUS: + synced_rows = key_value_pairs.get(constants.SNOWFLAKE_TOTAL_ROWS) + logger.info(f"Snowflake sync for `{source_table}` has finished. Synced rows: {synced_rows}") + return True + return False + + +def __get_connection(config: SnowflakeConfig): + """ + Get a connection to Snowflake.
+ """ + # connect to snowflake + return snowflake.connector.connect( + user=config.user, + password=config.password, + account=config.account, + role=config.role, + warehouse=config.warehouse, + database=config.database, + schema=config.schema, + ) diff --git a/workflow/utils.py b/workflow/utils.py index 4eb05cb..60cea2d 100644 --- a/workflow/utils.py +++ b/workflow/utils.py @@ -1,5 +1,6 @@ import logging import os +import time from datetime import datetime, timedelta from typing import List, Dict @@ -122,3 +123,37 @@ def to_rai_date_format(date_format: str) -> str: for py_fmt, rai_fmt in fmt_part_map.items(): rai_date_format = rai_date_format.replace(py_fmt, rai_fmt) return rai_date_format + + +def call_with_overhead( + f, + logger: logging.Logger, + overhead_rate: float, + start_time: int = time.time(), + timeout: int = None, + max_tries: int = None, + first_delay: float = 0.5, + max_delay: int = 120, # 2 minutes +): + tries = 0 + max_time = time.time() + timeout if timeout else None + + while True: + logger.debug(f"Calling function. The number of try: {tries + 1}") + if f(): + break + + if max_tries is not None and tries >= max_tries: + raise Exception(f'max tries {max_tries} exhausted') + + if max_time is not None and time.time() >= max_time: + raise Exception(f'timed out after {timeout} seconds') + + tries += 1 + duration = min((time.time() - start_time) * overhead_rate, max_delay) + if tries == 1: + logger.debug(f"Sleep duration for the first try: {first_delay}s") + time.sleep(first_delay) + else: + logger.debug(f"Sleep duration for a try: {duration}s") + time.sleep(duration)