
Refactor sharded serialization to remove code duplication
Similar to sigstore#241, there is duplication in the directory traversal logic between serializing to a digest and serializing to a manifest. This time, both implementations supported parallelism, so there is really no need for the duplication.

We make an abstract `ShardedFilesSerializer` class to contain the directory traversal logic and then create the better-named `DigestSerializer` and `ManifestSerializer` subclasses for the two serialization modes.

This time, instead of trying extremely hard to match the old behavior for digest serialization, we just update the goldens. This means that this change depends on sigstore#244.

We still had to update some other tests: since hashes are computed only for files, we no longer differentiate between a model containing an empty directory and the same model with that empty directory removed entirely. This is a corner case, and the behavior change is acceptable.

In fact, ignoring empty directories is part of the optimization hinted at in sigstore#197.
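Condensed sketch of the new structure (illustration only; see the diff below for the real code):

    class ShardedFilesSerializer(serialization.Serializer):
        # Shared: directory traversal and parallel hashing of file shards.
        def serialize(self, model_path): ...

        @abc.abstractmethod
        def _build_manifest(self, items): ...

    class ManifestSerializer(ShardedFilesSerializer):
        def _build_manifest(self, items):
            return manifest.ShardLevelManifest(items)  # itemized, per shard

    class DigestSerializer(ShardedFilesSerializer):
        def _build_manifest(self, items):
            ...  # merge all shard digests into one model digest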

Signed-off-by: Mihai Maruseac <[email protected]>
mihaimaruseac committed Jul 22, 2024
1 parent 9774883 commit f9df44c
Showing 12 changed files with 148 additions and 222 deletions.
266 changes: 96 additions & 170 deletions model_signing/serialization/serialize_by_file_shard.py
@@ -14,10 +14,11 @@

"""Model serializers that operated at file shard level of granularity."""

import abc
import base64
import concurrent.futures
import pathlib
from typing import Callable, Iterable, TypeAlias
from typing import Callable, Iterable, TypeAlias, cast
from typing_extensions import override

from model_signing.hashing import file
@@ -27,21 +28,16 @@
from model_signing.serialization import serialize_by_file


_ShardSignTask: TypeAlias = tuple[pathlib.PurePath, str, int, int]


def _build_header(
*,
entry_name: str,
entry_type: str,
name: str,
start: int,
end: int,
) -> bytes:
"""Builds a header to encode a path with given name and type.
"""Builds a header to encode a path with given name and shard range.
Args:
entry_name: The name of the entry to build the header for.
entry_type: The type of the entry (file or directory).
start: Offset for the start of the path shard.
end: Offset for the end of the path shard.
@@ -50,14 +46,11 @@ def _build_header(
bytes. Each argument is separated by dots and the last byte is also a
dot (so the file digest can be appended unambiguously).
"""
# Note: This will get replaced in subsequent change, right now we're just
# moving existing code around.
encoded_type = entry_type.encode("utf-8")
# Prevent confusion if name has a "." inside by encoding to base64.
encoded_name = base64.b64encode(entry_name.encode("utf-8"))
encoded_name = base64.b64encode(name.encode("utf-8"))
encoded_range = f"{start}-{end}".encode("utf-8")
# Note: empty string at the end, to terminate header with a "."
return b".".join([encoded_type, encoded_name, encoded_range, b""])
return b".".join([encoded_name, encoded_range, b""])


def _endpoints(step: int, end: int) -> Iterable[int]:
@@ -83,164 +76,15 @@ def _endpoints(step: int, end: int) -> Iterable[int]:
yield end
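# The body is collapsed in this view; inferred behavior (an assumption based
# on how the function is used below): _endpoints(step=1000, end=2500) yields
# 1000, 2000, 2500, the right endpoint of each shard.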


class ShardedDFSSerializer(serialization.Serializer):
"""DFSSerializer that uses a sharded hash engine to exploit parallelism."""

def __init__(
self,
file_hasher_factory: Callable[
[pathlib.Path, int, int], file.ShardedFileHasher
],
merge_hasher: hashing.StreamingHashEngine,
max_workers: int | None = None,
):
"""Initializes an instance to serialize a model with this serializer.
Args:
file_hasher_factory: A callable to build the hash engine used to hash
every shard of the files in the model. Because each shard is
processed in parallel, every thread needs to call the factory to
start hashing. The arguments are the file, and the endpoints of
the shard.
merge_hasher: A `hashing.StreamingHashEngine` instance used to merge
individual file digests to compute an aggregate digest.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
"""
self._file_hasher_factory = file_hasher_factory
self._merge_hasher = merge_hasher
self._max_workers = max_workers

# Precompute some private values only once by using a mock file hasher.
# None of the arguments used to build the hasher are used.
hasher = file_hasher_factory(pathlib.Path(), 0, 1)
self._shard_size = hasher.shard_size

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
# Note: This function currently uses `pathlib.Path.glob` so the DFS
# expansion relies on the `glob` implementation performing a DFS. We
# will be truthful again when switching to `pathlib.Path.walk`, after
# Python 3.12 is the minimum version we support.

# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
serialize_by_file.check_file_or_directory(model_path)

if model_path.is_file():
entries = [model_path]
else:
# TODO: github.com/sigstore/model-transparency/issues/200 - When
# Python3.12 is the minimum supported version, this can be replaced
# with `pathlib.Path.walk` for a clearer interface, and some speed
# improvement.
entries = sorted(model_path.glob("**/*"))

tasks = self._convert_paths_to_tasks(entries, model_path)

digest_len = self._merge_hasher.digest_size
digests_buffer = bytearray(len(tasks) * digest_len)

with concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_workers
) as tpe:
futures_dict = {
tpe.submit(self._perform_hash_task, model_path, task): i
for i, task in enumerate(tasks)
}
for future in concurrent.futures.as_completed(futures_dict):
i = futures_dict[future]
task_digest = future.result()

task_path, task_type, task_start, task_end = tasks[i]
header = _build_header(
entry_name=task_path.name,
entry_type=task_type,
start=task_start,
end=task_end,
)
self._merge_hasher.reset(header)
self._merge_hasher.update(task_digest)
digest = self._merge_hasher.compute().digest_value

start = i * digest_len
end = start + digest_len
digests_buffer[start:end] = digest

self._merge_hasher.reset(digests_buffer)
return manifest.DigestManifest(self._merge_hasher.compute())

def _convert_paths_to_tasks(
self, paths: Iterable[pathlib.Path], root_path: pathlib.Path
) -> list[_ShardSignTask]:
"""Returns the tasks that would hash shards of files in parallel.
Every file in `paths` is replaced by a set of tasks. Each task computes
the digest over a shard of the file. Directories result in a single
task, just to compute a digest over a header.
To differentiate between (empty) files and directories with the same
name, every task needs to also include a header. The header needs to
include relative path to the model root, as we want to obtain the same
digest if the model is moved.
We don't construct an enum for the type of the entry, because these will
never escape this class.
Note that the path component of the tasks is a `pathlib.PurePath`, so
operations on it cannot touch the filesystem.
"""
# TODO: github.com/sigstore/model-transparency/issues/196 - Add support
# for excluded files.

tasks = []
for path in paths:
serialize_by_file.check_file_or_directory(path)
relative_path = path.relative_to(root_path)

if path.is_file():
path_size = path.stat().st_size
start = 0
for end in _endpoints(self._shard_size, path_size):
tasks.append((relative_path, "file", start, end))
start = end
else:
tasks.append((relative_path, "dir", 0, 0))

return tasks
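# Illustrative example (assuming _endpoints behaves as sketched above): with
# shard_size=1000, a 2500-byte file "w.bin" expands to the tasks
#   (PurePath("w.bin"), "file", 0, 1000),
#   (PurePath("w.bin"), "file", 1000, 2000),
#   (PurePath("w.bin"), "file", 2000, 2500),
# while a directory "d" becomes the single task (PurePath("d"), "dir", 0, 0).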

def _perform_hash_task(
self, model_path: pathlib.Path, task: _ShardSignTask
) -> bytes:
"""Produces the hash of the file shard included in `task`."""
task_path, task_type, task_start, task_end = task

# TODO: github.com/sigstore/model-transparency/issues/197 - Directories
# don't need to use the file hasher. Rather than starting a process
# just for them, we should filter these ahead of time, and only use
# threading for file shards. For now, just return an empty result.
if task_type == "dir":
return b""

# TODO: github.com/sigstore/model-transparency/issues/197 - Similarly,
# empty files should be hashed outside of a parallel task, to not waste
# resources.
if task_start == task_end:
return b""

full_path = model_path.joinpath(task_path)
hasher = self._file_hasher_factory(full_path, task_start, task_end)
return hasher.compute().digest_value


class ShardedFilesSerializer(serialization.Serializer):
"""Model serializers that produces an itemized manifest, at shard level.
"""Generic file shard serializer.
Traverses the model directory and creates digests for every file found,
sharding the file in equal shards and computing the digests in parallel.
Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
Subclasses can then create a manifest with these digests, either listing
them item by item, combining them into file digests, or combining all of
them into a single digest.
"""

def __init__(
@@ -270,9 +114,7 @@ def __init__(
self._shard_size = hasher.shard_size

@override
def serialize(
self, model_path: pathlib.Path
) -> manifest.ShardLevelManifest:
def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
# TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
# to exclude symlinks if desired.
serialize_by_file.check_file_or_directory(model_path)
@@ -337,12 +179,96 @@ def _compute_hash(
path=relative_path, digest=digest, start=start, end=end
)
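# For example, hashing bytes [0, 1000) of "weights.bin" yields an item like
#   ShardedFileManifestItem(path=PurePath("weights.bin"),
#                           digest=<digest of that shard>, start=0, end=1000)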

@abc.abstractmethod
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.ShardLevelManifest:
) -> manifest.Manifest:
"""Builds an itemized manifest from a given list of items.
Every subclass needs to implement this method to determine the format of
the manifest.
"""
pass


class ManifestSerializer(ShardedFilesSerializer):
"""Model serializers that produces an itemized manifest, at shard level.
Since the manifest lists each item individually, this will also enable
support for incremental updates (to be added later).
"""

@override
def serialize(
self, model_path: pathlib.Path
) -> manifest.ShardLevelManifest:
"""Serializes the model given by the `model_path` argument.
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.ShardLevelManifest` instances.
"""
return cast(manifest.ShardLevelManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.ShardLevelManifest:
return manifest.ShardLevelManifest(items)
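# Usage sketch (editor's illustration; `memory.SHA256` and the exact
# `ShardedFileHasher` signature are assumptions, not shown in this diff):
#
#   def hasher_factory(path: pathlib.Path, start: int, end: int):
#       return file.ShardedFileHasher(path, memory.SHA256(), start=start, end=end)
#
#   serializer = ManifestSerializer(hasher_factory, max_workers=4)
#   shard_manifest = serializer.serialize(pathlib.Path("path/to/model"))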


class DigestSerializer(ShardedFilesSerializer):
"""Serializer for a model that performs a traversal of the model directory.
This serializer produces a single hash for the entire model.
"""

def __init__(
self,
file_hasher_factory: Callable[
[pathlib.Path, int, int], file.ShardedFileHasher
],
merge_hasher: hashing.StreamingHashEngine,
max_workers: int | None = None,
):
"""Initializes an instance to serialize a model with this serializer.
Args:
file_hasher_factory: A callable to build the hash engine used to hash
every shard of the files in the model. Because each shard is
processed in parallel, every thread needs to call the factory to
start hashing. The arguments are the file, and the endpoints of
the shard.
merge_hasher: A `hashing.StreamingHashEngine` instance used to merge
individual file shard digests to compute an aggregate digest.
max_workers: Maximum number of workers to use in parallel. Default
is to defer to the `concurrent.futures` library.
"""
super().__init__(file_hasher_factory, max_workers)
self._merge_hasher = merge_hasher

@override
def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
"""Serializes the model given by the `model_path` argument.
The only reason for the override is to change the return type, to be
more restrictive. This is to signal that the only manifests that can be
returned are `manifest.DigestManifest` instances.
"""
return cast(manifest.DigestManifest, super().serialize(model_path))

@override
def _build_manifest(
self, items: Iterable[manifest.ShardedFileManifestItem]
) -> manifest.DigestManifest:
self._merge_hasher.reset()

for item in sorted(items, key=lambda i: (i.path, i.start, i.end)):
header = _build_header(
name=item.path.name, start=item.start, end=item.end
)
self._merge_hasher.update(header)
self._merge_hasher.update(item.digest.digest_value)

digest = self._merge_hasher.compute()
return manifest.DigestManifest(digest)
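# Similarly, for a single model digest (reusing the hypothetical
# hasher_factory from the sketch above):
#
#   serializer = DigestSerializer(hasher_factory, memory.SHA256(), max_workers=4)
#   digest_manifest = serializer.serialize(pathlib.Path("path/to/model"))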