Skip to content

Commit

Permalink
WIP sympy integration via ufunc dispatch
Browse files Browse the repository at this point in the history
Signed-off-by: Tim Paine <[email protected]>
  • Loading branch information
timkpaine committed Feb 10, 2024
1 parent f4b8ac9 commit b8c03e8
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 61 deletions.
190 changes: 147 additions & 43 deletions csp/baselib.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,28 @@
"min",
"gate",
"floordiv",
"mod",
"pow",
"abs",
"ln",
"log2",
"log10",
"exp",
"abs",
"exp2",
"sqrt",
"erf",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"unroll",
"collect",
"flatten",
Expand Down Expand Up @@ -670,11 +688,44 @@ def or_(*inputs):
# Math/comparison binary operators are supported in C++ only for (int,int) and
# (float, float) arguments. For all other types, the Python implementation is used.

MATH_OPS = ["add", "sub", "multiply", "divide", "pow", "max", "min"]
MATH_OPS = [
# binary
"add",
"sub",
"multiply",
"divide",
"pow",
"max",
"min",
"floordiv",
"mod",
# unary
"abs",
"ln",
"log2",
"log10",
"exp",
"exp2",
"sqrt",
"erf",
"sin",
"cos",
"tan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
]

COMP_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]

MATH_COMP_OPS_CPP = {
# binary math
("add", "float"): _cspbaselibimpl.add_f,
("add", "int"): _cspbaselibimpl.add_i,
("sub", "float"): _cspbaselibimpl.sub_f,
Expand All @@ -689,6 +740,11 @@ def or_(*inputs):
("max", "int"): _cspbaselibimpl.max_i,
("min", "float"): _cspbaselibimpl.min_f,
("min", "int"): _cspbaselibimpl.min_i,
# unary math
("abs", "float"): _cspbaselibimpl.abs,
("ln", "float"): _cspbaselibimpl.ln,
("exp", "float"): _cspbaselibimpl.exp,
# binary comparator
("eq", "float"): _cspbaselibimpl.eq_f,
("eq", "int"): _cspbaselibimpl.eq_i,
("ne", "float"): _cspbaselibimpl.ne_f,
Expand All @@ -705,7 +761,7 @@ def or_(*inputs):


@lru_cache(maxsize=512)
def define_op(name, op_lambda):
def define_binary_op(name, op_lambda):
float_out_type, int_out_type, generic_out_type = [None] * 3
if name in COMP_OPS:
float_out_type = bool
Expand All @@ -722,12 +778,12 @@ def define_op(name, op_lambda):

from csp.impl.wiring.node import _node_internal_use

@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP[(name, "float")], name=name)
@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "float"), None), name=name)
def float_type(x: ts[float], y: ts[float]) -> ts[float_out_type]:
if csp.valid(x, y):
return op_lambda(x, y)

@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP[(name, "int")], name=name)
@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "int"), None), name=name)
def int_type(x: ts[int], y: ts[int]) -> ts[int_out_type]:
if csp.valid(x, y):
return op_lambda(x, y)
Expand Down Expand Up @@ -759,32 +815,98 @@ def comp(x: ts["T"], y: ts["U"]):
return comp


# Math operators
@lru_cache(maxsize=512)
def define_unary_op(name, op_lambda):
float_out_type, int_out_type, generic_out_type = [None] * 3
if name in COMP_OPS:
float_out_type = bool
int_out_type = bool
generic_out_type = bool
elif name in MATH_OPS:
float_out_type = float
if name in ("abs",):
int_out_type = int
generic_out_type = "T"
else:
int_out_type = float
generic_out_type = float

add = define_op("add", lambda x, y: x + y)
sub = define_op("sub", lambda x, y: x - y)
multiply = define_op("multiply", lambda x, y: x * y)
pow = define_op("pow", lambda x, y: x**y)
divide = define_op("divide", lambda x, y: x / y)
min = define_op("min", lambda x, y: x if x < y else y)
max = define_op("max", lambda x, y: x if x > y else y)
from csp.impl.wiring.node import _node_internal_use

# Comparison operators
@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "float"), None), name=name)
def float_type(x: ts[float]) -> ts[float_out_type]:
if csp.valid(x):
return op_lambda(x)

eq = define_op("eq", lambda x, y: x == y)
ne = define_op("ne", lambda x, y: x != y)
gt = define_op("gt", lambda x, y: x > y)
lt = define_op("lt", lambda x, y: x < y)
ge = define_op("ge", lambda x, y: x >= y)
le = define_op("le", lambda x, y: x <= y)
@_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "int"), None), name=name)
def int_type(x: ts[int]) -> ts[int_out_type]:
if csp.valid(x):
return op_lambda(x)

# Other math ops
@_node_internal_use(name=name)
def numpy_type(x: ts["T"]) -> ts[np.ndarray]:
if csp.valid(x):
return op_lambda(x)

@_node_internal_use(name=name)
def generic_type(x: ts["T"]) -> ts[generic_out_type]:
if csp.valid(x):
return op_lambda(x)

def comp(x: ts["T"]):
if x.tstype.typ in [Numpy1DArray[float], NumpyNDArray[float]]:
return numpy_type(x)
elif x.tstype.typ is float:
return float_type(x)
elif x.tstype.typ is int:
return int_type(x)
return generic_type(x)

@node
def floordiv(x: ts["T"], y: ts["T"]) -> ts["T"]:
if csp.ticked(x, y) and csp.valid(x, y):
return x // y
comp.__name__ = name
return comp


# Math operators
add = define_binary_op("add", lambda x, y: x + y)
sub = define_binary_op("sub", lambda x, y: x - y)
multiply = define_binary_op("multiply", lambda x, y: x * y)
divide = define_binary_op("divide", lambda x, y: x / y)
pow = define_binary_op("pow", lambda x, y: x**y)
min = define_binary_op("min", lambda x, y: x if x < y else y)
max = define_binary_op("max", lambda x, y: x if x > y else y)
floordiv = define_binary_op("floordiv", lambda x, y: x // y)
mod = define_binary_op("mod", lambda x, y: x % y)

# Other math ops
_python_abs = abs
abs = define_unary_op("abs", lambda x: _python_abs(x))
ln = define_unary_op("ln", lambda x: math.log(x))
log2 = define_unary_op("log2", lambda x: math.log2(x))
log10 = define_unary_op("log10", lambda x: math.log10(x))
exp = define_unary_op("exp", lambda x: math.exp(x))
exp2 = define_unary_op("exp2", lambda x: math.exp2(x))
sqrt = define_unary_op("sqrt", lambda x: math.sqrt(x))
erf = define_unary_op("erf", lambda x: math.erf(x))
sin = define_unary_op("sin", lambda x: math.sin(x))
cos = define_unary_op("cos", lambda x: math.cos(x))
tan = define_unary_op("tan", lambda x: math.tan(x))
arcsin = define_unary_op("arcsin", lambda x: math.asin(x))
arccos = define_unary_op("arccos", lambda x: math.acos(x))
arctan = define_unary_op("arctan", lambda x: math.atan(x))
sinh = define_unary_op("sinh", lambda x: math.sinh(x))
cosh = define_unary_op("cosh", lambda x: math.cosh(x))
tanh = define_unary_op("tanh", lambda x: math.tanh(x))
arcsinh = define_unary_op("arcsinh", lambda x: math.asinh(x))
arccosh = define_unary_op("arccosh", lambda x: math.acosh(x))
arctanh = define_unary_op("arctanh", lambda x: math.atanh(x))

# Comparison operators
eq = define_binary_op("eq", lambda x, y: x == y)
ne = define_binary_op("ne", lambda x, y: x != y)
gt = define_binary_op("gt", lambda x, y: x > y)
lt = define_binary_op("lt", lambda x, y: x < y)
ge = define_binary_op("ge", lambda x, y: x >= y)
le = define_binary_op("le", lambda x, y: x <= y)


@node
Expand All @@ -797,24 +919,6 @@ def accum(x: ts["T"], start: "~T" = 0) -> ts["T"]:
return s_accum


@node(cppimpl=_cspbaselibimpl.ln)
def ln(x: ts[float]) -> ts[float]:
if csp.ticked(x):
return math.log(x)


@node(cppimpl=_cspbaselibimpl.exp)
def exp(x: ts[float]) -> ts[float]:
if csp.ticked(x):
return math.exp(x)


@node(cppimpl=_cspbaselibimpl.abs)
def abs(x: ts[float]) -> ts[float]:
if csp.ticked(x):
return abs(x)


@node(cppimpl=_cspbaselibimpl.exprtk_impl)
def _csp_exprtk_impl(
expression_str: str,
Expand Down
Loading

0 comments on commit b8c03e8

Please sign in to comment.