Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spirv-val: Some Float16 fixes #6009

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/val/validate_capability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ bool IsSupportOptionalOpenCL_1_2(uint32_t capability) {
switch (spv::Capability(capability)) {
case spv::Capability::ImageBasic:
case spv::Capability::Float64:
case spv::Capability::Float16:
return true;
default:
break;
Expand Down
2 changes: 1 addition & 1 deletion source/val/validate_small_type_uses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace val {

spv_result_t ValidateSmallTypeUses(ValidationState_t& _,
const Instruction* inst) {
if (!_.HasCapability(spv::Capability::Shader) || inst->type_id() == 0 ||
if (inst->type_id() == 0 ||
!_.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
return SPV_SUCCESS;
}
Expand Down
13 changes: 13 additions & 0 deletions test/val/val_data_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ std::string header_with_float64 = R"(
OpCapability Float64
OpMemoryModel Logical GLSL450
)";
std::string header_with_kernel_float16 = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability Linkage
OpCapability Float16
OpMemoryModel Physical32 OpenCL
)";

std::string invalid_comp_error = "Illegal number of components";
std::string missing_cap_error = "requires the Vector16 capability";
Expand Down Expand Up @@ -340,6 +347,12 @@ TEST_F(ValidateData, float16_buffer_good) {
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

TEST_F(ValidateData, float16_kernel_good) {
std::string str = header_with_kernel_float16 + "%2 = OpTypeFloat 16";
CompileSuccessfully(str.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_OPENCL_1_2));
}

TEST_F(ValidateData, float16_bad) {
std::string str = header + "%2 = OpTypeFloat 16";
CompileSuccessfully(str.c_str());
Expand Down
47 changes: 47 additions & 0 deletions test/val/val_small_type_uses_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,53 @@ INSTANTIATE_TEST_SUITE_P(
"%inst = OpFunctionCall %void %half_func %ld_half",
"%inst = OpFunctionCall %void %half_func %float_to_half"));

TEST_F(ValidateSmallTypeUses, F16OpsFail) {
const std::string body = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability Float16Buffer
OpCapability Linkage
OpMemoryModel Physical64 OpenCL
%f16 = OpTypeFloat 16
%func = OpTypeFunction %f16 %f16 %f16
%add = OpFunction %f16 None %func
%a = OpFunctionParameter %f16
%b = OpFunctionParameter %f16
%add_entry = OpLabel
%result = OpFAdd %f16 %a %b
OpReturnValue %result
OpFunctionEnd
)";

CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Invalid use of 8- or 16-bit result"));
}

TEST_F(ValidateSmallTypeUses, F16OpsSuccess) {
const std::string body = R"(
OpCapability Addresses
OpCapability Kernel
OpCapability Float16Buffer
OpCapability Float16
OpCapability Linkage
OpMemoryModel Physical64 OpenCL
%f16 = OpTypeFloat 16
%func = OpTypeFunction %f16 %f16 %f16
%add = OpFunction %f16 None %func
%a = OpFunctionParameter %f16
%b = OpFunctionParameter %f16
%add_entry = OpLabel
%result = OpFAdd %f16 %a %b
OpReturnValue %result
OpFunctionEnd
)";

CompileSuccessfully(body.c_str());
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}

} // namespace
} // namespace val
} // namespace spvtools