From 33078aa9924a7e9c0a31be2f6c9277c0400c2ac7 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 15 Jan 2025 19:01:53 +0530 Subject: [PATCH 01/22] init commit --- pyproject.toml | 1 + .../remorph/connections/credential_manager.py | 20 +++++ .../remorph/connections/database_manager.py | 73 +++++++++++++++++++ .../remorph/resources/config/credentials.yml | 29 ++++++++ 4 files changed, 123 insertions(+) create mode 100644 src/databricks/labs/remorph/connections/credential_manager.py create mode 100644 src/databricks/labs/remorph/connections/database_manager.py create mode 100644 src/databricks/labs/remorph/resources/config/credentials.yml diff --git a/pyproject.toml b/pyproject.toml index d1f75b696a..a7acc6c5db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "databricks-labs-blueprint[yaml]>=0.2.3", "databricks-labs-lsql>=0.7.5,<0.14.0", # TODO: Limit the LSQL version until dependencies are correct. "cryptography>=41.0.3", + "SQLAlchemy" ] [project.urls] diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py new file mode 100644 index 0000000000..cdd299f8f8 --- /dev/null +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -0,0 +1,20 @@ +from pathlib import Path +import yaml + + +class Credentials: + + def __init__(self, product_info, source): + self._product_info = product_info + self._credentials = self._load_credentials(self._get_local_version_file_path()) + + def _get_local_version_file_path(self): + user_home = f"{Path(__file__).home()}" + return Path(f"{user_home}/.databricks/labs/{self._product_info.product_name()}/credentials.yml") + + def _load_credentials(self, file_path): + with open(file_path, encoding="utf-8") as f: + return yaml.safe_load(f) + + + diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py new file mode 100644 index 0000000000..1f5cb21cfa --- /dev/null +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -0,0 +1,73 @@ +from abc import ABC, abstractmethod +from pathlib import Path +import yaml +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +class ISourceSystemConnector(ABC): + @abstractmethod + def connect(self): + pass + + @abstractmethod + def execute_query(self, query: str): + pass + +class SnowflakeConnector(ISourceSystemConnector): + def __init__(self, config): + self.config = config + self.engine = None + + def connect(self): + connection_string = ( + f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" + f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" + ) + self.engine = create_engine(connection_string) + + def execute_query(self, query: str): + if not self.engine: + raise ConnectionError("Not connected to the database.") + Session = sessionmaker(bind=self.engine) + session = Session() + try: + result = session.execute(query) + return [dict(row) for row in result] + finally: + session.close() + +class MSSQLConnector(ISourceSystemConnector): + def __init__(self, config): + self.config = config + self.engine = None + + def connect(self): + connection_string = ( + f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" + f"{self.config['database']}?driver={self.config['driver']}" + ) + self.engine = create_engine(connection_string) + + def 
execute_query(self, query: str): + if not self.engine: + raise ConnectionError("Not connected to the database.") + Session = sessionmaker(bind=self.engine) + session = Session() + try: + result = session.execute(query) + return [dict(row) for row in result] + finally: + session.close() + + +#TODO move this factory to Application Context +class SourceSystemConnectorFactory: + @staticmethod + def create_connector(db_type: str, config: dict) -> ISourceSystemConnector: + if db_type == "snowflake": + return SnowflakeConnector(config) + elif db_type == "mssql": + return MSSQLConnector(config) + else: + raise ValueError(f"Unsupported database type: {db_type}") + diff --git a/src/databricks/labs/remorph/resources/config/credentials.yml b/src/databricks/labs/remorph/resources/config/credentials.yml new file mode 100644 index 0000000000..68fedc86f8 --- /dev/null +++ b/src/databricks/labs/remorph/resources/config/credentials.yml @@ -0,0 +1,29 @@ +credentials: + snowflake: + account: example_account + connect_retries: 1 + connect_timeout: null + host: null + insecure_mode: false + oauth_client_id: null + oauth_client_secret: null + password: null + port: null + private_key: null + private_key_passphrase: null + private_key_path: null + role: null + token: null + user: null + warehouse: null + + msssql: + database: example_database + driver: ODBC Driver 18 for SQL Server + server: example_host + port: null + user: null + password: null + + + From e12db4d55c55e9442a6fbf208b70ef5840da5ad5 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 15 Jan 2025 19:06:52 +0530 Subject: [PATCH 02/22] Added Base Connector --- .../remorph/connections/credential_manager.py | 8 ++- .../remorph/connections/database_manager.py | 49 +++++++------------ 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index cdd299f8f8..3f784ccf9a 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -4,7 +4,7 @@ class Credentials: - def __init__(self, product_info, source): + def __init__(self, product_info): self._product_info = product_info self._credentials = self._load_credentials(self._get_local_version_file_path()) @@ -16,5 +16,9 @@ def _load_credentials(self, file_path): with open(file_path, encoding="utf-8") as f: return yaml.safe_load(f) - + def get(self, source): + if source in self._credentials: + return self._credentials[source] + else: + raise KeyError(f"source system: {source} credentials not found not in credentials: {key}") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 1f5cb21cfa..f4e26c1bb0 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -3,29 +3,26 @@ import yaml from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from typing import Any, Dict class ISourceSystemConnector(ABC): @abstractmethod - def connect(self): + def connect(self) -> None: pass @abstractmethod - def execute_query(self, query: str): + def execute_query(self, query: str) -> list[Dict[str, Any]]: pass -class SnowflakeConnector(ISourceSystemConnector): - def __init__(self, config): +class BaseConnector(ISourceSystemConnector): + def __init__(self, config: Dict[str, Any]): self.config = config self.engine = None 
- def connect(self): - connection_string = ( - f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" - f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" - ) - self.engine = create_engine(connection_string) + def connect(self) -> None: + raise NotImplementedError("Subclasses should implement this method") - def execute_query(self, query: str): + def execute_query(self, query: str) -> list[Dict[str, Any]]: if not self.engine: raise ConnectionError("Not connected to the database.") Session = sessionmaker(bind=self.engine) @@ -36,38 +33,28 @@ def execute_query(self, query: str): finally: session.close() -class MSSQLConnector(ISourceSystemConnector): - def __init__(self, config): - self.config = config - self.engine = None +class SnowflakeConnector(BaseConnector): + def connect(self) -> None: + connection_string = ( + f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" + f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" + ) + self.engine = create_engine(connection_string) - def connect(self): +class MSSQLConnector(BaseConnector): + def connect(self) -> None: connection_string = ( f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) self.engine = create_engine(connection_string) - def execute_query(self, query: str): - if not self.engine: - raise ConnectionError("Not connected to the database.") - Session = sessionmaker(bind=self.engine) - session = Session() - try: - result = session.execute(query) - return [dict(row) for row in result] - finally: - session.close() - - -#TODO move this factory to Application Context class SourceSystemConnectorFactory: @staticmethod - def create_connector(db_type: str, config: dict) -> ISourceSystemConnector: + def create_connector(db_type: str, config: Dict[str, Any]) -> ISourceSystemConnector: if db_type == "snowflake": return SnowflakeConnector(config) elif db_type == "mssql": return MSSQLConnector(config) else: raise ValueError(f"Unsupported database type: {db_type}") - From 7fc59ffd88c21d03efdf54de29ce589af8e84154 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 15 Jan 2025 19:45:31 +0530 Subject: [PATCH 03/22] Moved the Abstract class to private --- .../remorph/connections/credential_manager.py | 15 +++++------ .../remorph/connections/database_manager.py | 27 ++++++++++--------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 3f784ccf9a..2180a4732c 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,24 +1,21 @@ from pathlib import Path import yaml - class Credentials: - - def __init__(self, product_info): + def __init__(self, product_info: str) -> None: self._product_info = product_info - self._credentials = self._load_credentials(self._get_local_version_file_path()) + self._credentials: dict[str, Any] = self._load_credentials(self._get_local_version_file_path()) - def _get_local_version_file_path(self): + def _get_local_version_file_path(self) -> Path: user_home = f"{Path(__file__).home()}" return Path(f"{user_home}/.databricks/labs/{self._product_info.product_name()}/credentials.yml") - def 
_load_credentials(self, file_path): + def _load_credentials(self, file_path: Path) -> dict[str, str]: with open(file_path, encoding="utf-8") as f: return yaml.safe_load(f) - def get(self, source): + def get(self, source: str) -> dict[str, str]: if source in self._credentials: return self._credentials[source] else: - raise KeyError(f"source system: {source} credentials not found not in credentials: {key}") - + raise KeyError(f"source system: {source} credentials not found not in credentials: {source}") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index f4e26c1bb0..2b7dd66164 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -3,26 +3,26 @@ import yaml from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from typing import Any, Dict +from typing import Any -class ISourceSystemConnector(ABC): +class _ISourceSystemConnector(ABC): @abstractmethod - def connect(self) -> None: + def connect(self) -> Engine: pass @abstractmethod - def execute_query(self, query: str) -> list[Dict[str, Any]]: + def execute_query(self, query: str) -> list[dict[str, Any]]: pass -class BaseConnector(ISourceSystemConnector): - def __init__(self, config: Dict[str, Any]): +class _BaseConnector(_ISourceSystemConnector): + def __init__(self, config: dict[str, Any]): self.config = config self.engine = None - def connect(self) -> None: + def connect(self) -> Engine: raise NotImplementedError("Subclasses should implement this method") - def execute_query(self, query: str) -> list[Dict[str, Any]]: + def execute_query(self, query: str) -> list[dict[str, Any]]: if not self.engine: raise ConnectionError("Not connected to the database.") Session = sessionmaker(bind=self.engine) @@ -33,25 +33,26 @@ def execute_query(self, query: str) -> list[Dict[str, Any]]: finally: session.close() -class SnowflakeConnector(BaseConnector): - def connect(self) -> None: +class SnowflakeConnector(_BaseConnector): + def connect(self) -> Engine: connection_string = ( f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" ) self.engine = create_engine(connection_string) -class MSSQLConnector(BaseConnector): - def connect(self) -> None: +class MSSQLConnector(_BaseConnector): + def connect(self) -> Engine: connection_string = ( f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) self.engine = create_engine(connection_string) +#TODO: Move this application context class SourceSystemConnectorFactory: @staticmethod - def create_connector(db_type: str, config: Dict[str, Any]) -> ISourceSystemConnector: + def create_connector(db_type: str, config: dict[str, str]) -> ISourceSystemConnector: if db_type == "snowflake": return SnowflakeConnector(config) elif db_type == "mssql": From 978f8bcd5200fb3db002f298fcf893aee6da2552 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 15 Jan 2025 20:46:26 +0530 Subject: [PATCH 04/22] Added pyodbc dependency --- pyproject.toml | 4 +- .../remorph/connections/credential_manager.py | 5 +- .../remorph/connections/database_manager.py | 11 +++-- .../remorph/resources/config/credentials.yml | 49 +++++++++---------- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git 
a/pyproject.toml b/pyproject.toml index a7acc6c5db..207945ae8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,9 @@ dependencies = [ "databricks-labs-blueprint[yaml]>=0.2.3", "databricks-labs-lsql>=0.7.5,<0.14.0", # TODO: Limit the LSQL version until dependencies are correct. "cryptography>=41.0.3", - "SQLAlchemy" + "pyodbc", + "SQLAlchemy", + ] [project.urls] diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 2180a4732c..3a770608e5 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,8 +1,9 @@ from pathlib import Path import yaml +from databricks.labs.blueprint.wheels import ProductInfo class Credentials: - def __init__(self, product_info: str) -> None: + def __init__(self, product_info: ProductInfo) -> None: self._product_info = product_info self._credentials: dict[str, Any] = self._load_credentials(self._get_local_version_file_path()) @@ -18,4 +19,4 @@ def get(self, source: str) -> dict[str, str]: if source in self._credentials: return self._credentials[source] else: - raise KeyError(f"source system: {source} credentials not found not in credentials: {source}") + raise KeyError(f"source system: {source} credentials not found not in file credentials.yml") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 2b7dd66164..45651175eb 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -2,6 +2,7 @@ from pathlib import Path import yaml from sqlalchemy import create_engine +from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from typing import Any @@ -34,7 +35,7 @@ def execute_query(self, query: str) -> list[dict[str, Any]]: session.close() class SnowflakeConnector(_BaseConnector): - def connect(self) -> Engine: + def connect(self) -> _ISourceSystemConnector: connection_string = ( f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" @@ -42,7 +43,7 @@ def connect(self) -> Engine: self.engine = create_engine(connection_string) class MSSQLConnector(_BaseConnector): - def connect(self) -> Engine: + def connect(self) -> _ISourceSystemConnector: connection_string = ( f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" @@ -52,10 +53,10 @@ def connect(self) -> Engine: #TODO: Move this application context class SourceSystemConnectorFactory: @staticmethod - def create_connector(db_type: str, config: dict[str, str]) -> ISourceSystemConnector: + def create_connector(db_type: str, config: dict[str, str]) -> _ISourceSystemConnector: if db_type == "snowflake": - return SnowflakeConnector(config) + return SnowflakeConnector(config).connect() elif db_type == "mssql": - return MSSQLConnector(config) + return MSSQLConnector(config).connect() else: raise ValueError(f"Unsupported database type: {db_type}") diff --git a/src/databricks/labs/remorph/resources/config/credentials.yml b/src/databricks/labs/remorph/resources/config/credentials.yml index 68fedc86f8..c6017de1f5 100644 --- a/src/databricks/labs/remorph/resources/config/credentials.yml +++ 
b/src/databricks/labs/remorph/resources/config/credentials.yml @@ -1,29 +1,28 @@ -credentials: - snowflake: - account: example_account - connect_retries: 1 - connect_timeout: null - host: null - insecure_mode: false - oauth_client_id: null - oauth_client_secret: null - password: null - port: null - private_key: null - private_key_passphrase: null - private_key_path: null - role: null - token: null - user: null - warehouse: null +snowflake: + account: example_account + connect_retries: 1 + connect_timeout: null + host: null + insecure_mode: false + oauth_client_id: null + oauth_client_secret: null + password: null + port: null + private_key: null + private_key_passphrase: null + private_key_path: null + role: null + token: null + user: null + warehouse: null - msssql: - database: example_database - driver: ODBC Driver 18 for SQL Server - server: example_host - port: null - user: null - password: null +msssql: + database: example_database + driver: ODBC Driver 18 for SQL Server + server: example_host + port: null + user: null + password: null From 39cdc3eaf6ceaf2467739663f3cb127761766ae4 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 15 Jan 2025 22:59:25 +0530 Subject: [PATCH 05/22] fmt fixes --- .../remorph/connections/credential_manager.py | 12 +++-- .../remorph/connections/database_manager.py | 50 ++++++++++--------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 3a770608e5..c0e6a830b4 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -2,10 +2,11 @@ import yaml from databricks.labs.blueprint.wheels import ProductInfo + class Credentials: def __init__(self, product_info: ProductInfo) -> None: self._product_info = product_info - self._credentials: dict[str, Any] = self._load_credentials(self._get_local_version_file_path()) + self._credentials: dict[str, str] = self._load_credentials(self._get_local_version_file_path()) def _get_local_version_file_path(self) -> Path: user_home = f"{Path(__file__).home()}" @@ -16,7 +17,10 @@ def _load_credentials(self, file_path: Path) -> dict[str, str]: return yaml.safe_load(f) def get(self, source: str) -> dict[str, str]: + error_msg = f"source system: {source} credentials not found not in file credentials.yml" if source in self._credentials: - return self._credentials[source] - else: - raise KeyError(f"source system: {source} credentials not found not in file credentials.yml") + value = self._credentials[source] + if isinstance(value, dict): + return value + raise KeyError(error_msg) + raise KeyError(error_msg) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 45651175eb..c6dbce9ae3 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -1,62 +1,64 @@ from abc import ABC, abstractmethod -from pathlib import Path -import yaml +from typing import Any + from sqlalchemy import create_engine -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Engine, Result from sqlalchemy.orm import sessionmaker -from typing import Any +from sqlalchemy import text + class _ISourceSystemConnector(ABC): @abstractmethod - def connect(self) -> Engine: + def _connect(self) -> Engine: pass @abstractmethod - def execute_query(self, query: 
str) -> list[dict[str, Any]]: + def execute_query(self, query: str) -> Result[Any]: pass + class _BaseConnector(_ISourceSystemConnector): def __init__(self, config: dict[str, Any]): self.config = config - self.engine = None + self.engine: Engine = self._connect() - def connect(self) -> Engine: + def _connect(self) -> Engine: raise NotImplementedError("Subclasses should implement this method") - def execute_query(self, query: str) -> list[dict[str, Any]]: + def execute_query(self, query: str) -> Result[Any]: if not self.engine: raise ConnectionError("Not connected to the database.") - Session = sessionmaker(bind=self.engine) - session = Session() - try: - result = session.execute(query) - return [dict(row) for row in result] - finally: - session.close() + session = sessionmaker(bind=self.engine) + connection = session() + return connection.execute(text(query)) + class SnowflakeConnector(_BaseConnector): - def connect(self) -> _ISourceSystemConnector: + def _connect(self) -> Engine: connection_string = ( f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" ) self.engine = create_engine(connection_string) + return self.engine + class MSSQLConnector(_BaseConnector): - def connect(self) -> _ISourceSystemConnector: + def _connect(self) -> Engine: connection_string = ( f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) self.engine = create_engine(connection_string) + return self.engine + -#TODO: Move this application context class SourceSystemConnectorFactory: @staticmethod def create_connector(db_type: str, config: dict[str, str]) -> _ISourceSystemConnector: if db_type == "snowflake": - return SnowflakeConnector(config).connect() - elif db_type == "mssql": - return MSSQLConnector(config).connect() - else: - raise ValueError(f"Unsupported database type: {db_type}") + return SnowflakeConnector(config) + if db_type == "mssql": + return MSSQLConnector(config) + + raise ValueError(f"Unsupported database type: {db_type}") From d1263b54badb718301e8db5fbd36c0e1bfdf2021 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Thu, 16 Jan 2025 15:13:54 +0530 Subject: [PATCH 06/22] Added Vault Manager --- .../remorph/connections/database_manager.py | 20 ++++++++++++++----- .../remorph/resources/config/credentials.yml | 4 +++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index c6dbce9ae3..2bf7342077 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -35,10 +35,20 @@ def execute_query(self, query: str) -> Result[Any]: class SnowflakeConnector(_BaseConnector): def _connect(self) -> Engine: - connection_string = ( - f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" - f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}&role={self.config['role']}" - ) + if self.config['private_key_path'] is not None: + connection_string = ( + f"snowflake://{self.config['user']}@{self.config['account']}/" + f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}" + f"&role={self.config['role']}" + 
f"&authenticator=externalbrowser&private_key_path={self.config['private_key_path']}" + f"&private_key_passphrase={self.config['private_key_passphrase']}" + ) + else: + connection_string = ( + f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" + f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}" + f"&role={self.config['role']}" + ) self.engine = create_engine(connection_string) return self.engine @@ -49,7 +59,7 @@ def _connect(self) -> Engine: f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) - self.engine = create_engine(connection_string) + self.engine = create_engine(connection_string, echo = True) return self.engine diff --git a/src/databricks/labs/remorph/resources/config/credentials.yml b/src/databricks/labs/remorph/resources/config/credentials.yml index c6017de1f5..8ace5b4256 100644 --- a/src/databricks/labs/remorph/resources/config/credentials.yml +++ b/src/databricks/labs/remorph/resources/config/credentials.yml @@ -1,3 +1,5 @@ +secret_vault_type: local | databricks | env +secret_vault_name: null snowflake: account: example_account connect_retries: 1 @@ -17,7 +19,7 @@ snowflake: warehouse: null msssql: - database: example_database + database: DB_NAME driver: ODBC Driver 18 for SQL Server server: example_host port: null From 0c40eca89876d2f22059293f961bf948408e9b8c Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 17 Jan 2025 10:31:15 +0530 Subject: [PATCH 07/22] Added TODO --- src/databricks/labs/remorph/connections/database_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 2bf7342077..d3d44aa4f6 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -62,7 +62,7 @@ def _connect(self) -> Engine: self.engine = create_engine(connection_string, echo = True) return self.engine - +#TODO Refactor into application context class SourceSystemConnectorFactory: @staticmethod def create_connector(db_type: str, config: dict[str, str]) -> _ISourceSystemConnector: From d7eed0832b1a100d8b6d1f1cdce6aa9df11dc09d Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 15:55:23 +0530 Subject: [PATCH 08/22] Adding credential manager for multiple secret --- .../remorph/connections/credential_manager.py | 17 +++++++++++--- .../remorph/connections/database_manager.py | 22 ++++++++++++------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index c0e6a830b4..7c476c1693 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,7 +1,7 @@ from pathlib import Path import yaml from databricks.labs.blueprint.wheels import ProductInfo - +import os class Credentials: def __init__(self, product_info: ProductInfo) -> None: @@ -17,10 +17,21 @@ def _load_credentials(self, file_path: Path) -> dict[str, str]: return yaml.safe_load(f) def get(self, source: str) -> dict[str, str]: - error_msg = f"source system: {source} credentials not found not in file credentials.yml" + error_msg = f"source system: {source} credentials not found in file credentials.yml" if 
source in self._credentials: value = self._credentials[source] if isinstance(value, dict): - return value + return {k: self.get_secret_value(v) for k, v in value.items()} raise KeyError(error_msg) raise KeyError(error_msg) + + def get_secret_value(self, key: str) -> str: + secret_vault_type = self._credentials.get('secret_vault_type', 'local') + if secret_vault_type == 'local': + return key + elif secret_vault_type == 'env': + return os.getenv(key, f"Environment variable {key} not found") + elif secret_vault_type == 'databricks': + return NotImplemented + else: + raise ValueError(f"Unsupported secret vault type: {secret_vault_type}") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index d3d44aa4f6..2c00bfc6a2 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -63,12 +63,18 @@ def _connect(self) -> Engine: return self.engine #TODO Refactor into application context -class SourceSystemConnectorFactory: - @staticmethod - def create_connector(db_type: str, config: dict[str, str]) -> _ISourceSystemConnector: - if db_type == "snowflake": - return SnowflakeConnector(config) - if db_type == "mssql": - return MSSQLConnector(config) +class DatabaseManager: + def __init__(self, db_type: str, config: dict[str, str]): + self.db_type = db_type + self.config = config + self.connector = self._create_connector() + + def _create_connector(self) -> _ISourceSystemConnector: + if self.db_type == "snowflake": + return SnowflakeConnector(self.config) + if self.db_type == "mssql": + return MSSQLConnector(self.config) + raise ValueError(f"Unsupported database type: {self.db_type}") - raise ValueError(f"Unsupported database type: {db_type}") + def execute_query(self, query: str) -> Result[Any]: + return self.connector.execute_query(query) From bd0d012f11e21f042fbd9faf32fc77c8a374d717 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 16:27:59 +0530 Subject: [PATCH 09/22] Added reading credentials from env and then falling back to key itself --- .../labs/remorph/connections/credential_manager.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 7c476c1693..30b6d06890 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -30,7 +30,13 @@ def get_secret_value(self, key: str) -> str: if secret_vault_type == 'local': return key elif secret_vault_type == 'env': - return os.getenv(key, f"Environment variable {key} not found") + print(f"key: {key}") + v = os.getenv(str(key)) # Port numbers can be int + print(v) + if v is None: + print(f"Environment variable {key} not found Failing back to actual strings") + return key + return v elif secret_vault_type == 'databricks': return NotImplemented else: From c7f91f52fbb2dab0de05dff7ba4830e3d5671ede Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 16:29:04 +0530 Subject: [PATCH 10/22] fixed case agnostic connection creation. 
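
A minimal sketch of the behaviour this change targets, assuming the DatabaseManager introduced in PATCH 08: the connector lookup becomes case-insensitive, and anything unrecognised still fails fast before a connector is built.

    from databricks.labs.remorph.connections.database_manager import DatabaseManager

    # "MSSQL", "Mssql" and "mssql" all resolve to MSSQLConnector after this
    # change; previously only the exact lowercase string matched.
    try:
        DatabaseManager("ORACLE", {})  # unsupported type is still rejected
    except ValueError as err:
        print(err)  # Unsupported database type: ORACLE
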
--- src/databricks/labs/remorph/connections/database_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 2c00bfc6a2..304204d6b9 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -70,9 +70,9 @@ def __init__(self, db_type: str, config: dict[str, str]): self.connector = self._create_connector() def _create_connector(self) -> _ISourceSystemConnector: - if self.db_type == "snowflake": + if self.db_type.lower() == "snowflake": return SnowflakeConnector(self.config) - if self.db_type == "mssql": + if self.db_type.lower() == "mssql": return MSSQLConnector(self.config) raise ValueError(f"Unsupported database type: {self.db_type}") From 1dcff154b920af827877308a7392f24ac2310337 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 19:40:04 +0530 Subject: [PATCH 11/22] Added UT --- .../remorph/connections/credential_manager.py | 7 +- .../remorph/connections/database_manager.py | 28 ++---- .../remorph/resources/config/credentials.yml | 5 +- .../connections/test_credential_manager.py | 90 +++++++++++++++++++ .../unit/connections/test_database_manager.py | 0 5 files changed, 107 insertions(+), 23 deletions(-) create mode 100644 tests/unit/connections/test_credential_manager.py create mode 100644 tests/unit/connections/test_database_manager.py diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 30b6d06890..1993e23ea9 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -3,6 +3,7 @@ from databricks.labs.blueprint.wheels import ProductInfo import os + class Credentials: def __init__(self, product_info: ProductInfo) -> None: self._product_info = product_info @@ -26,18 +27,18 @@ def get(self, source: str) -> dict[str, str]: raise KeyError(error_msg) def get_secret_value(self, key: str) -> str: - secret_vault_type = self._credentials.get('secret_vault_type', 'local') + secret_vault_type = self._credentials.get('secret_vault_type', 'local').lower() if secret_vault_type == 'local': return key elif secret_vault_type == 'env': print(f"key: {key}") - v = os.getenv(str(key)) # Port numbers can be int + v = os.getenv(str(key)) # Port numbers can be int print(v) if v is None: print(f"Environment variable {key} not found Failing back to actual strings") return key return v elif secret_vault_type == 'databricks': - return NotImplemented + raise NotImplementedError("Databricks secret vault not implemented") else: raise ValueError(f"Unsupported secret vault type: {secret_vault_type}") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 304204d6b9..cc81a31a50 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -34,23 +34,9 @@ def execute_query(self, query: str) -> Result[Any]: class SnowflakeConnector(_BaseConnector): + # TODO: Not Implemented def _connect(self) -> Engine: - if self.config['private_key_path'] is not None: - connection_string = ( - f"snowflake://{self.config['user']}@{self.config['account']}/" - f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}" - 
f"&role={self.config['role']}" - f"&authenticator=externalbrowser&private_key_path={self.config['private_key_path']}" - f"&private_key_passphrase={self.config['private_key_passphrase']}" - ) - else: - connection_string = ( - f"snowflake://{self.config['user']}:{self.config['password']}@{self.config['account']}/" - f"{self.config['database']}/{self.config['schema']}?warehouse={self.config['warehouse']}" - f"&role={self.config['role']}" - ) - self.engine = create_engine(connection_string) - return self.engine + raise NotImplementedError("Snowflake connector not implemented") class MSSQLConnector(_BaseConnector): @@ -59,10 +45,13 @@ def _connect(self) -> Engine: f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) - self.engine = create_engine(connection_string, echo = True) + # TODO: Add support for other connection parameters through a custom dictionary or config + self.engine = create_engine(connection_string, echo=True, connect_args=None) return self.engine -#TODO Refactor into application context + + +# TODO Refactor into application context class DatabaseManager: def __init__(self, db_type: str, config: dict[str, str]): self.db_type = db_type @@ -72,9 +61,10 @@ def __init__(self, db_type: str, config: dict[str, str]): def _create_connector(self) -> _ISourceSystemConnector: if self.db_type.lower() == "snowflake": return SnowflakeConnector(self.config) - if self.db_type.lower() == "mssql": + if self.db_type.lower() in ("mssql", "tsql", "synapse"): return MSSQLConnector(self.config) raise ValueError(f"Unsupported database type: {self.db_type}") + def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) diff --git a/src/databricks/labs/remorph/resources/config/credentials.yml b/src/databricks/labs/remorph/resources/config/credentials.yml index 8ace5b4256..72ffbedfd8 100644 --- a/src/databricks/labs/remorph/resources/config/credentials.yml +++ b/src/databricks/labs/remorph/resources/config/credentials.yml @@ -18,7 +18,8 @@ snowflake: user: null warehouse: null -msssql: +mssql: + #TODO Expand to support sqlpools, and legacy dwh database: DB_NAME driver: ODBC Driver 18 for SQL Server server: example_host @@ -28,3 +29,5 @@ msssql: + + diff --git a/tests/unit/connections/test_credential_manager.py b/tests/unit/connections/test_credential_manager.py new file mode 100644 index 0000000000..be280b44ee --- /dev/null +++ b/tests/unit/connections/test_credential_manager.py @@ -0,0 +1,90 @@ +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path +from databricks.labs.remorph.connections.credential_manager import Credentials +from databricks.labs.blueprint.wheels import ProductInfo + + +@pytest.fixture +def product_info(): + mock_product_info = MagicMock(spec=ProductInfo) + mock_product_info.product_name.return_value = "test_product" + return mock_product_info + + +@pytest.fixture +def local_credentials(): + return { + 'secret_vault_type': 'local', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'local_user', + 'password': 'local_password', + }, + } + + +@pytest.fixture +def env_credentials(): + return { + 'secret_vault_type': 'env', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'MSSQL_USER_ENV', + 'password': 'MSSQL_PASSWORD_ENV', + }, + } + + +@pytest.fixture +def databricks_credentials(): + 
return { + 'secret_vault_type': 'databricks', + 'secret_vault_name': 'databricks_vault_name', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'databricks_user', + 'password': 'databricks_password', + }, + } + + +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') +def test_local_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, local_credentials): + mock_load_credentials.return_value = local_credentials + mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") + credentials = Credentials(product_info) + creds = credentials.get('mssql') + assert creds['user'] == 'local_user' + assert creds['password'] == 'local_password' + + +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') +@patch.dict('os.environ', {'MSSQL_USER_ENV': 'env_user', 'MSSQL_PASSWORD_ENV': 'env_password'}) +def test_env_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, env_credentials): + mock_load_credentials.return_value = env_credentials + mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") + credentials = Credentials(product_info) + creds = credentials.get('mssql') + assert creds['user'] == 'env_user' + assert creds['password'] == 'env_password' + + +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') +@patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') +def test_databricks_credentials( + mock_load_credentials, mock_get_local_version_file_path, product_info, databricks_credentials +): + mock_load_credentials.return_value = databricks_credentials + mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") + credentials = Credentials(product_info) + with pytest.raises(NotImplementedError): + credentials.get('mssql') diff --git a/tests/unit/connections/test_database_manager.py b/tests/unit/connections/test_database_manager.py new file mode 100644 index 0000000000..e69de29bb2 From 211944e683ca0105f395b828667a1b3128ffcf03 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 19:43:24 +0530 Subject: [PATCH 12/22] fmt fixes --- .../remorph/connections/credential_manager.py | 20 +++++++++---------- .../remorph/connections/database_manager.py | 4 +--- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 1993e23ea9..5a21f918e6 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,7 +1,9 @@ from pathlib import Path +import os import yaml + + from databricks.labs.blueprint.wheels import ProductInfo -import os class Credentials: @@ -30,15 +32,13 @@ def get_secret_value(self, key: str) -> str: secret_vault_type = self._credentials.get('secret_vault_type', 'local').lower() if secret_vault_type == 'local': return key - elif secret_vault_type == 'env': - print(f"key: {key}") - v = os.getenv(str(key)) # Port numbers can be int - print(v) - if v is None: + if 
secret_vault_type == 'env': + value = os.getenv(str(key)) # Port numbers can be int + if value is None: print(f"Environment variable {key} not found Failing back to actual strings") return key - return v - elif secret_vault_type == 'databricks': + return value + if secret_vault_type == 'databricks': raise NotImplementedError("Databricks secret vault not implemented") - else: - raise ValueError(f"Unsupported secret vault type: {secret_vault_type}") + + raise ValueError(f"Unsupported secret vault type: {secret_vault_type}") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index cc81a31a50..71860ac7f3 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -50,7 +50,6 @@ def _connect(self) -> Engine: return self.engine - # TODO Refactor into application context class DatabaseManager: def __init__(self, db_type: str, config: dict[str, str]): @@ -61,10 +60,9 @@ def __init__(self, db_type: str, config: dict[str, str]): def _create_connector(self) -> _ISourceSystemConnector: if self.db_type.lower() == "snowflake": return SnowflakeConnector(self.config) - if self.db_type.lower() in ("mssql", "tsql", "synapse"): + if self.db_type.lower() in {"mssql", "tsql", "synapse"}: return MSSQLConnector(self.config) raise ValueError(f"Unsupported database type: {self.db_type}") - def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) From a445abaf065cf1638fae3b8d08d4dc0f087d8b84 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 20 Jan 2025 19:54:04 +0530 Subject: [PATCH 13/22] initial test case setup --- .../unit/connections/test_database_manager.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/unit/connections/test_database_manager.py b/tests/unit/connections/test_database_manager.py index e69de29bb2..ff6150d8df 100644 --- a/tests/unit/connections/test_database_manager.py +++ b/tests/unit/connections/test_database_manager.py @@ -0,0 +1,52 @@ +import pytest +from unittest.mock import create_autospec, patch + +from databricks.labs.remorph.connections.database_manager import DatabaseManager, MSSQLConnector + +def test_unsupported_database_type(): + config = {'user': 'user', 'password': 'password', 'server': 'server', 'database': 'database', 'driver': 'driver'} + with pytest.raises(ValueError, match="Unsupported database type: invalid_db"): + DatabaseManager('invalid_db', config) + +def test_mssql_connector_connection(): + config = { + 'user': 'valid_user', + 'password': 'valid_password', + 'server': 'valid_server', + 'database': 'valid_database', + 'driver': 'ODBC Driver 17 for SQL Server' + } + + with patch('sqlalchemy.create_engine') as mock_create_engine, \ + patch.object(MSSQLConnector, '_connect', return_value=None): # Prevent actual connection + connector = MSSQLConnector(config) + engine = connector._connect() + + assert engine is not None + mock_create_engine.assert_called_once_with( + f"mssql+pyodbc://{config['user']}:{config['password']}@{config['server']}/{config['database']}?driver={config['driver']}", + echo=True, + connect_args=None + ) + +def test_execute_successful_query(): + config = { + 'user': 'valid_user', + 'password': 'valid_password', + 'server': 'valid_server', + 'database': 'valid_database', + 'driver': 'ODBC Driver 17 for SQL Server' + } + + with patch('sqlalchemy.create_engine'), \ + patch('sqlalchemy.orm.sessionmaker') as mock_session, \ 
+ patch.object(MSSQLConnector, '_connect', return_value='result'): + connector = MSSQLConnector(config) + connector._connect() + + mock_connection = mock_session.return_value.__enter__.return_value + mock_connection.execute.return_value = "result" + + result = connector.execute_query("SELECT * FROM valid_table") + + assert result == "result" From 3cb9c05113f276e91a31ab9bae598dc20a141fea Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Tue, 21 Jan 2025 18:26:18 +0530 Subject: [PATCH 14/22] test case setup --- .../remorph/connections/database_manager.py | 38 ++++--- .../unit/connections/test_database_manager.py | 105 +++++++++--------- 2 files changed, 74 insertions(+), 69 deletions(-) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 71860ac7f3..4df0b6777d 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -29,12 +29,11 @@ def execute_query(self, query: str) -> Result[Any]: if not self.engine: raise ConnectionError("Not connected to the database.") session = sessionmaker(bind=self.engine) - connection = session() - return connection.execute(text(query)) + with session() as connection: + return connection.execute(text(query)) class SnowflakeConnector(_BaseConnector): - # TODO: Not Implemented def _connect(self) -> Engine: raise NotImplementedError("Snowflake connector not implemented") @@ -45,24 +44,27 @@ def _connect(self) -> Engine: f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" f"{self.config['database']}?driver={self.config['driver']}" ) - # TODO: Add support for other connection parameters through a custom dictionary or config - self.engine = create_engine(connection_string, echo=True, connect_args=None) - return self.engine + return create_engine(connection_string, echo=True) -# TODO Refactor into application context class DatabaseManager: - def __init__(self, db_type: str, config: dict[str, str]): - self.db_type = db_type - self.config = config - self.connector = self._create_connector() - - def _create_connector(self) -> _ISourceSystemConnector: - if self.db_type.lower() == "snowflake": - return SnowflakeConnector(self.config) - if self.db_type.lower() in {"mssql", "tsql", "synapse"}: - return MSSQLConnector(self.config) - raise ValueError(f"Unsupported database type: {self.db_type}") + def __init__(self, db_type: str, config: dict[str, Any]): + self.connector = self._create_connector(db_type, config) + + def _create_connector(self, db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector: + connectors = { + "snowflake": SnowflakeConnector, + "mssql": MSSQLConnector, + "tsql": MSSQLConnector, + "synapse": MSSQLConnector, + } + + connector_class = connectors.get(db_type.lower()) + + if connector_class is None: + raise ValueError(f"Unsupported database type: {db_type}") + + return connector_class(config) def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) diff --git a/tests/unit/connections/test_database_manager.py b/tests/unit/connections/test_database_manager.py index ff6150d8df..4d1e21770d 100644 --- a/tests/unit/connections/test_database_manager.py +++ b/tests/unit/connections/test_database_manager.py @@ -1,52 +1,55 @@ import pytest -from unittest.mock import create_autospec, patch - -from databricks.labs.remorph.connections.database_manager import DatabaseManager, MSSQLConnector - -def 
test_unsupported_database_type(): - config = {'user': 'user', 'password': 'password', 'server': 'server', 'database': 'database', 'driver': 'driver'} - with pytest.raises(ValueError, match="Unsupported database type: invalid_db"): - DatabaseManager('invalid_db', config) - -def test_mssql_connector_connection(): - config = { - 'user': 'valid_user', - 'password': 'valid_password', - 'server': 'valid_server', - 'database': 'valid_database', - 'driver': 'ODBC Driver 17 for SQL Server' - } - - with patch('sqlalchemy.create_engine') as mock_create_engine, \ - patch.object(MSSQLConnector, '_connect', return_value=None): # Prevent actual connection - connector = MSSQLConnector(config) - engine = connector._connect() - - assert engine is not None - mock_create_engine.assert_called_once_with( - f"mssql+pyodbc://{config['user']}:{config['password']}@{config['server']}/{config['database']}?driver={config['driver']}", - echo=True, - connect_args=None - ) - -def test_execute_successful_query(): - config = { - 'user': 'valid_user', - 'password': 'valid_password', - 'server': 'valid_server', - 'database': 'valid_database', - 'driver': 'ODBC Driver 17 for SQL Server' - } - - with patch('sqlalchemy.create_engine'), \ - patch('sqlalchemy.orm.sessionmaker') as mock_session, \ - patch.object(MSSQLConnector, '_connect', return_value='result'): - connector = MSSQLConnector(config) - connector._connect() - - mock_connection = mock_session.return_value.__enter__.return_value - mock_connection.execute.return_value = "result" - - result = connector.execute_query("SELECT * FROM valid_table") - - assert result == "result" +from unittest.mock import MagicMock, patch +from databricks.labs.remorph.connections.database_manager import DatabaseManager + +sample_config = { + 'user': 'test_user', + 'password': 'test_pass', + 'server': 'test_server', + 'database': 'test_db', + 'driver': 'ODBC Driver 17 for SQL Server', +} + + +def test_create_connector_unsupported_db_type(): + with pytest.raises(ValueError, match="Unsupported database type: unsupported_db"): + DatabaseManager("unsupported_db", sample_config) + + +# Test case for MSSQLConnector +@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector') +def test_mssql_connector(mock_mssql_connector): + mock_connector_instance = MagicMock() + mock_mssql_connector.return_value = mock_connector_instance + + db_manager = DatabaseManager("mssql", sample_config) + + assert db_manager.connector == mock_connector_instance + mock_mssql_connector.assert_called_once_with(sample_config) + + +@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector') +def test_execute_query(mock_mssql_connector): + mock_connector_instance = MagicMock() + mock_mssql_connector.return_value = mock_connector_instance + + db_manager = DatabaseManager("mssql", sample_config) + + query = "SELECT * FROM users" + mock_result = MagicMock() + mock_connector_instance.execute_query.return_value = mock_result + + result = db_manager.execute_query(query) + + assert result == mock_result + mock_connector_instance.execute_query.assert_called_once_with(query) + + +def test_execute_query_without_connection(): + db_manager = DatabaseManager("mssql", sample_config) + + # Simulating that the engine is not connected + db_manager.connector.engine = None + + with pytest.raises(ConnectionError, match="Not connected to the database."): + db_manager.execute_query("SELECT * FROM users") From 8b1c254033312202c4c09eb1b19991a3f89cb935 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Tue, 21 Jan 
2025 19:42:06 +0530 Subject: [PATCH 15/22] Refactored to better --- .../remorph/connections/database_manager.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 4df0b6777d..ed79ca0dbd 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -29,8 +29,24 @@ def execute_query(self, query: str) -> Result[Any]: if not self.engine: raise ConnectionError("Not connected to the database.") session = sessionmaker(bind=self.engine) - with session() as connection: - return connection.execute(text(query)) + connection = session() + return connection.execute(text(query)) + + +def _create_connector(db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector: + connectors = { + "snowflake": SnowflakeConnector, + "mssql": MSSQLConnector, + "tsql": MSSQLConnector, + "synapse": MSSQLConnector, + } + + connector_class = connectors.get(db_type.lower()) + + if connector_class is None: + raise ValueError(f"Unsupported database type: {db_type}") + + return connector_class(config) class SnowflakeConnector(_BaseConnector): @@ -49,22 +65,7 @@ def _connect(self) -> Engine: class DatabaseManager: def __init__(self, db_type: str, config: dict[str, Any]): - self.connector = self._create_connector(db_type, config) - - def _create_connector(self, db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector: - connectors = { - "snowflake": SnowflakeConnector, - "mssql": MSSQLConnector, - "tsql": MSSQLConnector, - "synapse": MSSQLConnector, - } - - connector_class = connectors.get(db_type.lower()) - - if connector_class is None: - raise ValueError(f"Unsupported database type: {db_type}") - - return connector_class(config) + self.connector = _create_connector(db_type, config) def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) From 74030d3545fc3b7b7e301c6849801fdedfa40be9 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 24 Jan 2025 12:20:04 +0530 Subject: [PATCH 16/22] Added Integration Test --- .../remorph/connections/credential_manager.py | 17 ++++--- .../remorph/connections/database_manager.py | 24 +++++++-- .../labs/remorph/connections/env_getter.py | 24 +++++++++ .../connections/test_mssql_connector.py | 50 +++++++++++++++++++ .../connections/test_credential_manager.py | 30 +++++++---- 5 files changed, 124 insertions(+), 21 deletions(-) create mode 100644 src/databricks/labs/remorph/connections/env_getter.py create mode 100644 tests/integration/connections/test_mssql_connector.py diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 5a21f918e6..6593f3fdff 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,14 +1,15 @@ from pathlib import Path -import os import yaml from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.remorph.connections.env_getter import EnvGetter class Credentials: - def __init__(self, product_info: ProductInfo) -> None: + def __init__(self, product_info: ProductInfo, env: EnvGetter) -> None: self._product_info = product_info + self._env = env self._credentials: dict[str, str] = self._load_credentials(self._get_local_version_file_path()) def _get_local_version_file_path(self) -> 
Path: @@ -19,25 +20,27 @@ def _load_credentials(self, file_path: Path) -> dict[str, str]: with open(file_path, encoding="utf-8") as f: return yaml.safe_load(f) - def get(self, source: str) -> dict[str, str]: + def load(self, source: str) -> dict[str, str]: error_msg = f"source system: {source} credentials not found in file credentials.yml" if source in self._credentials: value = self._credentials[source] if isinstance(value, dict): - return {k: self.get_secret_value(v) for k, v in value.items()} + return {k: self._get_secret_value(v) for k, v in value.items()} raise KeyError(error_msg) raise KeyError(error_msg) - def get_secret_value(self, key: str) -> str: + def _get_secret_value(self, key: str) -> str: secret_vault_type = self._credentials.get('secret_vault_type', 'local').lower() if secret_vault_type == 'local': return key if secret_vault_type == 'env': - value = os.getenv(str(key)) # Port numbers can be int - if value is None: + try: + value = self._env.get(str(key)) # Port numbers can be int + except KeyError: print(f"Environment variable {key} not found Failing back to actual strings") return key return value + if secret_vault_type == 'databricks': raise NotImplementedError("Databricks secret vault not implemented") diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index ed79ca0dbd..72c6433631 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -2,7 +2,7 @@ from typing import Any from sqlalchemy import create_engine -from sqlalchemy.engine import Engine, Result +from sqlalchemy.engine import Engine, Result, URL from sqlalchemy.orm import sessionmaker from sqlalchemy import text @@ -56,11 +56,21 @@ def _connect(self) -> Engine: class MSSQLConnector(_BaseConnector): def _connect(self) -> Engine: - connection_string = ( - f"mssql+pyodbc://{self.config['user']}:{self.config['password']}@{self.config['server']}/" - f"{self.config['database']}?driver={self.config['driver']}" + query_params = {"driver": self.config['driver']} + + for key, value in self.config.items(): + if key not in ["user", "password", "server", "database", "port"]: + query_params[key] = value + connection_string = URL.create( + "mssql+pyodbc", + username=self.config['user'], + password=self.config['password'], + host=self.config['server'], + port=self.config.get('port', 1433), + database=self.config['database'], + query=query_params, ) - return create_engine(connection_string, echo=True) + return create_engine(connection_string) class DatabaseManager: @@ -69,3 +79,7 @@ def __init__(self, db_type: str, config: dict[str, Any]): def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) + + # query to ORM + # ORM model to some inbuilt daabase/DUCKDB + # push that ducbk db to rsult to databricks for sharing. 
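
For context, a sketch of how these pieces are expected to be wired together once EnvGetter (added below) is in place. It mirrors the integration test further down and assumes env-based secrets (secret_vault_type: env), a credentials.yml in the location Credentials resolves, and a reachable SQL Server; it is illustrative rather than a definitive API.

    from databricks.labs.blueprint.wheels import ProductInfo
    from databricks.labs.remorph.config import RemorphConfigs
    from databricks.labs.remorph.connections.credential_manager import Credentials
    from databricks.labs.remorph.connections.database_manager import DatabaseManager
    from databricks.labs.remorph.connections.env_getter import EnvGetter

    # Credential values in credentials.yml name environment variables; EnvGetter
    # resolves them, and unresolved keys fall back to the literal value.
    creds = Credentials(ProductInfo.from_class(RemorphConfigs), EnvGetter())
    mssql_config = creds.load("mssql")
    db = DatabaseManager("mssql", mssql_config)
    result = db.execute_query("SELECT 101 AS test_column")
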
diff --git a/src/databricks/labs/remorph/connections/env_getter.py b/src/databricks/labs/remorph/connections/env_getter.py new file mode 100644 index 0000000000..2b64c247ff --- /dev/null +++ b/src/databricks/labs/remorph/connections/env_getter.py @@ -0,0 +1,24 @@ +import os +import json +import logging + + +class EnvGetter: + def __init__(self, is_debug: bool = False): + self.env = self._get_debug_env() if is_debug else dict(os.environ) + + def get(self, key: str) -> str: + if key in self.env: + return self.env[key] + raise KeyError(f"not in env: {key}") + + def _get_debug_env(self) -> dict: + try: + debug_env_file = f"{os.path.expanduser('~')}/.databricks/debug-env.json" + with open(debug_env_file, 'r', encoding='utf-8') as file: + contents = file.read() + logging.debug(f"Found debug env file: {debug_env_file}") + raw = json.loads(contents) + return raw.get("ucws", {}) + except FileNotFoundError: + return dict(os.environ) diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py new file mode 100644 index 0000000000..6cd379f003 --- /dev/null +++ b/tests/integration/connections/test_mssql_connector.py @@ -0,0 +1,50 @@ +from databricks.labs.remorph.connections.credential_manager import Credentials +from databricks.labs.remorph.connections.database_manager import DatabaseManager, MSSQLConnector +from databricks.labs.blueprint.wheels import ProductInfo +from databricks.labs.remorph.config import RemorphConfigs +from databricks.labs.remorph.connections.env_getter import EnvGetter +import pytest +from unittest.mock import patch +from urllib.parse import urlparse + +@pytest.fixture(scope="module") +def mock_credentials(): + with patch.object(Credentials, '_load_credentials', return_value={ + 'secret_vault_type': 'env', + 'secret_vault_name': '', + 'mssql': { + 'user': 'TEST_TSQL_USER', + 'password': 'TEST_TSQL_PASS', + 'server': 'TEST_TSQL_JDBC', + 'database': 'TEST_TSQL_JDBC', + 'driver': 'ODBC Driver 18 for SQL Server' + } + }): + yield + +@pytest.fixture(scope="module") +def db_manager(mock_credentials): + config = Credentials(ProductInfo.from_class(RemorphConfigs),EnvGetter(True)).load("mssql") + # since the kv has only URL so added explicit parse rules + base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1) + url_parts = urlparse(base_url) + server = url_parts.hostname + query_params = dict( + param.split("=", 1) for param in params.split(";") if "=" in param + ) + database = query_params.get("database", "") + config['server'] = '' #server + config['database'] = database + + return DatabaseManager("mssql", config) + + +def test_mssql_connector_connection(db_manager): + assert isinstance(db_manager.connector, MSSQLConnector) + +def test_mssql_connector_execute_query(db_manager): + # Test executing a query + query = "SELECT 101 AS test_column" + result = db_manager.execute_query(query) + row = result.fetchone() + assert row[0] == 101 diff --git a/tests/unit/connections/test_credential_manager.py b/tests/unit/connections/test_credential_manager.py index be280b44ee..f8e46d88bd 100644 --- a/tests/unit/connections/test_credential_manager.py +++ b/tests/unit/connections/test_credential_manager.py @@ -2,7 +2,9 @@ from unittest.mock import patch, MagicMock from pathlib import Path from databricks.labs.remorph.connections.credential_manager import Credentials +from databricks.labs.remorph.connections.env_getter import EnvGetter from databricks.labs.blueprint.wheels import ProductInfo +import os @pytest.fixture @@ 
-12,6 +14,11 @@ def product_info(): return mock_product_info +@pytest.fixture +def env_getter(): + return MagicMock(spec=EnvGetter) + + @pytest.fixture def local_credentials(): return { @@ -57,11 +64,13 @@ def databricks_credentials(): @patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') @patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') -def test_local_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, local_credentials): +def test_local_credentials( + mock_load_credentials, mock_get_local_version_file_path, product_info, local_credentials, env_getter +): mock_load_credentials.return_value = local_credentials mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") - credentials = Credentials(product_info) - creds = credentials.get('mssql') + credentials = Credentials(product_info, env_getter) + creds = credentials.load('mssql') assert creds['user'] == 'local_user' assert creds['password'] == 'local_password' @@ -69,11 +78,14 @@ def test_local_credentials(mock_load_credentials, mock_get_local_version_file_pa @patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') @patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') @patch.dict('os.environ', {'MSSQL_USER_ENV': 'env_user', 'MSSQL_PASSWORD_ENV': 'env_password'}) -def test_env_credentials(mock_load_credentials, mock_get_local_version_file_path, product_info, env_credentials): +def test_env_credentials( + mock_load_credentials, mock_get_local_version_file_path, product_info, env_credentials, env_getter +): mock_load_credentials.return_value = env_credentials mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") - credentials = Credentials(product_info) - creds = credentials.get('mssql') + env_getter.get.side_effect = lambda key: os.environ[key] + credentials = Credentials(product_info, env_getter) + creds = credentials.load('mssql') assert creds['user'] == 'env_user' assert creds['password'] == 'env_password' @@ -81,10 +93,10 @@ def test_env_credentials(mock_load_credentials, mock_get_local_version_file_path @patch('databricks.labs.remorph.connections.credential_manager.Credentials._get_local_version_file_path') @patch('databricks.labs.remorph.connections.credential_manager.Credentials._load_credentials') def test_databricks_credentials( - mock_load_credentials, mock_get_local_version_file_path, product_info, databricks_credentials + mock_load_credentials, mock_get_local_version_file_path, product_info, databricks_credentials, env_getter ): mock_load_credentials.return_value = databricks_credentials mock_get_local_version_file_path.return_value = Path("/fake/path/to/credentials.yml") - credentials = Credentials(product_info) + credentials = Credentials(product_info, env_getter) with pytest.raises(NotImplementedError): - credentials.get('mssql') + credentials.load('mssql') From 29b14be09a76d6a36d7d4e353384ae599864bcc8 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 24 Jan 2025 12:42:30 +0530 Subject: [PATCH 17/22] Added Integration Test --- Makefile | 4 +- .../remorph/connections/credential_manager.py | 6 ++- .../remorph/connections/database_manager.py | 17 +++++---- .../connections/test_mssql_connector.py | 37 +++++++++++-------- 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 0855b199a5..cb8a497d06 100644 --- 
a/Makefile +++ b/Makefile @@ -17,10 +17,10 @@ fmt: setup_spark_remote: .github/scripts/setup_spark_remote.sh -test: setup_spark_remote +test: hatch run test -integration: +integration: setup_spark_remote hatch run integration coverage: diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 6593f3fdff..4910f02dc1 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,10 +1,12 @@ from pathlib import Path import yaml - +import logging from databricks.labs.blueprint.wheels import ProductInfo from databricks.labs.remorph.connections.env_getter import EnvGetter +logger = logging.getLogger(__name__) + class Credentials: def __init__(self, product_info: ProductInfo, env: EnvGetter) -> None: @@ -37,7 +39,7 @@ def _get_secret_value(self, key: str) -> str: try: value = self._env.get(str(key)) # Port numbers can be int except KeyError: - print(f"Environment variable {key} not found Failing back to actual strings") + logger.debug(f"Environment variable {key} not found Failing back to actual string value") return key return value diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index 72c6433631..b08eb39d0b 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -1,3 +1,4 @@ +import logging from abc import ABC, abstractmethod from typing import Any @@ -5,6 +6,9 @@ from sqlalchemy.engine import Engine, Result, URL from sqlalchemy.orm import sessionmaker from sqlalchemy import text +from pyodbc import OperationalError + +logger = logging.getLogger(__name__) class _ISourceSystemConnector(ABC): @@ -28,9 +32,12 @@ def _connect(self) -> Engine: def execute_query(self, query: str) -> Result[Any]: if not self.engine: raise ConnectionError("Not connected to the database.") - session = sessionmaker(bind=self.engine) - connection = session() - return connection.execute(text(query)) + try: + session = sessionmaker(bind=self.engine) + connection = session() + return connection.execute(text(query)) + except OperationalError: + raise ConnectionError("Error connecting to the database check credentials") def _create_connector(db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector: @@ -79,7 +86,3 @@ def __init__(self, db_type: str, config: dict[str, Any]): def execute_query(self, query: str) -> Result[Any]: return self.connector.execute_query(query) - - # query to ORM - # ORM model to some inbuilt daabase/DUCKDB - # push that ducbk db to rsult to databricks for sharing. 
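For orientation, a rough caller-side sketch (not part of the patch) of how the pieces added so far fit together, assuming ~/.databricks/labs/remorph/credentials.yml exists with an mssql section and the ODBC driver is installed. With this change, execute_query surfaces the driver-level OperationalError as a plain ConnectionError, so callers do not need to import SQLAlchemy exception types.

from databricks.labs.blueprint.wheels import ProductInfo
from databricks.labs.remorph.config import RemorphConfigs
from databricks.labs.remorph.connections.credential_manager import Credentials
from databricks.labs.remorph.connections.database_manager import DatabaseManager
from databricks.labs.remorph.connections.env_getter import EnvGetter

# Resolve the "mssql" section of credentials.yml; with secret_vault_type "env", each value
# names an environment variable and falls back to the literal string when the variable is unset.
config = Credentials(ProductInfo.from_class(RemorphConfigs), EnvGetter()).load("mssql")

db = DatabaseManager("mssql", config)
try:
    result = db.execute_query("SELECT 101 AS test_column")
    print(result.fetchone())  # -> (101,)
except ConnectionError as exc:
    print(f"Could not reach the source system, check credentials and network: {exc}")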
diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py index 6cd379f003..f89ca7940b 100644 --- a/tests/integration/connections/test_mssql_connector.py +++ b/tests/integration/connections/test_mssql_connector.py @@ -7,33 +7,37 @@ from unittest.mock import patch from urllib.parse import urlparse + @pytest.fixture(scope="module") def mock_credentials(): - with patch.object(Credentials, '_load_credentials', return_value={ - 'secret_vault_type': 'env', - 'secret_vault_name': '', - 'mssql': { - 'user': 'TEST_TSQL_USER', - 'password': 'TEST_TSQL_PASS', - 'server': 'TEST_TSQL_JDBC', - 'database': 'TEST_TSQL_JDBC', - 'driver': 'ODBC Driver 18 for SQL Server' - } - }): + with patch.object( + Credentials, + '_load_credentials', + return_value={ + 'secret_vault_type': 'env', + 'secret_vault_name': '', + 'mssql': { + 'user': 'TEST_TSQL_USER', + 'password': 'TEST_TSQL_PASS', + 'server': 'TEST_TSQL_JDBC', + 'database': 'TEST_TSQL_JDBC', + 'driver': 'ODBC Driver 18 for SQL Server', + }, + }, + ): yield + @pytest.fixture(scope="module") def db_manager(mock_credentials): - config = Credentials(ProductInfo.from_class(RemorphConfigs),EnvGetter(True)).load("mssql") + config = Credentials(ProductInfo.from_class(RemorphConfigs), EnvGetter(True)).load("mssql") # since the kv has only URL so added explicit parse rules base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1) url_parts = urlparse(base_url) server = url_parts.hostname - query_params = dict( - param.split("=", 1) for param in params.split(";") if "=" in param - ) + query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param) database = query_params.get("database", "") - config['server'] = '' #server + config['server'] = server config['database'] = database return DatabaseManager("mssql", config) @@ -42,6 +46,7 @@ def db_manager(mock_credentials): def test_mssql_connector_connection(db_manager): assert isinstance(db_manager.connector, MSSQLConnector) + def test_mssql_connector_execute_query(db_manager): # Test executing a query query = "SELECT 101 AS test_column" From 0aa457b72d6ced4e81cfc107fc67553e1de51a7e Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 24 Jan 2025 13:53:35 +0530 Subject: [PATCH 18/22] fmt fixes --- src/databricks/labs/remorph/connections/credential_manager.py | 2 +- src/databricks/labs/remorph/connections/database_manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py index 4910f02dc1..dac58eab07 100644 --- a/src/databricks/labs/remorph/connections/credential_manager.py +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -1,6 +1,6 @@ from pathlib import Path -import yaml import logging +import yaml from databricks.labs.blueprint.wheels import ProductInfo from databricks.labs.remorph.connections.env_getter import EnvGetter diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py index b08eb39d0b..852eb76c48 100644 --- a/src/databricks/labs/remorph/connections/database_manager.py +++ b/src/databricks/labs/remorph/connections/database_manager.py @@ -6,7 +6,7 @@ from sqlalchemy.engine import Engine, Result, URL from sqlalchemy.orm import sessionmaker from sqlalchemy import text -from pyodbc import OperationalError +from sqlalchemy.exc import OperationalError logger = 
logging.getLogger(__name__) @@ -37,7 +37,7 @@ def execute_query(self, query: str) -> Result[Any]: connection = session() return connection.execute(text(query)) except OperationalError: - raise ConnectionError("Error connecting to the database check credentials") + raise ConnectionError("Error connecting to the database check credentials") from None def _create_connector(db_type: str, config: dict[str, Any]) -> _ISourceSystemConnector: From 9e1f7fd99b699915cb68b3497b0cbfb0b9efdd6d Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 24 Jan 2025 17:59:29 +0530 Subject: [PATCH 19/22] added fixture --- tests/integration/connections/__init__.py | 0 tests/integration/connections/conftest.py | 20 +++++++++++++++++++ .../connections/test_mssql_connector.py | 8 +++++--- 3 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tests/integration/connections/__init__.py create mode 100644 tests/integration/connections/conftest.py diff --git a/tests/integration/connections/__init__.py b/tests/integration/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/connections/conftest.py b/tests/integration/connections/conftest.py new file mode 100644 index 0000000000..992a65b9aa --- /dev/null +++ b/tests/integration/connections/conftest.py @@ -0,0 +1,20 @@ +import logging +import pytest + +from databricks.labs.remorph.__about__ import __version__ + + +logging.getLogger("tests").setLevel("DEBUG") +logging.getLogger("databricks.labs.remorph").setLevel("DEBUG") + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def debug_env_name(): + return "ucws" + + +@pytest.fixture +def product_info(): + return "remorph", __version__ diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py index f89ca7940b..e309b1f5b8 100644 --- a/tests/integration/connections/test_mssql_connector.py +++ b/tests/integration/connections/test_mssql_connector.py @@ -1,11 +1,13 @@ +from unittest.mock import patch +from urllib.parse import urlparse + +import pytest + from databricks.labs.remorph.connections.credential_manager import Credentials from databricks.labs.remorph.connections.database_manager import DatabaseManager, MSSQLConnector from databricks.labs.blueprint.wheels import ProductInfo from databricks.labs.remorph.config import RemorphConfigs from databricks.labs.remorph.connections.env_getter import EnvGetter -import pytest -from unittest.mock import patch -from urllib.parse import urlparse @pytest.fixture(scope="module") From 79e3a861f8b176157b6c5a99d81273b10f4ecfac Mon Sep 17 00:00:00 2001 From: SundarShankar89 <72757199+sundarshankar89@users.noreply.github.com> Date: Tue, 28 Jan 2025 05:23:25 +0000 Subject: [PATCH 20/22] add acceptance (#1428) Add Integration Test --- .github/scripts/setup_mssql_odbc.sh | 13 +++++ .github/workflows/acceptance.yml | 58 +++++++++++++++++++ pyproject.toml | 3 +- .../labs/remorph/transpiler/execute.py | 8 +-- tests/integration/conftest.py | 32 ++++++++++ tests/integration/connections/conftest.py | 20 ------- .../connections/test_mssql_connector.py | 3 +- 7 files changed, 110 insertions(+), 27 deletions(-) create mode 100644 .github/scripts/setup_mssql_odbc.sh create mode 100644 .github/workflows/acceptance.yml delete mode 100644 tests/integration/connections/conftest.py diff --git a/.github/scripts/setup_mssql_odbc.sh b/.github/scripts/setup_mssql_odbc.sh new file mode 100644 index 0000000000..09971ee597 --- /dev/null +++ b/.github/scripts/setup_mssql_odbc.sh 
@@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -xve +#Repurposed from https://github.com/Yarden-zamir/install-mssql-odbc + +curl -sSL -O https://packages.microsoft.com/config/ubuntu/$(grep VERSION_ID /etc/os-release | cut -d '"' -f 2)/packages-microsoft-prod.deb + +sudo dpkg -i packages-microsoft-prod.deb +#rm packages-microsoft-prod.deb + +sudo apt-get update +sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 + diff --git a/.github/workflows/acceptance.yml b/.github/workflows/acceptance.yml new file mode 100644 index 0000000000..97fec0367f --- /dev/null +++ b/.github/workflows/acceptance.yml @@ -0,0 +1,58 @@ +name: acceptance + +on: + pull_request: + types: [ opened, synchronize, ready_for_review ] + merge_group: + types: [ checks_requested ] + push: + branches: + - main + +permissions: + id-token: write + contents: read + pull-requests: write + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + integration: + if: github.event_name == 'pull_request' && github.event.pull_request.draft == false + environment: tool + runs-on: larger + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==1.9.4 + + - name: Install MSSQL ODBC Driver + run: | + chmod +x $GITHUB_WORKSPACE/.github/scripts/setup_mssql_odbc.sh + $GITHUB_WORKSPACE/.github/scripts/setup_mssql_odbc.sh + + - name: Run integration tests + uses: databrickslabs/sandbox/acceptance@acceptance/v0.4.2 + with: + vault_uri: ${{ secrets.VAULT_URI }} + directory: ${{ github.workspace }} + timeout: 2h + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ARM_CLIENT_ID: ${{ secrets.ARM_CLIENT_ID }} + ARM_TENANT_ID: ${{ secrets.ARM_TENANT_ID }} + TEST_ENV: 'ACCEPTANCE' + diff --git a/pyproject.toml b/pyproject.toml index a432a91b94..160ea36221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "pytest", "pytest-cov>=5.0.0,<6.0.0", "pytest-asyncio>=0.24.0", + "pytest-xdist~=3.5.0", "black>=23.1.0", "ruff>=0.0.243", "databricks-connect==15.1", @@ -69,7 +70,7 @@ reconcile = "databricks.labs.remorph.reconcile.execute:main" [tool.hatch.envs.default.scripts] test = "pytest --cov src --cov-report=xml tests/unit" coverage = "pytest --cov src tests/unit --cov-report=html" -integration = "pytest --cov src tests/integration --durations 20" +integration = "pytest --cov src tests/integration --durations 20 --ignore=tests/integration/connections" fmt = ["black .", "ruff check . --fix", "mypy --disable-error-code 'annotation-unchecked' .", diff --git a/src/databricks/labs/remorph/transpiler/execute.py b/src/databricks/labs/remorph/transpiler/execute.py index b15f74bea4..1e234ae78d 100644 --- a/src/databricks/labs/remorph/transpiler/execute.py +++ b/src/databricks/labs/remorph/transpiler/execute.py @@ -209,17 +209,15 @@ async def _do_transpile( def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClient: - # pylint: disable=protected-access """ [Private] Verifies and updates the workspace client configuration. TODO: In future refactor this function so it can be used for reconcile module without cross access. 
""" - product_info = workspace_client.config._product_info + # Using reflection to set right value for _product_info as dqx for telemetry + product_info = getattr(workspace_client.config, '_product_info') if product_info[0] != "remorph": - product_info[0] = "remorph" - if product_info[1] != __version__: - product_info[1] = __version__ + setattr(workspace_client.config, '_product_info', ('remorph', __version__)) return workspace_client diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2f655d78c9..202e8f5215 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,6 +1,38 @@ +import os +import logging import pytest from pyspark.sql import SparkSession +from databricks.labs.remorph.__about__ import __version__ + +logging.getLogger("tests").setLevel("DEBUG") +logging.getLogger("databricks.labs.remorph").setLevel("DEBUG") + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def debug_env_name(): + return "ucws" + + +@pytest.fixture +def product_info(): + return "remorph", __version__ + + +def pytest_collection_modifyitems(config, items): + if os.getenv('TEST_ENV') == 'ACCEPTANCE': + selected_items = [] + deselected_items = [] + for item in items: + if 'tests/integration/connections' in str(item.fspath): + selected_items.append(item) + else: + deselected_items.append(item) + items[:] = selected_items + config.hook.pytest_deselected(items=deselected_items) + @pytest.fixture(scope="session") def mock_spark() -> SparkSession: diff --git a/tests/integration/connections/conftest.py b/tests/integration/connections/conftest.py deleted file mode 100644 index 992a65b9aa..0000000000 --- a/tests/integration/connections/conftest.py +++ /dev/null @@ -1,20 +0,0 @@ -import logging -import pytest - -from databricks.labs.remorph.__about__ import __version__ - - -logging.getLogger("tests").setLevel("DEBUG") -logging.getLogger("databricks.labs.remorph").setLevel("DEBUG") - -logger = logging.getLogger(__name__) - - -@pytest.fixture -def debug_env_name(): - return "ucws" - - -@pytest.fixture -def product_info(): - return "remorph", __version__ diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py index e309b1f5b8..9944126ea4 100644 --- a/tests/integration/connections/test_mssql_connector.py +++ b/tests/integration/connections/test_mssql_connector.py @@ -35,10 +35,11 @@ def db_manager(mock_credentials): config = Credentials(ProductInfo.from_class(RemorphConfigs), EnvGetter(True)).load("mssql") # since the kv has only URL so added explicit parse rules base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1) + url_parts = urlparse(base_url) server = url_parts.hostname query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param) - database = query_params.get("database", "") + database = query_params.get("database", "" "") config['server'] = server config['database'] = database From f44a09e94f344a1dfb8f2a9ba786772dea9e5edc Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 29 Jan 2025 23:09:20 +0530 Subject: [PATCH 21/22] fmt fixes --- src/databricks/labs/remorph/transpiler/execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/labs/remorph/transpiler/execute.py b/src/databricks/labs/remorph/transpiler/execute.py index 8863bff9e0..fa49bb164e 100644 --- a/src/databricks/labs/remorph/transpiler/execute.py +++ b/src/databricks/labs/remorph/transpiler/execute.py @@ -214,7 +214,7 @@ def 
verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClien TODO: In future refactor this function so it can be used for reconcile module without cross access. """ - + # Using reflection to set right value for _product_info for telemetry product_info = getattr(workspace_client.config, '_product_info') if product_info[0] != "remorph": From 355b76d4404e0a8630e4695bbd9ee5b7fde5fd75 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Mon, 3 Feb 2025 15:08:21 +0530 Subject: [PATCH 22/22] Simplified installation journey --- labs.yml | 10 +++-- src/databricks/labs/remorph/base_install.py | 9 ++++ src/databricks/labs/remorph/cli.py | 48 +++++++++++++++++++++ src/databricks/labs/remorph/install.py | 33 +++----------- tests/unit/test_install.py | 45 +++++++------------ 5 files changed, 85 insertions(+), 60 deletions(-) create mode 100644 src/databricks/labs/remorph/base_install.py diff --git a/labs.yml b/labs.yml index d11592f0fb..a1bc5cfd24 100644 --- a/labs.yml +++ b/labs.yml @@ -3,9 +3,7 @@ name: remorph description: Code Transpiler and Data Reconciliation tool for Accelerating Data onboarding to Databricks from EDW, CDW and other ETL sources. install: min_runtime_version: 13.3 - require_running_cluster: false - require_databricks_connect: false - script: src/databricks/labs/remorph/install.py + script: src/databricks/labs/remorph/base_install.py uninstall: script: src/databricks/labs/remorph/uninstall.py entrypoint: src/databricks/labs/remorph/cli.py @@ -66,3 +64,9 @@ commands: description: Utility to setup Scope and Secrets on Databricks Workspace - name: debug-me description: "[INTERNAL] Debug SDK connectivity" + - name: install-assessment + description: "Install Assessment" + - name: install-transpile + description: "Install Transpile" + - name: install-reconcile + description: "Install Reconcile" diff --git a/src/databricks/labs/remorph/base_install.py b/src/databricks/labs/remorph/base_install.py new file mode 100644 index 0000000000..b0b4d199dd --- /dev/null +++ b/src/databricks/labs/remorph/base_install.py @@ -0,0 +1,9 @@ +from databricks.labs.blueprint.entrypoint import get_logger, is_in_debug + +if __name__ == "__main__": + logger = get_logger(__file__) + logger.setLevel("INFO") + if is_in_debug(): + logger.getLogger("databricks").setLevel(logger.setLevel("DEBUG")) + + logger.info("Successfully Setup Remorph Components Locally") diff --git a/src/databricks/labs/remorph/cli.py b/src/databricks/labs/remorph/cli.py index 98fd863dcf..ca001539a2 100644 --- a/src/databricks/labs/remorph/cli.py +++ b/src/databricks/labs/remorph/cli.py @@ -8,6 +8,8 @@ from databricks.labs.remorph.config import TranspileConfig from databricks.labs.remorph.contexts.application import ApplicationContext from databricks.labs.remorph.helpers.recon_config_utils import ReconConfigPrompts +from databricks.labs.remorph.__about__ import __version__ +from databricks.labs.remorph.install import WorkspaceInstaller from databricks.labs.remorph.reconcile.runner import ReconcileRunner from databricks.labs.remorph.lineage import lineage_generator from databricks.labs.remorph.transpiler.execute import transpile as do_transpile @@ -34,6 +36,32 @@ def raise_validation_exception(msg: str) -> Exception: proxy_command(remorph, "debug-bundle") +def _installer(ws: WorkspaceClient) -> WorkspaceInstaller: + app_context = ApplicationContext(_verify_workspace_client(ws)) + return WorkspaceInstaller( + app_context.workspace_client, + app_context.prompts, + app_context.installation, + app_context.install_state, + 
app_context.product_info, + app_context.resource_configurator, + app_context.workspace_installation, + ) + + +def _verify_workspace_client(ws: WorkspaceClient) -> WorkspaceClient: + """ + [Private] Verifies and updates the workspace client configuration. + """ + + # Using reflection to set right value for _product_info for telemetry + product_info = getattr(ws.config, '_product_info') + if product_info[0] != "remorph": + setattr(ws.config, '_product_info', ('remorph', __version__)) + + return ws + + @remorph.command def transpile( w: WorkspaceClient, @@ -168,5 +196,25 @@ def configure_secrets(w: WorkspaceClient): recon_conf.prompt_and_save_connection_details() +@remorph.command(is_unauthenticated=True) +def install_assessment(): + """Install the Remorph Assessment package""" + raise NotImplementedError("Assessment package is not available yet.") + + +@remorph.command() +def install_transpile(w: WorkspaceClient): + """Install the Remorph Transpile package""" + installer = _installer(w) + installer.run(module="transpile") + + +@remorph.command(is_unauthenticated=False) +def install_reconcile(w: WorkspaceClient): + """Install the Remorph Reconcile package""" + installer = _installer(w) + installer.run(module="reconcile") + + if __name__ == "__main__": remorph() diff --git a/src/databricks/labs/remorph/install.py b/src/databricks/labs/remorph/install.py index 32968c77e7..4b911c8b1d 100644 --- a/src/databricks/labs/remorph/install.py +++ b/src/databricks/labs/remorph/install.py @@ -3,7 +3,6 @@ import os import webbrowser -from databricks.labs.blueprint.entrypoint import get_logger, is_in_debug from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.installation import SerdeError from databricks.labs.blueprint.installer import InstallState @@ -12,7 +11,6 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound, PermissionDenied -from databricks.labs.remorph.__about__ import __version__ from databricks.labs.remorph.config import ( TranspileConfig, ReconcileConfig, @@ -20,7 +18,6 @@ RemorphConfigs, ReconcileMetadataConfig, ) -from databricks.labs.remorph.contexts.application import ApplicationContext from databricks.labs.remorph.deployment.configurator import ResourceConfigurator from databricks.labs.remorph.deployment.installation import WorkspaceInstallation from databricks.labs.remorph.reconcile.constants import ReconReportType, ReconSourceType @@ -29,7 +26,6 @@ logger = logging.getLogger(__name__) TRANSPILER_WAREHOUSE_PREFIX = "Remorph Transpiler Validation" -MODULES = sorted({"transpile", "reconcile", "all"}) class WorkspaceInstaller: @@ -61,20 +57,20 @@ def __init__( def run( self, + module: str, config: RemorphConfigs | None = None, ) -> RemorphConfigs: logger.info(f"Installing Remorph v{self._product_info.version()}") if not config: - config = self.configure() + config = self.configure(module) if self._is_testing(): return config self._ws_installation.install(config) logger.info("Installation completed successfully! 
Please refer to the documentation for the next steps.") return config - def configure(self, module: str | None = None) -> RemorphConfigs: - selected_module = module or self._prompts.choice("Select a module to configure:", MODULES) - match selected_module: + def configure(self, module: str) -> RemorphConfigs: + match module: case "transpile": logger.info("Configuring remorph `transpile`.") return RemorphConfigs(self._configure_transpile(), None) @@ -88,7 +84,7 @@ def configure(self, module: str | None = None) -> RemorphConfigs: self._configure_reconcile(), ) case _: - raise ValueError(f"Invalid input: {selected_module}") + raise ValueError(f"Invalid input: {module}") def _is_testing(self): return self._product_info.product_name() != "remorph" @@ -282,22 +278,3 @@ def _save_config(self, config: TranspileConfig | ReconcileConfig): def _has_necessary_access(self, catalog_name: str, schema_name: str, volume_name: str | None = None): self._resource_configurator.has_necessary_access(catalog_name, schema_name, volume_name) - - -if __name__ == "__main__": - logger = get_logger(__file__) - logger.setLevel("INFO") - if is_in_debug(): - logging.getLogger("databricks").setLevel(logging.DEBUG) - - app_context = ApplicationContext(WorkspaceClient(product="remorph", product_version=__version__)) - installer = WorkspaceInstaller( - app_context.workspace_client, - app_context.prompts, - app_context.installation, - app_context.install_state, - app_context.product_info, - app_context.resource_configurator, - app_context.workspace_installation, - ) - installer.run() diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 7f93a96973..da2942a394 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -9,7 +9,7 @@ from databricks.labs.remorph.contexts.application import ApplicationContext from databricks.labs.remorph.deployment.configurator import ResourceConfigurator from databricks.labs.remorph.deployment.installation import WorkspaceInstallation -from databricks.labs.remorph.install import WorkspaceInstaller, MODULES +from databricks.labs.remorph.install import WorkspaceInstaller from databricks.labs.remorph.config import TranspileConfig from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2 from databricks.labs.remorph.reconcile.constants import ReconSourceType, ReconReportType @@ -63,7 +63,7 @@ def test_workspace_installer_run_install_not_called_in_test(ws): ctx.resource_configurator, ctx.workspace_installation, ) - returned_config = workspace_installer.run(config=provided_config) + returned_config = workspace_installer.run(module="transpile", config=provided_config) assert returned_config == provided_config ws_installation.install.assert_not_called() @@ -85,7 +85,7 @@ def test_workspace_installer_run_install_called_with_provided_config(ws): ctx.resource_configurator, ctx.workspace_installation, ) - returned_config = workspace_installer.run(config=provided_config) + returned_config = workspace_installer.run(module="transpile", config=provided_config) assert returned_config == provided_config ws_installation.install.assert_called_once_with(provided_config) @@ -113,7 +113,6 @@ def test_configure_error_if_invalid_module_selected(ws): def test_workspace_installer_run_install_called_with_generated_config(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Do you want to override the existing installation?": "no", r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source dialect": 
sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), @@ -142,7 +141,7 @@ def test_workspace_installer_run_install_called_with_generated_config(ws): ctx.resource_configurator, ctx.workspace_installation, ) - workspace_installer.run() + workspace_installer.run("transpile") installation.assert_file_written( "config.yml", { @@ -163,7 +162,6 @@ def test_workspace_installer_run_install_called_with_generated_config(ws): def test_configure_transpile_no_existing_installation(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Do you want to override the existing installation?": "no", r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), @@ -191,7 +189,7 @@ def test_configure_transpile_no_existing_installation(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_morph_config = TranspileConfig( transpiler_config_path="sqlglot", source_dialect="snowflake", @@ -225,7 +223,6 @@ def test_configure_transpile_no_existing_installation(ws): def test_configure_transpile_installation_no_override(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Do you want to override the existing installation?": "no", } ) @@ -262,13 +259,12 @@ def test_configure_transpile_installation_no_override(ws): ctx.workspace_installation, ) with pytest.raises(SystemExit): - workspace_installer.configure() + workspace_installer.configure("transpile") def test_configure_transpile_installation_config_error_continue_install(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Do you want to override the existing installation?": "no", r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), @@ -312,7 +308,7 @@ def test_configure_transpile_installation_config_error_continue_install(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_morph_config = TranspileConfig( transpiler_config_path="sqlglot", source_dialect="snowflake", @@ -347,7 +343,6 @@ def test_configure_transpile_installation_config_error_continue_install(ws): def test_configure_transpile_installation_with_no_validation(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source dialect": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), r"Enter input SQL path.*": "/tmp/queries/snow", @@ -375,7 +370,7 @@ def test_configure_transpile_installation_with_no_validation(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_morph_config = TranspileConfig( transpiler_config_path="sqlglot", source_dialect="snowflake", @@ -409,7 +404,6 @@ def test_configure_transpile_installation_with_no_validation(ws): def test_configure_transpile_installation_with_validation_and_cluster_id_in_config(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), r"Enter input SQL path.*": 
"/tmp/queries/snow", @@ -444,7 +438,7 @@ def test_configure_transpile_installation_with_validation_and_cluster_id_in_conf ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_config = RemorphConfigs( transpile=TranspileConfig( transpiler_config_path="sqlglot", @@ -479,7 +473,6 @@ def test_configure_transpile_installation_with_validation_and_cluster_id_in_conf def test_configure_transpile_installation_with_validation_and_cluster_id_from_prompt(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), r"Enter input SQL path.*": "/tmp/queries/snow", @@ -515,7 +508,7 @@ def test_configure_transpile_installation_with_validation_and_cluster_id_from_pr ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_config = RemorphConfigs( transpile=TranspileConfig( transpiler_config_path="sqlglot", @@ -550,7 +543,6 @@ def test_configure_transpile_installation_with_validation_and_cluster_id_from_pr def test_configure_transpile_installation_with_validation_and_warehouse_id_from_prompt(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), r"Enter input SQL path.*": "/tmp/queries/snow", @@ -584,7 +576,7 @@ def test_configure_transpile_installation_with_validation_and_warehouse_id_from_ ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("transpile") expected_config = RemorphConfigs( transpile=TranspileConfig( transpiler_config_path="sqlglot", @@ -619,7 +611,6 @@ def test_configure_transpile_installation_with_validation_and_warehouse_id_from_ def test_configure_reconcile_installation_no_override(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("reconcile"), r"Do you want to override the existing installation?": "no", } ) @@ -660,13 +651,12 @@ def test_configure_reconcile_installation_no_override(ws): ctx.workspace_installation, ) with pytest.raises(SystemExit): - workspace_installer.configure() + workspace_installer.configure("reconcile") def test_configure_reconcile_installation_config_error_continue_install(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("reconcile"), r"Select the Data Source": RECONCILE_DATA_SOURCES.index("oracle"), r"Select the report type": RECONCILE_REPORT_TYPES.index("all"), r"Enter Secret scope name to store .* connection details / secrets": "remorph_oracle", @@ -719,7 +709,7 @@ def test_configure_reconcile_installation_config_error_continue_install(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("reconcile") expected_config = RemorphConfigs( reconcile=ReconcileConfig( data_source="oracle", @@ -763,7 +753,6 @@ def test_configure_reconcile_installation_config_error_continue_install(ws): def test_configure_reconcile_no_existing_installation(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("reconcile"), r"Select the Data Source": 
RECONCILE_DATA_SOURCES.index("snowflake"), r"Select the report type": RECONCILE_REPORT_TYPES.index("all"), r"Enter Secret scope name to store .* connection details / secrets": "remorph_snowflake", @@ -797,7 +786,7 @@ def test_configure_reconcile_no_existing_installation(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("reconcile") expected_config = RemorphConfigs( reconcile=ReconcileConfig( data_source="snowflake", @@ -842,7 +831,6 @@ def test_configure_reconcile_no_existing_installation(ws): def test_configure_all_override_installation(ws): prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("all"), r"Do you want to override the existing installation?": "yes", r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), @@ -917,7 +905,7 @@ def test_configure_all_override_installation(ws): ctx.resource_configurator, ctx.workspace_installation, ) - config = workspace_installer.configure() + config = workspace_installer.configure("all") expected_transpile_config = TranspileConfig( transpiler_config_path="sqlglot", source_dialect="snowflake", @@ -1015,7 +1003,6 @@ def test_runs_upgrades_on_more_recent_version(ws): ctx = ApplicationContext(ws) prompts = MockPrompts( { - r"Select a module to configure:": MODULES.index("transpile"), r"Do you want to override the existing installation?": "yes", r"Enter path to the transpiler configuration file": "sqlglot", r"Select the source": sorted(SQLGLOT_DIALECTS.keys()).index("snowflake"), @@ -1048,7 +1035,7 @@ def test_runs_upgrades_on_more_recent_version(ws): ctx.workspace_installation, ) - workspace_installer.run() + workspace_installer.run("transpile") mock_workspace_installation.install.assert_called_once_with( RemorphConfigs(