Skip to content

Commit

Permalink
Allow lazyexpr() to seek for operands if not passed
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescAlted committed Nov 4, 2024
1 parent 544d7d3 commit 9d5e481
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
59 changes: 52 additions & 7 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os
import pathlib
import re
import sys
import threading
from abc import ABC, abstractmethod
from enum import Enum
Expand Down Expand Up @@ -2209,11 +2210,48 @@ def lazyudf(
return LazyUDF(func, inputs, dtype, chunked_eval, **kwargs)


def seek_operands(names, local_dict=None, global_dict=None, _frame_depth: int = 2):
"""
Get the arguments based on the names.
"""
call_frame = sys._getframe(_frame_depth)

clear_local_dict = False
if local_dict is None:
local_dict = call_frame.f_locals
clear_local_dict = True
try:
frame_globals = call_frame.f_globals
if global_dict is None:
global_dict = frame_globals

# If `call_frame` is the top frame of the interpreter we can't clear its
# `local_dict`, because it is actually the `global_dict`.
clear_local_dict = clear_local_dict and frame_globals is not local_dict

op_dict = {}
for name in names:
try:
a = local_dict[name]
except KeyError:
a = global_dict[name]
op_dict[name] = a
finally:
# If we generated local_dict via an explicit reference to f_locals,
# clear the dict to prevent creating extra ref counts in the caller's scope
if clear_local_dict and hasattr(local_dict, "clear"):
local_dict.clear()

return op_dict


def lazyexpr(
expression: str | bytes | LazyExpr,
operands: dict | None = None,
out: blosc2.NDArray | np.ndarray = None,
where: tuple | list | None = None,
local_dict: dict | None = None,
global_dict: dict | None = None,
) -> LazyExpr:
"""
Get a LazyExpr from an expression.
Expand All @@ -2227,17 +2265,18 @@ def lazyexpr(
operands: dict
The dictionary with operands. Supported values are NumPy.ndarray,
Python scalars, :ref:`NDArray`, :ref:`NDField` or :ref:`C2Array` instances.
If None, the operands will be seeked in the local and global dictionaries.
out: NDArray or np.ndarray, optional
The output array where the result will be stored. If not provided,
a new array will be created.
where: tuple, list, optional
A sequence of arguments for the where clause in the expression.
guess: bool, optional
Whether to guess the output dtype and shape. If False, the dtype and shape
will be computed producing temporary arrays in the process (e.g. for reductions).
If True, the dtype and shape will be guessed from the expression, but without
evaluating any part of it. Use True when you want to e.g. save the expression
but without evaluating it.
local_dict: dict, optional
The local dictionary to use when looking for operands in the expression.
If not provided, the local dictionary of the caller will be used.
global_dict: dict, optional
The global dictionary to use when looking for operands in the expression.
If not provided, the global dictionary of the caller will be used.
Returns
-------
Expand Down Expand Up @@ -2279,7 +2318,13 @@ def lazyexpr(
expression._where_args = where_args
return expression
if operands is None:
raise ValueError("`operands` must be provided for a string expression")
# Try to get operands from variables in the stack
operands = get_expr_operands(expression)
# If no operands are found, raise an error
if operands is None:
raise ValueError("No operands found in the expression")
# Look for operands in the stack
operands = seek_operands(operands, local_dict, global_dict)

return LazyExpr._new_expr(expression, operands, guess=True, out=out, where=where)

Expand Down
8 changes: 6 additions & 2 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,8 @@ def test_broadcasting(broadcast_fixture):
("numpy", "numpy"),
],
)
def test_lazyexpr(array_fixture, operand_mix):
@pytest.mark.parametrize("operand_guess", [True, False])
def test_lazyexpr(array_fixture, operand_mix, operand_guess):
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
if operand_mix[0] == "NDArray" and operand_mix[1] == "NDArray":
operands = {"a1": a1, "a2": a2, "a3": a3, "a4": a4}
Expand All @@ -795,7 +796,10 @@ def test_lazyexpr(array_fixture, operand_mix):
operands = {"a1": na1, "a2": na2, "a3": na3, "a4": na4}

# Check eval()
expr = blosc2.lazyexpr("a1 + a2 - a3 * a4", operands=operands)
if operand_guess:
expr = blosc2.lazyexpr("a1 + a2 - a3 * a4")
else:
expr = blosc2.lazyexpr("a1 + a2 - a3 * a4", operands=operands)
nres = ne.evaluate("na1 + na2 - na3 * na4")
res = expr.compute()
np.testing.assert_allclose(res[:], nres)
Expand Down

0 comments on commit 9d5e481

Please sign in to comment.