Skip to content

Commit

Permalink
spirv-opt: Fix OpCompositeExtract relaxation with struct operands
Browse files Browse the repository at this point in the history
  • Loading branch information
bejado committed Jan 18, 2024
1 parent c96fe8b commit 832d568
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 2 deletions.
24 changes: 22 additions & 2 deletions source/opt/convert_to_half_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ bool ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {

bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
bool modified = false;
// If this is a OpCompositeExtract instruction and has a struct operand, we
// should not relax this instruction. Doing so could cause a mismatch between
// the result type and the struct member type.
bool hasStructOperand = false;
if (inst->opcode() == spv::Op::OpCompositeExtract) {
inst->ForEachInId([&hasStructOperand, this](uint32_t* idp) {
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
if (IsStruct(op_inst)) hasStructOperand = true;
});
if (hasStructOperand) {
return false;
}
}
// Convert all float32 based operands to float16 equivalent and change
// instruction type to float16 equivalent.
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
Expand Down Expand Up @@ -303,12 +316,19 @@ bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
if (closure_ops_.count(inst->opcode()) == 0) return false;
// Can relax if all float operands are relaxed
bool relax = true;
inst->ForEachInId([&relax, this](uint32_t* idp) {
bool hasStructOperand = false;
inst->ForEachInId([&relax, &hasStructOperand, this](uint32_t* idp) {
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
if (IsStruct(op_inst)) relax = false;
if (IsStruct(op_inst)) hasStructOperand = true;
if (!IsFloat(op_inst, 32)) return;
if (!IsRelaxed(*idp)) relax = false;
});
// If the instruction has a struct operand, we should not relax it, even if
// all its uses are relaxed. Doing so could cause a mismatch between the
// result type and the struct member type.
if (hasStructOperand) {
return false;
}
if (relax) {
AddRelaxed(inst->result_id());
return true;
Expand Down
69 changes: 69 additions & 0 deletions test/opt/convert_relaxed_to_half_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,75 @@ TEST_F(ConvertToHalfTest, PreserveImageOperandPrecision) {
SinglePassRunAndMatch<ConvertToHalfPass>(test, true);
}

TEST_F(ConvertToHalfTest, DontRelaxDecoratedOpCompositeExtract) {
// This test checks that a OpCompositeExtract with a Struct operand won't be
// relaxed, even if it is explicitly decorated with RelaxedPrecision.
const std::string test =
R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
OpDecorate %9 RelaxedPrecision
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_struct_6 = OpTypeStruct %v4float
%7 = OpUndef %_struct_6
%1 = OpFunction %void None %3
%8 = OpLabel
%9 = OpCompositeExtract %float %7 0 3
OpReturn
OpFunctionEnd
)";

const std::string expected =
R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_struct_6 = OpTypeStruct %v4float
%7 = OpUndef %_struct_6
%1 = OpFunction %void None %3
%8 = OpLabel
%9 = OpCompositeExtract %float %7 0 3
OpReturn
OpFunctionEnd
)";

SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<ConvertToHalfPass>(test, expected, true);
}

TEST_F(ConvertToHalfTest, DontRelaxOpCompositeExtract) {
// This test checks that a OpCompositeExtract with a Struct operand won't be
// relaxed, even if its result has no uses.
const std::string test =
R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%_struct_6 = OpTypeStruct %v4float
%7 = OpUndef %_struct_6
%1 = OpFunction %void None %3
%8 = OpLabel
%9 = OpCompositeExtract %float %7 0 3
OpReturn
OpFunctionEnd
)";

SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndCheck<ConvertToHalfPass>(test, test, true);
}

} // namespace
} // namespace opt
} // namespace spvtools

0 comments on commit 832d568

Please sign in to comment.