diff --git a/.github/scripts/setup_mssql_odbc.sh b/.github/scripts/setup_mssql_odbc.sh new file mode 100644 index 0000000000..09971ee597 --- /dev/null +++ b/.github/scripts/setup_mssql_odbc.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -xve +#Repurposed from https://github.com/Yarden-zamir/install-mssql-odbc + +curl -sSL -O https://packages.microsoft.com/config/ubuntu/$(grep VERSION_ID /etc/os-release | cut -d '"' -f 2)/packages-microsoft-prod.deb + +sudo dpkg -i packages-microsoft-prod.deb +#rm packages-microsoft-prod.deb + +sudo apt-get update +sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 + diff --git a/.github/workflows/acceptance.yml b/.github/workflows/acceptance.yml new file mode 100644 index 0000000000..97fec0367f --- /dev/null +++ b/.github/workflows/acceptance.yml @@ -0,0 +1,58 @@ +name: acceptance + +on: + pull_request: + types: [ opened, synchronize, ready_for_review ] + merge_group: + types: [ checks_requested ] + push: + branches: + - main + +permissions: + id-token: write + contents: read + pull-requests: write + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + integration: + if: github.event_name == 'pull_request' && github.event.pull_request.draft == false + environment: tool + runs-on: larger + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Python + uses: actions/setup-python@v5 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==1.9.4 + + - name: Install MSSQL ODBC Driver + run: | + chmod +x $GITHUB_WORKSPACE/.github/scripts/setup_mssql_odbc.sh + $GITHUB_WORKSPACE/.github/scripts/setup_mssql_odbc.sh + + - name: Run integration tests + uses: databrickslabs/sandbox/acceptance@acceptance/v0.4.2 + with: + vault_uri: ${{ secrets.VAULT_URI }} + directory: ${{ github.workspace }} + timeout: 2h + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + 
ARM_CLIENT_ID: ${{ secrets.ARM_CLIENT_ID }} + ARM_TENANT_ID: ${{ secrets.ARM_TENANT_ID }} + TEST_ENV: 'ACCEPTANCE' + diff --git a/Makefile b/Makefile index 0855b199a5..cb8a497d06 100644 --- a/Makefile +++ b/Makefile @@ -17,10 +17,10 @@ fmt: setup_spark_remote: .github/scripts/setup_spark_remote.sh -test: setup_spark_remote +test: hatch run test -integration: +integration: setup_spark_remote hatch run integration coverage: diff --git a/pyproject.toml b/pyproject.toml index 23d0b4838d..b3ec477d37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "databricks-labs-blueprint[yaml]>=0.2.3", "databricks-labs-lsql>=0.7.5,<0.14.0", # TODO: Limit the LSQL version until dependencies are correct. "cryptography>=41.0.3", + "pyodbc", + "SQLAlchemy", "pygls>=2.0.0a2", ] @@ -50,6 +52,7 @@ dependencies = [ "pytest", "pytest-cov>=5.0.0,<6.0.0", "pytest-asyncio>=0.24.0", + "pytest-xdist~=3.5.0", "black>=23.1.0", "ruff>=0.0.243", "databricks-connect==15.1", @@ -65,8 +68,8 @@ reconcile = "databricks.labs.remorph.reconcile.execute:main" [tool.hatch.envs.default.scripts] test = "pytest --cov src --cov-report=xml tests/unit" -coverage = "pytest --cov src tests/unit --cov-report=html" -integration = "pytest --cov src tests/integration --durations 20" +coverage = "pytest --cov src tests --cov-report=html --ignore=tests/integration/connections" +integration = "pytest --cov src tests/integration --durations 20 --ignore=tests/integration/connections" fmt = ["black .", "ruff check . 
--fix", "mypy --disable-error-code 'annotation-unchecked' .", diff --git a/src/databricks/labs/remorph/connections/__init__.py b/src/databricks/labs/remorph/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/remorph/connections/credential_manager.py b/src/databricks/labs/remorph/connections/credential_manager.py new file mode 100644 index 0000000000..7ec476eff5 --- /dev/null +++ b/src/databricks/labs/remorph/connections/credential_manager.py @@ -0,0 +1,85 @@ +from pathlib import Path +import logging +from typing import Protocol + +import yaml + +from databricks.labs.remorph.connections.env_getter import EnvGetter + + +logger = logging.getLogger(__name__) + + +class SecretProvider(Protocol): + def get_secret(self, key: str) -> str: + pass + + +class LocalSecretProvider: + def get_secret(self, key: str) -> str: + return key + + +class EnvSecretProvider: + def __init__(self, env_getter: EnvGetter): + self._env_getter = env_getter + + def get_secret(self, key: str) -> str: + try: + return self._env_getter.get(str(key)) + except KeyError: + logger.debug(f"Environment variable {key} not found. 
Falling back to actual value") + return key + + +class DatabricksSecretProvider: + def get_secret(self, key: str) -> str: + raise NotImplementedError("Databricks secret vault not implemented") + + +class CredentialManager: + def __init__(self, credential_loader: dict, secret_providers: dict): + self._credentials = credential_loader + self._secret_providers = secret_providers + self._default_vault = self._credentials.get('secret_vault_type', 'local').lower() + + def get_credentials(self, source: str) -> dict: + if source not in self._credentials: + raise KeyError(f"Source system: {source} credentials not found") + + value = self._credentials[source] + if not isinstance(value, dict): + raise KeyError(f"Invalid credential format for source: {source}") + + return {k: self._get_secret_value(v) for k, v in value.items()} + + def _get_secret_value(self, key: str) -> str: + provider = self._secret_providers.get(self._default_vault) + if not provider: + raise ValueError(f"Unsupported secret vault type: {self._default_vault}") + return provider.get_secret(key) + + +def _get_home() -> Path: + return Path.home() + + +def _load_credentials(path: Path) -> dict: + try: + with open(path, encoding="utf-8") as f: + return yaml.safe_load(f) + except FileNotFoundError as e: + raise FileNotFoundError(f"Credentials file not found at {path}") from e + + +def create_credential_manager(product_name: str, env_getter: EnvGetter): + file_path = Path(f"{_get_home()}/.databricks/labs/{product_name}/.credentials.yml") + + secret_providers = { + 'local': LocalSecretProvider(), + 'env': EnvSecretProvider(env_getter), + 'databricks': DatabricksSecretProvider(), + } + + loader = _load_credentials(file_path) + return CredentialManager(loader, secret_providers) diff --git a/src/databricks/labs/remorph/connections/database_manager.py b/src/databricks/labs/remorph/connections/database_manager.py new file mode 100644 index 0000000000..fb65281dfd --- /dev/null +++ 
b/src/databricks/labs/remorph/connections/database_manager.py @@ -0,0 +1,96 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any + +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine, Result, URL +from sqlalchemy.orm import sessionmaker +from sqlalchemy import text +from sqlalchemy.exc import OperationalError + +logger = logging.getLogger(__name__) + + +class DatabaseConnector(ABC): + @abstractmethod + def _connect(self) -> Engine: + pass + + @abstractmethod + def execute_query(self, query: str) -> Result[Any]: + pass + + +class _BaseConnector(DatabaseConnector): + def __init__(self, config: dict[str, Any]): + self.config = config + self.engine: Engine = self._connect() + + def _connect(self) -> Engine: + raise NotImplementedError("Subclasses should implement this method") + + def execute_query(self, query: str) -> Result[Any]: + if not self.engine: + raise ConnectionError("Not connected to the database.") + session = sessionmaker(bind=self.engine) + connection = session() + return connection.execute(text(query)) + + +def _create_connector(db_type: str, config: dict[str, Any]) -> DatabaseConnector: + connectors = { + "snowflake": SnowflakeConnector, + "mssql": MSSQLConnector, + "tsql": MSSQLConnector, + "synapse": MSSQLConnector, + } + + connector_class = connectors.get(db_type.lower()) + + if connector_class is None: + raise ValueError(f"Unsupported database type: {db_type}") + + return connector_class(config) + + +class SnowflakeConnector(_BaseConnector): + def _connect(self) -> Engine: + raise NotImplementedError("Snowflake connector not implemented") + + +class MSSQLConnector(_BaseConnector): + def _connect(self) -> Engine: + query_params = {"driver": self.config['driver']} + + for key, value in self.config.items(): + if key not in ["user", "password", "server", "database", "port"]: + query_params[key] = value + connection_string = URL.create( + "mssql+pyodbc", + username=self.config['user'], + 
password=self.config['password'], + host=self.config['server'], + port=self.config.get('port', 1433), + database=self.config['database'], + query=query_params, + ) + return create_engine(connection_string) + + +class DatabaseManager: + def __init__(self, db_type: str, config: dict[str, Any]): + self.connector = _create_connector(db_type, config) + + def execute_query(self, query: str) -> Result[Any]: + try: + return self.connector.execute_query(query) + except OperationalError: + raise ConnectionError("Error connecting to the database check credentials") from None + + def check_connection(self) -> bool: + query = "SELECT 101 AS test_column" + result = self.execute_query(query) + row = result.fetchone() + if row is None: + return False + return row[0] == 101 diff --git a/src/databricks/labs/remorph/connections/env_getter.py b/src/databricks/labs/remorph/connections/env_getter.py new file mode 100644 index 0000000000..0f6c45098e --- /dev/null +++ b/src/databricks/labs/remorph/connections/env_getter.py @@ -0,0 +1,13 @@ +import os + + +class EnvGetter: + """Standardised in order to support testing capabilities; check debug_envgetter.py""" + + def __init__(self): + self.env = dict(os.environ) + + def get(self, key: str) -> str: + if key in self.env: + return self.env[key] + raise KeyError(f"not in env: {key}") diff --git a/src/databricks/labs/remorph/resources/config/credentials.yml b/src/databricks/labs/remorph/resources/config/credentials.yml new file mode 100644 index 0000000000..72ffbedfd8 --- /dev/null +++ b/src/databricks/labs/remorph/resources/config/credentials.yml @@ -0,0 +1,33 @@ +secret_vault_type: local | databricks | env +secret_vault_name: null +snowflake: + account: example_account + connect_retries: 1 + connect_timeout: null + host: null + insecure_mode: false + oauth_client_id: null + oauth_client_secret: null + password: null + port: null + private_key: null + private_key_passphrase: null + private_key_path: null + role: null + token: null + user: null 
+ warehouse: null + +mssql: + #TODO Expand to support sqlpools, and legacy dwh + database: DB_NAME + driver: ODBC Driver 18 for SQL Server + server: example_host + port: null + user: null + password: null + + + + + diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2f655d78c9..202e8f5215 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,6 +1,38 @@ +import os +import logging import pytest from pyspark.sql import SparkSession +from databricks.labs.remorph.__about__ import __version__ + +logging.getLogger("tests").setLevel("DEBUG") +logging.getLogger("databricks.labs.remorph").setLevel("DEBUG") + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def debug_env_name(): + return "ucws" + + +@pytest.fixture +def product_info(): + return "remorph", __version__ + + +def pytest_collection_modifyitems(config, items): + if os.getenv('TEST_ENV') == 'ACCEPTANCE': + selected_items = [] + deselected_items = [] + for item in items: + if 'tests/integration/connections' in str(item.fspath): + selected_items.append(item) + else: + deselected_items.append(item) + items[:] = selected_items + config.hook.pytest_deselected(items=deselected_items) + @pytest.fixture(scope="session") def mock_spark() -> SparkSession: diff --git a/tests/integration/connections/__init__.py b/tests/integration/connections/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/connections/debug_envgetter.py b/tests/integration/connections/debug_envgetter.py new file mode 100644 index 0000000000..06d1230c8a --- /dev/null +++ b/tests/integration/connections/debug_envgetter.py @@ -0,0 +1,24 @@ +import os +import json +import logging + + +class EnvGetter: + def __init__(self, is_debug: bool = True): + self.env = self._get_debug_env() if is_debug else dict(os.environ) + + def get(self, key: str) -> str: + if key in self.env: + return self.env[key] + raise KeyError(f"not in env: {key}") + + def 
_get_debug_env(self) -> dict: + try: + debug_env_file = f"{os.path.expanduser('~')}/.databricks/debug-env.json" + with open(debug_env_file, 'r', encoding='utf-8') as file: + contents = file.read() + logging.debug(f"Found debug env file: {debug_env_file}") + raw = json.loads(contents) + return raw.get("ucws", {}) + except FileNotFoundError: + return dict(os.environ) diff --git a/tests/integration/connections/test_mssql_connector.py b/tests/integration/connections/test_mssql_connector.py new file mode 100644 index 0000000000..9b2dac9145 --- /dev/null +++ b/tests/integration/connections/test_mssql_connector.py @@ -0,0 +1,59 @@ +from unittest.mock import patch +from urllib.parse import urlparse + +import pytest + +from databricks.labs.remorph.connections.credential_manager import create_credential_manager +from databricks.labs.remorph.connections.database_manager import DatabaseManager, MSSQLConnector +from .debug_envgetter import EnvGetter + + +@pytest.fixture(scope="module") +def mock_credentials(): + with patch( + 'databricks.labs.remorph.connections.credential_manager._load_credentials', + return_value={ + 'secret_vault_type': 'env', + 'secret_vault_name': '', + 'mssql': { + 'user': 'TEST_TSQL_USER', + 'password': 'TEST_TSQL_PASS', + 'server': 'TEST_TSQL_JDBC', + 'database': 'TEST_TSQL_JDBC', + 'driver': 'ODBC Driver 18 for SQL Server', + }, + }, + ): + yield + + +@pytest.fixture(scope="module") +def db_manager(mock_credentials): + config = create_credential_manager("remorph", EnvGetter(True)).get_credentials("mssql") + # since the kv has only URL so added explicit parse rules + base_url, params = config['server'].replace("jdbc:", "", 1).split(";", 1) + + url_parts = urlparse(base_url) + server = url_parts.hostname + query_params = dict(param.split("=", 1) for param in params.split(";") if "=" in param) + database = query_params.get("database", "" "") + config['server'] = server + config['database'] = database + + return DatabaseManager("mssql", config) + + +def 
test_mssql_connector_connection(db_manager): + assert isinstance(db_manager.connector, MSSQLConnector) + + +def test_mssql_connector_execute_query(db_manager): + # Test executing a query + query = "SELECT 101 AS test_column" + result = db_manager.execute_query(query) + row = result.fetchone() + assert row[0] == 101 + + +def test_connection_test(db_manager): + assert db_manager.check_connection() diff --git a/tests/unit/connections/test_credential_manager.py b/tests/unit/connections/test_credential_manager.py new file mode 100644 index 0000000000..71a6e7133e --- /dev/null +++ b/tests/unit/connections/test_credential_manager.py @@ -0,0 +1,90 @@ +import pytest +from unittest.mock import patch, MagicMock +from pathlib import Path +from databricks.labs.remorph.connections.credential_manager import create_credential_manager +from databricks.labs.remorph.connections.env_getter import EnvGetter +import os + +product_name = "remorph" + + +@pytest.fixture +def env_getter(): + return MagicMock(spec=EnvGetter) + + +@pytest.fixture +def local_credentials(): + return { + 'secret_vault_type': 'local', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'local_user', + 'password': 'local_password', + }, + } + + +@pytest.fixture +def env_credentials(): + return { + 'secret_vault_type': 'env', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'MSSQL_USER_ENV', + 'password': 'MSSQL_PASSWORD_ENV', + }, + } + + +@pytest.fixture +def databricks_credentials(): + return { + 'secret_vault_type': 'databricks', + 'secret_vault_name': 'databricks_vault_name', + 'mssql': { + 'database': 'DB_NAME', + 'driver': 'ODBC Driver 18 for SQL Server', + 'server': 'example_host', + 'user': 'databricks_user', + 'password': 'databricks_password', + }, + } + + +@patch('databricks.labs.remorph.connections.credential_manager._load_credentials') 
+@patch('databricks.labs.remorph.connections.credential_manager._get_home') +def test_local_credentials(mock_get_home, mock_load_credentials, local_credentials, env_getter): + mock_load_credentials.return_value = local_credentials + mock_get_home.return_value = Path("/fake/home") + credentials = create_credential_manager(product_name, env_getter) + creds = credentials.get_credentials('mssql') + assert creds['user'] == 'local_user' + assert creds['password'] == 'local_password' + + +@patch('databricks.labs.remorph.connections.credential_manager._load_credentials') +@patch('databricks.labs.remorph.connections.credential_manager._get_home') +@patch.dict('os.environ', {'MSSQL_USER_ENV': 'env_user', 'MSSQL_PASSWORD_ENV': 'env_password'}) +def test_env_credentials(mock_get_home, mock_load_credentials, env_credentials, env_getter): + mock_load_credentials.return_value = env_credentials + mock_get_home.return_value = Path("/fake/home") + env_getter.get.side_effect = lambda key: os.environ[key] + credentials = create_credential_manager(product_name, env_getter) + creds = credentials.get_credentials('mssql') + assert creds['user'] == 'env_user' + assert creds['password'] == 'env_password' + + +@patch('databricks.labs.remorph.connections.credential_manager._load_credentials') +@patch('databricks.labs.remorph.connections.credential_manager._get_home') +def test_databricks_credentials(mock_get_home, mock_load_credentials, databricks_credentials, env_getter): + mock_load_credentials.return_value = databricks_credentials + mock_get_home.return_value = Path("/fake/home") + credentials = create_credential_manager(product_name, env_getter) + with pytest.raises(NotImplementedError): + credentials.get_credentials('mssql') diff --git a/tests/unit/connections/test_database_manager.py b/tests/unit/connections/test_database_manager.py new file mode 100644 index 0000000000..4d1e21770d --- /dev/null +++ b/tests/unit/connections/test_database_manager.py @@ -0,0 +1,55 @@ +import pytest +from 
unittest.mock import MagicMock, patch +from databricks.labs.remorph.connections.database_manager import DatabaseManager + +sample_config = { + 'user': 'test_user', + 'password': 'test_pass', + 'server': 'test_server', + 'database': 'test_db', + 'driver': 'ODBC Driver 17 for SQL Server', +} + + +def test_create_connector_unsupported_db_type(): + with pytest.raises(ValueError, match="Unsupported database type: unsupported_db"): + DatabaseManager("unsupported_db", sample_config) + + +# Test case for MSSQLConnector +@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector') +def test_mssql_connector(mock_mssql_connector): + mock_connector_instance = MagicMock() + mock_mssql_connector.return_value = mock_connector_instance + + db_manager = DatabaseManager("mssql", sample_config) + + assert db_manager.connector == mock_connector_instance + mock_mssql_connector.assert_called_once_with(sample_config) + + +@patch('databricks.labs.remorph.connections.database_manager.MSSQLConnector') +def test_execute_query(mock_mssql_connector): + mock_connector_instance = MagicMock() + mock_mssql_connector.return_value = mock_connector_instance + + db_manager = DatabaseManager("mssql", sample_config) + + query = "SELECT * FROM users" + mock_result = MagicMock() + mock_connector_instance.execute_query.return_value = mock_result + + result = db_manager.execute_query(query) + + assert result == mock_result + mock_connector_instance.execute_query.assert_called_once_with(query) + + +def test_execute_query_without_connection(): + db_manager = DatabaseManager("mssql", sample_config) + + # Simulating that the engine is not connected + db_manager.connector.engine = None + + with pytest.raises(ConnectionError, match="Not connected to the database."): + db_manager.execute_query("SELECT * FROM users")