Skip to content

Commit

Permalink
initial support for "value pointers": large objects...
Browse files Browse the repository at this point in the history
identified by human-readable names, and not saved to storage.
  • Loading branch information
amakelov committed Dec 20, 2024
1 parent c154a68 commit efcf66f
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 18 deletions.
2 changes: 1 addition & 1 deletion mandala/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .storage import Storage
from .model import op, Ignore, NewArgDefault, wrap_atom
from .model import op, Ignore, NewArgDefault, wrap_atom, ValuePointer
from .tps import MList, MDict
from .deps.tracers.dec_impl import track

Expand Down
50 changes: 50 additions & 0 deletions mandala/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __repr__(self) -> str:
return "Atom" + super().__repr__()


################################################################################
### wrappers for values that should be treated in special ways
###############################################################################
class _Ignore:
"""
Used to mark values that should be ignored by the storage, but still
Expand All @@ -83,6 +86,41 @@ class _NewArgDefault(_Ignore):
"""
pass


class ValuePointer:
"""
Replace an object by a human-readable name from the point of view of the
storage.
A value wrapped in a `ValuePointer` will:
- be identified via its `id` when hashing (when computing both the content
and history IDs)
- will not be saved in the storage
- will be replaced by the underlying `obj` when passed to a memoized
function
This is useful for passing large, complex objects (e.g., a machine learning
dataset or model) that is immutable over the course of a project and is
either not serializable, or stored elsewhere and thus we don't need to
duplicate it in the storage.
"""
def __init__(self, id: str, obj: Any):
if not isinstance(id, str) or not id:
raise ValueError("The `id` must be a non-empty string.")
self.id = id
self.obj = obj

def __eq__(self, other: Any) -> bool:
raise NotImplementedError("ValuePointer objects should not be compared.")

def __hash__(self) -> int:
raise NotImplementedError("ValuePointer objects should not be hashed.")

def __repr__(self) -> str:
obj_repr = repr(self.obj)
return f"ValuePointer({self.id!r}, {obj_repr})"


T = TypeVar("T")
def Ignore(value: T = None) -> T:
Expand All @@ -92,6 +130,9 @@ def NewArgDefault(value: T = None) -> T:
return _NewArgDefault(value)


################################################################################
### ops and calls
################################################################################
class Op:
def __init__(
self,
Expand Down Expand Up @@ -244,10 +285,19 @@ def wrap_atom(obj: Any, history_id: Optional[str] = None) -> AtomRef:
it unchanged. If `history_id` is not provided, it will be initialized
from the object's content hash (thereby representing an object without any
history).
All operations that wrap a value (whether a collection or an atom) must
factor through this function.
"""
if isinstance(obj, Ref):
if not isinstance(obj, AtomRef):
raise ValueError(f"Expected an AtomRef, got {type(obj)}")
assert history_id is None
return obj
if isinstance(obj, ValuePointer):
# we never directly hash the object, but rather the id
uid = get_content_hash(obj.id)
return AtomRef(cid=uid, hid=uid, in_memory=True, obj=ValuePointer(id=obj.id, obj=None))
uid = get_content_hash(obj)
if history_id is None:
history_id = get_content_hash(uid)
Expand Down
27 changes: 11 additions & 16 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
from .model import *
import sqlite3
from .model import __make_list__, __list_getitem__, __make_dict__, __dict_getitem__, _Ignore, _NewArgDefault
from .model import __make_list__, __list_getitem__, __make_dict__, __dict_getitem__, _Ignore, _NewArgDefault, ValuePointer
from .utils import dataframe_to_prettytable, parse_returns, _conservative_equality_check
from .viz import _get_colorized_diff
from .deps.versioner import Versioner, CodeState
Expand Down Expand Up @@ -510,11 +510,6 @@ def get_struct_inputs(self, tp: Type, val: Any) -> Dict[str, Any]:
# the keys must be strings
assert all(isinstance(k, str) for k in val.keys())
res = val
# sorted_keys = sorted(val.keys())
# res = {}
# for i, k in enumerate(sorted_keys):
# res[f'key_{i}'] = k
# res[f'value_{i}'] = val[k]
return res
else:
raise NotImplementedError
Expand All @@ -532,19 +527,15 @@ def get_struct_tps(
result = {}
for input_name in struct_inputs.keys():
result[input_name] = tp.val
# if input_name.startswith("key_"):
# i = int(input_name.split("_")[-1])
# result[f"key_{i}"] = tp.key
# elif input_name.startswith("value_"):
# i = int(input_name.split("_")[-1])
# result[f"value_{i}"] = tp.val
# else:
# raise ValueError(f"Invalid input name {input_name}")
return result
else:
raise NotImplementedError

def construct(self, tp: Type, val: Any) -> Tuple[Ref, List[Call]]:
"""
Given a target type and a value, construct a `Ref` of the target type,
as well as the associated structural calls, if any.
"""
if isinstance(val, Ref):
return val, []
if isinstance(tp, AtomType):
Expand Down Expand Up @@ -833,10 +824,10 @@ def call_internal(
if op.__structural__:
returns = f(**wrapped_inputs)
else:
# #! guard against side effects
# # guard against side effects
# cids_before = {k: v.cid for k, v in wrapped_inputs.items()}
# raw_values = {k: self.unwrap(v) for k, v in wrapped_inputs.items()}
#! call the function
### we must run the function
kwargs = {}
if kwarg_keys is not None:
for k in kwarg_keys:
Expand All @@ -855,13 +846,17 @@ def call_internal(
args = self.unwrap(args)
kwargs.update(leftover_kwargs)
kwargs = self.unwrap(kwargs)
# replace any ValuePointer instances with their underlying objects
kwargs = {k: v.obj if isinstance(v, ValuePointer) else v for k, v in kwargs.items()}
args = tuple([v.obj if isinstance(v, ValuePointer) else v for v in args])

if tracer_option is not None:
tracer = tracer_option
with tracer:
if isinstance(tracer, DecTracer):
f = track(op.f)
node = tracer.register_call(func=f)
#! call the function
returns = f(*args, **kwargs)
if isinstance(tracer, DecTracer):
tracer.register_return(node=node)
Expand Down
22 changes: 21 additions & 1 deletion mandala/tests/test_memoization.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,24 @@ def add_array(x:np.ndarray, y=NewArgDefault(None)):

# now test passing a wrapped value
with storage:
add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9])))
add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9])))




def test_value_pointer():
storage = Storage()

@op
def get_mean(x: np.ndarray) -> float:
return x.mean()

with storage:
X = np.array([1, 2, 3, 4, 5])
X_pointer = ValuePointer("X", X)
mean = get_mean(X_pointer)

assert storage.unwrap(mean) == 3.0
df = storage.cf(get_mean).df()
assert len(df) == 1
assert df['x'].item().id == "X"

0 comments on commit efcf66f

Please sign in to comment.