diff --git a/setup.py b/setup.py
index 9681fbd45a5..d8129e3a746 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@
     "pandas>=0.25.0",
     "packaging>=20.0",
     "psutil>=5.0.0",
-    "pydantic>=1.8.2,<2.0.0",
+    "pydantic>=2.0.0,<2.8.0",
     "requests>=2.0.0",
     "scikit-learn>=0.24.2",
     "scipy<1.9.2,>=1.8; python_version <= '3.9'",
diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py
index bb60b23d90f..c383ec38046 100644
--- a/src/sparseml/core/modifier/modifier.py
+++ b/src/sparseml/core/modifier/modifier.py
@@ -40,9 +40,9 @@ class Modifier(BaseModel, ModifierInterface, MultiFrameworkObject):
     :param update: The update step for the modifier
     """

-    index: int = None
-    group: str = None
-    start: float = None
+    index: Optional[int] = None
+    group: Optional[str] = None
+    start: Optional[float] = None
     end: Optional[float] = None
     update: Optional[float] = None
diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py
index e1377915114..ecd438f127f 100644
--- a/src/sparseml/core/recipe/base.py
+++ b/src/sparseml/core/recipe/base.py
@@ -13,9 +13,9 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, Optional

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict

 from sparseml.core.framework import Framework
 from sparseml.core.recipe.args import RecipeArgs
@@ -36,6 +36,8 @@ class RecipeBase(BaseModel, ABC):
     - create_modifier
     """

+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     @abstractmethod
     def calculate_start(self) -> int:
         raise NotImplementedError()
@@ -45,7 +47,7 @@ def calculate_end(self) -> int:
         raise NotImplementedError()

     @abstractmethod
-    def evaluate(self, args: RecipeArgs = None, shift: int = None):
+    def evaluate(self, args: Optional[RecipeArgs] = None, shift: Optional[int] = None):
         raise NotImplementedError()

     @abstractmethod
diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py
index a140c81b53a..514c5fb9572 100644
--- a/src/sparseml/core/recipe/modifier.py
+++ b/src/sparseml/core/recipe/modifier.py
@@ -14,7 +14,7 @@

 from typing import Any, Dict, Optional

-from pydantic import root_validator
+from pydantic import model_validator

 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
@@ -99,7 +99,8 @@ def create_modifier(self, framework: Framework) -> "Modifier":
             **self.args_evaluated,
         )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         modifier = {"group": values.pop("group")}
         assert len(values) == 1, "multiple key pairs found for modifier"
diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py
index 11f0c5f683f..f6ab08af1e6 100644
--- a/src/sparseml/core/recipe/recipe.py
+++ b/src/sparseml/core/recipe/recipe.py
@@ -20,7 +20,7 @@
 from typing import Any, Dict, List, Optional, Union

 import yaml
-from pydantic import Field, root_validator
+from pydantic import Field, model_validator

 from sparseml.core.framework import Framework
 from sparseml.core.modifier import StageModifiers
@@ -152,7 +152,7 @@ def create_instance(
             )
             _LOGGER.debug(f"Input string: {path_or_modifiers}")
             obj = _load_json_or_yaml_string(path_or_modifiers)
-            return Recipe.parse_obj(obj)
+            return Recipe.model_validate(obj)
         else:
             _LOGGER.info(f"Loading recipe from file {path_or_modifiers}")

@@ -174,7 +174,7 @@ def create_instance(
                 raise ValueError(
                     f"Could not parse recipe from path {path_or_modifiers}"
                 )
-            return Recipe.parse_obj(obj)
+            return Recipe.model_validate(obj)

     @staticmethod
     def simplify_recipe(
@@ -391,7 +391,8 @@ def create_modifier(self, framework: Framework) -> List["StageModifiers"]:

         return modifiers

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         stages = []

@@ -515,25 +516,9 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]):

     def dict(self, *args, **kwargs) -> Dict[str, Any]:
         """
-        >>> recipe_str = '''
-        ... test_stage:
-        ...     pruning_modifiers:
-        ...         ConstantPruningModifier:
-        ...             start: 0.0
-        ...             end: 2.0
-        ...             targets: ['re:.*weight']
-        ... '''
-        >>> recipe = Recipe.create_instance(recipe_str)
-        >>> recipe_dict = recipe.dict()
-        >>> stage = recipe_dict["stages"]["test"]
-        >>> pruning_mods = stage[0]['modifiers']['pruning']
-        >>> modifier_args = pruning_mods[0]['ConstantPruningModifier']
-        >>> modifier_args == {'start': 0.0, 'end': 2.0, 'targets': ['re:.*weight']}
-        True
-
         :return: A dictionary representation of the recipe
         """
-        dict_ = super().dict(*args, **kwargs)
+        dict_ = super().model_dump(*args, **kwargs)
         stages = {}

         for stage in dict_["stages"]:
@@ -577,36 +562,34 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
         """
         Get a dictionary representation of the recipe for yaml serialization
         The returned dict will only contain information necessary for yaml
-        serialization (ignores metadata, version, etc), and must not be used
-        in place of the dict method
+        serialization and must not be used in place of the dict method

         :return: A dictionary representation of the recipe for yaml serialization
         """

-        def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
-            # convert a list of modifiers to a dict of modifiers
-            return {
-                key: value
-                for modifier in modifier_group
-                for key, value in modifier.items()
-            }
+        original_recipe_dict = self.dict()
+        yaml_recipe_dict = {}

-        def _stage_to_dict(stage: Dict[str, Any]):
-            # convert a stage to a dict of modifiers
-            return {
-                modifier_group_name: _modifier_group_to_dict(modifier_group)
-                for modifier_group_name, modifier_group in stage["modifiers"].items()
-            }
+        # populate recipe level attributes
+        recipe_level_attributes = ["version", "args", "metadata"]

-        final_dict = {}
-        for stage_name, stages in self.dict()["stages"].items():
-            if len(stages) == 1:
-                final_dict[stage_name] = _stage_to_dict(stages[0])
-            else:
-                for idx, stage in enumerate(stages):
-                    final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)
+        for attribute in recipe_level_attributes:
+            if attribute_value := original_recipe_dict.get(attribute):
+                yaml_recipe_dict[attribute] = attribute_value
+
+        # populate stages
+        stages = original_recipe_dict["stages"]
+        for stage_name, stage_list in stages.items():
+            # stage is always a list of size 1
+            stage = stage_list[0]
+            stage_dict = get_yaml_serializable_stage_dict(modifiers=stage["modifiers"])

-        return final_dict
+            # infer run_type from stage
+            if run_type := stage.get("run_type"):
+                stage_dict["run_type"] = run_type
+
+            yaml_recipe_dict[stage_name] = stage_dict
+        return yaml_recipe_dict


 @dataclass
@@ -704,9 +687,58 @@ def create_recipe_string_from_modifiers(
     recipe_dict = {
         f"{modifier_group_name}_stage": {
             f"{default_group_name}_modifiers": {
-                modifier.__class__.__name__: modifier.dict() for modifier in modifiers
+                modifier.__class__.__name__: modifier.model_dump()
+                for modifier in modifiers
             }
         }
     }
     recipe_str: str = yaml.dump(recipe_dict)
     return recipe_str
+
+
+def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Group a flat list of modifier dicts by their "group" key, mapping each
+    group name to a list of {modifier_type: modifier_args} dicts.
+    """
+    group_dict = {}
+
+    for modifier in modifiers:
+        modifier_type = modifier["type"]
+        modifier_group = modifier["group"]
+
+        if modifier_group not in group_dict:
+            group_dict[modifier_group] = []
+
+        modifier_dict = {modifier_type: modifier["args"]}
+        group_dict[modifier_group].append(modifier_dict)
+
+    return group_dict
+
+
+def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
+    """
+    Convert a list of modifiers into a dictionary where the keys are the
+    group names and the values are the modifiers, which in turn are
+    dictionaries with the modifier type as the key and the modifier args
+    as the value.
+
+    This is needed to conform to our recipe structure during yaml
+    serialization, where each stage, modifier group, and modifier is
+    represented as a valid yaml dictionary.
+
+    Note: This function assumes that a modifier group does not contain the
+    same modifier type more than once. This assumption is also held by the
+    Recipe.create_instance(...) method.
+
+    :param modifiers: A list of dictionaries where each dictionary
+        holds all information about a modifier
+    :return: A dictionary where the keys are the group names and the values
+        are the modifiers, which in turn are dictionaries with the modifier
+        type as the key and the modifier args as the value
+    """
+    stage_dict = {}
+    for modifier in modifiers:
+        group_name = f"{modifier['group']}_modifiers"
+        modifier_type = modifier["type"]
+        if group_name not in stage_dict:
+            stage_dict[group_name] = {}
+        stage_dict[group_name][modifier_type] = modifier["args"]
+    return stage_dict
diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py
index e0ed6926f77..31638648f98 100644
--- a/src/sparseml/core/recipe/stage.py
+++ b/src/sparseml/core/recipe/stage.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional

-from pydantic import Field, root_validator
+from pydantic import ConfigDict, Field, model_validator

 from sparseml.core.framework import Framework
 from sparseml.core.modifier import StageModifiers
@@ -46,6 +46,8 @@ class RecipeStage(RecipeBase):
     :param args_evaluated: the evaluated RecipeArgs for the stage
     """

+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
     group: Optional[str] = None
     run_type: Optional[StageRunType] = None
     args: Optional[RecipeArgs] = None
@@ -139,7 +141,8 @@ def create_modifier(

         return stage_modifiers

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def remap_modifiers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         modifiers = RecipeStage.extract_dict_modifiers(values)
         values["modifiers"] = modifiers
diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py
index ede89cb5109..686adf5c7d5 100644
--- a/src/sparseml/exporters/transforms/kv_cache/configs.py
+++ b/src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -17,7 +17,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from sparseml.exporters.transforms import OnnxTransform
 from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
@@ -47,8 +47,9 @@ class KeyValueCacheConfig(BaseModel):
     additional_transforms: Union[
         List[Type[OnnxTransform]], Type[OnnxTransform], None
     ] = Field(
+        None,
         description="A transform class (or list thereof) to use for additional "
-        "transforms to the model required for finalizing the kv cache injection."
+        "transforms to the model required for finalizing the kv cache injection.",
     )
     key_num_attention_heads: str = Field(
         description="The key to use to get the number of attention heads from the "
@@ -59,10 +60,10 @@ class KeyValueCacheConfig(BaseModel):
         "from the transformer's `config.json` file."
     )
     num_attention_heads: Optional[int] = Field(
-        description="The number of attention heads."
+        None, description="The number of attention heads."
     )
     hidden_size_kv_cache: Optional[int] = Field(
-        description="The hidden size of the key/value cache. "
+        None, description="The hidden size of the key/value cache."
     )
     multiply_batch_by_num_att_heads: bool = Field(
         default=False,
@@ -83,9 +84,7 @@ class KeyValueCacheConfig(BaseModel):
         "the kv cache. If this is not provided, no transpose will "
         "be applied.",
     )
-
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)


 OPT_CONFIG = KeyValueCacheConfig(
diff --git a/src/sparseml/framework/info.py b/src/sparseml/framework/info.py
index fcee41c806e..f4e764d7ae7 100644
--- a/src/sparseml/framework/info.py
+++ b/src/sparseml/framework/info.py
@@ -231,13 +231,13 @@ def save_framework_info(framework: Any, path: Optional[str] = None):
         create_parent_dirs(path)

         with open(path, "w") as file:
-            file.write(info.json())
+            file.write(info.model_dump_json())
         _LOGGER.info(
             "saved framework info for framework %s in file at %s", framework, path
         ),
     else:
-        print(info.json(indent=4))
+        print(info.model_dump_json(indent=4))
         _LOGGER.info("printed out framework info for framework %s", framework)
diff --git a/src/sparseml/integration_helper_functions.py b/src/sparseml/integration_helper_functions.py
index cf8cbcdd690..8f332c9c675 100644
--- a/src/sparseml/integration_helper_functions.py
+++ b/src/sparseml/integration_helper_functions.py
@@ -188,6 +188,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
         default=export_model,
     )
     apply_optimizations: Optional[Callable[[Any], None]] = Field(
+        None,
         description="A function that takes:"
         " - path to the exported model"
         " - names of the optimizations to apply"
@@ -223,6 +224,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
     )

     deployment_directory_files_optional: Optional[List[str]] = Field(
+        None,
         description="A list that describes the "
         "optional expected files of the deployment directory",
     )
diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py
index 049d3862cbc..29e64bf6477 100644
--- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py
+++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py
@@ -21,7 +21,7 @@

 import torch
 from packaging import version
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 from torch.nn import Identity

@@ -121,7 +121,8 @@ def get_observer(self) -> "torch.quantization.FakeQuantize":
             qconfig_kwargs=self.kwargs,
         )

-    @validator("strategy")
+    @field_validator("strategy")
+    @classmethod
     def validate_strategy(cls, value):
         valid_scopes = ["tensor", "channel"]
         if value not in valid_scopes:
@@ -263,7 +264,7 @@ def __str__(self) -> str:
         """
         :return: YAML friendly string serialization
         """
-        dict_repr = self.dict()
+        dict_repr = self.model_dump()
         dict_repr = {
             key: val if val is not None else "null" for key, val in dict_repr.items()
         }
diff --git a/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py b/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py
index b3ef1807227..186748d4b1e 100644
--- a/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py
+++ b/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py
@@ -21,7 +21,7 @@

 import torch
 from packaging import version
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, field_validator
 from torch.nn import Identity

@@ -119,7 +119,8 @@ def get_observer(self) -> "torch.quantization.FakeQuantize":
             qconfig_kwargs=self.kwargs,
         )

-    @validator("strategy")
+    @field_validator("strategy")
+    @classmethod
     def validate_strategy(cls, value):
         valid_scopes = ["tensor", "channel"]
         if value not in valid_scopes:
diff --git a/src/sparseml/pytorch/utils/sparsification_info/configs.py b/src/sparseml/pytorch/utils/sparsification_info/configs.py
index 2cc4a16b262..c615a248ea1 100644
--- a/src/sparseml/pytorch/utils/sparsification_info/configs.py
+++ b/src/sparseml/pytorch/utils/sparsification_info/configs.py
@@ -17,7 +17,7 @@
 from typing import Any, Dict, Generator, Tuple, Union

 import torch.nn
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field

 from sparseml.pytorch.utils.sparsification_info.helpers import (
     get_leaf_operations,
@@ -326,9 +326,7 @@ class SparsificationQuantization(SparsificationInfo):
         description="A dictionary that maps the name of a layer"
         "to the precision of that layer."
     )
-
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(arbitrary_types_allowed=True)

     @classmethod
     def from_module(
diff --git a/src/sparseml/sparsification/info.py b/src/sparseml/sparsification/info.py
index 8aa2af8a29d..4d8c1b8dcc7 100644
--- a/src/sparseml/sparsification/info.py
+++ b/src/sparseml/sparsification/info.py
@@ -235,13 +235,13 @@ def save_sparsification_info(framework: Any, path: Optional[str] = None):
         create_parent_dirs(path)

         with open(path, "w") as file:
-            file.write(info.json())
+            file.write(info.model_dump_json())
         _LOGGER.info(
             "saved sparsification info for framework %s in file at %s", framework, path
         ),
     else:
-        print(info.json(indent=4))
+        print(info.model_dump_json(indent=4))
         _LOGGER.info("printed out sparsification info for framework %s", framework)
diff --git a/src/sparseml/sparsification/model_info.py b/src/sparseml/sparsification/model_info.py
index b9f2d2d035a..d82403ff4c1 100644
--- a/src/sparseml/sparsification/model_info.py
+++ b/src/sparseml/sparsification/model_info.py
@@ -25,7 +25,7 @@
 from typing import Any, Dict, List, Optional, Set, Union

 import numpy
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_validator

 from sparseml.utils import clean_path, create_parent_dirs

@@ -87,7 +87,8 @@ class LayerInfo(BaseModel):
         description="dictionary of string attribute names to their values",
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def check_params_if_prunable(_, values):
         prunable = values.get("prunable")
         params = values.get("params")
diff --git a/tests/integrations/image_classification/args.py b/tests/integrations/image_classification/args.py
index 4f662aa0fd9..c046138a797 100644
--- a/tests/integrations/image_classification/args.py
+++ b/tests/integrations/image_classification/args.py
@@ -149,7 +149,7 @@ def __post_init__(self):

 class ImageClassificationExportArgs(_ImageClassificationBaseArgs):
     model_tag: Optional[str] = Field(
-        Default=None, description="required - tag for model under save_dir"
+        None, description="required - tag for model under save_dir"
     )
     onnx_opset: int = Field(
         default=13, description="The onnx opset to use for exporting the model"
diff --git a/tests/integrations/transformers/args.py b/tests/integrations/transformers/args.py
index a7fa5f32aa4..54833ded5e2 100644
--- a/tests/integrations/transformers/args.py
+++ b/tests/integrations/transformers/args.py
@@ -507,7 +507,7 @@ class _TransformersTrainArgs(BaseModel):
     push_to_hub_token: str = Field(
         default=None, description="The token to use to push to the Model Hub."
     )
-    _n_gpu: int = Field(init=False, repr=False, default=-1)
+    n_gpu_: int = Field(init=False, repr=False, default=-1)
     mp_parameters: Optional[str] = Field(
         default=None,
         description="Used by the SageMaker launcher to send mp-specific args. "
diff --git a/tests/sparseml/core/recipe/test_recipe.py b/tests/sparseml/core/recipe/test_recipe.py
index 937ed34570a..0061a0502d7 100644
--- a/tests/sparseml/core/recipe/test_recipe.py
+++ b/tests/sparseml/core/recipe/test_recipe.py
@@ -41,7 +41,8 @@ def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str):
 @pytest.mark.parametrize("recipe_str", valid_recipe_strings())
 def test_serialization(recipe_str):
     recipe_instance = Recipe.create_instance(recipe_str)
-    recipe_from_serialized = Recipe.create_instance(recipe_instance.yaml())
+    serialized_recipe = recipe_instance.yaml()
+    recipe_from_serialized = Recipe.create_instance(serialized_recipe)

     expected_dict = recipe_instance.dict()
     actual_dict = recipe_from_serialized.dict()
diff --git a/tests/sparseml/framework/test_info.py b/tests/sparseml/framework/test_info.py
index 335c8868ca7..0411beb0e5c 100644
--- a/tests/sparseml/framework/test_info.py
+++ b/tests/sparseml/framework/test_info.py
@@ -90,11 +90,11 @@ def test_framework_info_lifecycle(const_args):
     assert info, "No object returned for info constructor"

     # test serialization
-    info_str = info.json()
+    info_str = info.model_dump_json()
     assert info_str, "No json returned for info"

     # test deserialization
-    info_reconst = FrameworkInfo.parse_raw(info_str)
+    info_reconst = FrameworkInfo.model_validate_json(info_str)
     assert info == info_reconst, "Reconstructed does not equal original"

@@ -110,7 +110,7 @@ def test_save_load_framework_info():
         framework=Framework.unknown, package_versions={"unknown": "0.0.1"}
     )
     save_framework_info(info)
-    loaded_json = load_framework_info(info.json())
+    loaded_json = load_framework_info(info.model_dump_json())
     assert info == loaded_json

     test_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
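
For reference, below is a minimal, self-contained sketch collecting the pydantic v1 idioms and the v2 replacements this diff applies, assuming pydantic>=2.0 as pinned in setup.py above. The ExampleModel class, its field, and its validators are hypothetical illustrations, not code from this change:

from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, field_validator, model_validator


class ExampleModel(BaseModel):
    # v1: class Config: arbitrary_types_allowed = True
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # v2 requires an explicit Optional[...] annotation for None defaults
    start: Optional[float] = None  # v1: start: float = None

    # v1: @root_validator(pre=True); v2 "before" validators see the raw input
    @model_validator(mode="before")
    @classmethod
    def remap(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        return values

    # v1: @validator("start")
    @field_validator("start")
    @classmethod
    def check_start(cls, value):
        return value


model = ExampleModel.model_validate({"start": 0.0})    # v1: parse_obj(...)
data = model.model_dump()                              # v1: model.dict()
json_str = model.model_dump_json(indent=4)             # v1: model.json(indent=4)
restored = ExampleModel.model_validate_json(json_str)  # v1: parse_raw(...)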