Skip to content

Commit

Permalink
Fix vectorized shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 24, 2023
1 parent 861f95c commit 7ecb9f8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 16 deletions.
17 changes: 15 additions & 2 deletions pytensor/tensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytensor
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
Expand Down Expand Up @@ -155,7 +155,20 @@ def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim


_vectorize_node.register(Shape, _vectorize_not_needed)
@_vectorize_node.register(Shape)
def vectorize_shape(op, node, batched_x):
from pytensor.tensor.extra_ops import broadcast_to

[old_x] = node.inputs
core_ndims = old_x.type.ndim
batch_ndims = batched_x.type.ndim - core_ndims
batched_x_shape = shape(batched_x)
if not batch_ndims:
return batched_x_shape.owner
else:
batch_shape = batched_x_shape[:batch_ndims]
core_shape = batched_x_shape[batch_ndims:]
return broadcast_to(core_shape, (*batch_shape, core_ndims)).owner


def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
Expand Down
65 changes: 51 additions & 14 deletions tests/tensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,29 +709,66 @@ def test_shape_tuple():


class TestVectorize:
@pytensor.config.change_flags(cxx="") # For faster eval
def test_shape(self):
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))

vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64")
node = shape(vec).owner
vect_node = vectorize_node(node, mat)
assert equal_computations(vect_node.outputs, [shape(mat)])

[vect_out] = vectorize_node(node, mat).outputs
assert equal_computations(
[vect_out], [broadcast_to(mat.shape[1:], (*mat.shape[:1], 1))]
)

mat_test_value = np.ones((5, 3))
ref_fn = np.vectorize(lambda vec: np.asarray(vec.shape), signature="(vec)->(1)")
np.testing.assert_array_equal(
vect_out.eval({mat: mat_test_value}),
ref_fn(mat_test_value),
)

mat = tensor(shape=(None, None), dtype="float64")
tns = tensor(shape=(None, None, None, None), dtype="float64")
node = shape(mat).owner
[vect_out] = vectorize_node(node, tns).outputs
assert equal_computations(
[vect_out], [broadcast_to(tns.shape[2:], (*tns.shape[:2], 2))]
)

tns_test_value = np.ones((4, 6, 5, 3))
ref_fn = np.vectorize(
lambda vec: np.asarray(vec.shape), signature="(m1,m2)->(2)"
)
np.testing.assert_array_equal(
vect_out.eval({tns: tns_test_value}),
ref_fn(tns_test_value),
)

@pytensor.config.change_flags(cxx="") # For faster eval
def test_reshape(self):
x = scalar("x", dtype=int)
vec = tensor(shape=(None,))
mat = tensor(shape=(None, None))
vec = tensor(shape=(None,), dtype="float64")
mat = tensor(shape=(None, None), dtype="float64")

shape = (2, x)
shape = (-1, x)
node = reshape(vec, shape).owner
vect_node = vectorize_node(node, mat, shape)
assert equal_computations(
vect_node.outputs, [reshape(mat, (*mat.shape[:1], 2, x))]

[vect_out] = vectorize_node(node, mat, shape).outputs
assert equal_computations([vect_out], [reshape(mat, (*mat.shape[:1], -1, x))])

x_test_value = 2
mat_test_value = np.ones((5, 6))
ref_fn = np.vectorize(
lambda x, vec: vec.reshape(-1, x), signature="(),(vec1)->(mat1,mat2)"
)
np.testing.assert_array_equal(
vect_out.eval({x: x_test_value, mat: mat_test_value}),
ref_fn(x_test_value, mat_test_value),
)

new_shape = (5, 2, x)
vect_node = vectorize_node(node, mat, new_shape)
assert equal_computations(vect_node.outputs, [reshape(mat, new_shape)])
new_shape = (5, -1, x)
[vect_out] = vectorize_node(node, mat, new_shape).outputs
assert equal_computations([vect_out], [reshape(mat, new_shape)])

with pytest.raises(NotImplementedError):
vectorize_node(node, vec, broadcast_to(as_tensor([5, 2, x]), (2, 3)))
Expand Down

0 comments on commit 7ecb9f8

Please sign in to comment.