Skip to content

Commit

Permalink
Disallow calling rfactor() after fusing a pure var
Browse files Browse the repository at this point in the history
Fixes #7854
  • Loading branch information
alexreinking committed Jun 7, 2024
1 parent 0265f61 commit d987c64
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 34 deletions.
19 changes: 4 additions & 15 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,21 +841,6 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
<< ", since it is already used in this Func's schedule elsewhere.\n"
<< dump_argument_list();
}
{
// Check that the RVar has not had a pure var fused into it
for (const Split &s : splits) {
if (s.is_fuse() && var_name_match(s.old_var, rv.name())) {
const auto &iter = std::find_if(pure_vars_used.begin(), pure_vars_used.end(), [&s](const std::string &v) {
return (var_name_match(v, s.outer) || var_name_match(v, s.inner));
});
user_assert(iter == pure_vars_used.end())
<< "In schedule for " << name()
<< ", can't perform rfactor() on " << rv.name()
<< " because the pure var " << (*iter)
<< " is fused into it.\n";
}
}
}
}

// If the operator is associative but non-commutative, rfactor() on inner
Expand Down Expand Up @@ -885,6 +870,10 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
for (const Split &s : splits) {
// If it's already applied, we should remove it from the split list.
if (!apply_split_directive(s, rvars, predicates, args, values)) {
user_assert(!s.is_fuse())
<< "In schedule for " << name()
<< ", can't perform rfactor() after fusing " << s.outer
<< " and " << s.inner << "\n";
temp.push_back(s);
}
}
Expand Down
18 changes: 0 additions & 18 deletions test/correctness/rfactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,23 +992,6 @@ int self_assignment_rfactor_test() {
return 0;
}

int rfactor_with_partial_pure_fusion_test() {
Func f{"f"};
Var x{"x"}, y{"y"}, z{"z"};
RDom r({{0, 5}, {0, 5}, {0, 5}}, "r");
RVar rxy{"rxy"}, yrz{"yrz"};

f(x, y) = 0;
f(x, y) += r.x + r.y + r.z;

f.update()
.fuse(r.x, r.y, rxy)
.fuse(r.z, y, yrz)
.rfactor(rxy, z);

return 0;
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -1049,7 +1032,6 @@ int main(int argc, char **argv) {
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
{"complex multiply rfactor test", complex_multiply_rfactor_test},
{"argmin rfactor test", argmin_rfactor_test},
{"rfactor a fused rvar+rvar in the presence of a fused var+rvar", rfactor_with_partial_pure_fusion_test},
};

using Sharder = Halide::Internal::Test::Sharder;
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ tests(GROUPS error
require_fail.cpp
reuse_var_in_schedule.cpp
reused_args.cpp
rfactor_after_var_and_rvar_fusion.cpp
rfactor_fused_var_and_rvar.cpp
rfactor_inner_dim_non_commutative.cpp
round_up_and_blend_race.cpp
Expand Down
25 changes: 25 additions & 0 deletions test/error/rfactor_after_var_and_rvar_fusion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv) {
Func f{"f"};
RDom r({{0, 5}, {0, 5}, {0, 5}}, "r");
Var x{"x"}, y{"y"};
f(x, y) = 0;
f(x, y) += r.x + r.y + r.z;

RVar rxy{"rxy"}, yrz{"yrz"};
Var z{"z"};

// Error: In schedule for f.update(0), can't perform rfactor() after fusing y and r$z
f.update()
.fuse(r.x, r.y, rxy)
.fuse(r.z, y, yrz)
.rfactor(rxy, z);

f.print_loop_nest();

printf("Success!\n");
return 0;
}
2 changes: 1 addition & 1 deletion test/error/rfactor_fused_var_and_rvar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ int main(int argc, char **argv) {
RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"};
Var z{"z"};

// Error: In schedule for f.update(0), can't perform rfactor() on yr because the pure var y is fused into it.
// Error: In schedule for f.update(0), can't perform rfactor() after fusing r$z and y
f.update()
.fuse(r.x, r.y, rxy)
.fuse(y, r.z, yrz)
Expand Down

0 comments on commit d987c64

Please sign in to comment.