From e524b5d83a5c2d3d85f0a7ee0aaf239c5e87200d Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Thu, 6 Jun 2024 08:44:32 -0400 Subject: [PATCH 1/6] Add crashing test case --- test/error/CMakeLists.txt | 1 + test/error/rfactor_fused_var_and_rvar.cpp | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 test/error/rfactor_fused_var_and_rvar.cpp diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 52a2a01cd65e..d5d26894e032 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -94,6 +94,7 @@ tests(GROUPS error require_fail.cpp reuse_var_in_schedule.cpp reused_args.cpp + rfactor_fused_var_and_rvar.cpp rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp run_with_large_stack_throws.cpp diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp new file mode 100644 index 000000000000..e121ad82c31f --- /dev/null +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f; + RDom r(0, 100); + Var x, y; + f(x, y) = 0; + f(x, y) += r; + + RVar yr; + Var z; + f.update().fuse(y, r, yr).rfactor(yr, z); + + return 0; +} \ No newline at end of file From 932e43f752269c08fd33042cfa640df7772e47fc Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Thu, 6 Jun 2024 10:30:12 -0400 Subject: [PATCH 2/6] Add error when calling rfactor on a fused Var+RVar Fixes #7854 --- src/Func.cpp | 15 +++++++++++++++ test/error/rfactor_fused_var_and_rvar.cpp | 11 +++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 1f480c99983c..53aec89a4281 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -841,6 +841,21 @@ Func Stage::rfactor(vector> preserved) { << ", since it is already used in this Func's schedule elsewhere.\n" << dump_argument_list(); } + { + // Check that the RVar has not had a pure var fused into it + for (const Split &s : splits) { + if (s.is_fuse() && var_name_match(s.old_var, rv.name())) { + const auto &iter = std::find_if(pure_vars_used.begin(), pure_vars_used.end(), [&s](const std::string &v) { + return (var_name_match(v, s.outer) || var_name_match(v, s.inner)); + }); + user_assert(iter == pure_vars_used.end()) + << "In schedule for " << name() + << ", can't perform rfactor() on " << rv.name() + << " because the pure var " << (*iter) + << " is fused into it.\n"; + } + } + } } // If the operator is associative but non-commutative, rfactor() on inner diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp index e121ad82c31f..3eea850cbeae 100644 --- a/test/error/rfactor_fused_var_and_rvar.cpp +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -3,15 +3,18 @@ using namespace Halide; int main(int argc, char **argv) { - Func f; + Func f{"f"}; RDom r(0, 100); - Var x, y; + Var x{"x"}, y{"y"}; f(x, y) = 0; f(x, y) += r; - RVar yr; - Var z; + RVar yr{"yr"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() on yr because the pure var y is fused into it. f.update().fuse(y, r, yr).rfactor(yr, z); + printf("Success!\n"); return 0; } \ No newline at end of file From 3ff7c5c1bb9c36def17e734392769644365b3d13 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Thu, 6 Jun 2024 14:48:24 -0400 Subject: [PATCH 3/6] Update test cases --- test/correctness/rfactor.cpp | 18 ++++++++++++++++++ test/error/rfactor_fused_var_and_rvar.cpp | 14 ++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index 02bcc9e0f74e..f5983a781842 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -992,6 +992,23 @@ int self_assignment_rfactor_test() { return 0; } +int rfactor_with_partial_pure_fusion_test() { + Func f{"f"}; + Var x{"x"}, y{"y"}, z{"z"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + RVar rxy{"rxy"}, yrz{"yrz"}; + + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + f.update() + .fuse(r.x, r.y, rxy) + .fuse(r.z, y, yrz) + .rfactor(rxy, z); + + return 0; +} + } // namespace int main(int argc, char **argv) { @@ -1032,6 +1049,7 @@ int main(int argc, char **argv) { {"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test}, {"complex multiply rfactor test", complex_multiply_rfactor_test}, {"argmin rfactor test", argmin_rfactor_test}, + {"rfactor a fused rvar+rvar in the presence of a fused var+rvar", rfactor_with_partial_pure_fusion_test}, }; using Sharder = Halide::Internal::Test::Sharder; diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp index 3eea850cbeae..ed1e94bd1cac 100644 --- a/test/error/rfactor_fused_var_and_rvar.cpp +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -4,16 +4,22 @@ using namespace Halide; int main(int argc, char **argv) { Func f{"f"}; - RDom r(0, 100); + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); Var x{"x"}, y{"y"}; f(x, y) = 0; - f(x, y) += r; + f(x, y) += r.x + r.y + r.z; - RVar yr{"yr"}; + RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"}; Var z{"z"}; // Error: In schedule for f.update(0), can't perform rfactor() on yr because the pure var y is fused into it. - f.update().fuse(y, r, yr).rfactor(yr, z); + f.update() + .fuse(r.x, r.y, rxy) + .fuse(y, r.z, yrz) + .fuse(rxy, yrz, yr) + .rfactor(yr, z); + + f.print_loop_nest(); printf("Success!\n"); return 0; From ceb8502a21da27beac95a7544a27a5e7e7c47cc4 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Fri, 7 Jun 2024 12:37:57 -0400 Subject: [PATCH 4/6] Disallow calling rfactor() after fusing a pure var Fixes #7854 --- src/Func.cpp | 19 +++----------- test/correctness/rfactor.cpp | 18 ------------- test/error/CMakeLists.txt | 1 + .../rfactor_after_var_and_rvar_fusion.cpp | 25 +++++++++++++++++++ test/error/rfactor_fused_var_and_rvar.cpp | 2 +- 5 files changed, 31 insertions(+), 34 deletions(-) create mode 100644 test/error/rfactor_after_var_and_rvar_fusion.cpp diff --git a/src/Func.cpp b/src/Func.cpp index 53aec89a4281..36d9cb526d95 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -841,21 +841,6 @@ Func Stage::rfactor(vector> preserved) { << ", since it is already used in this Func's schedule elsewhere.\n" << dump_argument_list(); } - { - // Check that the RVar has not had a pure var fused into it - for (const Split &s : splits) { - if (s.is_fuse() && var_name_match(s.old_var, rv.name())) { - const auto &iter = std::find_if(pure_vars_used.begin(), pure_vars_used.end(), [&s](const std::string &v) { - return (var_name_match(v, s.outer) || var_name_match(v, s.inner)); - }); - user_assert(iter == pure_vars_used.end()) - << "In schedule for " << name() - << ", can't perform rfactor() on " << rv.name() - << " because the pure var " << (*iter) - << " is fused into it.\n"; - } - } - } } // If the operator is associative but non-commutative, rfactor() on inner @@ -885,6 +870,10 @@ Func Stage::rfactor(vector> preserved) { for (const Split &s : splits) { // If it's already applied, we should remove it from the split list. if (!apply_split_directive(s, rvars, predicates, args, values)) { + user_assert(!s.is_fuse()) + << "In schedule for " << name() + << ", can't perform rfactor() after fusing " << s.outer + << " and " << s.inner << "\n"; temp.push_back(s); } } diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index f5983a781842..02bcc9e0f74e 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -992,23 +992,6 @@ int self_assignment_rfactor_test() { return 0; } -int rfactor_with_partial_pure_fusion_test() { - Func f{"f"}; - Var x{"x"}, y{"y"}, z{"z"}; - RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); - RVar rxy{"rxy"}, yrz{"yrz"}; - - f(x, y) = 0; - f(x, y) += r.x + r.y + r.z; - - f.update() - .fuse(r.x, r.y, rxy) - .fuse(r.z, y, yrz) - .rfactor(rxy, z); - - return 0; -} - } // namespace int main(int argc, char **argv) { @@ -1049,7 +1032,6 @@ int main(int argc, char **argv) { {"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test}, {"complex multiply rfactor test", complex_multiply_rfactor_test}, {"argmin rfactor test", argmin_rfactor_test}, - {"rfactor a fused rvar+rvar in the presence of a fused var+rvar", rfactor_with_partial_pure_fusion_test}, }; using Sharder = Halide::Internal::Test::Sharder; diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index d5d26894e032..b946b081cd99 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -94,6 +94,7 @@ tests(GROUPS error require_fail.cpp reuse_var_in_schedule.cpp reused_args.cpp + rfactor_after_var_and_rvar_fusion.cpp rfactor_fused_var_and_rvar.cpp rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp diff --git a/test/error/rfactor_after_var_and_rvar_fusion.cpp b/test/error/rfactor_after_var_and_rvar_fusion.cpp new file mode 100644 index 000000000000..8b94fbdd1b16 --- /dev/null +++ b/test/error/rfactor_after_var_and_rvar_fusion.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing y and r$z + f.update() + .fuse(r.x, r.y, rxy) + .fuse(r.z, y, yrz) + .rfactor(rxy, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +} \ No newline at end of file diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp index ed1e94bd1cac..a167ca543c47 100644 --- a/test/error/rfactor_fused_var_and_rvar.cpp +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -12,7 +12,7 @@ int main(int argc, char **argv) { RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"}; Var z{"z"}; - // Error: In schedule for f.update(0), can't perform rfactor() on yr because the pure var y is fused into it. + // Error: In schedule for f.update(0), can't perform rfactor() after fusing r$z and y f.update() .fuse(r.x, r.y, rxy) .fuse(y, r.z, yrz) From d480a647a8d9602c6e485b3bd97d8ed1a1ba5294 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sat, 8 Jun 2024 11:43:56 -0400 Subject: [PATCH 5/6] Prove associativity first thing in rfactor(). --- src/Func.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Func.cpp b/src/Func.cpp index 36d9cb526d95..d290f8a1a382 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -788,6 +788,15 @@ Func Stage::rfactor(vector> preserved) { vector &args = definition.args(); vector &values = definition.values(); + // Check whether the operator is associative and determine the operator and + // its identity for each value in the definition if it is a Tuple + const auto &prover_result = prove_associativity(func_name, args, values); + + user_assert(prover_result.associative()) + << "Failed to call rfactor() on " << name() + << " since it can't prove associativity of the operator\n"; + internal_assert(prover_result.size() == values.size()); + // Figure out which pure vars were used in this update definition. std::set pure_vars_used; internal_assert(args.size() == dim_vars.size()); @@ -799,15 +808,6 @@ Func Stage::rfactor(vector> preserved) { } } - // Check whether the operator is associative and determine the operator and - // its identity for each value in the definition if it is a Tuple - const auto &prover_result = prove_associativity(func_name, args, values); - - user_assert(prover_result.associative()) - << "Failed to call rfactor() on " << name() - << " since it can't prove associativity of the operator\n"; - internal_assert(prover_result.size() == values.size()); - vector &splits = definition.schedule().splits(); vector &dims = definition.schedule().dims(); vector &rvars = definition.schedule().rvars(); From 95a753fd49bde4b5c15fdfadfe19db76cd9c67c3 Mon Sep 17 00:00:00 2001 From: Alex Reinking Date: Sat, 8 Jun 2024 11:51:01 -0400 Subject: [PATCH 6/6] Remove unused is_update parameter from apply_split --- src/ApplySplit.cpp | 3 +-- src/ApplySplit.h | 4 +--- src/Func.cpp | 8 ++++---- src/ScheduleFunctions.cpp | 2 +- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index 48d27b1ffc7a..ef2fe12f62a7 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -11,8 +11,7 @@ using std::map; using std::string; using std::vector; -vector apply_split(const Split &split, bool is_update, const string &prefix, - map &dim_extent_alignment) { +vector apply_split(const Split &split, const string &prefix, map &dim_extent_alignment) { vector result; Expr outer = Variable::make(Int(32), prefix + split.outer); diff --git a/src/ApplySplit.h b/src/ApplySplit.h index 5e646b22f08b..812f616af1ba 100644 --- a/src/ApplySplit.h +++ b/src/ApplySplit.h @@ -78,9 +78,7 @@ struct ApplySplitResult { * the definition (in ascending order of application), and let stmts which * defined the values of variables referred by the predicates and substitutions * (ordered from innermost to outermost let). */ -std::vector apply_split( - const Split &split, bool is_update, const std::string &prefix, - std::map &dim_extent_alignment); +std::vector apply_split(const Split &split, const std::string &prefix, std::map &dim_extent_alignment); /** Compute the loop bounds of the new dimensions resulting from applying the * split schedules using the loop bounds of the old dimensions. */ diff --git a/src/Func.cpp b/src/Func.cpp index d290f8a1a382..69534535a84e 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -646,7 +646,7 @@ bool apply_split(const Split &s, vector &rvars, rvars.insert(it + 1, {s.outer, 0, simplify((old_extent - 1 + s.factor) / s.factor)}); - vector splits_result = apply_split(s, true, "", dim_extent_alignment); + vector splits_result = apply_split(s, "", dim_extent_alignment); vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); @@ -681,7 +681,7 @@ bool apply_fuse(const Split &s, vector &rvars, iter_outer->extent = extent; rvars.erase(iter_inner); - vector splits_result = apply_split(s, true, "", dim_extent_alignment); + vector splits_result = apply_split(s, "", dim_extent_alignment); vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); @@ -705,7 +705,7 @@ bool apply_purify(const Split &s, vector &rvars, << ", deleting it from the rvars list\n"; rvars.erase(iter); - vector splits_result = apply_split(s, true, "", dim_extent_alignment); + vector splits_result = apply_split(s, "", dim_extent_alignment); vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); @@ -725,7 +725,7 @@ bool apply_rename(const Split &s, vector &rvars, debug(4) << " Renaming " << iter->var << " into " << s.outer << "\n"; iter->var = s.outer; - vector splits_result = apply_split(s, true, "", dim_extent_alignment); + vector splits_result = apply_split(s, "", dim_extent_alignment); vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index b5d8f35aac28..a6f6f24b4135 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -220,7 +220,7 @@ Stmt build_loop_nest( user_assert(predicated_vars.count(split.old_var) == 0) << "Cannot split a loop variable resulting from a split using PredicateLoads or PredicateStores."; - vector splits_result = apply_split(split, is_update, prefix, dim_extent_alignment); + vector splits_result = apply_split(split, prefix, dim_extent_alignment); // To ensure we substitute all indices used in call or provide, // we need to substitute all lets in, so we correctly guard x in