Skip to content

Commit

Permalink
Added common install state primitives with strong typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx committed Jan 20, 2024
1 parent e01a2a9 commit 023368c
Showing 1 changed file with 246 additions and 1 deletion.
247 changes: 246 additions & 1 deletion src/databricks/labs/blueprint/installer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,140 @@
import abc
import dataclasses
import enum
import json
import logging
import threading
import types
import typing
from json import JSONDecodeError
from typing import TypedDict
from typing import TypedDict, Any, Callable, io

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.mixins import workspace
from databricks.sdk.service.workspace import ImportFormat

logger = logging.getLogger(__name__)

Resources = dict[str, str]
Json = dict[str, Any]

@dataclass
class ConnectConfig:
# Keep all the fields in sync with databricks.sdk.core.Config
host: str | None = None
account_id: str | None = None
token: str | None = None
client_id: str | None = None
client_secret: str | None = None
azure_client_id: str | None = None
azure_tenant_id: str | None = None
azure_client_secret: str | None = None
azure_environment: str | None = None
cluster_id: str | None = None
profile: str | None = None
debug_headers: bool | None = False
# Truncate JSON fields in HTTP requests and responses above this limit.
# If this occurs, the log message will include the text `... (XXX additional elements)`
debug_truncate_bytes: int | None = 250000
rate_limit: int | None = None
max_connections_per_pool: int | None = None
max_connection_pools: int | None = None

@staticmethod
def from_databricks_config(cfg: Config) -> "ConnectConfig":
return ConnectConfig(
host=cfg.host,
token=cfg.token,
client_id=cfg.client_id,
client_secret=cfg.client_secret,
azure_client_id=cfg.azure_client_id,
azure_tenant_id=cfg.azure_tenant_id,
azure_client_secret=cfg.azure_client_secret,
azure_environment=cfg.azure_environment,
cluster_id=cfg.cluster_id,
profile=cfg.profile,
debug_headers=cfg.debug_headers,
debug_truncate_bytes=cfg.debug_truncate_bytes,
rate_limit=cfg.rate_limit,
max_connection_pools=cfg.max_connection_pools,
max_connections_per_pool=cfg.max_connections_per_pool,
)

def to_databricks_config(self):
return Config(
host=self.host,
account_id=self.account_id,
token=self.token,
client_id=self.client_id,
client_secret=self.client_secret,
azure_client_id=self.azure_client_id,
azure_tenant_id=self.azure_tenant_id,
azure_client_secret=self.azure_client_secret,
azure_environment=self.azure_environment,
cluster_id=self.cluster_id,
profile=self.profile,
debug_headers=self.debug_headers,
debug_truncate_bytes=self.debug_truncate_bytes,
rate_limit=self.rate_limit,
max_connection_pools=self.max_connection_pools,
max_connections_per_pool=self.max_connections_per_pool,
product="ucx",
product_version=__version__,
)

@classmethod
def from_dict(cls, raw: dict):
return cls(**raw)

class _Config(Generic[T]):
connect: ConnectConfig | None = None

@classmethod
@abstractmethod
def from_dict(cls, raw: dict[str, Any]) -> T:
...

@classmethod
def from_bytes(cls, raw_str: str | bytes) -> T:
from yaml import safe_load

raw: dict[str, Any] = safe_load(raw_str)
empty: dict[str, Any] = {}
return cls.from_dict(empty if not raw else raw)

@classmethod
def from_file(cls, config_file: Path) -> T:
return cls.from_bytes(config_file.read_text())

def __post_init__(self):
if self.connect is None:
self.connect = ConnectConfig()

def to_databricks_config(self) -> Config:
connect = self.connect
if connect is None:
# default empty config
connect = ConnectConfig()
return connect.to_databricks_config()

def as_dict(self) -> dict[str, Any]:
from dataclasses import fields, is_dataclass

def inner(x):
if is_dataclass(x):
result = []
for f in fields(x):
value = inner(getattr(x, f.name))
if not value:
continue
result.append((f.name, value))
return dict(result)
return x

serialized = inner(self)
serialized["version"] = _CONFIG_VERSION
return serialized


class RawState(TypedDict):
Expand Down Expand Up @@ -86,3 +210,124 @@ def save(self) -> None:
format=ImportFormat.AUTO,
overwrite=True,
)

T = typing.TypeVar('T')
def load_typed_file(self, type_ref: typing.Type[T]) -> T:
# TODO: load with type_ref, convert JSON/YAML into a dataclass instance, discover format migrations from methods
# TODO: detect databricks config and allow using it as part of dataclass instance
# TODO: MockInstallState to get JSON/YAML created/loaded as dict-per-filename
raise NotImplementedError

def save_typed_file(self, inst: T):
# TODO: save JSON/YML, versioned
raise NotImplementedError

def load_csv(self, type_ref: typing.Type[T]) -> list[T]:
# TODO: load/save arrays in CSV format
# TODO: MockInstallState to get CSV file created/loaded as slice-of-dataclasses
raise NotImplementedError

def save_csv(self, records: list[T]) -> list[T]:
# TODO: load/save arrays in CSV format
raise NotImplementedError

def list_files(self) -> list[workspace.ObjectInfo]:
# TODO: list files under install folder
raise NotImplementedError

# TODO: add from_dict to databricks config (or make it temporary hack in unmarshaller)

def upload_to_dbfs(self) -> str:
# TODO: use this in Wheels to upload/download random files into correct prefix in WSFS/DBFS
with self._local_wheel.open("rb") as f:
self._ws.dbfs.mkdirs(self._remote_dir_name)
logger.info(f"Uploading wheel to dbfs:{self._remote_wheel}")
self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True)
return self._remote_wheel

def upload_to_wsfs(self) -> str:
with self._local_wheel.open("rb") as f:
self._ws.workspace.mkdirs(self._remote_dir_name)
logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}")
self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO)
return self._remote_wheel

def _load_versioned_json(self,
name: str,
expected_version: int,
parse_raw: Callable[[io.BinaryIO], Json] = json.load,
format_migrations: list[Callable[[int, Json], Json]] = None) -> Json:
if not format_migrations:
format_migrations = []
target_file = f'{self.install_folder()}/{name}'
try:
raw = parse_raw(self._ws.workspace.download(target_file))
version = raw.pop("$version", None)
if not version:
raise IllegalState('no $version found')
for migrate in format_migrations:
raw = migrate(version, raw)
if version != expected_version:
msg = f"expected state $version={self._config_version}, got={version}"
raise IllegalState(msg)
return raw
except NotFound:
return {}
except JSONDecodeError:
logger.warning(f"JSON state file corrupt: {self._state_file}")
return {}

def _is_assignable(self,
type_ref: type, raw: Any, path: list[str], name_transform: Callable[[str], str]
) -> tuple[bool, str | None]:
if dataclasses.is_dataclass(type_ref):
if not isinstance(raw, dict):
return False, self._explain_why(dict, raw, path)
for field, hint in typing.get_type_hints(type_ref).items():
field = name_transform(field)
valid, why_not = self._is_assignable(hint, raw.get(field), [*path, field], name_transform)
if not valid:
return False, why_not
return True, None
if isinstance(type_ref, types.GenericAlias):
if not isinstance(raw, list):
return False, self._explain_why(list, raw, path)
type_args = typing.get_args(type_ref)
if not type_args:
raise TypeError(f"Missing type arguments: {type_args}")
item_ref = type_args[0]
for i, v in enumerate(raw):
valid, why_not = self._is_assignable(item_ref, v, [*path, f"{i}"], name_transform)
if not valid:
return False, why_not
return True, None
if isinstance(type_ref, types.UnionType):
combo = []
for variant in typing.get_args(type_ref):
valid, why_not = self._is_assignable(variant, raw, [], name_transform)
if valid:
return True, None
if why_not:
combo.append(why_not)
return False, f'{".".join(path)}: union: {" or ".join(combo)}'
if isinstance(type_ref, abc.ABCMeta):
# until we generate method that returns subtypes
return True, None
if isinstance(type_ref, enum.EnumMeta):
if raw in type_ref._value2member_map_:
return True, None
return False, self._explain_why(type_ref, raw, path)
if type_ref == types.NoneType:
if raw is None:
return True, None
return False, None
if type_ref in (int, bool, float, str):
if type_ref == type(raw):
return True, None
return False, self._explain_why(type_ref, raw, path)
return False, f'{".".join(path)}: unknown: {raw}'

def _explain_why(self, type_ref: type, raw: Any, path: list[str]) -> str:
if raw is None:
raw = "value is missing"
return f'{".".join(path)}: not a {type_ref.__name__}: {raw}'

0 comments on commit 023368c

Please sign in to comment.