Skip to content

Commit

Permalink
external data fix (#466)
Browse files: browse the repository at this point in the history
  • Loading branch information
Satrat authored Feb 23, 2024
1 parent a00ca1e commit bc7218a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/sparsezoo/utils/onnx/external_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def save_onnx(
model_path: str,
max_external_file_size: int = 16e9,
external_data_file: Optional[str] = None,
do_split_external_data: bool = True,
) -> bool:
"""
Save model to the given path.
Expand All @@ -95,6 +96,8 @@ def save_onnx(
specified in the variable EXTERNAL_ONNX_DATA_NAME
:param max_external_file_size: The maximum file size in bytes of a single split
external data out file. Defaults to 16000000000 (16e9 = 16GB)
:param do_split_external_data: True to split the external data file into
chunks of at most max_external_file_size bytes, False to keep it as one file
:return: True if the model was saved with external data, False otherwise.
"""
if external_data_file is not None:
Expand All @@ -112,7 +115,8 @@ def save_onnx(
all_tensors_to_one_file=True,
location=external_data_file,
)
split_external_data(model_path, max_file_size=max_external_file_size)
if do_split_external_data:
split_external_data(model_path, max_file_size=max_external_file_size)
return True

if model.ByteSize() > DUMP_EXTERNAL_DATA_THRESHOLD:
Expand All @@ -132,7 +136,8 @@ def save_onnx(
all_tensors_to_one_file=True,
location=external_data_file,
)
split_external_data(model_path, max_file_size=max_external_file_size)
if do_split_external_data:
split_external_data(model_path, max_file_size=max_external_file_size)
return True

onnx.save(model, model_path)
Expand Down Expand Up @@ -247,6 +252,9 @@ def split_external_data(
f"{external_data_file_path} not found. {model_path} must have external "
"data written to a single file in the same directory"
)
if os.path.getsize(external_data_file_path) <= max_file_size:
# return immediately if file is small enough to not split
return

# UPDATE: external data info of graph tensors so they point to the new split out
# files with updated offsets
Expand Down Expand Up @@ -300,14 +308,6 @@ def split_external_data(
# WRITE - ONNX model with updated tensor external data info
onnx.save(model, model_path)

# RENAME - if as a result of splitting we end up with a single file, rename it to
# the original external data file name
if current_external_data_file_number == 1:
os.rename(
os.path.join(base_dir, updated_file_name),
os.path.join(base_dir, external_data_file),
)


def _write_external_data_file_from_base_bytes(
new_file_name, original_byte_ranges, original_file_bytes_reader
Expand Down

0 comments on commit bc7218a

Please sign in to comment.