Skip to content

Commit

Permalink
minor improvements on dce
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 10, 2025
1 parent 0103272 commit 909edfc
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 9 deletions.
23 changes: 17 additions & 6 deletions include/luisa/xir/passes/local_load_elimination.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
//
// Created by mike on 1/10/25.
//
#pragma once

#ifndef PEEPHOLE_LOAD_ELIMINATION_H
#define PEEPHOLE_LOAD_ELIMINATION_H
#include <luisa/core/dll_export.h>
#include <luisa/core/stl/unordered_map.h>

#endif //PEEPHOLE_LOAD_ELIMINATION_H
namespace luisa::compute::xir {

class LoadInst;
class Function;
class Module;

struct LocalLoadEliminationInfo {
luisa::unordered_map<LoadInst *, LoadInst *> eliminated_instructions;
};

[[nodiscard]] LC_XIR_API LocalLoadEliminationInfo local_load_elimination_pass_run_on_function(Function *function) noexcept;
[[nodiscard]] LC_XIR_API LocalLoadEliminationInfo local_load_elimination_pass_run_on_module(Module *module) noexcept;

}// namespace luisa::compute::xir
60 changes: 60 additions & 0 deletions src/xir/passes/dce.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <luisa/core/logging.h>
#include <luisa/core/stl/optional.h>
#include <luisa/xir/instructions/intrinsic.h>
#include <luisa/xir/passes/dce.h>
#include <luisa/xir/builder.h>
Expand Down Expand Up @@ -199,6 +200,34 @@ void propagate_unreachable_marks_in_function(Function *function, DCEInfo &info)
}
}

[[nodiscard]] static luisa::optional<bool> try_evaluate_static_branch_condition(Value *cond) noexcept {
LUISA_DEBUG_ASSERT(cond != nullptr, "Branch condition must not be null.");
if (cond->derived_value_tag() != DerivedValueTag::CONSTANT) { return luisa::nullopt; }
auto static_cond = static_cast<Constant *>(cond);
LUISA_DEBUG_ASSERT(const_cond->type()->is_bool(), "Branch condition must be a boolean constant.");
return static_cond->as<bool>();
}

[[nodiscard]] static luisa::optional<SwitchInst::case_value_type> try_evaluate_static_switch_condition(Value *cond) noexcept {
LUISA_DEBUG_ASSERT(cond != nullptr, "Switch condition must not be null.");
if (cond->derived_value_tag() != DerivedValueTag::CONSTANT) { return luisa::nullopt; }
return [static_cond = static_cast<Constant *>(cond)]() noexcept -> SwitchInst::case_value_type {
switch (auto t = static_cond->type(); t->tag()) {
case Type::Tag::BOOL: return static_cond->as<bool>();
case Type::Tag::INT8: return static_cond->as<int8_t>();
case Type::Tag::UINT8: return static_cond->as<uint8_t>();
case Type::Tag::INT16: return static_cond->as<int16_t>();
case Type::Tag::UINT16: return static_cond->as<uint16_t>();
case Type::Tag::INT32: return static_cond->as<int32_t>();
case Type::Tag::UINT32: return static_cond->as<uint32_t>();
case Type::Tag::INT64: return static_cond->as<int64_t>();
case Type::Tag::UINT64: return static_cond->as<uint64_t>();
default: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid switch condition type.");
}();
}

void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info) noexcept {
if (auto definition = function->definition()) {
luisa::unordered_set<BasicBlock *> reachable;
Expand All @@ -221,6 +250,37 @@ void eliminate_unreachable_blocks_in_function(Function *function, DCEInfo &info)
}
}
});
// also check if the terminator is a constant branch
switch (auto terminator = b->terminator(); terminator->derived_instruction_tag()) {
case DerivedInstructionTag::IF: [[fallthrough]];
case DerivedInstructionTag::CONDITIONAL_BRANCH: {
auto cond_br_inst = static_cast<ConditionalBranchTerminatorInstruction *>(terminator);
if (auto static_cond = try_evaluate_static_branch_condition(cond_br_inst->condition())) {
unreachable.emplace(*static_cond ? cond_br_inst->false_block() : cond_br_inst->true_block());
}
break;
}
case DerivedInstructionTag::SWITCH: {
auto switch_inst = static_cast<SwitchInst *>(terminator);
if (auto static_cond = try_evaluate_static_switch_condition(switch_inst->value())) {
auto any_match = false;
for (auto i = 0u; i < switch_inst->case_count(); i++) {
if (switch_inst->case_value(i) == *static_cond) {
any_match = true;
} else {
LUISA_DEBUG_ASSERT(switch_inst->case_block(i) != nullptr, "Switch case block must not be null.");
unreachable.emplace(switch_inst->case_block(i));
}
}
if (any_match) {
LUISA_DEBUG_ASSERT(switch_inst->default_block() != nullptr, "Switch default block must not be null.");
unreachable.emplace(switch_inst->default_block());
}
}
break;
}
default: break;
}
}
// eliminate all instructions in unreachable blocks
eliminate_instructions_in_unreachable_blocks(unreachable, info);
Expand Down
8 changes: 5 additions & 3 deletions src/xir/passes/local_load_elimination.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//
// Created by mike on 1/10/25.
//
#include <luisa/xir/passes/local_load_elimination.h>

namespace luisa::compute::xir {

}

0 comments on commit 909edfc

Please sign in to comment.