Skip to content

Commit

Permalink
Implement SPV_INTEL_subgroup_matrix_multiply_accumulate
Browse files Browse the repository at this point in the history
Spec:
  • Loading branch information
YuriPlyakhin committed Oct 31, 2024
1 parent 8dc0349 commit ba900ba
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,4 @@ EXT(SPV_INTEL_subgroup_requirements)
EXT(SPV_INTEL_task_sequence)
EXT(SPV_INTEL_maximum_registers)
EXT(SPV_INTEL_bindless_images)
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
18 changes: 18 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4344,5 +4344,23 @@ class SPIRVUntypedPrefetchKHR : public SPIRVInstruction {
std::vector<SPIRVId> CacheTy;
};

class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst
: public SPIRVInstTemplateBase {
public:
std::optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate;
}
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilitySubgroupMatrixMultiplyAccumulateINTEL);
}
};

#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVSubgroupMatrixMultiplyAccumulateINTELInst, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(SubgroupMatrixMultiplyAccumulate, true, 7, true, 4)
#undef _SPIRV_OP

} // namespace SPIRV
#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"SubgroupRequirementsINTEL");
add(internal::CapabilityTaskSequenceINTEL, "TaskSequenceINTEL");
add(internal::CapabilityBindlessImagesINTEL, "BindlessImagesINTEL");
add(internal::CapabilitySubgroupMatrixMultiplyAccumulateINTEL,
"SubgroupMatrixMultiplyAccumulateINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ _SPIRV_OP_INTERNAL(ConvertHandleToSamplerINTEL,
internal::ConvertHandleToSamplerINTEL)
_SPIRV_OP_INTERNAL(ConvertHandleToSampledImageINTEL,
internal::ConvertHandleToSampledImageINTEL)
_SPIRV_OP_INTERNAL(SubgroupMatrixMultiplyAccumulateINTEL,
internal::SubgroupMatrixMultiplyAccumulateINTEL)
25 changes: 24 additions & 1 deletion lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ enum InternalOp {
IOpConvertHandleToImageINTEL = 6529,
IOpConvertHandleToSamplerINTEL = 6530,
IOpConvertHandleToSampledImageINTEL = 6531,
IOpSubgroupMatrixMultiplyAccumulateINTEL = 7777,
IOpPrev = OpMax - 2,
IOpForward
};
Expand Down Expand Up @@ -126,7 +127,8 @@ enum InternalCapability {
ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439,
ICapabilityCacheControlsINTEL = 6441,
ICapabilitySubgroupRequirementsINTEL = 6445,
ICapabilityBindlessImagesINTEL = 6528
ICapabilityBindlessImagesINTEL = 6528,
ICapabilitySubgroupMatrixMultiplyAccumulateINTEL = 8888,
};

enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
Expand Down Expand Up @@ -161,6 +163,24 @@ enum InternalBuiltIn {
IBuiltInGlobalHWThreadIDINTEL = 6136,
};

enum InternalMatrixMultiplyAccumulateOperandsMask {
IMatrixMultiplyAccumulateOperandsMaskNone = 0x00000000,
IMatrixMultiplyAccumulateOperandsMatrixASignedComponentsINTEL = 0x00000001,
IMatrixMultiplyAccumulateOperandsMatrixBSignedComponentsINTEL = 0x00000002,
IMatrixMultiplyAccumulateOperandsMatrixCBFloat16INTEL = 0x00000004,
IMatrixMultiplyAccumulateOperandsMatrixResultBFloat16INTEL = 0x00000008,
IMatrixMultiplyAccumulateOperandsMatrixAPackedInt8INTEL = 0x00000010,
IMatrixMultiplyAccumulateOperandsMatrixBPackedInt8INTEL = 0x00000020,
IMatrixMultiplyAccumulateOperandsMatrixAPackedInt4INTEL = 0x00000040,
IMatrixMultiplyAccumulateOperandsMatrixBPackedInt4INTEL = 0x00000080,
IMatrixMultiplyAccumulateOperandsMatrixATF32INTEL = 0x00000100,
IMatrixMultiplyAccumulateOperandsMatrixBTF32INTEL = 0x00000200,
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat16INTEL = 0x00000400,
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat16INTEL = 0x00000800,
IMatrixMultiplyAccumulateOperandsMatrixAPackedBFloat16INTEL = 0x00001000,
IMatrixMultiplyAccumulateOperandsMatrixBPackedBFloat16INTEL = 0x00002000
};

#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
_SPIRV_OP(Capability, JointMatrixINTEL)
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
Expand Down Expand Up @@ -220,6 +240,9 @@ _SPIRV_OP(Capability, BindlessImagesINTEL)
_SPIRV_OP(Op, ConvertHandleToImageINTEL)
_SPIRV_OP(Op, ConvertHandleToSamplerINTEL)
_SPIRV_OP(Op, ConvertHandleToSampledImageINTEL)

_SPIRV_OP(Capability, SubgroupMatrixMultiplyAccumulateINTEL)
_SPIRV_OP(Op, SubgroupMatrixMultiplyAccumulateINTEL)
#undef _SPIRV_OP

constexpr SourceLanguage SourceLanguagePython =
Expand Down
Loading

0 comments on commit ba900ba

Please sign in to comment.