From 77fe8b70972486557f84973d5dec4a44fe9d8921 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Sun, 12 Jan 2025 19:27:00 +0800 Subject: [PATCH] enhance store-load forwarding for single-store allocas --- .../luisa/xir/passes/local_store_forward.h | 8 +- src/backends/fallback/fallback_codegen.cpp | 7 +- src/xir/passes/local_store_forward.cpp | 112 +++++++++++++++++- 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/include/luisa/xir/passes/local_store_forward.h b/include/luisa/xir/passes/local_store_forward.h index aa26e1b53..1936e3dbe 100644 --- a/include/luisa/xir/passes/local_store_forward.h +++ b/include/luisa/xir/passes/local_store_forward.h @@ -12,9 +12,11 @@ class BasicBlock; class Function; class Module; -// This pass is used to forward stores to loads for scalar variables -// within straight-line basic blocks. It is a simple peephole optimization -// that can be used to reduce the number of memory operations. +// This pass is used to forward stores to loads for thread-local variables within +// straight-line basic blocks, and function-wide for local variables that have exactly one store. It is a simple peephole-style optimization but can +// effectively reduce the number of memory operations produced by the C++ DSL. +// Note: we do not remove the stores after forwarding them to loads, as the DCE +// pass should be able to remove them if they are dead. 
struct LocalStoreForwardInfo { luisa::unordered_map forwarded_instructions; diff --git a/src/backends/fallback/fallback_codegen.cpp b/src/backends/fallback/fallback_codegen.cpp index da567cad7..ba599c8a7 100644 --- a/src/backends/fallback/fallback_codegen.cpp +++ b/src/backends/fallback/fallback_codegen.cpp @@ -3051,7 +3051,12 @@ class FallbackCodegen { // loop merge b.SetInsertPoint(llvm_loop_merge_block); b.CreateRetVoid(); - + // hoist the loop variable to the top + { + auto &llvm_entry = llvm_wrapper_function->getEntryBlock(); + auto &llvm_first_inst = llvm_entry.front(); + llvm_ptr_i->moveBefore(&llvm_first_inst); + } return llvm_wrapper_function; } diff --git a/src/xir/passes/local_store_forward.cpp b/src/xir/passes/local_store_forward.cpp index eaf275f3b..5c39c86b3 100644 --- a/src/xir/passes/local_store_forward.cpp +++ b/src/xir/passes/local_store_forward.cpp @@ -1,3 +1,5 @@ +#include "luisa/core/logging.h" + #include #include #include @@ -105,12 +107,114 @@ static void run_local_store_forward_on_basic_block(luisa::unordered_set visited; + function->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, [&](BasicBlock *block) noexcept { + run_local_store_forward_on_basic_block(visited, block, info); + }); +} + +// find all loads from local variables that have exactly one store, replace them with the stored value (or an extract from it) and remove them; NOTE(review): no check that the store dominates the load is visible here, please confirm +static void forward_single_store_to_loads_on_function(FunctionDefinition *function, LocalStoreForwardInfo &info) noexcept { + luisa::unordered_map single_store; + // search for local variables that only have a single store; every use by an instruction other than a load or GEP is counted as a potential store + { + luisa::unordered_map store_count; + function->traverse_instructions([&](Instruction *inst) noexcept { + switch (inst->derived_instruction_tag()) { + case DerivedInstructionTag::LOAD: [[fallthrough]]; + case DerivedInstructionTag::GEP: break; + default: { + for (auto op_use : inst->operand_uses()) { + if (auto base_alloca = trace_pointer_base_local_alloca_inst(op_use->value())) { + store_count.try_emplace(base_alloca, 
0u).first->second++; + } + } + break; + } + } + }); + for (auto [alloca_inst, count] : store_count) { + if (count == 1u) { + for (auto &&use : alloca_inst->use_list()) { + if (auto user = use.user(); + user->derived_value_tag() == DerivedValueTag::INSTRUCTION && + static_cast(user)->derived_instruction_tag() == DerivedInstructionTag::STORE) { + auto store_inst = static_cast(user); + LUISA_DEBUG_ASSERT(store_inst->variable() == alloca_inst, "Store variable must match alloca."); + single_store.emplace(alloca_inst, store_inst); + break; + } + } + } + } + } + // collect the loads that might be eliminated + luisa::vector removable_loads; + function->traverse_instructions([&](Instruction *inst) noexcept { + if (inst->derived_instruction_tag() == DerivedInstructionTag::LOAD) { + auto load = static_cast(inst); + if (auto base_alloca = trace_pointer_base_local_alloca_inst(load->variable()); + base_alloca != nullptr && single_store.contains(base_alloca)) { + removable_loads.emplace_back(load); + } + } + }); + // do the elimination + for (auto load : removable_loads) { + // convert load to extract + luisa::fixed_vector extract_args; + LUISA_DEBUG_ASSERT(load->variable()->derived_value_tag() == DerivedValueTag::INSTRUCTION, + "Load variable must be an instruction."); + auto pointer = static_cast(load->variable()); + for (;;) { + if (auto tag = pointer->derived_instruction_tag(); tag == DerivedInstructionTag::ALLOCA) { + break; + } else if (tag == DerivedInstructionTag::GEP) { + auto gep = static_cast(pointer); + LUISA_DEBUG_ASSERT(gep->base()->derived_value_tag() == DerivedValueTag::INSTRUCTION, + "GEP base must be an instruction."); + auto sub_indices = gep->index_uses(); + // note: we emplace the indices in reverse order to avoid + // expensive insertions at the beginning of the vector + for (auto iter = sub_indices.rbegin(); iter != sub_indices.rend(); ++iter) { + extract_args.emplace_back((*iter)->value()); + } + pointer = static_cast(gep->base()); + } else { + 
LUISA_ERROR_WITH_LOCATION("Unexpected instruction type."); + } + } + // process the alloca pointer + LUISA_DEBUG_ASSERT(pointer->derived_instruction_tag() == DerivedInstructionTag::ALLOCA, + "Pointer must be an alloca."); + auto store = single_store[static_cast(pointer)]; + LUISA_DEBUG_ASSERT(store != nullptr, "Store must not be null."); + extract_args.emplace_back(store->value()); + auto value = [&]() noexcept -> Value * { + // simple case: the load reads the whole variable (no GEP indices were collected) + if (extract_args.size() == 1u) { return extract_args.front(); } + // reverse the indices to the correct order + std::reverse(extract_args.begin(), extract_args.end()); + // create the extract instruction + Builder builder; + builder.set_insertion_point(load); + return builder.call(load->type(), ArithmeticOp::EXTRACT, extract_args); + }(); + load->replace_all_uses_with(value); + load->remove_self(); + // record the elimination + info.forwarded_instructions.emplace(load, store); + } +} + static void run_local_store_forward_on_function(Function *function, LocalStoreForwardInfo &info) noexcept { if (auto definition = function->definition()) { - luisa::unordered_set visited; - definition->traverse_basic_blocks(BasicBlockTraversalOrder::REVERSE_POST_ORDER, [&](BasicBlock *block) noexcept { - run_local_store_forward_on_basic_block(visited, block, info); - }); + // first pass: forward stores to loads within straight-line code + forward_straight_line_stores_to_loads_on_function(definition, info); + // second pass: forward the store to the loads of local variables that have exactly one store + forward_single_store_to_loads_on_function(definition, info); } }