From 8a3da59fdf27c14306b1e11b000742ab1efab287 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Fri, 6 Dec 2024 17:08:04 +0100 Subject: [PATCH] test --- .../{test_vector.mlir => batched_scalar.mlir} | 8 +++--- .../test/MLIR/ForwardMode/batched_tensor.mlir | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) rename enzyme/test/MLIR/ForwardMode/{test_vector.mlir => batched_scalar.mlir} (74%) create mode 100644 enzyme/test/MLIR/ForwardMode/batched_tensor.mlir diff --git a/enzyme/test/MLIR/ForwardMode/test_vector.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir similarity index 74% rename from enzyme/test/MLIR/ForwardMode/test_vector.mlir rename to enzyme/test/MLIR/ForwardMode/batched_scalar.mlir index 1aa1f9621fd..09f85c5f68a 100644 --- a/enzyme/test/MLIR/ForwardMode/test_vector.mlir +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -11,14 +11,14 @@ module { } } -// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<2xf64>, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> // CHECK-NEXT: return %[[i0]] : tensor<2xf64> // CHECK-NEXT: } // CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { -// CHECK-NEXT: %[[s0:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64> // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> -// CHECK-NEXT: %[[s1:.+]] = tensor.splat %[[arg0]] : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{width = 2 : i64}> : f64 -> tensor<2xf64> // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir new file mode 100644 index 00000000000..5895a8bad24 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ + %y = arith.mulf %x, %x : tensor<10xf64> + return %y : tensor<10xf64> + } + func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>) + return %r : tensor<2x10xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{width = 2 : i64}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> +// CHECK-NEXT: } \ No newline at end of file