diff --git a/src/sparseml/core/modifier/modifier.py b/src/sparseml/core/modifier/modifier.py index bb60b23d90f..c383ec38046 100644 --- a/src/sparseml/core/modifier/modifier.py +++ b/src/sparseml/core/modifier/modifier.py @@ -40,9 +40,9 @@ class Modifier(BaseModel, ModifierInterface, MultiFrameworkObject): :param update: The update step for the modifier """ - index: int = None - group: str = None - start: float = None + index: Optional[int] = None + group: Optional[str] = None + start: Optional[float] = None end: Optional[float] = None update: Optional[float] = None diff --git a/src/sparseml/core/recipe/base.py b/src/sparseml/core/recipe/base.py index e1377915114..ecd438f127f 100644 --- a/src/sparseml/core/recipe/base.py +++ b/src/sparseml/core/recipe/base.py @@ -13,9 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from sparseml.core.framework import Framework from sparseml.core.recipe.args import RecipeArgs @@ -36,6 +36,8 @@ class RecipeBase(BaseModel, ABC): - create_modifier """ + model_config = ConfigDict(arbitrary_types_allowed=True) + @abstractmethod def calculate_start(self) -> int: raise NotImplementedError() @@ -45,7 +47,7 @@ def calculate_end(self) -> int: raise NotImplementedError() @abstractmethod - def evaluate(self, args: RecipeArgs = None, shift: int = None): + def evaluate(self, args: Optional[RecipeArgs] = None, shift: Optional[int] = None): raise NotImplementedError() @abstractmethod diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 1fc30c50fa1..f2d94c8c5a9 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Union import yaml -from pydantic import model_validator, Field +from pydantic import Field, model_validator from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers @@ -152,7 +152,7 @@ def create_instance( ) _LOGGER.debug(f"Input string: {path_or_modifiers}") obj = _load_json_or_yaml_string(path_or_modifiers) - return Recipe.parse_obj(obj) + return Recipe.model_validate(obj) else: _LOGGER.info(f"Loading recipe from file {path_or_modifiers}") @@ -534,7 +534,7 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: :return: A dictionary representation of the recipe """ - dict_ = super().dict(*args, **kwargs) + dict_ = super().model_dump(*args, **kwargs) stages = {} for stage in dict_["stages"]: diff --git a/src/sparseml/core/recipe/stage.py b/src/sparseml/core/recipe/stage.py index b8e8fe25108..31638648f98 100644 --- a/src/sparseml/core/recipe/stage.py +++ b/src/sparseml/core/recipe/stage.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import model_validator, Field +from pydantic import ConfigDict, Field, model_validator from sparseml.core.framework import Framework from sparseml.core.modifier import StageModifiers @@ -46,6 +46,8 @@ class RecipeStage(RecipeBase): :param args_evaluated: the evaluated RecipeArgs for the stage """ + model_config = ConfigDict(arbitrary_types_allowed=True) + group: Optional[str] = None run_type: Optional[StageRunType] = None args: Optional[RecipeArgs] = None diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index aaea9c661bf..686adf5c7d5 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type, Union -from pydantic import ConfigDict, BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sparseml.exporters.transforms import OnnxTransform from sparseml.exporters.transforms.kv_cache.transforms_codegen import ( @@ -47,8 +47,9 @@ class KeyValueCacheConfig(BaseModel): additional_transforms: Union[ List[Type[OnnxTransform]], Type[OnnxTransform], None ] = Field( - None, description="A transform class (or list thereof) to use for additional " - "transforms to the model required for finalizing the kv cache injection." + None, + description="A transform class (or list thereof) to use for additional " + "transforms to the model required for finalizing the kv cache injection.", ) key_num_attention_heads: str = Field( description="The key to use to get the number of attention heads from the " diff --git a/src/sparseml/integration_helper_functions.py b/src/sparseml/integration_helper_functions.py index 46621ee48a4..8f332c9c675 100644 --- a/src/sparseml/integration_helper_functions.py +++ b/src/sparseml/integration_helper_functions.py @@ -188,7 +188,8 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel): default=export_model, ) apply_optimizations: Optional[Callable[[Any], None]] = Field( - None, description="A function that takes:" + None, + description="A function that takes:" " - path to the exported model" " - names of the optimizations to apply" " and applies the optimizations to the model", @@ -223,6 +224,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel): ) deployment_directory_files_optional: Optional[List[str]] = Field( - None, description="A list that describes the " + None, + description="A list that describes the " "optional expected files of the deployment directory", ) diff --git a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py index 7f8df1fe614..5660e7acec9 100644 --- a/src/sparseml/modifiers/quantization/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization/utils/quantization_scheme.py @@ -21,7 +21,7 @@ import torch from packaging import version -from pydantic import field_validator, BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Identity diff --git a/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py b/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py index 2bac742c4e2..186748d4b1e 100644 --- a/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py +++ b/src/sparseml/pytorch/sparsification/quantization/quantization_scheme.py @@ -21,7 +21,7 @@ import torch from packaging import version -from pydantic import field_validator, BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Identity diff --git a/src/sparseml/pytorch/utils/sparsification_info/configs.py b/src/sparseml/pytorch/utils/sparsification_info/configs.py index ef8323a00e1..c615a248ea1 100644 --- a/src/sparseml/pytorch/utils/sparsification_info/configs.py +++ b/src/sparseml/pytorch/utils/sparsification_info/configs.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Generator, Tuple, Union import torch.nn -from pydantic import ConfigDict, BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sparseml.pytorch.utils.sparsification_info.helpers import ( get_leaf_operations, diff --git a/src/sparseml/sparsification/info.py b/src/sparseml/sparsification/info.py index 8aa2af8a29d..4d8c1b8dcc7 100644 --- a/src/sparseml/sparsification/info.py +++ b/src/sparseml/sparsification/info.py @@ -235,13 +235,13 @@ def save_sparsification_info(framework: Any, path: Optional[str] = None): create_parent_dirs(path) with open(path, "w") as file: - file.write(info.json()) + file.write(info.model_dump_json()) _LOGGER.info( "saved sparsification info for framework %s in file at %s", framework, path ), else: - print(info.json(indent=4)) + print(info.model_dump_json(indent=4)) _LOGGER.info("printed out sparsification info for framework %s", framework) diff --git a/src/sparseml/sparsification/model_info.py b/src/sparseml/sparsification/model_info.py index c280adccc5f..d82403ff4c1 100644 --- a/src/sparseml/sparsification/model_info.py +++ b/src/sparseml/sparsification/model_info.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Set, Union import numpy -from pydantic import model_validator, BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sparseml.utils import clean_path, create_parent_dirs diff --git a/tests/integrations/transformers/args.py b/tests/integrations/transformers/args.py index a7fa5f32aa4..54833ded5e2 100644 --- a/tests/integrations/transformers/args.py +++ b/tests/integrations/transformers/args.py @@ -507,7 +507,7 @@ class _TransformersTrainArgs(BaseModel): push_to_hub_token: str = Field( default=None, description="The token to use to push to the Model Hub." ) - _n_gpu: int = Field(init=False, repr=False, default=-1) + n_gpu_: int = Field(init=False, repr=False, default=-1) mp_parameters: Optional[str] = Field( default=None, description="Used by the SageMaker launcher to send mp-specific args. "