-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e1fe5c6
commit 347eb39
Showing
6 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#pragma once | ||
|
||
#include <luisa/core/dll_export.h> | ||
#include <luisa/core/stl/vector.h> | ||
#include <luisa/core/stl/unordered_map.h> | ||
|
||
namespace luisa::compute::xir { | ||
|
||
class Function; | ||
class FunctionDefinition; | ||
class CallInst; | ||
|
||
class LC_XIR_API CallGraph { | ||
|
||
private: | ||
luisa::vector<Function *> _root_functions; | ||
luisa::unordered_map<FunctionDefinition *, luisa::vector<CallInst *>> _call_edges; | ||
|
||
public: | ||
// only for internal use | ||
void _add_function(Function *f) noexcept; | ||
|
||
public: | ||
[[nodiscard]] luisa::span<Function *const> root_functions() const noexcept; | ||
[[nodiscard]] luisa::span<CallInst *const> call_edges(FunctionDefinition *f) const noexcept; | ||
}; | ||
|
||
[[nodiscard]] LC_XIR_API CallGraph compute_call_graph(Module *module) noexcept; | ||
|
||
}// namespace luisa::compute::xir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#pragma once | ||
|
||
#include <luisa/core/dll_export.h> | ||
#include <luisa/core/stl/unordered_map.h> | ||
|
||
namespace luisa::compute::xir { | ||
|
||
class ReferenceArgument; | ||
class ValueArgument; | ||
class Module; | ||
|
||
struct PromoteRefArgInfo { | ||
luisa::unordered_map<ReferenceArgument *, ValueArgument *> promoted_ref_args; | ||
}; | ||
|
||
[[nodiscard]] LC_XIR_API PromoteRefArgInfo promote_ref_arg_pass_run_on_module(Module *module) noexcept; | ||
|
||
}// namespace luisa::compute::xir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include <luisa/core/logging.h> | ||
#include <luisa/xir/module.h> | ||
#include <luisa/xir/function.h> | ||
#include <luisa/xir/instructions/call.h> | ||
#include <luisa/xir/passes/call_graph.h> | ||
|
||
namespace luisa::compute::xir { | ||
|
||
inline void CallGraph::_add_function(Function *f) noexcept { | ||
auto any_caller = false; | ||
for (auto &&use : f->use_list()) { | ||
if (auto user = use.user(); user != nullptr && user->isa<CallInst>()) { | ||
auto call = static_cast<CallInst *>(user); | ||
auto caller = call->parent_function()->definition(); | ||
LUISA_DEBUG_ASSERT(caller != nullptr, "Invalid caller."); | ||
_call_edges[caller].emplace_back(call); | ||
any_caller = true; | ||
} | ||
} | ||
if (!any_caller) { _root_functions.emplace_back(f); } | ||
} | ||
|
||
luisa::span<Function *const> CallGraph::root_functions() const noexcept { | ||
return luisa::span{_root_functions}; | ||
} | ||
|
||
luisa::span<CallInst *const> CallGraph::call_edges(FunctionDefinition *f) const noexcept { | ||
auto iter = _call_edges.find(f); | ||
return iter == _call_edges.cend() ? luisa::span<CallInst *const>{} : luisa::span{iter->second}; | ||
} | ||
|
||
CallGraph compute_call_graph(Module *module) noexcept { | ||
CallGraph graph; | ||
for (auto &&f : module->function_list()) { graph._add_function(&f); } | ||
return graph; | ||
} | ||
|
||
}// namespace luisa::compute::xir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
#include <luisa/core/logging.h> | ||
#include <luisa/xir/module.h> | ||
#include <luisa/xir/function.h> | ||
#include <luisa/xir/instructions/call.h> | ||
#include <luisa/xir/passes/call_graph.h> | ||
#include <luisa/xir/passes/promote_ref_arg.h> | ||
|
||
namespace luisa::compute::xir { | ||
|
||
namespace detail { | ||
|
||
struct ArgumentBitmap { | ||
|
||
luisa::unordered_map<CallableFunction *, size_t> callable_bit_offsets; | ||
luisa::bitvector write_bits;// records whether an argument is written to (either by this function or by a callee) | ||
luisa::bitvector smem_bits; // records whether an argument might be a shared memory pointer | ||
|
||
void register_callable(CallableFunction *f) noexcept { | ||
auto offset = write_bits.size(); | ||
if (callable_bit_offsets.try_emplace(f, offset).second) { | ||
write_bits.resize(offset + f->arguments().size(), false); | ||
smem_bits.resize(offset + f->arguments().size(), false); | ||
} | ||
} | ||
|
||
struct Range { | ||
|
||
size_t offset; | ||
luisa::bitvector &write_bits; | ||
luisa::bitvector &smem_bits; | ||
|
||
// returns true if changed | ||
[[nodiscard]] auto _mark(luisa::bitvector &bits, size_t i) const noexcept { | ||
if (!bits[offset + i]) { | ||
bits[offset + i] = true; | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
[[nodiscard]] auto mark_write(size_t i) const noexcept { return _mark(write_bits, i); } | ||
[[nodiscard]] auto mark_smem(size_t i) const noexcept { return _mark(smem_bits, i); } | ||
}; | ||
|
||
[[nodiscard]] auto operator[](CallableFunction *f) noexcept { | ||
auto iter = callable_bit_offsets.find(f); | ||
LUISA_DEBUG_ASSERT(iter != callable_bit_offsets.end(), "Callable function not found."); | ||
return Range{iter->second, write_bits, smem_bits}; | ||
} | ||
}; | ||
|
||
// checks if a function is a promotable callable, i.e., it is a callable function and all of its uses are call instructions | ||
[[nodiscard]] static auto is_promotable_callable(Function *f) noexcept { | ||
if (f->isa<CallableFunction>()) { | ||
for (auto &&use : f->use_list()) { | ||
if (auto user = use.user(); user != nullptr && !user->isa<CallInst>()) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
static void traverse_call_graph_post_order(Function *f, const CallGraph &call_graph, | ||
const ArgumentBitmap &bitmap, | ||
luisa::unordered_set<Function *> &visited, | ||
luisa::vector<CallableFunction *> &post_order) noexcept { | ||
if (visited.emplace(f).second) { | ||
if (auto def = f->definition()) { | ||
auto edges = call_graph.call_edges(def); | ||
for (auto &&call : edges) { | ||
traverse_call_graph_post_order(call->callee(), call_graph, bitmap, visited, post_order); | ||
} | ||
if (def->isa<CallableFunction>() && bitmap.callable_bit_offsets.contains(static_cast<CallableFunction *>(def))) { | ||
post_order.emplace_back(f); | ||
} | ||
} | ||
} | ||
} | ||
|
||
static void promote_ref_args_in_module(Module *m, PromoteRefArgInfo &info) noexcept { | ||
ArgumentBitmap bitmap; | ||
for (auto &&f : m->function_list()) { | ||
if (is_promotable_callable(&f)) { | ||
bitmap.register_callable(static_cast<CallableFunction *>(&f)); | ||
} | ||
} | ||
auto call_graph = compute_call_graph(m); | ||
luisa::vector<CallableFunction *> post_order; | ||
{ | ||
post_order.reserve(bitmap.callable_bit_offsets.size()); | ||
luisa::unordered_set<Function *> visited; | ||
for (auto &&f : call_graph.root_functions()) { | ||
traverse_call_graph_post_order(f, call_graph, bitmap, visited, post_order); | ||
} | ||
} | ||
} | ||
|
||
}// namespace detail | ||
|
||
PromoteRefArgInfo promote_ref_arg_pass_run_on_module(Module *module) noexcept { | ||
PromoteRefArgInfo info; | ||
detail::promote_ref_args_in_module(module, info); | ||
return info; | ||
} | ||
|
||
}// namespace luisa::compute::xir |