diff --git a/src/hl_ops/movement.rs b/src/hl_ops/movement.rs index 410b90f9..e0a1e0e0 100644 --- a/src/hl_ops/movement.rs +++ b/src/hl_ops/movement.rs @@ -17,7 +17,14 @@ impl GraphTensor { S: BroadcastShapeTo, { let new_dims = Dst::realized_shape(); - if !new_dims.is_empty() { + let src_dims = S::realized_shape(); + let is_noop = src_dims.len() == new_dims.len() + && src_dims + .iter() + .zip(new_dims.iter()) + .all(|(src_dim, new_dim)| src_dim.as_num() == new_dim.as_num()); + + if !new_dims.is_empty() && !is_noop { for (i, dim) in Ax::as_array().into_iter().map(|i| (i, new_dims[i])) { self.shape.expand(i, dim); } @@ -514,4 +521,14 @@ mod tests { assert_close(&c.data(), &d_c.as_vec()); } + + #[test] + fn test_noop_expand() { + type S = R1<1>; + type Tensor = GraphTensor; + let mut cx = Graph::new(); + let a: Tensor = cx.tensor(); + let noop_expanded: Tensor = a.expand::>(); + assert_eq!(a.shape, noop_expanded.shape); + } } diff --git a/src/hl_ops/reduction.rs b/src/hl_ops/reduction.rs index 604a5aff..6a76f3f9 100644 --- a/src/hl_ops/reduction.rs +++ b/src/hl_ops/reduction.rs @@ -96,7 +96,7 @@ mod tests { let a_data = random_vec(6); let a = cx.tensor::>(); a.set(a_data.clone()); - let b = a.sum_reduce::<_, LAxis<1>>(); + let b = a.sum_reduce::, LAxis<1>>(); b.retrieve(); cx.execute(); @@ -114,7 +114,7 @@ mod tests { let a_data = random_vec(6); let a = cx.tensor::>(); a.set(a_data.clone()); - let b = a.max_reduce::<_, LAxis<1>>(); + let b = a.max_reduce::, LAxis<1>>(); b.retrieve(); cx.execute(); @@ -132,7 +132,7 @@ mod tests { let a_data = random_vec(6); let a = cx.tensor::>(); a.set(a_data.clone()); - let b = a.mean_reduce::<_, LAxis<1>>(); + let b = a.mean_reduce::, LAxis<1>>(); b.retrieve(); cx.execute(); diff --git a/src/shape/broadcast.rs b/src/shape/broadcast.rs index 08b8649a..c8fd65eb 100644 --- a/src/shape/broadcast.rs +++ b/src/shape/broadcast.rs @@ -12,7 +12,7 @@ pub trait ReduceShape: Sized + HasAxes + ReduceShapeTo; } -impl ReduceShapeTo<(), Axis<0>> for () {} +impl, AnyAxes> ReduceShapeTo for S {} impl ReduceShape> for () { type Reduced = (); } diff --git a/src/tests/test_prim.rs b/src/tests/test_prim.rs index 33116c62..5e895879 100644 --- a/src/tests/test_prim.rs +++ b/src/tests/test_prim.rs @@ -267,9 +267,15 @@ fn test_sum_reduce() { let a = cx .tensor::>() .set([[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]]); - let b = a.sum_reduce::<_, crate::prelude::Axis<1>>().retrieve(); - let c = a.sum_reduce::<_, crate::prelude::Axis<0>>().retrieve(); - let d = a.sum_reduce::<_, crate::prelude::Axis<2>>().retrieve(); + let b = a + .sum_reduce::, crate::prelude::Axis<1>>() + .retrieve(); + let c = a + .sum_reduce::, crate::prelude::Axis<0>>() + .retrieve(); + let d = a + .sum_reduce::, crate::prelude::Axis<2>>() + .retrieve(); cx.execute(); let d_dev = Cpu::default(); @@ -290,7 +296,9 @@ fn test_sum_reduce2() { [[34.4, -96.0, 144.0], [43.0, 560.0, 180.0]], [[39.6, -120.0, 180.0], [49.5, 700.0, 225.0]], ]]); - let b = a.sum_reduce::<_, crate::prelude::Axis<3>>().retrieve(); + let b = a + .sum_reduce::, crate::prelude::Axis<3>>() + .retrieve(); cx.execute(); let d_dev = Cpu::default(); @@ -309,9 +317,15 @@ fn test_max_reduce() { let a = cx .tensor::>() .set([[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]]); - let b = a.max_reduce::<_, crate::prelude::Axis<1>>().retrieve(); - let c = a.max_reduce::<_, crate::prelude::Axis<0>>().retrieve(); - let d = a.max_reduce::<_, crate::prelude::Axis<2>>().retrieve(); + let b = a + .max_reduce::, crate::prelude::Axis<1>>() + .retrieve(); + let c = a + .max_reduce::, crate::prelude::Axis<0>>() + .retrieve(); + let d = a + .max_reduce::, crate::prelude::Axis<2>>() + .retrieve(); cx.execute(); let d_dev = Cpu::default();