From 5fc7288745cca8dc187a82f431853e641f9cdf34 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Fri, 9 Feb 2024 22:25:18 -0500 Subject: [PATCH] WIP sympy integration via ufunc dispatch Signed-off-by: Tim Paine <3105306+timkpaine@users.noreply.github.com> --- csp/baselib.py | 190 +++++++++++++++++++++++++++++--------- csp/impl/wiring/edge.py | 163 +++++++++++++++++++++++++++++--- csp/tests/test_baselib.py | 61 +++++++++++- csp/tests/test_engine.py | 1 + 4 files changed, 357 insertions(+), 58 deletions(-) diff --git a/csp/baselib.py b/csp/baselib.py index 0aac1f52e..102609231 100644 --- a/csp/baselib.py +++ b/csp/baselib.py @@ -44,10 +44,28 @@ "min", "gate", "floordiv", + "mod", "pow", + "abs", "ln", + "log2", + "log10", "exp", - "abs", + "exp2", + "sqrt", + "erf", + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", "unroll", "collect", "flatten", @@ -670,11 +688,44 @@ def or_(*inputs): # Math/comparison binary operators are supported in C++ only for (int,int) and # (float, float) arguments. For all other types, the Python implementation is used. -MATH_OPS = ["add", "sub", "multiply", "divide", "pow", "max", "min"] +MATH_OPS = [ + # binary + "add", + "sub", + "multiply", + "divide", + "pow", + "max", + "min", + "floordiv", + "mod", + # unary + "abs", + "ln", + "log2", + "log10", + "exp", + "exp2", + "sqrt", + "erf", + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", +] COMP_OPS = ["eq", "ne", "lt", "gt", "le", "ge"] MATH_COMP_OPS_CPP = { + # binary math ("add", "float"): _cspbaselibimpl.add_f, ("add", "int"): _cspbaselibimpl.add_i, ("sub", "float"): _cspbaselibimpl.sub_f, @@ -689,6 +740,11 @@ def or_(*inputs): ("max", "int"): _cspbaselibimpl.max_i, ("min", "float"): _cspbaselibimpl.min_f, ("min", "int"): _cspbaselibimpl.min_i, + # unary math + ("abs", "float"): _cspbaselibimpl.abs, + ("ln", "float"): _cspbaselibimpl.ln, + ("exp", "float"): _cspbaselibimpl.exp, + # binary comparator ("eq", "float"): _cspbaselibimpl.eq_f, ("eq", "int"): _cspbaselibimpl.eq_i, ("ne", "float"): _cspbaselibimpl.ne_f, @@ -705,7 +761,7 @@ def or_(*inputs): @lru_cache(maxsize=512) -def define_op(name, op_lambda): +def define_binary_op(name, op_lambda): float_out_type, int_out_type, generic_out_type = [None] * 3 if name in COMP_OPS: float_out_type = bool @@ -722,12 +778,12 @@ def define_op(name, op_lambda): from csp.impl.wiring.node import _node_internal_use - @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP[(name, "float")], name=name) + @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "float"), None), name=name) def float_type(x: ts[float], y: ts[float]) -> ts[float_out_type]: if csp.valid(x, y): return op_lambda(x, y) - @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP[(name, "int")], name=name) + @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "int"), None), name=name) def int_type(x: ts[int], y: ts[int]) -> ts[int_out_type]: if csp.valid(x, y): return op_lambda(x, y) @@ -759,32 +815,98 @@ def comp(x: ts["T"], y: ts["U"]): return comp -# Math operators +@lru_cache(maxsize=512) +def define_unary_op(name, op_lambda): + float_out_type, int_out_type, generic_out_type = [None] * 3 + if name in COMP_OPS: + float_out_type = bool + int_out_type = bool + generic_out_type = bool + elif name in MATH_OPS: + float_out_type = float + if name in ("abs",): + int_out_type = int + generic_out_type = "T" + else: + int_out_type = float + generic_out_type = float -add = define_op("add", lambda x, y: x + y) -sub = define_op("sub", lambda x, y: x - y) -multiply = define_op("multiply", lambda x, y: x * y) -pow = define_op("pow", lambda x, y: x**y) -divide = define_op("divide", lambda x, y: x / y) -min = define_op("min", lambda x, y: x if x < y else y) -max = define_op("max", lambda x, y: x if x > y else y) + from csp.impl.wiring.node import _node_internal_use -# Comparison operators + @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "float"), None), name=name) + def float_type(x: ts[float]) -> ts[float_out_type]: + if csp.valid(x): + return op_lambda(x) -eq = define_op("eq", lambda x, y: x == y) -ne = define_op("ne", lambda x, y: x != y) -gt = define_op("gt", lambda x, y: x > y) -lt = define_op("lt", lambda x, y: x < y) -ge = define_op("ge", lambda x, y: x >= y) -le = define_op("le", lambda x, y: x <= y) + @_node_internal_use(cppimpl=MATH_COMP_OPS_CPP.get((name, "int"), None), name=name) + def int_type(x: ts[int]) -> ts[int_out_type]: + if csp.valid(x): + return op_lambda(x) -# Other math ops + @_node_internal_use(name=name) + def numpy_type(x: ts["T"]) -> ts[np.ndarray]: + if csp.valid(x): + return op_lambda(x) + @_node_internal_use(name=name) + def generic_type(x: ts["T"]) -> ts[generic_out_type]: + if csp.valid(x): + return op_lambda(x) + + def comp(x: ts["T"]): + if x.tstype.typ in [Numpy1DArray[float], NumpyNDArray[float]]: + return numpy_type(x) + elif x.tstype.typ is float: + return float_type(x) + elif x.tstype.typ is int: + return int_type(x) + return generic_type(x) -@node -def floordiv(x: ts["T"], y: ts["T"]) -> ts["T"]: - if csp.ticked(x, y) and csp.valid(x, y): - return x // y + comp.__name__ = name + return comp + + +# Math operators +add = define_binary_op("add", lambda x, y: x + y) +sub = define_binary_op("sub", lambda x, y: x - y) +multiply = define_binary_op("multiply", lambda x, y: x * y) +divide = define_binary_op("divide", lambda x, y: x / y) +pow = define_binary_op("pow", lambda x, y: x**y) +min = define_binary_op("min", lambda x, y: x if x < y else y) +max = define_binary_op("max", lambda x, y: x if x > y else y) +floordiv = define_binary_op("floordiv", lambda x, y: x // y) +mod = define_binary_op("mod", lambda x, y: x % y) + +# Other math ops +_python_abs = abs +abs = define_unary_op("abs", lambda x: _python_abs(x)) +ln = define_unary_op("ln", lambda x: math.log(x)) +log2 = define_unary_op("log2", lambda x: math.log2(x)) +log10 = define_unary_op("log10", lambda x: math.log10(x)) +exp = define_unary_op("exp", lambda x: math.exp(x)) +exp2 = define_unary_op("exp2", lambda x: math.exp2(x)) +sqrt = define_unary_op("sqrt", lambda x: math.sqrt(x)) +erf = define_unary_op("erf", lambda x: math.erf(x)) +sin = define_unary_op("sin", lambda x: math.sin(x)) +cos = define_unary_op("cos", lambda x: math.cos(x)) +tan = define_unary_op("tan", lambda x: math.tan(x)) +arcsin = define_unary_op("arcsin", lambda x: math.asin(x)) +arccos = define_unary_op("arccos", lambda x: math.acos(x)) +arctan = define_unary_op("arctan", lambda x: math.atan(x)) +sinh = define_unary_op("sinh", lambda x: math.sinh(x)) +cosh = define_unary_op("cosh", lambda x: math.cosh(x)) +tanh = define_unary_op("tanh", lambda x: math.tanh(x)) +arcsinh = define_unary_op("arcsinh", lambda x: math.asinh(x)) +arccosh = define_unary_op("arccosh", lambda x: math.acosh(x)) +arctanh = define_unary_op("arctanh", lambda x: math.atanh(x)) + +# Comparison operators +eq = define_binary_op("eq", lambda x, y: x == y) +ne = define_binary_op("ne", lambda x, y: x != y) +gt = define_binary_op("gt", lambda x, y: x > y) +lt = define_binary_op("lt", lambda x, y: x < y) +ge = define_binary_op("ge", lambda x, y: x >= y) +le = define_binary_op("le", lambda x, y: x <= y) @node @@ -797,24 +919,6 @@ def accum(x: ts["T"], start: "~T" = 0) -> ts["T"]: return s_accum -@node(cppimpl=_cspbaselibimpl.ln) -def ln(x: ts[float]) -> ts[float]: - if csp.ticked(x): - return math.log(x) - - -@node(cppimpl=_cspbaselibimpl.exp) -def exp(x: ts[float]) -> ts[float]: - if csp.ticked(x): - return math.exp(x) - - -@node(cppimpl=_cspbaselibimpl.abs) -def abs(x: ts[float]) -> ts[float]: - if csp.ticked(x): - return abs(x) - - @node(cppimpl=_cspbaselibimpl.exprtk_impl) def _csp_exprtk_impl( expression_str: str, diff --git a/csp/impl/wiring/edge.py b/csp/impl/wiring/edge.py index d285fe2ef..1626cce82 100644 --- a/csp/impl/wiring/edge.py +++ b/csp/impl/wiring/edge.py @@ -1,3 +1,6 @@ +import numpy as np + + class Edge: __slots__ = ["tstype", "nodedef", "output_idx", "basket_idx"] @@ -13,7 +16,7 @@ def __repr__(self): def __bool__(self): raise ValueError("boolean evaluation of an edge is not supported") - def __wrap_method(self, other, method): + def __wrap_binary_method(self, other, method): import csp if isinstance(other, Edge): @@ -30,7 +33,7 @@ def __hash__(self): def __add__(self, other): import csp - return self.__wrap_method(other, csp.add) + return self.__wrap_binary_method(other, csp.add) def __radd__(self, other): return self.__add__(other) @@ -38,7 +41,7 @@ def __radd__(self, other): def __sub__(self, other): import csp - return self.__wrap_method(other, csp.sub) + return self.__wrap_binary_method(other, csp.sub) def __rsub__(self, other): import csp @@ -48,7 +51,7 @@ def __rsub__(self, other): def __mul__(self, other): import csp - return self.__wrap_method(other, csp.multiply) + return self.__wrap_binary_method(other, csp.multiply) def __rmul__(self, other): return self.__mul__(other) @@ -56,7 +59,7 @@ def __rmul__(self, other): def __truediv__(self, other): import csp - return self.__wrap_method(other, csp.divide) + return self.__wrap_binary_method(other, csp.divide) def __rtruediv__(self, other): import csp @@ -66,7 +69,7 @@ def __rtruediv__(self, other): def __floordiv__(self, other): import csp - return self.__wrap_method(other, csp.floordiv) + return self.__wrap_binary_method(other, csp.floordiv) def __rfloordiv__(self, other): import csp @@ -76,42 +79,52 @@ def __rfloordiv__(self, other): def __pow__(self, other): import csp - return self.__wrap_method(other, csp.pow) + return self.__wrap_binary_method(other, csp.pow) def __rpow__(self, other): import csp return csp.pow(csp.const(other), self) + def __mod__(self, other): + import csp + + return self.__wrap_binary_method(other, csp.mod) + + def __mod__(self, other): + import csp + + return csp.mod(csp.const(other), self) + def __gt__(self, other): import csp - return self.__wrap_method(other, csp.gt) + return self.__wrap_binary_method(other, csp.gt) def __ge__(self, other): import csp - return self.__wrap_method(other, csp.ge) + return self.__wrap_binary_method(other, csp.ge) def __lt__(self, other): import csp - return self.__wrap_method(other, csp.lt) + return self.__wrap_binary_method(other, csp.lt) def __le__(self, other): import csp - return self.__wrap_method(other, csp.le) + return self.__wrap_binary_method(other, csp.le) def __eq__(self, other): import csp - return self.__wrap_method(other, csp.eq) + return self.__wrap_binary_method(other, csp.eq) def __ne__(self, other): import csp - return self.__wrap_method(other, csp.ne) + return self.__wrap_binary_method(other, csp.ne) def __invert__(self): import csp @@ -120,6 +133,130 @@ def __invert__(self): return csp.bitwise_not(self) raise TypeError(f"Cannot call invert with a ts[{self.tstype.typ.__name__}], not an integer type") + def abs(self): + import csp + + return csp.abs(self) + + def ln(self): + import csp + + return csp.ln(self) + + def log2(self): + import csp + + return csp.log2(self) + + def log10(self): + import csp + + return csp.log10(self) + + def exp(self): + import csp + + return csp.exp(self) + + def sin(self): + import csp + + return csp.sin(self) + + def cos(self): + import csp + + return csp.cos(self) + + def tan(self): + import csp + + return csp.tan(self) + + def arcsin(self): + import csp + + return csp.arcsin(self) + + def arccos(self): + import csp + + return csp.arccos(self) + + def arctan(self): + import csp + + return csp.arctan(self) + + def sqrt(self): + import csp + + return csp.sqrt(self) + + def erf(self): + import csp + + return csp.erf(self) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if ufunc == np.add: + if isinstance(inputs[0], Edge): + return inputs[0].__add__(inputs[1]) + else: + return inputs[1].__add__(inputs[0]) + elif ufunc == np.subtract: + if isinstance(inputs[0], Edge): + return inputs[0].__sub__(inputs[1]) + else: + return inputs[1].__sub__(inputs[0]) + elif ufunc == np.multiply: + if isinstance(inputs[0], Edge): + return inputs[0].__mul__(inputs[1]) + else: + return inputs[1].__mul__(inputs[0]) + elif ufunc == np.divide: + if isinstance(inputs[0], Edge): + return inputs[0].__truediv__(inputs[1]) + else: + return inputs[1].__truediv__(inputs[0]) + elif ufunc == np.floor_divide: + if isinstance(inputs[0], Edge): + return inputs[0].__floordiv__(inputs[1]) + else: + return inputs[1].__floordiv__(inputs[0]) + elif ufunc == np.power: + return inputs[0].pow(inputs[1]) + elif ufunc == np.abs: + return inputs[0].abs() + elif ufunc == np.log: + return inputs[0].ln() + elif ufunc == np.log2: + return inputs[0].log2() + elif ufunc == np.log10: + return inputs[0].log10() + elif ufunc == np.exp: + return inputs[0].exp() + elif ufunc == np.exp2: + return inputs[0].exp2() + elif ufunc == np.sin: + return inputs[0].sin() + elif ufunc == np.cos: + return inputs[0].cos() + elif ufunc == np.tan: + return inputs[0].tan() + elif ufunc == np.arcsin: + return inputs[0].asin() + elif ufunc == np.arccos: + return inputs[0].acos() + elif ufunc == np.arctan: + return inputs[0].atan() + elif ufunc == np.sqrt: + return inputs[0].sqrt() + elif ufunc.__name__ == "erf": + # TODO can we use name for all? + return inputs[0].erf() + raise NotImplementedError("Not Implemented for type csp.Edge: {}".format(ufunc)) + def __getattr__(self, key): from csp.impl.struct import Struct diff --git a/csp/tests/test_baselib.py b/csp/tests/test_baselib.py index 972b3f3f6..d4ac7e321 100644 --- a/csp/tests/test_baselib.py +++ b/csp/tests/test_baselib.py @@ -417,16 +417,16 @@ def test_exprtk(self): results[0], list(zip([start_time + timedelta(seconds=i) for i in range(5)], [0, 77, 154, 231, 308])) ) - def test_math_ops(self): + def test_math_binary_ops(self): OPS = { csp.add: lambda x, y: x + y, csp.sub: lambda x, y: x - y, csp.multiply: lambda x, y: x * y, csp.divide: lambda x, y: x / y, csp.pow: lambda x, y: x**y, - csp.floordiv: lambda x, y: x // y, csp.min: lambda x, y: min(x, y), csp.max: lambda x, y: max(x, y), + csp.floordiv: lambda x, y: x // y, } @csp.graph @@ -469,6 +469,63 @@ def graph(use_promotion: bool): [v[1] for v in results[op.__name__ + "-rev"]], [comp(y, x) for x, y in zip(xv, yv)], op.__name__ ) + def test_math_unary_ops(self): + OPS = { + csp.abs: lambda x: abs(x), + csp.ln: lambda x: math.log(x), + csp.log2: lambda x: math.log2(x), + csp.log10: lambda x: math.log10(x), + csp.exp: lambda x: math.exp(x), + csp.exp2: lambda x: math.exp2(x), + csp.sin: lambda x: math.sin(x), + csp.cos: lambda x: math.cos(x), + csp.tan: lambda x: math.tan(x), + csp.arctan: lambda x: math.atan(x), + csp.sinh: lambda x: math.sinh(x), + csp.cosh: lambda x: math.cosh(x), + csp.tanh: lambda x: math.tanh(x), + csp.arcsinh: lambda x: math.asinh(x), + csp.arccosh: lambda x: math.acosh(x), + csp.erf: lambda x: math.erf(x), + } + + @csp.graph + def graph(): + x = csp.count(csp.timer(timedelta(seconds=0.25))) + csp.add_graph_output("x", x) + + for op in OPS.keys(): + csp.add_graph_output(op.__name__, op(x)) + + st = datetime(2020, 1, 1) + results = csp.run(graph, starttime=st, endtime=st + timedelta(seconds=3)) + xv = [v[1] for v in results["x"]] + + for op, comp in OPS.items(): + self.assertEqual([v[1] for v in results[op.__name__]], [comp(x) for x in xv], op.__name__) + + def test_math_unary_ops_other_domain(self): + OPS = { + csp.arcsin: lambda x: math.asin(x), + csp.arccos: lambda x: math.acos(x), + csp.arctanh: lambda x: math.atanh(x), + } + + @csp.graph + def graph(): + x = 1 / (csp.count(csp.timer(timedelta(seconds=0.25))) * math.pi) + csp.add_graph_output("x", x) + + for op in OPS.keys(): + csp.add_graph_output(op.__name__, op(x)) + + st = datetime(2020, 1, 1) + results = csp.run(graph, starttime=st, endtime=st + timedelta(seconds=3)) + xv = [v[1] for v in results["x"]] + + for op, comp in OPS.items(): + self.assertEqual([v[1] for v in results[op.__name__]], [comp(x) for x in xv], op.__name__) + def test_comparisons(self): OPS = { csp.gt: lambda x, y: x > y, diff --git a/csp/tests/test_engine.py b/csp/tests/test_engine.py index 2b45c3197..bd604f2cb 100644 --- a/csp/tests/test_engine.py +++ b/csp/tests/test_engine.py @@ -488,6 +488,7 @@ def graph(): def test_bugreport_csp28(self): """bug where non-basket inputs after basket inputs were not being assigne dproperly in c++""" + @csp.node def buggy(basket: [ts[int]], x: ts[bool]) -> ts[bool]: if csp.ticked(x) and csp.valid(x):