From 53bac08762c97375ac16ba481a76ac0f9101d6cd Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 29 Oct 2024 10:08:12 +0000 Subject: [PATCH] Fallback to ort if hadamard shapes not available --- src/brevitas/graph/equalize.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index bc38a3089..63be40559 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -46,7 +46,13 @@ else: RMSNorm = object -__all__ = ['GraphActivationEqualization', 'LayerwiseActivationEqualization', 'EqualizeGraph'] +__all__ = [ + 'GraphActivationEqualization', + 'LayerwiseActivationEqualization', + 'EqualizeGraph', + 'LayerwiseActivationRotation', + 'MergeLnAffine', + 'GraphRotationEqualization'] EPSILON = 1e-9 @@ -1285,6 +1291,12 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= rot_mat, K = get_hadK(hidden_dim) except AssertionError as e: print(f"Incomptible shapes {hidden_dim}") + if not insert_rotation_module: + print("Falling back to orthogonal matrices") + rot_mat = random_orthogonal_matrix(hidden_dim) + K = None + rot_func = _apply_ort_device + print("Skipping layers") continue rot_func = _apply_had_device