Refactor DFSSerializer to remove code duplication. #241

Merged: 4 commits, Jul 22, 2024

model_signing/serialization/serialize_by_file.py: 246 changes (174 additions, 72 deletions)

@@ -14,10 +14,11 @@

 """Model serializers that operated at file level granularity."""

+import abc
 import base64
 import concurrent.futures
 import pathlib
-from typing import Callable, Iterable
+from typing import Callable, Iterable, cast
 from typing_extensions import override

 from model_signing.hashing import file
@@ -65,84 +66,21 @@ def _build_header(
     bytes. Each argument is separated by dots and the last byte is also a
     dot (so the file digest can be appended unambiguously).
     """
-    # Note: This will get replaced in subsequent change, right now we're just
-    # moving existing code around.
     encoded_type = entry_type.encode("utf-8")
     # Prevent confusion if name has a "." inside by encoding to base64.
     encoded_name = base64.b64encode(entry_name.encode("utf-8"))
     # Note: empty string at the end, to terminate header with a "."
     return b".".join([encoded_type, encoded_name, b""])


-class DFSSerializer(serialization.Serializer):
-    """Serializer for a model that performs a traversal of the model directory.
-
-    This serializer produces a single hash for the entire model. If the model is
-    a file, the hash is the digest of the file. If the model is a directory, we
-    perform a depth-first traversal of the directory, hash each individual files
-    and aggregate the hashes together.
-    """
-
-    def __init__(
-        self,
-        file_hasher: file.SimpleFileHasher,
-        merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
-    ):
-        """Initializes an instance to serialize a model with this serializer.
-
-        Args:
-            hasher: The hash engine used to hash the individual files.
-            merge_hasher_factory: A callable that returns a
-                `hashing.StreamingHashEngine` instance used to merge individual
-                file digests to compute an aggregate digest.
-        """
-        self._file_hasher = file_hasher
-        self._merge_hasher_factory = merge_hasher_factory
-
-    @override
-    def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
-        # TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
-        # to exclude symlinks if desired.
-        check_file_or_directory(model_path)
-
-        if model_path.is_file():
-            self._file_hasher.set_file(model_path)
-            return manifest.DigestManifest(self._file_hasher.compute())
-
-        return manifest.DigestManifest(self._dfs(model_path))
-
-    def _dfs(self, directory: pathlib.Path) -> hashing.Digest:
-        # TODO: github.com/sigstore/model-transparency/issues/196 - Add support
-        # for excluded files.
-        children = sorted([x for x in directory.iterdir()])
-
-        hasher = self._merge_hasher_factory()
-        for child in children:
-            check_file_or_directory(child)
-
-            if child.is_file():
-                header = _build_header(entry_name=child.name, entry_type="file")
-                hasher.update(header)
-                self._file_hasher.set_file(child)
-                digest = self._file_hasher.compute()
-                hasher.update(digest.digest_value)
-            else:
-                header = _build_header(entry_name=child.name, entry_type="dir")
-                hasher.update(header)
-                digest = self._dfs(child)
-                hasher.update(digest.digest_value)
-
-        return hasher.compute()


 class FilesSerializer(serialization.Serializer):
-    """Model serializers that produces an itemized manifest, at file level.
+    """Generic file serializer.

     Traverses the model directory and creates digests for every file found,
     possibly in parallel.

-    Since the manifest lists each item individually, this will also enable
-    support for incremental updates (to be added later).
+    Subclasses can then create a manifest with these digests, either listing
+    them item by item, or combining everything into a single digest.
     """

     def __init__(
@@ -162,7 +100,7 @@ def __init__(
         self._max_workers = max_workers

     @override
-    def serialize(self, model_path: pathlib.Path) -> manifest.FileLevelManifest:
+    def serialize(self, model_path: pathlib.Path) -> manifest.Manifest:
         # TODO: github.com/sigstore/model-transparency/issues/196 - Add checks
         # to exclude symlinks if desired.
         check_file_or_directory(model_path)
@@ -210,12 +148,176 @@ def _compute_hash(
         digest = self._hasher_factory(path).compute()
         return manifest.FileManifestItem(path=relative_path, digest=digest)

+    @abc.abstractmethod
     def _build_manifest(
         self, items: Iterable[manifest.FileManifestItem]
-    ) -> manifest.FileLevelManifest:
-        """Builds an itemized manifest from a given list of items.
-
-        Every subclass needs to implement this method to determine the format of
-        the manifest.
-        """
+    ) -> manifest.Manifest:
+        """Builds the manifest representing the serialization of the model."""
+        pass
+
+
+class ManifestSerializer(FilesSerializer):
+    """Model serializer that produces an itemized manifest, at file level.
+
+    Since the manifest lists each item individually, this will also enable
+    support for incremental updates (to be added later).
+    """
+
+    @override
+    def serialize(self, model_path: pathlib.Path) -> manifest.FileLevelManifest:
+        """Serializes the model given by the `model_path` argument.
+
+        The only reason for the override is to change the return type, to be
+        more restrictive. This is to signal that the only manifests that can be
+        returned are `manifest.FileLevelManifest` instances.
+        """
+        return cast(manifest.FileLevelManifest, super().serialize(model_path))
+
+    @override
+    def _build_manifest(
+        self, items: Iterable[manifest.FileManifestItem]
+    ) -> manifest.FileLevelManifest:
+        return manifest.FileLevelManifest(items)
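
A minimal usage sketch for ManifestSerializer, assuming the SimpleFileHasher(path, engine) constructor and the memory.SHA256 engine from this repo's hashing modules; treat it as illustrative, not as the canonical API:

# Illustrative usage; the constructor signatures are assumed, not from this PR.
import pathlib
from model_signing.hashing import file, memory

def hasher_factory(path: pathlib.Path) -> file.FileHasher:
    # A fresh hasher per path keeps parallel workers independent.
    return file.SimpleFileHasher(path, memory.SHA256())

serializer = ManifestSerializer(hasher_factory, max_workers=4)
file_level_manifest = serializer.serialize(pathlib.Path("path/to/model"))
# file_level_manifest lists one digest per file found in the model directory.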


+class _FileDigestTree:
+    """A tree of files with their digests.
+
+    Every leaf in the tree is a file, paired with its digest. Every intermediate
+    node represents a directory. We need to pair every directory with a digest,
+    in a bottom-up fashion.
+    """
+
+    def __init__(
+        self, path: pathlib.PurePath, digest: hashing.Digest | None = None
+    ):
+        """Builds a node in the digest tree.
+
+        Don't call this from outside of the class. Instead, use `build_tree`.
+
+        Args:
+            path: Path included in the node.
+            digest: Optional hash of the path. Files must have a digest,
+                directories never have one.
+        """
+        self._path = path
+        self._digest = digest
+        self._children: list[_FileDigestTree] = []
+
+    @classmethod
+    def build_tree(
+        cls, items: Iterable[manifest.FileManifestItem]
+    ) -> "_FileDigestTree":
+        """Builds a tree out of the sequence of manifest items."""
+        path_to_node: dict[pathlib.PurePath, _FileDigestTree] = {}
+
+        for file_item in items:
+            file = file_item.path
+            node = cls(file, file_item.digest)
+            for parent in file.parents:
+                if parent in path_to_node:
+                    parent_node = path_to_node[parent]
+                    parent_node._children.append(node)
+                    break  # everything else already exists
+
+                parent_node = cls(parent)  # no digest for directories
+                parent_node._children.append(node)
+                path_to_node[parent] = parent_node
+                node = parent_node
+
+        # Handle empty model
+        if not path_to_node:
+            return cls(pathlib.PurePosixPath())
+
+        return path_to_node[pathlib.PurePosixPath()]
+
+    def get_digest(
+        self,
+        hasher_factory: Callable[[], hashing.StreamingHashEngine],
+    ) -> hashing.Digest:
+        """Returns the digest of this tree of files.
+
+        Args:
+            hasher_factory: A callable that returns a
+                `hashing.StreamingHashEngine` instance used to merge individual
+                digests to compute an aggregate digest.
+        """
+        hasher = hasher_factory()
+
+        for child in sorted(self._children, key=lambda c: c._path):
+            name = child._path.name
+            if child._digest is not None:
+                header = _build_header(entry_name=name, entry_type="file")
+                hasher.update(header)
+                hasher.update(child._digest.digest_value)
+            else:
+                header = _build_header(entry_name=name, entry_type="dir")
+                hasher.update(header)
+                digest = child.get_digest(hasher_factory)
+                hasher.update(digest.digest_value)
+
+        return hasher.compute()
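
To see the bottom-up aggregation concretely, a small sketch with fabricated digests; the Digest(algorithm, digest_value) constructor is assumed from model_signing.hashing:

# Illustrative only: two files, one nested. All digest values are made up.
import pathlib
from model_signing import manifest
from model_signing.hashing import hashing, memory

items = [
    manifest.FileManifestItem(
        path=pathlib.PurePosixPath("weights.bin"),
        digest=hashing.Digest("sha256", b"\x01" * 32),
    ),
    manifest.FileManifestItem(
        path=pathlib.PurePosixPath("config/model.json"),
        digest=hashing.Digest("sha256", b"\x02" * 32),
    ),
]
tree = _FileDigestTree.build_tree(items)  # root node with weights.bin, config/
root_digest = tree.get_digest(memory.SHA256)  # digests config/ first, bottom-up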


+class DigestSerializer(FilesSerializer):
+    """Serializer for a model that performs a traversal of the model directory.
+
+    This serializer produces a single hash for the entire model. If the model is
+    a file, the hash is the digest of the file. If the model is a directory, we
+    perform a depth-first traversal of the directory, hash each individual files
+    and aggregate the hashes together.
+
+    Currently, this has a different initialization than `FilesSerializer`, but
+    this will likely change in a subsequent change. Similarly, currently, this
+    only supports one single worker, but this will change in the future.
+    """
+
+    def __init__(
+        self,
+        file_hasher: file.SimpleFileHasher,
+        merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
+    ):
+        """Initializes an instance to serialize a model with this serializer.
+
+        Args:
+            hasher: The hash engine used to hash the individual files.
+            merge_hasher_factory: A callable that returns a
+                `hashing.StreamingHashEngine` instance used to merge individual
+                file digests to compute an aggregate digest.
+        """
+
+        def _factory(path: pathlib.Path) -> file.FileHasher:
+            file_hasher.set_file(path)
+            return file_hasher
+
+        super().__init__(_factory, max_workers=1)
+        self._merge_hasher_factory = merge_hasher_factory
+
+    @override
+    def serialize(self, model_path: pathlib.Path) -> manifest.DigestManifest:
+        """Serializes the model given by the `model_path` argument.
+
+        The only reason for the override is to change the return type, to be
+        more restrictive. This is to signal that the only manifests that can be
+        returned are `manifest.DigestManifest` instances.
+        """
+        return cast(manifest.DigestManifest, super().serialize(model_path))
+
+    @override
+    def _build_manifest(
+        self, items: Iterable[manifest.FileManifestItem]
+    ) -> manifest.DigestManifest:
+        # Note: we do several computations here to try and match the old
+        # behavior but these would be simplified in the future. Since we are
+        # defining the hashing behavior, we can freely change this.
+
+        # If the model is just one file, return the hash of the file.
+        # A model is a file if we have one item only and its path is empty.
+        items = list(items)
+        if len(items) == 1 and not items[0].path.name:
+            return manifest.DigestManifest(items[0].digest)
+
+        # Otherwise, build a tree of files and compute the digests.
+        tree = _FileDigestTree.build_tree(items)
+        digest = tree.get_digest(self._merge_hasher_factory)
+        return manifest.DigestManifest(digest)
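
An end-to-end sketch for DigestSerializer; the "unused" placeholder path and the SHA256 engines are assumptions, mirroring the constructor above:

# Illustrative usage; the placeholder path works because _factory calls
# set_file() before each hash. SHA256 engines assumed from hashing.memory.
import pathlib
from model_signing.hashing import file, memory

serializer = DigestSerializer(
    file.SimpleFileHasher(pathlib.Path("unused"), memory.SHA256()),
    memory.SHA256,
)
digest_manifest = serializer.serialize(pathlib.Path("path/to/model"))
# digest_manifest wraps a single aggregate digest for the whole model.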