Fix serialization #120

Closed
wants to merge 5 commits
9 changes: 6 additions & 3 deletions quanto/nn/qmodule.py
@@ -166,17 +166,20 @@ def deserialize_tensor_subclass(t, state_dict, prefix):
             return t.__class__.__tensor_unflatten__(inner_tensors_dict, meta_dict, None, None)

         deserialized_weight = deserialize_tensor_subclass(self.qweight, state_dict, weight_name + ".")
+        device = self.weight.device if self.weight.device.type != "meta" else deserialized_weight.device
         assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
         if assign_to_params_buffers:
             self.weight = torch.nn.Parameter(deserialized_weight)
         else:
             if type(self.weight.data) != type(deserialized_weight):
                 # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
-                self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
+                self.weight = torch.nn.Parameter(deserialized_weight.to(device))
             else:
                 # FIXME: here we should copy frozen weights into frozen module, but this leads to grad error
-                self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
-
+                self.weight = torch.nn.Parameter(deserialized_weight.to(device))
+        # this is needed because we can't load it correctly when the bias is on the meta device
+        if prefix + "bias" in state_dict:
+            self.bias = torch.nn.Parameter(state_dict.pop(prefix + "bias"))
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
         )
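The `device` fallback added above matters when the module was materialized on the meta device (a skeleton without allocated storage): `self.weight.device` is then `meta` and cannot be the target for real data, so the device of the deserialized tensor is used instead. A minimal sketch of that selection logic, with hypothetical tensors standing in for the module weight and the loaded weight:

import torch

# A weight created on the meta device has no backing storage,
# so it cannot serve as the destination device for loaded data.
module_weight = torch.empty(2, 2, device="meta")
deserialized_weight = torch.zeros(2, 2)  # e.g. deserialized onto CPU

# Same selection as in the patch: prefer the module's device,
# falling back to the deserialized tensor's device when on meta.
device = module_weight.device if module_weight.device.type != "meta" else deserialized_weight.device
print(device)  # -> cpu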
10 changes: 9 additions & 1 deletion quanto/serialization.py
@@ -2,7 +2,7 @@
 from typing import Dict, Union

 import torch
-from safetensors.torch import safe_open, save_file
+from safetensors.torch import _remove_duplicate_names, safe_open, save_file


 def safe_save(state_dict: Dict[str, Union[torch.Tensor, str]], filename: Union[str, os.PathLike]):
@@ -20,6 +20,14 @@ def safe_save(state_dict: Dict[str, Union[torch.Tensor, str]], filename: Union[str, os.PathLike]):
             tensors[name] = value
         else:
             metadata[name] = value
+
+    to_removes = _remove_duplicate_names(tensors)
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            del tensors[to_remove]
+
+    metadata["format"] = "pt"
+
     save_file(tensors, filename, metadata)

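For context, safetensors' `save_file` refuses tensors that share storage (as with tied embeddings), which is why duplicates are dropped before saving; `_remove_duplicate_names` is a private safetensors helper that returns, for each tensor kept, the aliased names to delete. A minimal sketch of the deduplication, using a hypothetical state dict with tied weights:

import torch
from safetensors.torch import _remove_duplicate_names, save_file

# Hypothetical state dict where two names alias the same storage,
# as with tied input and output embeddings.
shared = torch.randn(4, 8)
tensors = {"embed.weight": shared, "lm_head.weight": shared}

# Maps each kept name to the aliases to drop,
# e.g. {"embed.weight": ["lm_head.weight"]}.
to_removes = _remove_duplicate_names(tensors)
for kept_name, to_remove_group in to_removes.items():
    for to_remove in to_remove_group:
        del tensors[to_remove]

# save_file would raise on aliased tensors; after deduplication it succeeds.
# The "pt" format marker matches what is written for PyTorch checkpoints.
save_file(tensors, "model.safetensors", metadata={"format": "pt"})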