Skip to content

Commit

Permalink
default Model.onnx_model to target onnx.model.tar.gz (#355)
Browse files Browse the repository at this point in the history
* default Model.onnx_model to target onnx.model.tar.gz

* quality

* fix a unit test

* Fix Typos
super().path call updated to super.path()

* Default force to True, as written in the docstring and expected by Tests
This change fixes failing test
`tests/sparsezoo/objects/test_directory.py ...F....`

* Add force argument to `Directory` class to override unzipping during creation
Add model.onnx.tar.gz to expected files
Update test_unzipping to force unzip, as the parent directory is already present

* fix typo

* Update src/sparsezoo/objects/directories.py

---------

Co-authored-by: Rahul Tuli <[email protected]>
  • Loading branch information
bfineran and rahul-tuli authored Aug 15, 2023
1 parent a516d08 commit 4577acf
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/sparsezoo/analyze/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def __sub__(self, other):
my_value = getattr(self, field)
other_value = getattr(other, field)

assert type(my_value) == type(other_value)
assert type(my_value) is type(other_value)
if field == "section_name":
new_fields[field] = my_value
elif isinstance(my_value, str):
Expand Down
10 changes: 9 additions & 1 deletion src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Directory,
File,
NumpyDirectory,
OnnxGz,
SelectDirectory,
is_directory,
)
Expand Down Expand Up @@ -156,7 +157,14 @@ def __init__(self, source: str, download_path: Optional[str] = None):
stub_params=self.stub_params,
)

self.onnx_model: File = self._file_from_files(files, display_name="model.onnx")
self._onnx_gz: Directory = self._directory_from_files(
files, directory_class=OnnxGz, display_name="model.onnx.tar.gz"
)
self.onnx_model: File = (
self._file_from_files(files, display_name="model.onnx")
if self._onnx_gz is None
else self._onnx_gz # if onnx.model.tar.gz present defer to that file
)

self.analysis: File = self._file_from_files(files, display_name="analysis.yaml")
self.benchmarks: File = self._file_from_files(
Expand Down
8 changes: 7 additions & 1 deletion src/sparsezoo/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ThroughputResults,
ValidationResult,
)
from sparsezoo.objects import Directory, File, NumpyDirectory
from sparsezoo.objects import Directory, File, NumpyDirectory, OnnxGz
from sparsezoo.utils import BASE_API_URL, convert_to_bool, save_numpy


Expand Down Expand Up @@ -575,6 +575,12 @@ def _copy_file_contents(
for _file in file:
copy_path = os.path.join(output_dir, os.path.basename(_file.path))
_copy_and_overwrite(_file.path, copy_path, shutil.copyfile)
elif isinstance(file, OnnxGz):
# copy all contents of unzipped onnx.tar.gz file to top level of output
onnx_gz_path = (
os.path.dirname(file.path) if os.path.isfile(file.path) else file.path
)
shutil.copytree(onnx_gz_path, output_dir, dirs_exist_ok=True)
elif isinstance(file, Directory):
copy_path = os.path.join(output_dir, os.path.basename(file.path))
_copy_and_overwrite(file.path, copy_path, shutil.copytree)
Expand Down
27 changes: 26 additions & 1 deletion src/sparsezoo/objects/directories.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import logging
import os.path
from collections import OrderedDict
from typing import Dict, List, Optional, Union

Expand All @@ -29,7 +30,11 @@
from sparsezoo.utils import DataLoader, Dataset, load_numpy_list


__all__ = ["NumpyDirectory", "SelectDirectory"]
__all__ = [
"NumpyDirectory",
"SelectDirectory",
"OnnxGz",
]

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -269,3 +274,23 @@ def available(self):
@available.setter
def available(self, value):
self._available = value


class OnnxGz(Directory):
"""
Special class to handle onnx.model.tar.gz files.
Desired behavior is that all information about files included in the tarball are
available however, when the `path` property is accessed, it will point only
to the `model.onnx` as this is the expected behavior for loading an onnx model
with or without external data.
"""

@property
def path(self):
_ = super().path # call self.path to download initial file if not already
if self.is_archive:
self.unzip()
if os.path.isdir(self._path) and "model.onnx" in os.listdir(self._path):
# if unzipped into a directory, refer directly to model.onnx
self._path = os.path.join(self._path, "model.onnx")
return self._path
31 changes: 22 additions & 9 deletions src/sparsezoo/objects/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Directory(File):
:param path: path of the Directory
:param url: url of the Directory
:param parent_directory: path of the parent Directory
:param force: boolean flag; True to force unzipping of archive files.
Default is False.
"""

def __init__(
Expand All @@ -50,6 +52,7 @@ def __init__(
path: Optional[str] = None,
url: Optional[str] = None,
parent_directory: Optional[str] = None,
force: bool = False,
):

self.files = (
Expand All @@ -63,7 +66,7 @@ def __init__(
)

if self._unpack():
self.unzip()
self.unzip(force=force)

@classmethod
def from_file(cls, file: File) -> "Directory":
Expand Down Expand Up @@ -207,6 +210,8 @@ def get_file(self, file_name: str) -> Optional[File]:
:return: File if found, otherwise None
"""
for file in self.files:
if file is None:
continue
if file.name == file_name:
return file
if isinstance(file, Directory):
Expand Down Expand Up @@ -254,18 +259,23 @@ def gzip(self, archive_directory: Optional[str] = None):
self._path = tar_file_path
self.is_archive = True

def unzip(self, extract_directory: Optional[str] = None):
def unzip(self, extract_directory: Optional[str] = None, force: bool = False):
"""
Extracts a tar archive Directory.
The extracted files would be saved in the parent directory of
`self`, unless `extract_directory` argument is specified
:param extract_directory: the local path to create
folder Directory at (default = None)
:param force: if True, will always unzip, even if the target directory
already exists. Default False
"""
if self._path is None:
# use path property to download so path exists
self._path = self.path
files = []
if extract_directory is None:
extract_directory = os.path.dirname(self.path)
extract_directory = os.path.dirname(self._path)

if not self.is_archive:
raise ValueError(
Expand All @@ -274,14 +284,17 @@ def unzip(self, extract_directory: Optional[str] = None):
)

name = ".".join(self.name.split(".")[:-2])
tar = tarfile.open(self.path, "r")
path = os.path.join(extract_directory, name)

for member in tar.getmembers():
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(File(name=member.name, path=os.path.join(path, member.name)))
tar.close()
if not os.path.exists(path) or force: # do not re-unzip if not forced
tar = tarfile.open(self._path, "r")
for member in tar.getmembers():
member.name = os.path.basename(member.name)
tar.extract(member=member, path=path)
files.append(
File(name=member.name, path=os.path.join(path, member.name))
)
tar.close()

self.name = name
self.files = files
Expand Down
4 changes: 3 additions & 1 deletion tests/sparsezoo/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"logs",
"onnx",
"model.onnx",
"model.onnx.tar.gz",
"recipe",
"sample_inputs.tar.gz",
"sample_originals.tar.gz",
Expand Down Expand Up @@ -198,6 +199,7 @@ def test_folder_structure(self, setup):
"sample_outputs_deepsparse",
]:
expected_files.update({file_name, file_name + ".tar.gz"})

assert not set(os.listdir(temp_dir.name)).difference(expected_files)

def test_validate(self, setup):
Expand All @@ -223,7 +225,7 @@ def _add_mock_files(directory_path: str, clone_sample_outputs: bool):
os.makedirs(onnx_folder_dir)
for opset in range(1, 3):
shutil.copyfile(
os.path.join(directory_path, "model.onnx"),
os.path.join(directory_path, "deployment", "model.onnx"),
os.path.join(onnx_folder_dir, f"model.{opset}.onnx"),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/objects/test_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_zipping_on_creation(self, setup):
) = setup
directory = Directory(name=name, files=files, path=path)
directory.gzip()
new_directory = Directory(name=directory.name, path=directory.path)
new_directory = Directory(name=directory.name, path=directory.path, force=True)
assert os.path.isdir(new_directory.path)
assert new_directory.path == directory.path.replace(".tar.gz", "")
assert new_directory.files
Expand Down

0 comments on commit 4577acf

Please sign in to comment.