Skip to content

Commit

Permalink
Fix typing and model_config dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Apr 22, 2024
1 parent bb9b60e commit c0dffa3
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 22 deletions.
6 changes: 3 additions & 3 deletions src/sparseml/core/modifier/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class Modifier(BaseModel, ModifierInterface, MultiFrameworkObject):
:param update: The update step for the modifier
"""

index: int = None
group: str = None
start: float = None
index: Optional[int] = None
group: Optional[str] = None
start: Optional[float] = None
end: Optional[float] = None
update: Optional[float] = None

Expand Down
8 changes: 5 additions & 3 deletions src/sparseml/core/recipe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from sparseml.core.framework import Framework
from sparseml.core.recipe.args import RecipeArgs
Expand All @@ -36,6 +36,8 @@ class RecipeBase(BaseModel, ABC):
- create_modifier
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

@abstractmethod
def calculate_start(self) -> int:
raise NotImplementedError()
Expand All @@ -45,7 +47,7 @@ def calculate_end(self) -> int:
raise NotImplementedError()

@abstractmethod
def evaluate(self, args: RecipeArgs = None, shift: int = None):
def evaluate(self, args: Optional[RecipeArgs] = None, shift: Optional[int] = None):
raise NotImplementedError()

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions src/sparseml/core/recipe/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import model_validator, Field
from pydantic import Field, model_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import StageModifiers
Expand Down Expand Up @@ -152,7 +152,7 @@ def create_instance(
)
_LOGGER.debug(f"Input string: {path_or_modifiers}")
obj = _load_json_or_yaml_string(path_or_modifiers)
return Recipe.parse_obj(obj)
return Recipe.model_validate(obj)
else:
_LOGGER.info(f"Loading recipe from file {path_or_modifiers}")

Expand Down Expand Up @@ -534,7 +534,7 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]:
:return: A dictionary representation of the recipe
"""
dict_ = super().dict(*args, **kwargs)
dict_ = super().model_dump(*args, **kwargs)
stages = {}

for stage in dict_["stages"]:
Expand Down
4 changes: 3 additions & 1 deletion src/sparseml/core/recipe/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import model_validator, Field
from pydantic import ConfigDict, Field, model_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import StageModifiers
Expand Down Expand Up @@ -46,6 +46,8 @@ class RecipeStage(RecipeBase):
:param args_evaluated: the evaluated RecipeArgs for the stage
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

group: Optional[str] = None
run_type: Optional[StageRunType] = None
args: Optional[RecipeArgs] = None
Expand Down
7 changes: 4 additions & 3 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import ConfigDict, BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
Expand Down Expand Up @@ -47,8 +47,9 @@ class KeyValueCacheConfig(BaseModel):
additional_transforms: Union[
List[Type[OnnxTransform]], Type[OnnxTransform], None
] = Field(
None, description="A transform class (or list thereof) to use for additional "
"transforms to the model required for finalizing the kv cache injection."
None,
description="A transform class (or list thereof) to use for additional "
"transforms to the model required for finalizing the kv cache injection.",
)
key_num_attention_heads: str = Field(
description="The key to use to get the number of attention heads from the "
Expand Down
6 changes: 4 additions & 2 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
default=export_model,
)
apply_optimizations: Optional[Callable[[Any], None]] = Field(
None, description="A function that takes:"
None,
description="A function that takes:"
" - path to the exported model"
" - names of the optimizations to apply"
" and applies the optimizations to the model",
Expand Down Expand Up @@ -223,6 +224,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
)

deployment_directory_files_optional: Optional[List[str]] = Field(
None, description="A list that describes the "
None,
description="A list that describes the "
"optional expected files of the deployment directory",
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import torch
from packaging import version
from pydantic import field_validator, BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from torch.nn import Identity


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import torch
from packaging import version
from pydantic import field_validator, BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from torch.nn import Identity


Expand Down
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/utils/sparsification_info/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Dict, Generator, Tuple, Union

import torch.nn
from pydantic import ConfigDict, BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from sparseml.pytorch.utils.sparsification_info.helpers import (
get_leaf_operations,
Expand Down
4 changes: 2 additions & 2 deletions src/sparseml/sparsification/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,13 @@ def save_sparsification_info(framework: Any, path: Optional[str] = None):
create_parent_dirs(path)

with open(path, "w") as file:
file.write(info.json())
file.write(info.model_dump_json())

_LOGGER.info(
"saved sparsification info for framework %s in file at %s", framework, path
),
else:
print(info.json(indent=4))
print(info.model_dump_json(indent=4))
_LOGGER.info("printed out sparsification info for framework %s", framework)


Expand Down
2 changes: 1 addition & 1 deletion src/sparseml/sparsification/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Any, Dict, List, Optional, Set, Union

import numpy
from pydantic import model_validator, BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from sparseml.utils import clean_path, create_parent_dirs

Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/transformers/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ class _TransformersTrainArgs(BaseModel):
push_to_hub_token: str = Field(
default=None, description="The token to use to push to the Model Hub."
)
_n_gpu: int = Field(init=False, repr=False, default=-1)
n_gpu_: int = Field(init=False, repr=False, default=-1)
mp_parameters: Optional[str] = Field(
default=None,
description="Used by the SageMaker launcher to send mp-specific args. "
Expand Down

0 comments on commit c0dffa3

Please sign in to comment.