Skip to content

Commit

Permalink
wip: promote_ref_arg pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Feb 12, 2025
1 parent e1fe5c6 commit 347eb39
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 0 deletions.
2 changes: 2 additions & 0 deletions include/luisa/luisa-compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@
#include <luisa/xir/module.h>
#include <luisa/xir/passes/aggregate_field_bitmask.h>
#include <luisa/xir/passes/autodiff.h>
#include <luisa/xir/passes/call_graph.h>
#include <luisa/xir/passes/dce.h>
#include <luisa/xir/passes/dom_tree.h>
#include <luisa/xir/passes/early_return_elimination.h>
Expand All @@ -248,6 +249,7 @@
#include <luisa/xir/passes/mem2reg.h>
#include <luisa/xir/passes/outline.h>
#include <luisa/xir/passes/pointer_usage.h>
#include <luisa/xir/passes/promote_ref_arg.h>
#include <luisa/xir/passes/reg2mem.h>
#include <luisa/xir/passes/sroa.h>
#include <luisa/xir/passes/trace_gep.h>
Expand Down
30 changes: 30 additions & 0 deletions include/luisa/xir/passes/call_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <luisa/core/dll_export.h>
#include <luisa/core/stl/vector.h>
#include <luisa/core/stl/unordered_map.h>

namespace luisa::compute::xir {

class Function;
class FunctionDefinition;
class CallInst;

class LC_XIR_API CallGraph {

private:
luisa::vector<Function *> _root_functions;
luisa::unordered_map<FunctionDefinition *, luisa::vector<CallInst *>> _call_edges;

public:
// only for internal use
void _add_function(Function *f) noexcept;

public:
[[nodiscard]] luisa::span<Function *const> root_functions() const noexcept;
[[nodiscard]] luisa::span<CallInst *const> call_edges(FunctionDefinition *f) const noexcept;
};

[[nodiscard]] LC_XIR_API CallGraph compute_call_graph(Module *module) noexcept;

}// namespace luisa::compute::xir
18 changes: 18 additions & 0 deletions include/luisa/xir/passes/promote_ref_arg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <luisa/core/dll_export.h>
#include <luisa/core/stl/unordered_map.h>

namespace luisa::compute::xir {

class ReferenceArgument;
class ValueArgument;
class Module;

struct PromoteRefArgInfo {
luisa::unordered_map<ReferenceArgument *, ValueArgument *> promoted_ref_args;
};

[[nodiscard]] LC_XIR_API PromoteRefArgInfo promote_ref_arg_pass_run_on_module(Module *module) noexcept;

}// namespace luisa::compute::xir
2 changes: 2 additions & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
passes/helpers.cpp
passes/dce.cpp
passes/dom_tree.cpp
passes/call_graph.cpp
passes/outline.cpp
passes/sroa.cpp
passes/trace_gep.cpp
Expand All @@ -73,6 +74,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
passes/transpose_gep.cpp
passes/unused_callable_removal.cpp
passes/reg2mem.cpp
passes/promote_ref_arg.cpp
)

add_library(luisa-compute-xir SHARED ${LUISA_COMPUTE_XIR_SOURCES})
Expand Down
38 changes: 38 additions & 0 deletions src/xir/passes/call_graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <luisa/core/logging.h>
#include <luisa/xir/module.h>
#include <luisa/xir/function.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/passes/call_graph.h>

namespace luisa::compute::xir {

inline void CallGraph::_add_function(Function *f) noexcept {
auto any_caller = false;
for (auto &&use : f->use_list()) {
if (auto user = use.user(); user != nullptr && user->isa<CallInst>()) {
auto call = static_cast<CallInst *>(user);
auto caller = call->parent_function()->definition();
LUISA_DEBUG_ASSERT(caller != nullptr, "Invalid caller.");
_call_edges[caller].emplace_back(call);
any_caller = true;
}
}
if (!any_caller) { _root_functions.emplace_back(f); }
}

luisa::span<Function *const> CallGraph::root_functions() const noexcept {
return luisa::span{_root_functions};
}

luisa::span<CallInst *const> CallGraph::call_edges(FunctionDefinition *f) const noexcept {
auto iter = _call_edges.find(f);
return iter == _call_edges.cend() ? luisa::span<CallInst *const>{} : luisa::span{iter->second};
}

CallGraph compute_call_graph(Module *module) noexcept {
CallGraph graph;
for (auto &&f : module->function_list()) { graph._add_function(&f); }
return graph;
}

}// namespace luisa::compute::xir
108 changes: 108 additions & 0 deletions src/xir/passes/promote_ref_arg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include <luisa/core/logging.h>
#include <luisa/xir/module.h>
#include <luisa/xir/function.h>
#include <luisa/xir/instructions/call.h>
#include <luisa/xir/passes/call_graph.h>
#include <luisa/xir/passes/promote_ref_arg.h>

namespace luisa::compute::xir {

namespace detail {

struct ArgumentBitmap {

luisa::unordered_map<CallableFunction *, size_t> callable_bit_offsets;
luisa::bitvector write_bits;// records whether an argument is written to (either by this function or by a callee)
luisa::bitvector smem_bits; // records whether an argument might be a shared memory pointer

void register_callable(CallableFunction *f) noexcept {
auto offset = write_bits.size();
if (callable_bit_offsets.try_emplace(f, offset).second) {
write_bits.resize(offset + f->arguments().size(), false);
smem_bits.resize(offset + f->arguments().size(), false);
}
}

struct Range {

size_t offset;
luisa::bitvector &write_bits;
luisa::bitvector &smem_bits;

// returns true if changed
[[nodiscard]] auto _mark(luisa::bitvector &bits, size_t i) const noexcept {
if (!bits[offset + i]) {
bits[offset + i] = true;
return true;
}
return false;
}

[[nodiscard]] auto mark_write(size_t i) const noexcept { return _mark(write_bits, i); }
[[nodiscard]] auto mark_smem(size_t i) const noexcept { return _mark(smem_bits, i); }
};

[[nodiscard]] auto operator[](CallableFunction *f) noexcept {
auto iter = callable_bit_offsets.find(f);
LUISA_DEBUG_ASSERT(iter != callable_bit_offsets.end(), "Callable function not found.");
return Range{iter->second, write_bits, smem_bits};
}
};

// checks if a function is a promotable callable, i.e., it is a callable function and all of its uses are call instructions
[[nodiscard]] static auto is_promotable_callable(Function *f) noexcept {
if (f->isa<CallableFunction>()) {
for (auto &&use : f->use_list()) {
if (auto user = use.user(); user != nullptr && !user->isa<CallInst>()) {
return false;
}
}
return true;
}
return false;
}

static void traverse_call_graph_post_order(Function *f, const CallGraph &call_graph,
const ArgumentBitmap &bitmap,
luisa::unordered_set<Function *> &visited,
luisa::vector<CallableFunction *> &post_order) noexcept {
if (visited.emplace(f).second) {
if (auto def = f->definition()) {
auto edges = call_graph.call_edges(def);
for (auto &&call : edges) {
traverse_call_graph_post_order(call->callee(), call_graph, bitmap, visited, post_order);
}
if (def->isa<CallableFunction>() && bitmap.callable_bit_offsets.contains(static_cast<CallableFunction *>(def))) {
post_order.emplace_back(f);
}
}
}
}

static void promote_ref_args_in_module(Module *m, PromoteRefArgInfo &info) noexcept {
ArgumentBitmap bitmap;
for (auto &&f : m->function_list()) {
if (is_promotable_callable(&f)) {
bitmap.register_callable(static_cast<CallableFunction *>(&f));
}
}
auto call_graph = compute_call_graph(m);
luisa::vector<CallableFunction *> post_order;
{
post_order.reserve(bitmap.callable_bit_offsets.size());
luisa::unordered_set<Function *> visited;
for (auto &&f : call_graph.root_functions()) {
traverse_call_graph_post_order(f, call_graph, bitmap, visited, post_order);
}
}
}

}// namespace detail

PromoteRefArgInfo promote_ref_arg_pass_run_on_module(Module *module) noexcept {
PromoteRefArgInfo info;
detail::promote_ref_args_in_module(module, info);
return info;
}

}// namespace luisa::compute::xir

0 comments on commit 347eb39

Please sign in to comment.