Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx committed Jan 20, 2024
1 parent 72aa4be commit 38df7be
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 15 deletions.
72 changes: 57 additions & 15 deletions src/databricks/labs/blueprint/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
logger = logging.getLogger(__name__)

Resources = dict[str, str]
Json = dict[str, Any] | list[dict[str, Any]]
Json = dict[str, Any]

# @dataclass
# class ConnectConfig:
Expand Down Expand Up @@ -210,21 +210,63 @@ def save(self) -> None:

def _overwrite(self, filename: str, raw: bytes):
with self._lock:
self._ws.workspace.upload(
f"{self.install_folder()}/{filename}",
raw, # type: ignore[arg-type]
format=ImportFormat.AUTO,
overwrite=True,
)
dst = f"{self.install_folder()}/{filename}"
attempt = partial(self._ws.workspace.upload, dst, raw, format=ImportFormat.AUTO, overwrite=True)
try:
attempt()
except NotFound:
self._ws.workspace.mkdirs(self.install_folder())
attempt()
return dst

T = typing.TypeVar("T")

def load_typed_file(self, type_ref: typing.Type[T]) -> T:
def load_typed_file(self, type_ref: typing.Type[T], *, filename: typing.Optional[str] = None) -> T:
    """Load a stored file and migrate it to the schema version ``type_ref`` expects.

    Work in progress: even when version handling succeeds, the method still ends
    with ``NotImplementedError`` because conversion into ``type_ref`` is not
    written yet.

    Args:
        type_ref: Class describing the file. Optional attributes read from it:
            ``__file__`` (file name), ``__version__`` (expected ``$version``) and
            ``v<N>_migrate`` callables that take the raw dict at version ``N``
            and return the upgraded dict.
        filename: Explicit file name. When falsy, ``type_ref.__file__`` is used
            if present, otherwise ``<type_ref.__name__>.json``.

    Raises:
        IllegalState: A migration did not advance ``$version``, or the version
            reached after all available migrations differs from the expected one.
        NotImplementedError: Always, once the version checks pass.
    """
    # TODO: load with type_ref, convert JSON/YAML into a dataclass instance, discover format migrations from methods
    # TODO: detect databricks config and allow using it as part of dataclass instance
    # TODO: MockInstallState to get JSON/YAML created/loaded as dict-per-filename
    if not filename:
        if hasattr(type_ref, "__file__"):
            filename = getattr(type_ref, "__file__")
        else:
            filename = f"{type_ref.__name__}.json"
    expected_version = getattr(type_ref, "__version__", None)
    as_dict = self._load_content(filename)
    if expected_version:
        # Files without an explicit "$version" key are treated as version 1.
        actual_version = as_dict.pop("$version", 1)
        while actual_version < expected_version:
            migrate = getattr(type_ref, f"v{actual_version}_migrate", None)
            if not migrate:
                break
            as_dict = migrate(as_dict)
            migrated_version = as_dict.pop("$version", 1)
            if migrated_version <= actual_version:
                # Without this guard a migration that forgets to bump "$version"
                # would be re-applied forever.
                raise IllegalState(
                    f"v{actual_version}_migrate did not advance $version: "
                    f"was={actual_version}, got={migrated_version}"
                )
            actual_version = migrated_version
        if actual_version != expected_version:
            raise IllegalState(f"expected state $version={expected_version}, got={actual_version}")
    raise NotImplementedError

def _load_yaml(self, raw: typing.BinaryIO) -> Json:
    """Parse a YAML stream with ``yaml.safe_load``.

    Args:
        raw: Stream holding the YAML document.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.

    Raises:
        JSONDecodeError: The document is not valid YAML. The YAML error is
            translated so callers handle JSON and YAML failures the same way.
        SyntaxError: PyYAML cannot be imported.
    """
    try:
        # Import inside the guarded region: otherwise nothing here can raise
        # ImportError and the install hint below is unreachable.
        import yaml

        try:
            return yaml.safe_load(raw)
        except yaml.YAMLError as err:
            raise JSONDecodeError(str(err), '<yaml>', 0) from err
    except ImportError as err:
        raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err

def _load_content(self, filename: str) -> Json:
    """Download ``filename`` from the install folder and parse it by extension.

    The text after the last dot selects the parser: ``json`` or ``yml``.

    Returns:
        The parsed content, or an empty dict when the file cannot be decoded
        or does not exist.

    Raises:
        KeyError: The extension has no registered parser.
    """
    parsers = {"json": json.load, "yml": self._load_yaml}
    suffix = filename.rpartition(".")[2]
    parse = parsers.get(suffix)
    if parse is None:
        raise KeyError(f"Unknown extension: {suffix}")
    try:
        with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as stream:
            content = parse(stream)
    except JSONDecodeError:
        # Undecodable content is treated the same as no content.
        return {}
    except NotFound:
        return {}
    return content

def _dump_yaml(self, raw: Json) -> bytes:
try:
return yaml.dump(raw).encode("utf8")
Expand Down Expand Up @@ -344,12 +386,9 @@ def upload_to_dbfs(self) -> str:
self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True)
return self._remote_wheel

def upload_to_wsfs(self) -> str:
with self._local_wheel.open("rb") as f:
self._ws.workspace.mkdirs(self._remote_dir_name)
logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}")
self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO)
return self._remote_wheel
def upload_to_wsfs(self, filename: str, raw: bytes) -> str:
    """Write ``raw`` under ``filename`` via ``self._overwrite`` and return its result."""
    # TODO: use in wheels
    written_to = self._overwrite(filename, raw)
    return written_to

def _load_versioned_json(
self,
Expand Down Expand Up @@ -448,7 +487,10 @@ def install_folder(self) -> str:
def _overwrite_content(self, filename: str, as_dict: Json):
    """Remember ``as_dict`` as the latest content written for ``filename``."""
    recorded = self._overwrites
    recorded[filename] = as_dict

def assert_file_written(self, filename: str, expected: Json):
def _load_content(self, filename: str) -> Json:
    """Return the content previously recorded for ``filename``.

    Raises:
        KeyError: Nothing was recorded under ``filename``.
    """
    recorded = self._overwrites
    return recorded[filename]

def assert_file_written(self, filename: str, expected: Any):
    """Assert that ``filename`` was written and that its content equals ``expected``."""
    recorded = self._overwrites
    assert filename in recorded, f"{filename} had no writes"
    actual = recorded[filename]
    assert expected == actual, f"{filename} content missmatch"
17 changes: 17 additions & 0 deletions tests/unit/test_installer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -122,6 +123,22 @@ def test_save_typed_file():
overwrite=True,
)


def test_load_typed_file():
    """Load a stored version-2 config through ``InstallState.load_typed_file``."""
    stored = {
        "$version": 2,
        "num_threads": 20,
        "inventory_database": "some_blueprint",
    }
    ws = create_autospec(WorkspaceClient)
    ws.current_user.me().user_name = "foo"
    # The mocked download hands back the YAML-serialised state as a text stream.
    ws.workspace.download.return_value = io.StringIO(yaml.dump(stored))

    state = InstallState(ws, "blueprint")
    cfg = state.load_typed_file(WorkspaceConfig)

    assert 20 == cfg.num_threads


def test_save_typed_file_array():
state = MockInstallState()

Expand Down

0 comments on commit 38df7be

Please sign in to comment.