Skip to content

Commit

Permalink
wip: reference arg usage pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 9, 2025
1 parent 80b3d48 commit 10719a9
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 11 deletions.
2 changes: 2 additions & 0 deletions include/luisa/luisa-compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@
#include <luisa/xir/metadata/location.h>
#include <luisa/xir/metadata/name.h>
#include <luisa/xir/module.h>
#include <luisa/xir/passes/aggregate_field_bitmask.h>
#include <luisa/xir/passes/dce.h>
#include <luisa/xir/passes/dom_tree.h>
#include <luisa/xir/passes/outline.h>
#include <luisa/xir/passes/ref_arg_usage.h>
#include <luisa/xir/passes/sink_alloca.h>
#include <luisa/xir/passes/trace_gep.h>
#include <luisa/xir/pool.h>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <luisa/core/dll_export.h>
#include <luisa/core/stl/memory.h>

namespace luisa::compute {
Expand All @@ -12,7 +13,7 @@ namespace detail {
class AggregateFieldTree;
}// namespace detail

class alignas(16) AggregateFieldBitmask {
class LC_XIR_API alignas(16) AggregateFieldBitmask {

private:
const detail::AggregateFieldTree *_field_tree;
Expand Down Expand Up @@ -43,23 +44,40 @@ class alignas(16) AggregateFieldBitmask {
[[nodiscard]] const Type *type() const noexcept;

public:
class ConstBitSpan {
class LC_XIR_API ConstBitSpan {
protected:
uint64_t *_bits;
uint32_t _offset;
uint32_t _size;
public:
ConstBitSpan(uint64_t *bits, uint32_t offset, uint32_t size) noexcept
: _bits{bits}, _offset{offset}, _size{size} {}
[[nodiscard]] const uint64_t *raw_bits() const noexcept { return _bits; }
[[nodiscard]] size_t offset() const noexcept { return _offset; }
[[nodiscard]] size_t size() const noexcept { return _size; }
[[nodiscard]] bool all() const noexcept;
[[nodiscard]] bool any() const noexcept;
[[nodiscard]] bool none() const noexcept;
};
class BitSpan : public ConstBitSpan {
class LC_XIR_API BitSpan : public ConstBitSpan {
public:
using ConstBitSpan::ConstBitSpan;
void set(bool value = true) && noexcept;
void flip() && noexcept;
[[nodiscard]] uint64_t *raw_bits() noexcept { return _bits; }

void set(bool value = true) noexcept;
void flip() noexcept;

BitSpan &operator|=(const ConstBitSpan &rhs) noexcept;
BitSpan &operator&=(const ConstBitSpan &rhs) noexcept;
BitSpan &operator^=(const ConstBitSpan &rhs) noexcept;
[[nodiscard]] bool operator==(const ConstBitSpan &rhs) const noexcept;
[[nodiscard]] bool operator!=(const ConstBitSpan &rhs) const noexcept;

BitSpan &operator|=(const AggregateFieldBitmask &rhs) noexcept { return *this |= rhs.access(); }
BitSpan &operator&=(const AggregateFieldBitmask &rhs) noexcept { return *this &= rhs.access(); }
BitSpan &operator^=(const AggregateFieldBitmask &rhs) noexcept { return *this ^= rhs.access(); }
[[nodiscard]] bool operator==(const AggregateFieldBitmask &rhs) const noexcept { return *this == rhs.access(); }
[[nodiscard]] bool operator!=(const AggregateFieldBitmask &rhs) const noexcept { return *this != rhs.access(); }
};
[[nodiscard]] BitSpan access(luisa::span<const size_t> access_chain) noexcept;
[[nodiscard]] ConstBitSpan access(luisa::span<const size_t> access_chain) const noexcept;
Expand Down
7 changes: 7 additions & 0 deletions include/luisa/xir/passes/ref_arg_usage.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

namespace luisa::compute::xir {



}
1 change: 1 addition & 0 deletions src/xir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ set(LUISA_COMPUTE_XIR_SOURCES
passes/sink_alloca.cpp
passes/trace_gep.cpp
passes/aggregate_field_bitmask.cpp
passes/ref_arg_usage.cpp
)

add_library(luisa-compute-xir SHARED ${LUISA_COMPUTE_XIR_SOURCES})
Expand Down
57 changes: 53 additions & 4 deletions src/xir/passes/aggregate_field_bitmask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
#include <luisa/core/stl/vector.h>
#include <luisa/core/stl/unordered_map.h>
#include <luisa/ast/type.h>

#include "aggregate_field_bitmask.h"
#include <luisa/xir/passes/aggregate_field_bitmask.h>

namespace luisa::compute::xir {

Expand Down Expand Up @@ -211,7 +210,7 @@ void AggregateFieldBitmask::flip() noexcept {
}
}

void AggregateFieldBitmask::BitSpan::set(bool value) && noexcept {
void AggregateFieldBitmask::BitSpan::set(bool value) noexcept {
auto lower = _offset / 64u;
auto upper = (_offset + _size - 1u) / 64u;
if (lower == upper) {// all selected bits are in the same bucket
Expand Down Expand Up @@ -252,7 +251,7 @@ void AggregateFieldBitmask::BitSpan::set(bool value) && noexcept {
}
}

void AggregateFieldBitmask::BitSpan::flip() && noexcept {
void AggregateFieldBitmask::BitSpan::flip() noexcept {
auto lower = _offset / 64u;
auto upper = (_offset + _size - 1u) / 64u;
if (lower == upper) {// all selected bits are in the same bucket
Expand All @@ -277,6 +276,56 @@ void AggregateFieldBitmask::BitSpan::flip() && noexcept {
}
}

// TODO: Implement the following methods in a SIMD-friendly way
AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator|=(const ConstBitSpan &rhs) noexcept {
LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch.");
for (auto i = 0u; i < _size; i++) {
if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u];
(rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) {
_bits[(_offset + i) / 64u] |= 1ull << ((_offset + i) % 64u);
}
}
return *this;
}

AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator&=(const ConstBitSpan &rhs) noexcept {
LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch.");
for (auto i = 0u; i < _size; i++) {
if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u];
(rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) {
_bits[(_offset + i) / 64u] &= 1ull << ((_offset + i) % 64u);
}
}
return *this;
}

AggregateFieldBitmask::BitSpan &AggregateFieldBitmask::BitSpan::operator^=(const ConstBitSpan &rhs) noexcept {
LUISA_DEBUG_ASSERT(_size == rhs.size(), "Size mismatch.");
for (auto i = 0u; i < _size; i++) {
if (auto rhs_bucket = rhs.raw_bits()[(i + rhs.offset()) / 64u];
(rhs_bucket >> ((i + rhs.offset()) % 64u)) & 1ull) {
_bits[(_offset + i) / 64u] ^= 1ull << ((_offset + i) % 64u);
}
}
return *this;
}

bool AggregateFieldBitmask::BitSpan::operator==(const ConstBitSpan &rhs) const noexcept {
if (_size != rhs.size()) { return false; }
if (this != &rhs) {
for (auto i = 0u; i < _size; i++) {
auto lhs_bit = (_bits[(_offset + i) / 64u] >> ((_offset + i) % 64u)) & 1ull;
auto rhs_bit = (rhs.raw_bits()[i / 64u] >> (i % 64u)) & 1ull;
if (lhs_bit != rhs_bit) { return false; }
}
}
return true;
}

bool AggregateFieldBitmask::BitSpan::operator!=(const ConstBitSpan &rhs) const noexcept {
return !(*this == rhs);
}

bool AggregateFieldBitmask::ConstBitSpan::all() const noexcept {
auto lower = _offset / 64u;
auto upper = (_offset + _size - 1u) / 64u;
Expand Down
5 changes: 5 additions & 0 deletions src/xir/passes/ref_arg_usage.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <luisa/xir/passes/ref_arg_usage.h>

namespace luisa::compute::xir {

}// namespace luisa::compute::xir
2 changes: 0 additions & 2 deletions src/xir/tests/test_aggregate_field_bitmasks.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#include <luisa/luisa-compute.h>
#include "../passes/aggregate_field_bitmask.h"
#include "luisa/dsl/struct.h"

using namespace luisa;
using namespace luisa::compute;
Expand Down

0 comments on commit 10719a9

Please sign in to comment.