Skip to content

Commit

Permalink
an option to make non-memoized calls raise an exception
Browse files Browse the repository at this point in the history
  • Loading branch information
amakelov committed Jan 14, 2025
1 parent af015e3 commit f9b7469
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
20 changes: 16 additions & 4 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(self,
# stack of modes
self._mode_stack = []
self._next_mode = 'run'
self._allow_new_calls = True

def dump_config(self) -> dict[str, Any]:
return {
Expand All @@ -144,17 +145,24 @@ def dump_config(self) -> dict[str, Any]:
"track_globals": self._track_globals,
}

@property
def mode(self) -> str:
return self._mode_stack[-1] if self._mode_stack else 'run'

def conn(self) -> sqlite3.Connection:
return self.db.conn()

def vacuum(self):
with self.conn() as conn:
conn.execute("VACUUM")


############################################################################
### runtime configuration options
############################################################################
@property
def mode(self) -> str:
return self._mode_stack[-1] if self._mode_stack else 'run'

def allow_new_calls(self, allow: bool):
self._allow_new_calls = allow

############################################################################
### managing the caches
############################################################################
Expand Down Expand Up @@ -816,6 +824,10 @@ def call_internal(
if not op.__structural__: logger.debug(f"Call to {op.name} with hid {call_hid} already exists.")
main_call = call_option
return main_call.outputs, main_call, input_calls

if not self._allow_new_calls:
# caller should decide how to handle this
raise RuntimeError(f"Call to {op.name} does not exist and new calls are not allowed.")

### execute the call if it doesn't exist
if not op.__structural__:
Expand Down
27 changes: 27 additions & 0 deletions mandala/tests/test_modes.py → mandala/tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,30 @@ def inc(x: int) -> int:
with noop():
x = inc(20)
assert x == 21


def test_no_new_calls():
@op
def inc(x: int) -> int:
return x + 1

storage = Storage()
with storage:
inc(20)

storage.allow_new_calls(False)

# memoized calls should still work
with storage:
inc(20)

try:
with storage:
inc(21)
except RuntimeError as e:
assert str(e) == "Call to inc does not exist and new calls are not allowed."
except Exception as e:
raise e
finally:
storage.allow_new_calls(True)

0 comments on commit f9b7469

Please sign in to comment.