Skip to content

Commit

Permalink
enhance store-load forwarding for single-store allocas
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 12, 2025
1 parent 933572a commit 77fe8b7
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 8 deletions.
8 changes: 5 additions & 3 deletions include/luisa/xir/passes/local_store_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ class BasicBlock;
class Function;
class Module;

// This pass forwards stores to loads for thread-local variables within
// straight-line basic blocks. It is a simple peephole-style optimization, but it
// can effectively reduce the number of memory operations produced by the C++ DSL.
// Note: the stores are not removed after being forwarded to loads, as the DCE
// pass should be able to remove them if they are dead.

struct LocalStoreForwardInfo {
luisa::unordered_map<LoadInst *, StoreInst *> forwarded_instructions;
Expand Down
7 changes: 6 additions & 1 deletion src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3051,7 +3051,12 @@ class FallbackCodegen {
// loop merge
b.SetInsertPoint(llvm_loop_merge_block);
b.CreateRetVoid();

// hoist the loop variable to the top
{
auto &llvm_entry = llvm_wrapper_function->getEntryBlock();
auto &llvm_first_inst = llvm_entry.front();
llvm_ptr_i->moveBefore(&llvm_first_inst);
}
return llvm_wrapper_function;
}

Expand Down
112 changes: 108 additions & 4 deletions src/xir/passes/local_store_forward.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "luisa/core/logging.h"

#include <luisa/xir/function.h>
#include <luisa/xir/module.h>
#include <luisa/xir/builder.h>
Expand Down Expand Up @@ -105,12 +107,114 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_set<BasicBlo
}
}

// Straight-line forwarding: visit every basic block of `function` in reverse
// post-order and run the block-local store-to-load forwarding on it. A single
// visited-set is shared by all visits, and every forwarding is recorded in `info`.
static void forward_straight_line_stores_to_loads_on_function(FunctionDefinition *function, LocalStoreForwardInfo &info) noexcept {
    luisa::unordered_set<BasicBlock *> seen_blocks;
    auto forward_in_block = [&seen_blocks, &info](BasicBlock *bb) noexcept {
        run_local_store_forward_on_basic_block(seen_blocks, bb, info);
    };
    function->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, forward_in_block);
}

// Find local variables (allocas) that are written by exactly one store and
// replace every load from them (directly or through a GEP chain) with the
// stored value, or with an EXTRACT of the stored value when the load goes
// through GEPs. Each replaced load is removed and recorded in
// `info.forwarded_instructions` together with the store it was forwarded from.
// Note: allocas without any counted use never enter `store_count`, so the
// "no store" case is not handled by this function.
// NOTE(review): nothing in this function checks that the single store executes
// before (dominates) the loads being replaced -- confirm that callers/IR shape
// guarantee this, otherwise a load preceding the store would be rewritten too.
static void forward_single_store_to_loads_on_function(FunctionDefinition *function, LocalStoreForwardInfo &info) noexcept {
// maps an alloca to its unique store; only filled for allocas that qualify
luisa::unordered_map<AllocaInst *, StoreInst *> single_store;
// search for local variables that only have a single store
{
// Despite its name, this counts every operand use by any instruction other
// than LOAD and GEP for which trace_pointer_base_local_alloca_inst() returns
// a non-null alloca, i.e. stores as well as any other kind of user.
luisa::unordered_map<AllocaInst *, size_t> store_count;
function->traverse_instructions([&](Instruction *inst) noexcept {
switch (inst->derived_instruction_tag()) {
// loads and GEPs are not counted as uses
case DerivedInstructionTag::LOAD: [[fallthrough]];
case DerivedInstructionTag::GEP: break;
default: {
for (auto op_use : inst->operand_uses()) {
if (auto base_alloca = trace_pointer_base_local_alloca_inst(op_use->value())) {
// insert with a zero count on first sight, then bump
store_count.try_emplace(base_alloca, 0u).first->second++;
}
}
break;
}
}
});
for (auto [alloca_inst, count] : store_count) {
if (count == 1u) {
// Look for a STORE among the users in the alloca's own use list. If the
// only counted use was not such a store, nothing is found here and the
// alloca is simply not registered in `single_store`.
for (auto &&use : alloca_inst->use_list()) {
if (auto user = use.user();
user->derived_value_tag() == DerivedValueTag::INSTRUCTION &&
static_cast<Instruction *>(user)->derived_instruction_tag() == DerivedInstructionTag::STORE) {
auto store_inst = static_cast<StoreInst *>(user);
LUISA_DEBUG_ASSERT(store_inst->variable() == alloca_inst, "Store variable must match alloca.");
single_store.emplace(alloca_inst, store_inst);
break;
}
}
}
}
}
// collect the loads that might be eliminated
// (gathered into a vector first; the rewriting below happens after the traversal)
luisa::vector<LoadInst *> removable_loads;
function->traverse_instructions([&](Instruction *inst) noexcept {
if (inst->derived_instruction_tag() == DerivedInstructionTag::LOAD) {
auto load = static_cast<LoadInst *>(inst);
if (auto base_alloca = trace_pointer_base_local_alloca_inst(load->variable());
base_alloca != nullptr && single_store.contains(base_alloca)) {
removable_loads.emplace_back(load);
}
}
});
// do the elimination
for (auto load : removable_loads) {
// convert load to extract
// `extract_args` is built back-to-front (innermost GEP indices first, stored
// value last) and reversed before use, giving [stored value, indices...].
luisa::fixed_vector<Value *, 8u> extract_args;
LUISA_DEBUG_ASSERT(load->variable()->derived_value_tag() == DerivedValueTag::INSTRUCTION,
"Load variable must be an instruction.");
auto pointer = static_cast<Instruction *>(load->variable());
// walk from the loaded pointer up through GEP bases until the alloca is reached
for (;;) {
if (auto tag = pointer->derived_instruction_tag(); tag == DerivedInstructionTag::ALLOCA) {
break;
} else if (tag == DerivedInstructionTag::GEP) {
auto gep = static_cast<GEPInst *>(pointer);
LUISA_DEBUG_ASSERT(gep->base()->derived_value_tag() == DerivedValueTag::INSTRUCTION,
"GEP base must be an instruction.");
auto sub_indices = gep->index_uses();
// note: we emplace the indices in reverse order to avoid
// expensive insertions at the beginning of the vector
for (auto iter = sub_indices.rbegin(); iter != sub_indices.rend(); ++iter) {
extract_args.emplace_back((*iter)->value());
}
pointer = static_cast<Instruction *>(gep->base());
} else {
// anything other than ALLOCA or GEP in the pointer chain is reported as an error
LUISA_ERROR_WITH_LOCATION("Unexpected instruction type.");
}
}
// process the alloca pointer
LUISA_DEBUG_ASSERT(pointer->derived_instruction_tag() == DerivedInstructionTag::ALLOCA,
"Pointer must be an alloca.");
// the key is expected to exist: the load was only collected if its base alloca
// is contained in `single_store`
auto store = single_store[static_cast<AllocaInst *>(pointer)];
LUISA_DEBUG_ASSERT(store != nullptr, "Store must not be null.");
extract_args.emplace_back(store->value());
auto value = [&]() noexcept -> Value * {
// simple case: scalar load
// (no GEP indices were collected, so the stored value itself replaces the load)
if (extract_args.size() == 1u) { return extract_args.front(); }
// reverse the indices to the correct order
std::reverse(extract_args.begin(), extract_args.end());
// create the extract instruction
// (the builder's insertion point is set to the load being replaced)
Builder builder;
builder.set_insertion_point(load);
return builder.call(load->type(), ArithmeticOp::EXTRACT, extract_args);
}();
// redirect all users to the forwarded value before detaching the load
load->replace_all_uses_with(value);
load->remove_self();
// record the elimination
info.forwarded_instructions.emplace(load, store);
}
}

// Runs store-to-load forwarding on a single function and records every
// forwarded load in `info`. Nothing is done when `function->definition()`
// yields no definition.
//
// The straight-line forwarding is performed exactly once, through
// forward_straight_line_stores_to_loads_on_function(); the previously inlined
// copy of the same traversal was redundant with that helper call.
static void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept {
    if (auto definition = function->definition()) {
        // first pass: forward stores to loads within straight-line code
        forward_straight_line_stores_to_loads_on_function(definition, info);
        // second pass: forward stores to loads from local variables that only have a single store
        forward_single_store_to_loads_on_function(definition, info);
    }
}

Expand Down

0 comments on commit 77fe8b7

Please sign in to comment.