feat(qmodule): avoid random weights initialization
dacorvo committed Aug 24, 2024
1 parent 04c8010 commit a1c310b
Showing 1 changed file with 7 additions and 1 deletion.
optimum/quanto/nn/qmodule.py
@@ -201,9 +201,15 @@ def from_module(
         activations: Optional[qtype] = None,
         optimizer: Optional[Optimizer] = None,
     ):
-        qmodule = cls.qcreate(module, weights, activations, optimizer)
+        # Create the quantized module on the meta device to prevent weights initialization
+        qmodule = cls.qcreate(module, weights, activations, optimizer, device="meta")
         if qmodule is None:
             return None
+        # Move the quantized module to the target device, but with empty weights
+        qmodule = qmodule.to_empty(device=module.weight.device)
+        # Set scales that were initialized to empty values
+        qmodule.input_scale = torch.ones_like(qmodule.input_scale)
+        qmodule.output_scale = torch.ones_like(qmodule.output_scale)
         with torch.no_grad():
             qmodule.weight.copy_(module.weight)
             if module.bias is not None:
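The change avoids paying for a random weight initialization that would immediately be overwritten: the quantized module is first created on PyTorch's meta device (no storage is allocated, so no initializer runs), then materialized with uninitialized storage via to_empty() and filled from the source module under no_grad(). Because to_empty() leaves tensors with arbitrary values, the input/output scales are explicitly reset to ones afterwards. Below is a minimal standalone sketch of the same pattern using only the public torch API; the nn.Linear modules are illustrative stand-ins, not optimum-quanto code.

    import torch
    import torch.nn as nn

    source = nn.Linear(4, 4)  # existing float module whose weights we want to keep

    # Constructing under the meta device allocates no storage, so the usual
    # random weight initialization is skipped entirely.
    with torch.device("meta"):
        target = nn.Linear(4, 4)

    # Materialize the meta module on the source module's device with empty
    # (uninitialized) tensors, then copy the real parameters in.
    target = target.to_empty(device=source.weight.device)
    with torch.no_grad():
        target.weight.copy_(source.weight)
        target.bias.copy_(source.bias)

    assert torch.equal(target.weight, source.weight)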
