Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V 1.4] Support OpPtrEqual, OpPtrNotEqual and OpPtrDiff to compare pointers #2482

Merged
merged 6 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,28 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, V);
}

case OpPtrEqual:
case OpPtrNotEqual: {
auto *BC = static_cast<SPIRVBinary *>(BV);
auto Ops = transValue(BC->getOperands(), F, BB);

IRBuilder<> Builder(BB);
Value *Op1 = Builder.CreatePtrToInt(Ops[0], Type::getInt64Ty(*Context));
Value *Op2 = Builder.CreatePtrToInt(Ops[1], Type::getInt64Ty(*Context));
CmpInst::Predicate P =
OC == OpPtrEqual ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
Value *V = Builder.CreateICmp(P, Op1, Op2);
return mapValue(BV, V);
}

case OpPtrDiff: {
auto *BC = static_cast<SPIRVBinary *>(BV);
auto Ops = transValue(BC->getOperands(), F, BB);
IRBuilder<> Builder(BB);
Value *V = Builder.CreatePtrDiff(transType(BC->getType()), Ops[0], Ops[1]);
return mapValue(BV, V);
}

case OpCompositeConstruct: {
auto *CC = static_cast<SPIRVCompositeConstruct *>(BV);
auto Constituents = transValue(CC->getOperands(), F, BB);
Expand Down
14 changes: 14 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -681,10 +681,21 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
"Invalid type for bitwise instruction");
assert((Op1Ty->getIntegerBitWidth() == Op2Ty->getIntegerBitWidth()) &&
"Inconsistent BitWidth");
} else if (isBinaryPtrOpCode(OpCode)) {
assert((Op1Ty->isTypePointer() && Op2Ty->isTypePointer()) &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
assert(static_cast<SPIRVTypePointer *>(Op1Ty)->getElementType() ==
static_cast<SPIRVTypePointer *>(Op2Ty)->getElementType() &&
"Invalid types for PtrEqual, PtrNotEqual, or PtrDiff instruction");
} else {
assert(0 && "Invalid op code!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a nice occasion to improve diagnostics in this place and add precise info about op code that is not supported. I've once hit this exactly line when SPIRV Backend generated code for missed OpPtrEqual/NotEqual and spent some time on debugging. To avoid this in future we may want to improve diagnostics right in this place and give a user-friendly description of what exactly SPIRV Translator doesn't have support for.
@MrSidims @asudarsa what do you think? ^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Slava

Copy link
Contributor

@asudarsa asudarsa Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the subsequent change from @vmaksimo does provide more information. One issue I see here is our reliance on 'assert'. Should we try to provide an error mechanism which will stop translation in such cases? This can be made as a separate effort.

Thanks

}
}
VersionNumber getRequiredSPIRVVersion() const override {
if (isBinaryPtrOpCode(OpCode))
return VersionNumber::SPIRV_1_4;
return VersionNumber::SPIRV_1_0;
}
};

template <Op OC>
Expand Down Expand Up @@ -719,6 +730,9 @@ _SPIRV_OP(BitwiseAnd)
_SPIRV_OP(BitwiseOr)
_SPIRV_OP(BitwiseXor)
_SPIRV_OP(Dot)
_SPIRV_OP(PtrEqual)
_SPIRV_OP(PtrNotEqual)
_SPIRV_OP(PtrDiff)
#undef _SPIRV_OP

template <Op TheOpCode> class SPIRVInstNoOperand : public SPIRVInstruction {
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ inline bool isBinaryOpCode(Op OpCode) {
OpCode == OpDot || OpCode == OpIAddCarry || OpCode == OpISubBorrow;
}

inline bool isBinaryPtrOpCode(Op OpCode) {
return (unsigned)OpCode >= OpPtrEqual && (unsigned)OpCode <= OpPtrDiff;
}

inline bool isShiftOpCode(Op OpCode) {
return (unsigned)OpCode >= OpShiftRightLogical &&
(unsigned)OpCode <= OpShiftLeftLogical;
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ _SPIRV_OP(GroupNonUniformBitwiseXor, 361)
_SPIRV_OP(GroupNonUniformLogicalAnd, 362)
_SPIRV_OP(GroupNonUniformLogicalOr, 363)
_SPIRV_OP(GroupNonUniformLogicalXor, 364)
_SPIRV_OP(PtrEqual, 401)
_SPIRV_OP(PtrNotEqual, 402)
_SPIRV_OP(PtrDiff, 403)
_SPIRV_OP(CopyLogical, 400)
_SPIRV_OP(GroupNonUniformRotateKHR, 4431)
_SPIRV_OP(SDotKHR, 4450)
Expand Down
52 changes: 52 additions & 0 deletions test/transcoding/ptr_diff.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; Check support of OpPtrDiff instruction that was added in SPIR-V 1.4

; RUN: llvm-as %s -o %t.bc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a small test description (one-liner). Thanks

; RUN: not llvm-spirv --spirv-max-version=1.3 %t.bc 2>&1 | FileCheck --check-prefix=CHECK-ERROR %s

; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: spirv-val %t.spv

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-ERROR: RequiresVersion: Cannot fulfill SPIR-V version restriction:
; CHECK-ERROR-NEXT: SPIR-V version was restricted to at most 1.3 (66304) but a construct from the input requires SPIR-V version 1.4 (66560) or above

; SPIR-V 1.4
; CHECK-SPIRV: 66560
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypePointer [[#TypePointer:]] [[#]] [[#TypeFloat]]

; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var:]]
; CHECK-SPIRV: PtrDiff [[#TypeInt]] [[#]] [[#Var]] [[#Var]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

; Function Attrs: nounwind
define spir_kernel void @test(float %a) local_unnamed_addr #0 {
entry:
%0 = alloca float, align 4
store float %a, ptr %0, align 4
; CHECK-LLVM: %[[#Arg1:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Arg2:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Sub:]] = sub i64 %[[#Arg1]], %[[#Arg2]]
; CHECK-LLVM: sdiv exact i64 %[[#Sub]], ptrtoint (ptr getelementptr (i32, ptr null, i32 1) to i64)
%1 = call spir_func noundef i32 @_Z15__spirv_PtrDiff(ptr %0, ptr %0)
ret void
}

declare spir_func noundef i32 @_Z15__spirv_PtrDiff(ptr, ptr)

attributes #0 = { convergent nounwind writeonly }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
60 changes: 60 additions & 0 deletions test/transcoding/ptr_not_equal.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
; Check support of OpPtrEqual and OpPtrNotEqual instructions that were added in SPIR-V 1.4

; RUN: llvm-as %s -o %t.bc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a small test description (one-liner). Thanks

; RUN: not llvm-spirv --spirv-max-version=1.3 %t.bc 2>&1 | FileCheck --check-prefix=CHECK-ERROR %s

; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: spirv-val %t.spv

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

; CHECK-ERROR: RequiresVersion: Cannot fulfill SPIR-V version restriction:
; CHECK-ERROR-NEXT: SPIR-V version was restricted to at most 1.3 (66304) but a construct from the input requires SPIR-V version 1.4 (66560) or above

; SPIR-V 1.4
; CHECK-SPIRV: 66560
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypePointer [[#TypePointer:]] [[#]] [[#TypeFloat]]
; CHECK-SPIRV: TypeBool [[#TypeBool:]]

; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var1:]]
; CHECK-SPIRV: Variable [[#TypePointer]] [[#Var2:]]
; CHECK-SPIRV: PtrEqual [[#TypeBool]] [[#]] [[#Var1]] [[#Var2]]
; CHECK-SPIRV: PtrNotEqual [[#TypeBool]] [[#]] [[#Var1]] [[#Var2]]

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64-unknown-unknown"

; Function Attrs: nounwind
define spir_kernel void @test(float %a, float %b) local_unnamed_addr #0 {
entry:
%0 = alloca float, align 4
%1 = alloca float, align 4
store float %a, ptr %0, align 4
store float %b, ptr %1, align 4
; CHECK-LLVM: %[[#Arg1:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Arg2:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: icmp eq i64 %[[#Arg1]], %[[#Arg2]]
%2 = call spir_func noundef i1 @_Z16__spirv_PtrEqual(ptr %0, ptr %1)
; CHECK-LLVM: %[[#Arg3:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: %[[#Arg4:]] = ptrtoint ptr %[[#]] to i64
; CHECK-LLVM: icmp ne i64 %[[#Arg3]], %[[#Arg4]]
%3 = call spir_func noundef i1 @_Z19__spirv_PtrNotEqual(ptr %0, ptr %1)
ret void
}

declare spir_func noundef i1 @_Z16__spirv_PtrEqual(ptr, ptr)
declare spir_func noundef i1 @_Z19__spirv_PtrNotEqual(ptr, ptr)

attributes #0 = { convergent nounwind writeonly }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
Loading