Skip to content

Commit

Permalink
fix #20 for now by ignoring closures in versioning;
Browse files Browse the repository at this point in the history
also, remove a leftover static method to compute globals
  • Loading branch information
amakelov committed Aug 20, 2024
1 parent b2af6d2 commit fb5bfd2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
45 changes: 28 additions & 17 deletions mandala/deps/tracers/dec_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,20 +192,28 @@ def set_active_trace_obj(trace_obj: Optional["DecTracer"]):
TracerState.tracer = trace_obj

def get_globals(self, func: Callable) -> List[GlobalVarNode]:
result = []
code_obj = extract_code(obj=func)
global_scope = extract_func_obj(obj=func, strict=self.strict).__globals__
for name in get_global_names_candidates(code=code_obj):
# names used by the function; not all of them are global variables
if name in global_scope.keys():
global_val = global_scope[name]
if not is_global_val(global_val):
continue
node = GlobalVarNode.from_obj(
obj=global_val, dep_key=(func.__module__, name)
)
result.append(node)
return result
"""
Get the global variables available to the function as a list of
GlobalVarNode objects.
Currently, this is not used, because it doesn't really track accesses
to globals, and can thus over-estimate the dependencies of a function.
"""
# result = []
# code_obj = extract_code(obj=func)
# global_scope = extract_func_obj(obj=func, strict=self.strict).__globals__
# for name in get_global_names_candidates(code=code_obj):
# # names used by the function; not all of them are global variables
# if name in global_scope.keys():
# global_val = global_scope[name]
# if not is_global_val(global_val):
# continue
# node = GlobalVarNode.from_obj(
# obj=global_val, dep_key=(func.__module__, name)
# )
# result.append(node)
# return result
return []

def register_call(self, func: Callable) -> CallableNode:
module_name = func.__module__
Expand All @@ -216,7 +224,7 @@ def register_call(self, func: Callable) -> CallableNode:
)
if len(closure_names) > 0:
msg = f"Found closure variables accessed by function {module_name}.{qualname}:\n{closure_names}"
self._process_failure(msg)
self._process_failure(msg, level='debug')
### get call node
node = CallableNode.from_runtime(
module_name=module_name, obj_name=qualname, code_obj=extract_code(obj=func)
Expand Down Expand Up @@ -278,8 +286,11 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
DecTracer.set_active_trace_obj(None)

def _process_failure(self, msg: str):
def _process_failure(self, msg: str, level: str = 'warning'):
if self.strict:
raise RuntimeError(msg)
else:
logger.warning(msg)
if level == 'warning':
logger.warning(msg)
elif level == 'debug':
logger.debug(msg)
6 changes: 4 additions & 2 deletions mandala/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .deps.versioner import Versioner, CodeState
from .deps.utils import get_dep_key_from_func, extract_func_obj
from .deps.tracers import DecTracer, SysTracer, TracerABC
from .deps.tracers.dec_impl import track

from .storage_utils import (
DBAdapter,
Expand All @@ -25,7 +26,7 @@ class Storage:
def __init__(self, db_path: str = ":memory:",
deps_path: Optional[Union[str, Path]] = None,
tracer_impl: Optional[type] = None,
strict_tracing: bool = True,
strict_tracing: bool = False,
deps_package: Optional[str] = None,
):
self.db = DBAdapter(db_path=db_path)
Expand Down Expand Up @@ -805,7 +806,8 @@ def call_internal(
tracer = tracer_option
with tracer:
if isinstance(tracer, DecTracer):
node = tracer.register_call(func=op.f)
f = track(op.f)
node = tracer.register_call(func=f)
returns = f(*args, **kwargs)
if isinstance(tracer, DecTracer):
tracer.register_return(node=node)
Expand Down

0 comments on commit fb5bfd2

Please sign in to comment.