From 9b7aba497aacfd66a6ae762bace124806914e51c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 14:42:45 +0100 Subject: [PATCH 1/7] Feat (core): use runtime parameter for scale --- src/brevitas/core/scaling/runtime.py | 4 ++-- src/brevitas/core/scaling/standalone.py | 6 +++--- src/brevitas/core/stats/stats_wrapper.py | 8 ++++---- src/brevitas/core/stats/view_wrapper.py | 11 +++++++++++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 09f891ed7..bbad0747e 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -60,8 +60,8 @@ def __init__( @brevitas.jit.script_method def forward( - self, ignored: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: - stats = self.parameter_list_stats() + self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: + stats = self.parameter_list_stats(x) if threshold is None: threshold = torch.ones(1).type_as(stats) return self.stats_scaling_impl(stats, threshold) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 13ead5afc..47b9af406 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -241,9 +241,9 @@ def __init__( self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method - def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor: + def forward(self, x: torch.Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: - threshold = torch.ones(1).type_as(ignored) + threshold = torch.ones(1).type_as(x) if self.init_done: threshold = self.stats_scaling_impl.restrict_clamp_threshold( self.restrict_threshold_pre(threshold)) @@ -251,7 +251,7 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor value = value / threshold return value else: - stats = self.parameter_list_stats() + stats = self.parameter_list_stats(x) # workaround to avoid find_ununsed_parameter=True in DDP stats = stats + 0. * self.value if self.local_loss_mode: diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index df3cec952..d7498a906 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -13,6 +13,7 @@ from brevitas.core.utils import inplace_tensor_mul from .view_wrapper import _ViewCatParameterWrapper +from .view_wrapper import _ViewParameter from .view_wrapper import _ViewParameterWrapper DEFAULT_MOMENTUM = 0.1 @@ -96,8 +97,7 @@ def __init__( super(_ParameterListStats, self).__init__() self.stats_input_concat_dim = stats_input_concat_dim - self.first_tracked_param = _ViewParameterWrapper( - tracked_parameter_list[0], stats_input_view_shape_impl) + self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) if len(tracked_parameter_list) > 1: extra_list = [ _ViewCatParameterWrapper( @@ -109,8 +109,8 @@ def __init__( self.stats = _Stats(stats_impl, stats_output_shape) @brevitas.jit.script_method - def forward(self) -> torch.Tensor: - stats_input = self.first_tracked_param() + def forward(self, x) -> torch.Tensor: + stats_input = self.first_tracked_param(x) if self.extra_tracked_params_list is not None: for extra_tracked_param in self.extra_tracked_params_list: stats_input = extra_tracked_param(stats_input) diff --git a/src/brevitas/core/stats/view_wrapper.py b/src/brevitas/core/stats/view_wrapper.py index acea542d9..c15b27ec5 100644 --- a/src/brevitas/core/stats/view_wrapper.py +++ b/src/brevitas/core/stats/view_wrapper.py @@ -39,6 +39,17 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): return output_dict +class _ViewParameter(brevitas.jit.ScriptModule): + + def __init__(self, view_shape_impl: Module) -> None: + super(_ViewParameter, self).__init__() + self.view_shape_impl = view_shape_impl + + @brevitas.jit.script_method + def forward(self, x: Tensor) -> Tensor: + return self.view_shape_impl(x) + + class _ViewCatParameterWrapper(brevitas.jit.ScriptModule): __constants__ = ['cat_dim'] From d20f65b92aa3028cc39683c07d0f39e63761f117 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 16:08:04 +0100 Subject: [PATCH 2/7] Optional input --- src/brevitas/core/scaling/runtime.py | 2 +- src/brevitas/core/scaling/standalone.py | 2 +- src/brevitas/core/stats/stats_wrapper.py | 9 +++++++-- src/brevitas/core/stats/view_wrapper.py | 10 ++++++++-- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index bbad0747e..2317b603c 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -60,7 +60,7 @@ def __init__( @brevitas.jit.script_method def forward( - self, x: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: + self, x: Optional[torch.Tensor], threshold: Optional[torch.Tensor] = None) -> torch.Tensor: stats = self.parameter_list_stats(x) if threshold is None: threshold = torch.ones(1).type_as(stats) diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 47b9af406..da13e84ff 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -241,7 +241,7 @@ def __init__( self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method - def forward(self, x: torch.Tensor, threshold: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: threshold = torch.ones(1).type_as(x) if self.init_done: diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index d7498a906..fa0da5595 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -97,7 +97,12 @@ def __init__( super(_ParameterListStats, self).__init__() self.stats_input_concat_dim = stats_input_concat_dim - self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) + if len(tracked_parameter_list) >= 1: + self.first_tracked_param = _ViewParameterWrapper( + tracked_parameter_list[0], stats_input_view_shape_impl) + else: + self.first_tracked_param = _ViewParameter(stats_input_view_shape_impl) + if len(tracked_parameter_list) > 1: extra_list = [ _ViewCatParameterWrapper( @@ -109,7 +114,7 @@ def __init__( self.stats = _Stats(stats_impl, stats_output_shape) @brevitas.jit.script_method - def forward(self, x) -> torch.Tensor: + def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor: stats_input = self.first_tracked_param(x) if self.extra_tracked_params_list is not None: for extra_tracked_param in self.extra_tracked_params_list: diff --git a/src/brevitas/core/stats/view_wrapper.py b/src/brevitas/core/stats/view_wrapper.py index c15b27ec5..98c6ab538 100644 --- a/src/brevitas/core/stats/view_wrapper.py +++ b/src/brevitas/core/stats/view_wrapper.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional + import torch from torch import Tensor from torch.nn import Module @@ -19,8 +21,12 @@ def __init__(self, parameter: Parameter, view_shape_impl: Module) -> None: self.view_shape_impl = view_shape_impl @brevitas.jit.script_method - def forward(self) -> Tensor: - return self.view_shape_impl(self.parameter) + def forward(self, x: Optional[Tensor]) -> Tensor: + if x is not None: + parameter = x + else: + parameter = self.parameter + return self.view_shape_impl(parameter) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, From b8fe5ff01262ed7d5921a5bd3df3ae566feac4cf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 16:13:50 +0100 Subject: [PATCH 3/7] added default None --- src/brevitas/core/stats/stats_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index fa0da5595..be549328e 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -114,7 +114,7 @@ def __init__( self.stats = _Stats(stats_impl, stats_output_shape) @brevitas.jit.script_method - def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor: + def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor: stats_input = self.first_tracked_param(x) if self.extra_tracked_params_list is not None: for extra_tracked_param in self.extra_tracked_params_list: From 0b8c98bde763f262ff8d47ab44a8995cab9c4766 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 16:33:42 +0100 Subject: [PATCH 4/7] Pre-forward to fix scales --- src/brevitas_examples/llm/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index f40a367e1..75c08e7c0 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -285,6 +285,8 @@ def main(args): model = offload_model(model) + model(**calibration_loader[0]) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) From 62a82f2f348547c67c408536e98f64c17db3bfa6 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 1 Oct 2024 18:55:40 +0100 Subject: [PATCH 5/7] Fix LSTM --- src/brevitas/core/stats/stats_wrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/stats/stats_wrapper.py b/src/brevitas/core/stats/stats_wrapper.py index be549328e..49bf62a82 100644 --- a/src/brevitas/core/stats/stats_wrapper.py +++ b/src/brevitas/core/stats/stats_wrapper.py @@ -115,9 +115,11 @@ def __init__( @brevitas.jit.script_method def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor: - stats_input = self.first_tracked_param(x) if self.extra_tracked_params_list is not None: + stats_input = self.first_tracked_param(None) for extra_tracked_param in self.extra_tracked_params_list: stats_input = extra_tracked_param(stats_input) + else: + stats_input = self.first_tracked_param(x) out = self.stats(stats_input) return out From 228f9830e4d0d6bbb5d59bc608ad2087f23ded23 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 2 Oct 2024 13:11:13 +0100 Subject: [PATCH 6/7] zp fix --- src/brevitas/core/zero_point.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 3f80f1dd4..796940f4f 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -82,7 +82,7 @@ def __init__( @brevitas.jit.script_method def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: - stats = self.parameter_list_stats() + stats = self.parameter_list_stats(x) return self.scale_shift_zero_point(-stats, scale, bit_width) @@ -266,7 +266,7 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> torch.Tensor: value = self.scale_shift_zero_point(value, scale, bit_width) return value else: - stats = self.parameter_list_stats() + stats = self.parameter_list_stats(x) # workaround to avoid find_ununsed_parameter=True in DDP stats = stats + 0. * self.value if self.local_loss_mode: From ae21bcad58fd9a8c584ecbb4f6ad2e1be4919b4b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 5 Nov 2024 13:12:15 +0000 Subject: [PATCH 7/7] precommit --- src/brevitas/core/scaling/runtime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index 2317b603c..a94f8cd6e 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -60,7 +60,9 @@ def __init__( @brevitas.jit.script_method def forward( - self, x: Optional[torch.Tensor], threshold: Optional[torch.Tensor] = None) -> torch.Tensor: + self, + x: Optional[torch.Tensor], + threshold: Optional[torch.Tensor] = None) -> torch.Tensor: stats = self.parameter_list_stats(x) if threshold is None: threshold = torch.ones(1).type_as(stats)