Fix (groupwise): correct log, groupdim, and scale computation (#1071)
Giuseppe5 authored Oct 31, 2024
1 parent ae3ec68 commit 7bae8ad
Showing 9 changed files with 170 additions and 68 deletions.
5 changes: 3 additions & 2 deletions src/brevitas/core/function_wrapper/shape.py
@@ -195,8 +195,9 @@ def forward(self, x):
 
         tensor_shape = x.shape
         tensor_shape_list = list(tensor_shape)
-        tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
-        block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
+        tensor_shape_list[self.group_dim] = (
+            tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
+        block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
         tensor_shape_list.insert(block_dim, self.group_size)
         x = x.view(tensor_shape_list)
         return x
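For context, a minimal standalone sketch of the grouped view this hunk fixes (the function name and tensor values are illustrative, not from the repo). It demonstrates both corrections: ceiling division for the group count, and appending the block dimension at `len(shape)` when `group_dim == -1`, since `list.insert(-1, ...)` would place it before the last dimension:

```python
import torch

def group_view(x: torch.Tensor, group_dim: int, group_size: int) -> torch.Tensor:
    shape = list(x.shape)
    # Ceiling division: truncating division undercounts groups when
    # group_size does not divide the dimension exactly.
    shape[group_dim] = (shape[group_dim] + group_size - 1) // group_size
    # list.insert(-1, v) inserts *before* the last element, so the fix
    # appends at len(shape) when group_dim == -1.
    block_dim = group_dim + 1 if group_dim != -1 else len(shape)
    shape.insert(block_dim, group_size)
    return x.view(shape)

x = torch.arange(24.).view(2, 12)
print(group_view(x, group_dim=-1, group_size=4).shape)  # torch.Size([2, 3, 4])
```

When group_size does not evenly divide the grouped dimension, the ceiling count corresponds to an input padded upstream; on an unpadded tensor, `view` raises a shape mismatch.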
17 changes: 4 additions & 13 deletions src/brevitas/core/restrict_val.py
@@ -24,7 +24,10 @@
 
 class _RestrictClampValue(brevitas.jit.ScriptModule):
 
-    def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
+    def __init__(
+            self,
+            scaling_min_val: Optional[float] = None,
+            restrict_value_impl: Optional[Module] = None):
         super(_RestrictClampValue, self).__init__()
         if scaling_min_val is not None and scaling_min_val != 0:
             self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
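With both arguments defaulted, callers can build a clamp-free wrapper by keyword only, which is how the threshold path is constructed later in this commit. A sketch (assuming `PowerOfTwoRestrictValue` constructs with no arguments, as in current Brevitas):

```python
from brevitas.core.restrict_val import PowerOfTwoRestrictValue, _RestrictClampValue

# No scaling_min_val is given, so no minimum clamp is applied: the wrapper
# only restricts the value, mirroring restrict_clamp_threshold below.
restrict_clamp_threshold = _RestrictClampValue(
    restrict_value_impl=PowerOfTwoRestrictValue())
```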
@@ -90,9 +93,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tensor:
         return x
@@ -116,9 +116,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.power_of_two(x)
@@ -143,9 +140,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return Identity()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x / threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
@@ -171,9 +165,6 @@ def restrict_init_module(self):
     def restrict_init_inplace_module(self):
         return InplaceLogTwo()
 
-    def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
-        return x - threshold
-
     @brevitas.jit.script_method
     def forward(self, x: Tensor):
         x = self.float_to_int_impl(x)
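The four deleted `combine_scale_threshold` hooks encoded the same identity per restrictor: dividing by the threshold in the linear domain corresponds to subtracting its logarithm in the log domain. Routing both quantities through one restrictor broke down once scale and threshold could be restricted differently, so the commit restricts each independently and combines them with a single linear-domain division. A small numeric check of the identity (values are illustrative):

```python
import torch

s, t = torch.tensor(12.0), torch.tensor(3.0)

# What the removed hooks relied on: x - threshold in the log2 domain
# equals x / threshold in the linear domain.
assert torch.allclose(2 ** (torch.log2(s) - torch.log2(t)), s / t)

# The new scheme restricts each quantity on its own, e.g. a power-of-two
# scale with a plain float threshold, then divides once in linear domain.
restricted_scale = 2 ** torch.round(torch.log2(s))  # -> 16.0
print(restricted_scale / t)                         # tensor(5.3333)
```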
56 changes: 44 additions & 12 deletions src/brevitas/core/scaling/runtime.py
@@ -30,12 +30,18 @@ def __init__(
             tracked_parameter_list: List[torch.nn.Parameter],
             scaling_shape: Tuple[int, ...],
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(StatsFromParameterScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.parameter_list_stats = _ParameterListStats(
             scaling_stats_impl,
             scaling_shape,
@@ -44,6 +50,7 @@ def __init__(
             tracked_parameter_list)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,
Expand All @@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule):
def __init__(
self,
restrict_scaling_impl: Module,
restrict_threshold_impl: Module,
scaling_shape: Tuple[int, ...],
scaling_min_val: Optional[float],
affine_rescaling: bool,
@@ -81,19 +89,22 @@ def __init__(
         else:
             self.affine_rescaling = Identity()
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
-        self.restrict_scaling_impl = restrict_scaling_impl
+        self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
     @brevitas.jit.script_method
     def forward(
             self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
         if threshold is None:
             threshold = torch.ones(1).type_as(stats)
-        threshold = self.restrict_scaling_pre(threshold)
+        threshold = self.restrict_threshold_pre(threshold)
+        threshold = self.restrict_clamp_threshold(threshold)
         stats = self.restrict_scaling_pre(stats)
-        stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
+        stats = stats / threshold
         return stats
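A minimal numeric trace of the reworked forward, assuming power-of-two restriction on both paths (the `restrict_*_pre` steps map to the log2 domain and the clamp-restrict rounds and maps back; values are illustrative):

```python
import torch

def restrict_po2(x_log2: torch.Tensor) -> torch.Tensor:
    # Stand-in for a power-of-two restrictor: round in the log2 domain,
    # then return to the linear domain.
    return 2 ** torch.round(x_log2)

stats = torch.tensor(6.0)      # raw statistic
threshold = torch.tensor(3.0)  # raw threshold

scale = restrict_po2(torch.log2(stats))    # restrict_scaling_pre + clamp -> 8.0
thr = restrict_po2(torch.log2(threshold))  # restrict_threshold_pre + clamp -> 4.0

# Only after both are restricted does the division happen, as in forward().
print(scale / thr)  # tensor(2.)
```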


@@ -107,12 +118,17 @@ def __init__(
             affine_rescaling: bool = False,
             affine_shift_scale: bool = False,
             restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None,
             scaling_stats_momentum: float = DEFAULT_MOMENTUM,
             scaling_min_val: Optional[float] = None,
             dtype: Optional[torch.dtype] = None,
             device: Optional[torch.device] = None) -> None:
         super(RuntimeStatsScaling, self).__init__()
 
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.runtime_stats = _RuntimeStats(
             scaling_stats_impl,
             scaling_shape,
@@ -122,6 +138,7 @@ def __init__(
             device)
         self.stats_scaling_impl = _StatsScaling(
             restrict_scaling_impl,
+            restrict_threshold_impl,
             scaling_shape,
             scaling_min_val,
             affine_rescaling,
@@ -173,20 +190,32 @@ def _load_from_state_dict(
 class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):
 
     def __init__(
-        self,
-        group_size: int,
-        group_dim: int,
-        input_view_impl: Module,
-        scaling_stats_impl: Module,
-        scaling_min_val: Optional[float],
-        restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
+            self,
+            group_size: int,
+            group_dim: int,
+            input_view_impl: Module,
+            scaling_stats_impl: Module,
+            scaling_min_val: Optional[float],
+            restrict_scaling_impl: Module = FloatRestrictValue(),
+            restrict_threshold_impl: Optional[Module] = None) -> None:
         super(RuntimeDynamicGroupStatsScaling, self).__init__()
+
+        # Ensure retro-compatibility with shared threshold/scaling restrict
+        if restrict_threshold_impl is None:
+            restrict_threshold_impl = restrict_scaling_impl
+
         self.group_size = group_size
         self.group_dim = group_dim
         self.scaling_stats_impl = scaling_stats_impl
         self.scaling_min_val = scaling_min_val
         self.input_view_impl = input_view_impl
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
+        self.restrict_clamp_threshold = _RestrictClampValue(
+            restrict_value_impl=restrict_threshold_impl)
+        self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
+        )
+        self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module(
+        )
 
     @brevitas.jit.script_method
     def forward(
@@ -196,7 +225,10 @@ def forward(
         if threshold is None:
             threshold = torch.ones(1).type_as(stats_input)
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
+        threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
+        out = self.scaling_stats_impl(stats_input_reshaped)
+        # Apply log scaling
+        out = self.restrict_scaling_pre(out)
         # Scaling min val
-        out = self.restrict_clamp_scaling(out)
+        out = self.restrict_clamp_scaling(out) / threshold
         return out
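A hedged end-to-end sketch of how the updated module might be wired, assuming Brevitas' JIT is disabled (the default). The view and stats modules here are minimal stand-ins, not Brevitas classes; the constructor arguments follow the diff:

```python
import torch
from brevitas.core.restrict_val import FloatRestrictValue, PowerOfTwoRestrictValue
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling

class GroupView(torch.nn.Module):
    # Minimal stand-in for input_view_impl: (B, D) -> (B, D // g, g).
    def __init__(self, group_size: int):
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(x.shape[0], -1, self.group_size)

class AbsMaxStats(torch.nn.Module):
    # Minimal stand-in for scaling_stats_impl: abs-max over each group.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.abs().amax(dim=-1, keepdim=True)

scaling = RuntimeDynamicGroupStatsScaling(
    group_size=4,
    group_dim=-1,
    input_view_impl=GroupView(4),
    scaling_stats_impl=AbsMaxStats(),
    scaling_min_val=1e-8,
    restrict_scaling_impl=PowerOfTwoRestrictValue(),  # scale snapped to 2**k
    restrict_threshold_impl=FloatRestrictValue())     # threshold kept in float

x = torch.randn(2, 16)
print(scaling(x).shape)  # per-group scales, torch.Size([2, 4, 1])
```

Passing different restrictors for scale and threshold, as above, is exactly the case the old shared `combine_scale_threshold` path could not express.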