diff --git a/source/val/validate_small_type_uses.cpp b/source/val/validate_small_type_uses.cpp index 69f61ee4f3..411a7b81c5 100644 --- a/source/val/validate_small_type_uses.cpp +++ b/source/val/validate_small_type_uses.cpp @@ -22,7 +22,7 @@ namespace val { spv_result_t ValidateSmallTypeUses(ValidationState_t& _, const Instruction* inst) { - if (!_.HasCapability(spv::Capability::Shader) || inst->type_id() == 0 || + if (inst->type_id() == 0 || !_.ContainsLimitedUseIntOrFloatType(inst->type_id())) { return SPV_SUCCESS; } diff --git a/test/val/val_small_type_uses_test.cpp b/test/val/val_small_type_uses_test.cpp index b950af5b01..22d7aa1898 100644 --- a/test/val/val_small_type_uses_test.cpp +++ b/test/val/val_small_type_uses_test.cpp @@ -333,6 +333,53 @@ INSTANTIATE_TEST_SUITE_P( "%inst = OpFunctionCall %void %half_func %ld_half", "%inst = OpFunctionCall %void %half_func %float_to_half")); +TEST_F(ValidateSmallTypeUses, F16OpsFail) { + const std::string body = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Float16Buffer +OpCapability Linkage +OpMemoryModel Physical64 OpenCL +%f16 = OpTypeFloat 16 +%func = OpTypeFunction %f16 %f16 %f16 +%add = OpFunction %f16 None %func + %a = OpFunctionParameter %f16 + %b = OpFunctionParameter %f16 +%add_entry = OpLabel + %result = OpFAdd %f16 %a %b + OpReturnValue %result +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Invalid use of 8- or 16-bit result")); +} + +TEST_F(ValidateSmallTypeUses, F16OpsSuccess) { + const std::string body = R"( +OpCapability Addresses +OpCapability Kernel +OpCapability Float16Buffer +OpCapability Float16 +OpCapability Linkage +OpMemoryModel Physical64 OpenCL +%f16 = OpTypeFloat 16 +%func = OpTypeFunction %f16 %f16 %f16 +%add = OpFunction %f16 None %func + %a = OpFunctionParameter %f16 + %b = OpFunctionParameter %f16 +%add_entry = OpLabel + %result = OpFAdd %f16 %a %b + OpReturnValue %result +OpFunctionEnd +)"; + + CompileSuccessfully(body.c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + } // namespace } // namespace val } // namespace spvtools