diff --git a/src/sparsezoo/analyze/utils/models.py b/src/sparsezoo/analyze/utils/models.py index d8f95cad..9ff1979f 100644 --- a/src/sparsezoo/analyze/utils/models.py +++ b/src/sparsezoo/analyze/utils/models.py @@ -255,7 +255,7 @@ def __sub__(self, other): my_value = getattr(self, field) other_value = getattr(other, field) - assert type(my_value) == type(other_value) + assert type(my_value) is type(other_value) if field == "section_name": new_fields[field] = my_value elif isinstance(my_value, str): diff --git a/src/sparsezoo/model/model.py b/src/sparsezoo/model/model.py index 1ec64764..004f20ed 100644 --- a/src/sparsezoo/model/model.py +++ b/src/sparsezoo/model/model.py @@ -34,6 +34,7 @@ Directory, File, NumpyDirectory, + OnnxGz, SelectDirectory, is_directory, ) @@ -156,7 +157,14 @@ def __init__(self, source: str, download_path: Optional[str] = None): stub_params=self.stub_params, ) - self.onnx_model: File = self._file_from_files(files, display_name="model.onnx") + self._onnx_gz: Directory = self._directory_from_files( + files, directory_class=OnnxGz, display_name="model.onnx.tar.gz" + ) + self.onnx_model: File = ( + self._file_from_files(files, display_name="model.onnx") + if self._onnx_gz is None + else self._onnx_gz # if onnx.model.tar.gz present defer to that file + ) self.analysis: File = self._file_from_files(files, display_name="analysis.yaml") self.benchmarks: File = self._file_from_files( diff --git a/src/sparsezoo/model/utils.py b/src/sparsezoo/model/utils.py index 0b306499..9eddcf1e 100644 --- a/src/sparsezoo/model/utils.py +++ b/src/sparsezoo/model/utils.py @@ -29,7 +29,7 @@ ThroughputResults, ValidationResult, ) -from sparsezoo.objects import Directory, File, NumpyDirectory +from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz from sparsezoo.utils import BASE_API_URL, convert_to_bool, save_numpy @@ -575,6 +575,12 @@ def _copy_file_contents( for _file in file: copy_path = os.path.join(output_dir, os.path.basename(_file.path)) _copy_and_overwrite(_file.path, copy_path, shutil.copyfile) + elif isinstance(file, OnnxGz): + # copy all contents of unzipped onnx.tar.gz file to top level of output + onnx_gz_path = ( + os.path.dirname(file.path) if os.path.isfile(file.path) else file.path + ) + shutil.copytree(onnx_gz_path, output_dir, dirs_exist_ok=True) elif isinstance(file, Directory): copy_path = os.path.join(output_dir, os.path.basename(file.path)) _copy_and_overwrite(file.path, copy_path, shutil.copytree) diff --git a/src/sparsezoo/objects/directories.py b/src/sparsezoo/objects/directories.py index 52fc38d9..7d14646b 100644 --- a/src/sparsezoo/objects/directories.py +++ b/src/sparsezoo/objects/directories.py @@ -18,6 +18,7 @@ import logging +import os.path from collections import OrderedDict from typing import Dict, List, Optional, Union @@ -29,7 +30,11 @@ from sparsezoo.utils import DataLoader, Dataset, load_numpy_list -__all__ = ["NumpyDirectory", "SelectDirectory"] +__all__ = [ + "NumpyDirectory", + "SelectDirectory", + "OnnxGz", +] _LOGGER = logging.getLogger(__name__) @@ -269,3 +274,23 @@ def available(self): @available.setter def available(self, value): self._available = value + + +class OnnxGz(Directory): + """ + Special class to handle onnx.model.tar.gz files. + Desired behavior is that all information about files included in the tarball are + available however, when the `path` property is accessed, it will point only + to the `model.onnx` as this is the expected behavior for loading an onnx model + with or without external data. + """ + + @property + def path(self): + _ = super().path # call self.path to download initial file if not already + if self.is_archive: + self.unzip() + if os.path.isdir(self._path) and "model.onnx" in os.listdir(self._path): + # if unzipped into a directory, refer directly to model.onnx + self._path = os.path.join(self._path, "model.onnx") + return self._path diff --git a/src/sparsezoo/objects/directory.py b/src/sparsezoo/objects/directory.py index 5557aed8..85509016 100644 --- a/src/sparsezoo/objects/directory.py +++ b/src/sparsezoo/objects/directory.py @@ -41,6 +41,8 @@ class Directory(File): :param path: path of the Directory :param url: url of the Directory :param parent_directory: path of the parent Directory + :param force: boolean flag; True to force unzipping of archive files. + Default is False. """ def __init__( @@ -50,6 +52,7 @@ def __init__( path: Optional[str] = None, url: Optional[str] = None, parent_directory: Optional[str] = None, + force: bool = False, ): self.files = ( @@ -63,7 +66,7 @@ def __init__( ) if self._unpack(): - self.unzip() + self.unzip(force=force) @classmethod def from_file(cls, file: File) -> "Directory": @@ -207,6 +210,8 @@ def get_file(self, file_name: str) -> Optional[File]: :return: File if found, otherwise None """ for file in self.files: + if file is None: + continue if file.name == file_name: return file if isinstance(file, Directory): @@ -254,7 +259,7 @@ def gzip(self, archive_directory: Optional[str] = None): self._path = tar_file_path self.is_archive = True - def unzip(self, extract_directory: Optional[str] = None): + def unzip(self, extract_directory: Optional[str] = None, force: bool = False): """ Extracts a tar archive Directory. The extracted files would be saved in the parent directory of @@ -262,10 +267,15 @@ def unzip(self, extract_directory: Optional[str] = None): :param extract_directory: the local path to create folder Directory at (default = None) + :param force: if True, will always unzip, even if the target directory + already exists. Default False """ + if self._path is None: + # use path property to download so path exists + self._path = self.path files = [] if extract_directory is None: - extract_directory = os.path.dirname(self.path) + extract_directory = os.path.dirname(self._path) if not self.is_archive: raise ValueError( @@ -274,14 +284,17 @@ def unzip(self, extract_directory: Optional[str] = None): ) name = ".".join(self.name.split(".")[:-2]) - tar = tarfile.open(self.path, "r") path = os.path.join(extract_directory, name) - for member in tar.getmembers(): - member.name = os.path.basename(member.name) - tar.extract(member=member, path=path) - files.append(File(name=member.name, path=os.path.join(path, member.name))) - tar.close() + if not os.path.exists(path) or force: # do not re-unzip if not forced + tar = tarfile.open(self._path, "r") + for member in tar.getmembers(): + member.name = os.path.basename(member.name) + tar.extract(member=member, path=path) + files.append( + File(name=member.name, path=os.path.join(path, member.name)) + ) + tar.close() self.name = name self.files = files diff --git a/tests/sparsezoo/model/test_model.py b/tests/sparsezoo/model/test_model.py index c1acd69c..cec1c134 100644 --- a/tests/sparsezoo/model/test_model.py +++ b/tests/sparsezoo/model/test_model.py @@ -30,6 +30,7 @@ "logs", "onnx", "model.onnx", + "model.onnx.tar.gz", "recipe", "sample_inputs.tar.gz", "sample_originals.tar.gz", @@ -198,6 +199,7 @@ def test_folder_structure(self, setup): "sample_outputs_deepsparse", ]: expected_files.update({file_name, file_name + ".tar.gz"}) + assert not set(os.listdir(temp_dir.name)).difference(expected_files) def test_validate(self, setup): @@ -223,7 +225,7 @@ def _add_mock_files(directory_path: str, clone_sample_outputs: bool): os.makedirs(onnx_folder_dir) for opset in range(1, 3): shutil.copyfile( - os.path.join(directory_path, "model.onnx"), + os.path.join(directory_path, "deployment", "model.onnx"), os.path.join(onnx_folder_dir, f"model.{opset}.onnx"), ) diff --git a/tests/sparsezoo/objects/test_directory.py b/tests/sparsezoo/objects/test_directory.py index a276bae7..6f052a02 100644 --- a/tests/sparsezoo/objects/test_directory.py +++ b/tests/sparsezoo/objects/test_directory.py @@ -112,7 +112,7 @@ def test_zipping_on_creation(self, setup): ) = setup directory = Directory(name=name, files=files, path=path) directory.gzip() - new_directory = Directory(name=directory.name, path=directory.path) + new_directory = Directory(name=directory.name, path=directory.path, force=True) assert os.path.isdir(new_directory.path) assert new_directory.path == directory.path.replace(".tar.gz", "") assert new_directory.files