Skip to content

Commit

Permalink
make_2d_sym: support flat input and shape
Browse files Browse the repository at this point in the history
Fixes #1236.  For now, just add a new underscore version that takes a
shape kwarg.
  • Loading branch information
cbm755 committed Sep 13, 2022
1 parent e74ce50 commit 353718d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
6 changes: 3 additions & 3 deletions inst/@sym/private/elementwise_op.m
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@
'# dbout(f"at least one matrix param, shape={q.shape}")'
'assert len(q.shape) == 2, "non-2D arrays/tensors not yet supported"'
'm, n = q.shape'
'g = [[0]*n for i in range(m)]'
'g = []'
'for i in range(m):'
' for j in range(n):'
' g[i][j] = _op(*[k[i, j] if isinstance(k, (MatrixBase, NDimArray)) else k for k in _ins])'
'return make_2d_sym(g)' ];
' g.append(_op(*[k[i, j] if isinstance(k, (MatrixBase, NDimArray)) else k for k in _ins]))'
'return _make_2d_sym(g, shape=q.shape)' ];

z = pycall_sympy__ (cmd, varargin{:});

Expand Down
15 changes: 15 additions & 0 deletions inst/private/python_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,21 @@ def make_2d_sym(it_of_it, dbg_matrix_only=False):
else:
dbout(f"make_2d_sym: constructing 2D sym...")
return Array(ls_of_ls)
def _make_2d_sym(flat, shape, dbg_matrix_only=False):
"""
If all elements of FLAT are Expr, construct the
corresponding Matrix. Otherwise, construct the
corresponding non-Matrix 2D sym.
"""
flat = list(flat)
if Version(spver) <= Version("1.11.1"):
# never use Array on older SymPy
dbg_matrix_only = True
if (dbg_matrix_only
or all(isinstance(elt, Expr) for elt in flat)):
return Matrix(*shape, flat)
dbout(f"make_2d_sym: constructing 2D sym...")
return Array(flat, shape)
except:
echo_exception_stdout("in python_header defining fcns block 5")
raise
Expand Down
15 changes: 15 additions & 0 deletions inst/private/python_ipc_native.m
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,21 @@
' else:'
' dbout(f"make_2d_sym: constructing 2D sym...")'
' return Array(ls_of_ls)'
' def _make_2d_sym(flat, shape, dbg_matrix_only=False):'
' """'
' If all elements of FLAT are Expr, construct the'
' corresponding Matrix. Otherwise, construct the'
' corresponding non-Matrix 2D sym.'
' """'
' flat = list(flat)'
' if Version(spver) <= Version("1.11.1"):'
' # never use Array on older SymPy'
' dbg_matrix_only = True'
' if (dbg_matrix_only'
' or all(isinstance(elt, Expr) for elt in flat)):'
' return Matrix(*shape, flat)'
' dbout(f"make_2d_sym: constructing 2D sym...")'
' return Array(flat, shape)'
}, newl))
have_headers = true;
end
Expand Down

0 comments on commit 353718d

Please sign in to comment.