Skip to content

Commit

Permalink
Remove need to unwrap Placeholder (#225)
Browse files Browse the repository at this point in the history
Use the `wrapt` package to implement a proxy placeholder used for `Commit`s in a transaction.  Now we do not need to unwrap a placeholder after it has been filled.
  • Loading branch information
maxmynter authored Dec 15, 2023
1 parent 3d51aaa commit 90f76a9
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 30 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ readme = "README.md"
license = { text = "Apache-2.0" }
authors = [{ name = "appliedAI Initiative", email = "[email protected]" }]
maintainers = [
{ name = "Nicholas Junge", email = "[email protected]" },
{ name = "Max Mynter", email = "[email protected]" },
{ name = "Adrian Rumpold", email = "[email protected]" },
{ name = "Nicholas Junge", email = "n.junge@appliedai-institute.de" },
{ name = "Max Mynter", email = "m.mynter@appliedai-institute.de" },
{ name = "Adrian Rumpold", email = "a.rumpold@appliedai-institute.de" },
]
classifiers = [
"Development Status :: 3 - Alpha",
Expand All @@ -32,7 +32,7 @@ classifiers = [
"Typing :: Typed",
]

dependencies = ["fsspec>=2023.6.0", "lakefs-sdk>=1.0.0", "pyyaml>=6.0.1"]
dependencies = ["fsspec>=2023.6.0", "lakefs-sdk>=1.0.0", "pyyaml>=6.0.1", "wrapt"]

dynamic = ["version"]

Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ six==1.16.0
typing-extensions==4.8.0
urllib3==2.0.7
virtualenv==20.25.0
wrapt==1.16.0

# The following packages are considered to be unsafe in a requirements file:
# setuptools
1 change: 1 addition & 0 deletions requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,5 @@ webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
widgetsnbextension==4.0.9
wrapt==1.16.0
zipp==3.17.0
36 changes: 15 additions & 21 deletions src/lakefs_spec/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

import wrapt
from fsspec.spec import AbstractBufferedFile
from fsspec.transaction import Transaction
from lakefs_sdk.client import LakeFSClient
Expand All @@ -25,30 +26,25 @@


@dataclass
class Placeholder(Generic[T]):
class Placeholder(Generic[T], wrapt.ObjectProxy):
"""A generic placeholder for a value computed by the lakeFS server in a versioning operation during a transaction."""

value: T | None = None
"""The abstract value. Set only on completion of the versioning operation during the defining transaction."""
def __init__(self, wrapped: T | None = None):
super().__init__(wrapped)

def available(self):
@property
def available(self) -> bool:
"""Whether the wrapped value is available, i.e. already computed."""
return self.value is not None
return self.__wrapped__ is not None

def set_value(self, value: T) -> None:
"""Fill in the placeholder. Not meant to be called directly except in the completion of the transaction."""
self.value = value

def unwrap(self) -> T:
"""Return the placeholder's value after it has been filled."""
if self.value is None:
raise RuntimeError("placeholder unfilled")
return self.value
@property
def value(self):
return self.__wrapped__


def unwrap_placeholders(kwargs: dict[str, Any]) -> dict[str, Any]:
"""Unwrap any placeholder values passed in a dictionary of keyword arguments."""
return {k: v.unwrap() if isinstance(v, Placeholder) else v for k, v in kwargs.items()}
@value.setter
def value(self, val: T) -> None:
"""Fill in the placeholder. Not meant to be called directly except in the completion of the transaction."""
self.__wrapped__ = val


class LakeFSTransaction(Transaction):
Expand Down Expand Up @@ -133,7 +129,7 @@ def complete(self, commit: bool = True) -> None:
# if the transaction member returns a placeholder,
# fill it with the result of the client helper.
if isinstance(retval, Placeholder):
retval.set_value(result)
retval.value = result

self.fs._intrans = False

Expand Down Expand Up @@ -228,7 +224,6 @@ def rev_parse(
"""

def rev_parse_op(client: LakeFSClient, **kwargs: Any) -> Commit:
kwargs = unwrap_placeholders(kwargs)
return rev_parse(client, **kwargs)

p: Placeholder[Commit] = Placeholder()
Expand Down Expand Up @@ -256,7 +251,6 @@ def tag(self, repository: str, ref: str | Placeholder[Commit], tag: str) -> str:
"""

def tag_op(client: LakeFSClient, **kwargs: Any) -> Ref:
kwargs = unwrap_placeholders(kwargs)
return create_tag(client, **kwargs)

self.files.append((partial(tag_op, repository=repository, ref=ref, tag=tag), tag))
Expand Down
36 changes: 31 additions & 5 deletions tests/test_transactions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from lakefs_sdk.models import Commit

from lakefs_spec import LakeFSFileSystem
from tests.util import RandomFileFactory, with_counter

Expand All @@ -24,9 +26,9 @@ def test_transaction_commit(
sha = tx.commit(repository, temp_branch, message=message)
# stack contains the file to upload, and the commit op.
assert len(tx.files) == 2
assert not sha.available()
assert not sha.available

assert sha.available()
assert sha.available

commits = fs.client.refs_api.log_commits(
repository=repository,
Expand All @@ -35,7 +37,7 @@ def test_transaction_commit(
latest_commit = commits.results[0]

assert latest_commit.message == message
assert latest_commit.id == sha.value.id
assert latest_commit.id == sha.id


def test_transaction_tag(fs: LakeFSFileSystem, repository: str) -> None:
Expand All @@ -45,11 +47,11 @@ def test_transaction_tag(fs: LakeFSFileSystem, repository: str) -> None:
sha = tx.rev_parse(repository, "main")
tag = tx.tag(repository=repository, ref=sha, tag="v2")

assert sha.available()
assert sha.available

tags = fs.client.tags_api.list_tags(repository).results
assert tags[0].id == tag
assert tags[0].commit_id == sha.value.id
assert tags[0].commit_id == sha.id
finally:
fs.client.tags_api.delete_tag(repository=repository, tag=tag)

Expand Down Expand Up @@ -166,3 +168,27 @@ def test_transaction_failure(

# assert that no commit happens because of the exception.
assert counter.count("commits_api.commit") == 0


def test_placeholder_representations(
random_file_factory: RandomFileFactory,
fs: LakeFSFileSystem,
repository: str,
temp_branch: str,
) -> None:
random_file = random_file_factory.make()

lpath = str(random_file)
rpath = f"{repository}/{temp_branch}/{random_file.name}"

message = f"Add file {random_file.name}"

with fs.transaction as tx:
fs.put_file(lpath, rpath)
sha = tx.commit(repository, temp_branch, message=message)
latest_commit = fs.client.refs_api.log_commits(repository=repository, ref=temp_branch).results[
0
]
assert isinstance(sha, Commit)
assert sha.id == latest_commit.id
assert repr(sha.id) == repr(latest_commit.id)

0 comments on commit 90f76a9

Please sign in to comment.