diff --git a/include/luisa/xir/passes/local_load_elimination.h b/include/luisa/xir/passes/local_load_elimination.h index f1af622ef..92a2887e8 100644 --- a/include/luisa/xir/passes/local_load_elimination.h +++ b/include/luisa/xir/passes/local_load_elimination.h @@ -1,8 +1,19 @@ -// -// Created by mike on 1/10/25. -// +#pragma once -#ifndef PEEPHOLE_LOAD_ELIMINATION_H -#define PEEPHOLE_LOAD_ELIMINATION_H +#include +#include -#endif //PEEPHOLE_LOAD_ELIMINATION_H +namespace luisa::compute::xir { + +class LoadInst; +class Function; +class Module; + +struct LocalLoadEliminationInfo { + luisa::unordered_map eliminated_instructions; +}; + +[[nodiscard]] LC_XIR_API LocalLoadEliminationInfo local_load_elimination_pass_run_on_function(Function *function) noexcept; +[[nodiscard]] LC_XIR_API LocalLoadEliminationInfo local_load_elimination_pass_run_on_module(Module *module) noexcept; + +}// namespace luisa::compute::xir diff --git a/src/xir/passes/dce.cpp b/src/xir/passes/dce.cpp index b4e6c1cbc..844d09d16 100644 --- a/src/xir/passes/dce.cpp +++ b/src/xir/passes/dce.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -199,6 +200,34 @@ void propagate_unreachable_marks_in_function(Function *function, DCEInfo &info) } } +[[nodiscard]] static luisa::optional try_evaluate_static_branch_condition(Value *cond) noexcept { + LUISA_DEBUG_ASSERT(cond != nullptr, "Branch condition must not be null."); + if (cond->derived_value_tag() != DerivedValueTag::CONSTANT) { return luisa::nullopt; } + auto static_cond = static_cast(cond); + LUISA_DEBUG_ASSERT(const_cond->type()->is_bool(), "Branch condition must be a boolean constant."); + return static_cond->as(); +} + +[[nodiscard]] static luisa::optional try_evaluate_static_switch_condition(Value *cond) noexcept { + LUISA_DEBUG_ASSERT(cond != nullptr, "Switch condition must not be null."); + if (cond->derived_value_tag() != DerivedValueTag::CONSTANT) { return luisa::nullopt; } + return [static_cond = static_cast(cond)]() noexcept -> SwitchInst::case_value_type { + switch (auto t = static_cond->type(); t->tag()) { + case Type::Tag::BOOL: return static_cond->as(); + case Type::Tag::INT8: return static_cond->as(); + case Type::Tag::UINT8: return static_cond->as(); + case Type::Tag::INT16: return static_cond->as(); + case Type::Tag::UINT16: return static_cond->as(); + case Type::Tag::INT32: return static_cond->as(); + case Type::Tag::UINT32: return static_cond->as(); + case Type::Tag::INT64: return static_cond->as(); + case Type::Tag::UINT64: return static_cond->as(); + default: break; + } + LUISA_ERROR_WITH_LOCATION("Invalid switch condition type."); + }(); +} + void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info) noexcept { if (auto definition = function->definition()) { luisa::unordered_set reachable; @@ -221,6 +250,37 @@ void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info) } } }); + // also check if the terminator is a constant branch + switch (auto terminator = b->terminator(); terminator->derived_instruction_tag()) { + case DerivedInstructionTag::IF: [[fallthrough]]; + case DerivedInstructionTag::CONDITIONAL_BRANCH: { + auto cond_br_inst = static_cast(terminator); + if (auto static_cond = try_evaluate_static_branch_condition(cond_br_inst->condition())) { + unreachable.emplace(*static_cond ? cond_br_inst->false_block() : cond_br_inst->true_block()); + } + break; + } + case DerivedInstructionTag::SWITCH: { + auto switch_inst = static_cast(terminator); + if (auto static_cond = try_evaluate_static_switch_condition(switch_inst->value())) { + auto any_match = false; + for (auto i = 0u; i < switch_inst->case_count(); i++) { + if (switch_inst->case_value(i) == *static_cond) { + any_match = true; + } else { + LUISA_DEBUG_ASSERT(switch_inst->case_block(i) != nullptr, "Switch case block must not be null."); + unreachable.emplace(switch_inst->case_block(i)); + } + } + if (any_match) { + LUISA_DEBUG_ASSERT(switch_inst->default_block() != nullptr, "Switch default block must not be null."); + unreachable.emplace(switch_inst->default_block()); + } + } + break; + } + default: break; + } } // eliminate all instructions in unreachable blocks eliminate_instructions_in_unreachable_blocks(unreachable, info); diff --git a/src/xir/passes/local_load_elimination.cpp b/src/xir/passes/local_load_elimination.cpp index 2bd1b4a72..5b2cbb993 100644 --- a/src/xir/passes/local_load_elimination.cpp +++ b/src/xir/passes/local_load_elimination.cpp @@ -1,3 +1,5 @@ -// -// Created by mike on 1/10/25. -// +#include + +namespace luisa::compute::xir { + +}