Skip to content

Commit

Permalink
remove logic not/and/or in xir
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 10, 2025
1 parent 38c9bf4 commit 854eb7c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 68 deletions.
4 changes: 0 additions & 4 deletions include/luisa/xir/instructions/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ enum class ArithmeticOp {
// unary operators
UNARY_PLUS, // +x
UNARY_MINUS, // -x
UNARY_LOGIC_NOT,// !x
UNARY_BIT_NOT, // ~x

// binary operators
Expand All @@ -19,9 +18,6 @@ enum class ArithmeticOp {
BINARY_DIV,
BINARY_MOD,

BINARY_LOGIC_AND,
BINARY_LOGIC_OR,

BINARY_BIT_AND,
BINARY_BIT_OR,
BINARY_BIT_XOR,
Expand Down
52 changes: 0 additions & 52 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,15 +544,6 @@ class FallbackCodegen {
return b.CreateZExt(value, i8_type);
}

[[nodiscard]] llvm::Value *_translate_unary_logic_not(CurrentFunction &current, IRBuilder &b, const xir::Value *operand) noexcept {
auto llvm_operand = _lookup_value(current, b, operand);
auto operand_type = operand->type();
LUISA_ASSERT(operand_type != nullptr, "Operand type is null.");
LUISA_ASSERT(operand_type->is_scalar() || operand_type->is_vector(), "Invalid operand type.");
auto llvm_cmp = _cmp_eq_zero(b, llvm_operand);
return _zext_i1_to_i8(b, llvm_cmp);
}

[[nodiscard]] llvm::Value *_translate_unary_bit_not(CurrentFunction &current, IRBuilder &b, const xir::Value *operand) noexcept {
auto llvm_operand = _lookup_value(current, b, operand);
LUISA_ASSERT(llvm_operand->getType()->isIntOrIntVectorTy() &&
Expand Down Expand Up @@ -669,46 +660,6 @@ class FallbackCodegen {
LUISA_ERROR_WITH_LOCATION("Invalid binary mod operand type: {}.", elem_type->description());
}

[[nodiscard]] llvm::Value *_translate_binary_logic_and(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
auto llvm_rhs = _lookup_value(current, b, rhs);
auto lhs_type = lhs->type();
auto rhs_type = rhs->type();
// Type and null checks
LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null.");
LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for logic and.");
LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type.");
LUISA_ASSERT(rhs_type->is_scalar() || rhs_type->is_vector(), "Invalid operand type.");
// Convert operands to boolean values (non-zero becomes true, zero becomes false)
auto llvm_lhs_bool = _cmp_ne_zero(b, llvm_lhs);
auto llvm_rhs_bool = _cmp_ne_zero(b, llvm_rhs);
// Perform logical AND (a && b)
auto llvm_and_result = b.CreateAnd(llvm_lhs_bool, llvm_rhs_bool);
// Convert result to i8 for consistency with your implementation needs
return _zext_i1_to_i8(b, llvm_and_result);
}

[[nodiscard]] llvm::Value *_translate_binary_logic_or(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
auto llvm_rhs = _lookup_value(current, b, rhs);
auto lhs_type = lhs->type();
auto rhs_type = rhs->type();
// Type and null checks
LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null.");
LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for logic and.");
LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type.");
LUISA_ASSERT(rhs_type->is_scalar() || rhs_type->is_vector(), "Invalid operand type.");
// Convert operands to boolean values (non-zero becomes true, zero becomes false)
auto llvm_lhs_bool = _cmp_ne_zero(b, llvm_lhs);
auto llvm_rhs_bool = _cmp_ne_zero(b, llvm_rhs);
// Perform logical OR (a && b)
auto llvm_or_result = b.CreateOr(llvm_lhs_bool, llvm_rhs_bool);
// Convert result to i8 for consistency with your implementation needs
return _zext_i1_to_i8(b, llvm_or_result);
}

[[nodiscard]] llvm::Value *_translate_binary_bit_and(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
Expand Down Expand Up @@ -1965,15 +1916,12 @@ class FallbackCodegen {
switch (inst->op()) {
case xir::ArithmeticOp::UNARY_PLUS: return _translate_unary_plus(current, b, inst->operand(0u));
case xir::ArithmeticOp::UNARY_MINUS: return _translate_unary_minus(current, b, inst->operand(0u));
case xir::ArithmeticOp::UNARY_LOGIC_NOT: return _translate_unary_logic_not(current, b, inst->operand(0u));
case xir::ArithmeticOp::UNARY_BIT_NOT: return _translate_unary_bit_not(current, b, inst->operand(0u));
case xir::ArithmeticOp::BINARY_ADD: return _translate_binary_add(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_SUB: return _translate_binary_sub(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_MUL: return _translate_binary_mul(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_DIV: return _translate_binary_div(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_MOD: return _translate_binary_mod(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_LOGIC_AND: return _translate_binary_logic_and(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_LOGIC_OR: return _translate_binary_logic_or(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_AND: return _translate_binary_bit_and(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_OR: return _translate_binary_bit_or(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_XOR: return _translate_binary_bit_xor(current, b, inst->operand(0u), inst->operand(1u));
Expand Down
6 changes: 0 additions & 6 deletions src/xir/instructions/op_name_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ luisa::string_view to_string(ArithmeticOp op) noexcept {
switch (op) {
case ArithmeticOp::UNARY_PLUS: return "unary_plus"sv;
case ArithmeticOp::UNARY_MINUS: return "unary_minus"sv;
case ArithmeticOp::UNARY_LOGIC_NOT: return "unary_logic_not"sv;
case ArithmeticOp::UNARY_BIT_NOT: return "unary_bit_not"sv;
case ArithmeticOp::BINARY_ADD: return "binary_add"sv;
case ArithmeticOp::BINARY_SUB: return "binary_sub"sv;
case ArithmeticOp::BINARY_MUL: return "binary_mul"sv;
case ArithmeticOp::BINARY_DIV: return "binary_div"sv;
case ArithmeticOp::BINARY_MOD: return "binary_mod"sv;
case ArithmeticOp::BINARY_LOGIC_AND: return "binary_logic_and"sv;
case ArithmeticOp::BINARY_LOGIC_OR: return "binary_logic_or"sv;
case ArithmeticOp::BINARY_BIT_AND: return "binary_bit_and"sv;
case ArithmeticOp::BINARY_BIT_OR: return "binary_bit_or"sv;
case ArithmeticOp::BINARY_BIT_XOR: return "binary_bit_xor"sv;
Expand Down Expand Up @@ -123,15 +120,12 @@ ArithmeticOp arithmetic_op_from_string(luisa::string_view name) noexcept {
static const luisa::unordered_map<luisa::string_view, ArithmeticOp> m{
{"unary_plus"sv, ArithmeticOp::UNARY_PLUS},
{"unary_minus"sv, ArithmeticOp::UNARY_MINUS},
{"unary_logic_not"sv, ArithmeticOp::UNARY_LOGIC_NOT},
{"unary_bit_not"sv, ArithmeticOp::UNARY_BIT_NOT},
{"binary_add"sv, ArithmeticOp::BINARY_ADD},
{"binary_sub"sv, ArithmeticOp::BINARY_SUB},
{"binary_mul"sv, ArithmeticOp::BINARY_MUL},
{"binary_div"sv, ArithmeticOp::BINARY_DIV},
{"binary_mod"sv, ArithmeticOp::BINARY_MOD},
{"binary_logic_and"sv, ArithmeticOp::BINARY_LOGIC_AND},
{"binary_logic_or"sv, ArithmeticOp::BINARY_LOGIC_OR},
{"binary_bit_and"sv, ArithmeticOp::BINARY_BIT_AND},
{"binary_bit_or"sv, ArithmeticOp::BINARY_BIT_OR},
{"binary_bit_xor"sv, ArithmeticOp::BINARY_BIT_XOR},
Expand Down
24 changes: 19 additions & 5 deletions src/xir/translators/ast2xir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,16 @@ class AST2XIRContext {
switch (unary_op) {
case UnaryOp::PLUS: return ArithmeticOp::UNARY_PLUS;
case UnaryOp::MINUS: return ArithmeticOp::UNARY_MINUS;
case UnaryOp::NOT: return ArithmeticOp::UNARY_LOGIC_NOT;
case UnaryOp::NOT: return ArithmeticOp::UNARY_BIT_NOT;
case UnaryOp::BIT_NOT: return ArithmeticOp::UNARY_BIT_NOT;
}
LUISA_ERROR_WITH_LOCATION("Unexpected unary operation.");
}();
if (expr->op() == UnaryOp::NOT) {
LUISA_DEBUG_ASSERT(expr->type()->is_bool() || expr->type()->is_bool_vector(),
"Invalid type for logical not operation.");
operand = _type_cast_if_necessary(b, expr->type(), operand);
}
return b.call(expr->type(), op, {operand});
}

Expand Down Expand Up @@ -145,8 +150,8 @@ class AST2XIRContext {
case BinaryOp::BIT_XOR: return ArithmeticOp::BINARY_BIT_XOR;
case BinaryOp::SHL: return ArithmeticOp::BINARY_SHIFT_LEFT;
case BinaryOp::SHR: return ArithmeticOp::BINARY_SHIFT_RIGHT;
case BinaryOp::AND: return ArithmeticOp::BINARY_LOGIC_AND;
case BinaryOp::OR: return ArithmeticOp::BINARY_LOGIC_OR;
case BinaryOp::AND: return ArithmeticOp::BINARY_BIT_AND;
case BinaryOp::OR: return ArithmeticOp::BINARY_BIT_OR;
case BinaryOp::LESS: return ArithmeticOp::BINARY_LESS;
case BinaryOp::GREATER: return ArithmeticOp::BINARY_GREATER;
case BinaryOp::LESS_EQUAL: return ArithmeticOp::BINARY_LESS_EQUAL;
Expand All @@ -158,8 +163,17 @@ class AST2XIRContext {
}();
auto lhs = _translate_expression(b, expr->lhs(), true);
auto rhs = _translate_expression(b, expr->rhs(), true);
lhs = _type_cast_if_necessary(b, type_promotion.lhs, lhs);
rhs = _type_cast_if_necessary(b, type_promotion.rhs, rhs);
if (expr->op() == BinaryOp::AND || expr->op() == BinaryOp::OR) {
LUISA_DEBUG_ASSERT(type_promotion.result->is_bool() ||
type_promotion.result->is_bool_vector(),
"Invalid type promotion result type for binary logic operator: {}.",
type_promotion.result->description());
lhs = b.static_cast_if_necessary(type_promotion.result, lhs);
rhs = b.static_cast_if_necessary(type_promotion.result, rhs);
} else {
lhs = _type_cast_if_necessary(b, type_promotion.lhs, lhs);
rhs = _type_cast_if_necessary(b, type_promotion.rhs, rhs);
}
auto result = b.call(expr->type(), op, {lhs, rhs});
return _type_cast_if_necessary(b, type_promotion.result, result);
}
Expand Down
2 changes: 1 addition & 1 deletion src/xir/translators/xir2text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class XIR2TextTranslator final {
}

void _emit_arithmetic_inst(const ArithmeticInst *inst) noexcept {
_main << xir::to_string(inst->op()) << " ";
_main << "arithmetic " << xir::to_string(inst->op()) << " ";
_emit_operands(inst);
}

Expand Down

0 comments on commit 854eb7c

Please sign in to comment.