From f1004f41c9cc662d34e4922b521a3d8629ed93f8 Mon Sep 17 00:00:00 2001 From: Mike Smith Date: Fri, 10 Jan 2025 20:25:31 +0800 Subject: [PATCH] fix fallback codegen on bool --- src/backends/fallback/fallback_codegen.cpp | 86 +++++++--------------- 1 file changed, 25 insertions(+), 61 deletions(-) diff --git a/src/backends/fallback/fallback_codegen.cpp b/src/backends/fallback/fallback_codegen.cpp index cfd95db97..da567cad7 100644 --- a/src/backends/fallback/fallback_codegen.cpp +++ b/src/backends/fallback/fallback_codegen.cpp @@ -549,6 +549,10 @@ class FallbackCodegen { LUISA_ASSERT(llvm_operand->getType()->isIntOrIntVectorTy() && !llvm_operand->getType()->isIntOrIntVectorTy(1), "Invalid operand type."); + if (operand->type()->is_bool() || operand->type()->is_bool_vector()) {// !b <=> (b == 0) + auto i1_operand = _cmp_eq_zero(b, llvm_operand); + return _zext_i1_to_i8(b, i1_operand); + } return b.CreateNot(llvm_operand); } @@ -660,7 +664,7 @@ class FallbackCodegen { LUISA_ERROR_WITH_LOCATION("Invalid binary mod operand type: {}.", elem_type->description()); } - [[nodiscard]] llvm::Value *_translate_binary_bit_and(CurrentFunction ¤t, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept { + [[nodiscard]] llvm::Value *_translate_binary_bit_op(CurrentFunction ¤t, IRBuilder &b, xir::ArithmeticOp op, const xir::Value *lhs, const xir::Value *rhs) noexcept { // Lookup LLVM values for operands auto llvm_lhs = _lookup_value(current, b, lhs); auto llvm_rhs = _lookup_value(current, b, rhs); @@ -682,68 +686,28 @@ class FallbackCodegen { case Type::Tag::UINT8: [[fallthrough]]; case Type::Tag::UINT16: [[fallthrough]]; case Type::Tag::UINT32: [[fallthrough]]; - case Type::Tag::UINT64: return b.CreateAnd(llvm_lhs, llvm_rhs); + case Type::Tag::UINT64: { + auto is_bool = elem_type->is_bool(); + if (is_bool) { + llvm_lhs = _cmp_ne_zero(b, llvm_lhs); + llvm_rhs = _cmp_ne_zero(b, llvm_rhs); + } + auto result = [&] { + switch (op) { + case xir::ArithmeticOp::BINARY_BIT_AND: return b.CreateAnd(llvm_lhs, llvm_rhs); + case xir::ArithmeticOp::BINARY_BIT_OR: return b.CreateOr(llvm_lhs, llvm_rhs); + case xir::ArithmeticOp::BINARY_BIT_XOR: return b.CreateXor(llvm_lhs, llvm_rhs); + default: break; + } + LUISA_ERROR_WITH_LOCATION("Invalid binary bit operation: {}.", static_cast(op)); + }(); + return is_bool ? _zext_i1_to_i8(b, result) : result; + } default: break; } LUISA_ERROR_WITH_LOCATION("Invalid binary bit and operand type: {}.", elem_type->description()); } - [[nodiscard]] llvm::Value *_translate_binary_bit_or(CurrentFunction ¤t, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept { - // Lookup LLVM values for operands - auto llvm_lhs = _lookup_value(current, b, lhs); - auto llvm_rhs = _lookup_value(current, b, rhs); - auto lhs_type = lhs->type(); - auto rhs_type = rhs->type(); - auto elem_type = lhs->type()->is_vector() ? lhs->type()->element() : lhs->type(); - // Type and null checks - LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null."); - LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for bitwise and."); - LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type."); - - // Perform bitwise AND operation - switch (elem_type->tag()) { - case Type::Tag::BOOL: [[fallthrough]]; - case Type::Tag::INT8: [[fallthrough]]; - case Type::Tag::INT16: [[fallthrough]]; - case Type::Tag::INT32: [[fallthrough]]; - case Type::Tag::INT64: [[fallthrough]]; - case Type::Tag::UINT8: [[fallthrough]]; - case Type::Tag::UINT16: [[fallthrough]]; - case Type::Tag::UINT32: [[fallthrough]]; - case Type::Tag::UINT64: return b.CreateOr(llvm_lhs, llvm_rhs); - default: break; - } - LUISA_ERROR_WITH_LOCATION("Invalid binary bit or operand type: {}.", elem_type->description()); - } - - [[nodiscard]] llvm::Value *_translate_binary_bit_xor(CurrentFunction ¤t, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept { - // Lookup LLVM values for operands - auto llvm_lhs = _lookup_value(current, b, lhs); - auto llvm_rhs = _lookup_value(current, b, rhs); - auto lhs_type = lhs->type(); - auto rhs_type = rhs->type(); - auto elem_type = lhs->type()->is_vector() ? lhs->type()->element() : lhs->type(); - // Type and null checks - LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null."); - LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for bitwise and."); - LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type."); - - // Perform bitwise AND operation - switch (elem_type->tag()) { - case Type::Tag::BOOL: [[fallthrough]]; - case Type::Tag::INT8: [[fallthrough]]; - case Type::Tag::INT16: [[fallthrough]]; - case Type::Tag::INT32: [[fallthrough]]; - case Type::Tag::INT64: [[fallthrough]]; - case Type::Tag::UINT8: [[fallthrough]]; - case Type::Tag::UINT16: [[fallthrough]]; - case Type::Tag::UINT32: [[fallthrough]]; - case Type::Tag::UINT64: return b.CreateXor(llvm_lhs, llvm_rhs); - default: break; - } - LUISA_ERROR_WITH_LOCATION("Invalid binary bit xor operand type: {}.", elem_type->description()); - } - [[nodiscard]] llvm::Value *_translate_binary_shift_left(CurrentFunction ¤t, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept { // Lookup LLVM values for operands auto llvm_lhs = _lookup_value(current, b, lhs); @@ -1922,9 +1886,9 @@ class FallbackCodegen { case xir::ArithmeticOp::BINARY_MUL: return _translate_binary_mul(current, b, inst->operand(0u), inst->operand(1u)); case xir::ArithmeticOp::BINARY_DIV: return _translate_binary_div(current, b, inst->operand(0u), inst->operand(1u)); case xir::ArithmeticOp::BINARY_MOD: return _translate_binary_mod(current, b, inst->operand(0u), inst->operand(1u)); - case xir::ArithmeticOp::BINARY_BIT_AND: return _translate_binary_bit_and(current, b, inst->operand(0u), inst->operand(1u)); - case xir::ArithmeticOp::BINARY_BIT_OR: return _translate_binary_bit_or(current, b, inst->operand(0u), inst->operand(1u)); - case xir::ArithmeticOp::BINARY_BIT_XOR: return _translate_binary_bit_xor(current, b, inst->operand(0u), inst->operand(1u)); + case xir::ArithmeticOp::BINARY_BIT_AND: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u)); + case xir::ArithmeticOp::BINARY_BIT_OR: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u)); + case xir::ArithmeticOp::BINARY_BIT_XOR: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u)); case xir::ArithmeticOp::BINARY_SHIFT_LEFT: return _translate_binary_shift_left(current, b, inst->operand(0u), inst->operand(1u)); case xir::ArithmeticOp::BINARY_SHIFT_RIGHT: return _translate_binary_shift_right(current, b, inst->operand(0u), inst->operand(1u)); case xir::ArithmeticOp::BINARY_ROTATE_LEFT: return _translate_binary_rotate_left(current, b, inst->operand(0u), inst->operand(1u));