Feat (activation_calibration): speed-up by skipping quantization (#1029)
Giuseppe5 authored Oct 8, 2024
1 parent db6c560 commit 746d97e
Showing 5 changed files with 70 additions and 19 deletions.
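In short: this commit adds an observer_only flag to Brevitas' integer and float quantizer implementations. During activation calibration the flag is set to True, so scale and zero-point statistics are still collected while the fake-quantization kernel is skipped and the input passes through unchanged. The MSE and HQO parameter searches temporarily force real quantization and then restore the previous per-module mode, since they need the actual quantization error. A minimal, self-contained sketch of the pattern (the class and names below are illustrative, not the Brevitas API):

import torch


class ToyObserverOnlyQuant(torch.nn.Module):
    """Toy int8 fake-quantizer illustrating the observer_only pattern (not Brevitas code)."""

    def __init__(self):
        super().__init__()
        self.observer_only = False
        self.register_buffer('running_absmax', torch.tensor(1e-6))

    def forward(self, x):
        # The statistics path always runs, so calibration keeps working...
        self.running_absmax = torch.maximum(self.running_absmax, x.abs().max())
        scale = self.running_absmax / 127.
        if self.observer_only:
            y = x  # ...but the quantize/dequantize round-trip is skipped
        else:
            y = torch.clamp(torch.round(x / scale), -128., 127.) * scale
        return y, scale


quant = ToyObserverOnlyQuant()
quant.observer_only = True   # calibration: observe statistics only
_ = quant(torch.randn(8))
quant.observer_only = False  # inference/QAT: quantize for real
y, scale = quant(torch.randn(8))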
20 changes: 12 additions & 8 deletions src/brevitas/core/quant/float.py
@@ -64,11 +64,10 @@ def __init__(
         if dtype is None:
             dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
-    def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
-
+    def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
@@ -86,10 +85,15 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        y, scale = self.quantize(x)
-        # after quantizing, clamp to special cases like NaN/inf if they are set
-        y, saturating, inf_values, nan_values = self.float_clamp_impl(
-            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-        y = self.dequantize(y, scale)
+        scale = self.scaling_impl(x)
+        if self.observer_only:
+            y = x
+            saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
+        else:
+            y, scale = self.quantize(x, scale)
+            # after quantizing, clamp to special cases like NaN/inf if they are set
+            y, saturating, inf_values, nan_values = self.float_clamp_impl(
+                y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+            y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
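In observer mode the float path skips the minifloat rounding and the NaN/inf clamping; the scale is still computed so runtime statistics keep updating, and the clamp metadata (saturating, inf_values, nan_values) is read from float_clamp_impl's attributes so the returned tuple keeps its shape. For intuition, here is a crude sketch of the per-element work that gets skipped (illustrative only; it ignores saturation and the special-value handling that float_clamp_impl performs):

import torch


def toy_minifloat_round(x, mantissa_bits=3, exponent_bias=7):
    """Round each value onto an e4m3-like grid (illustrative sketch only;
    no saturation, no NaN/inf encoding, unlike Brevitas' float_clamp_impl)."""
    eps = torch.finfo(x.dtype).tiny
    # per-element exponent, clamped at the subnormal threshold
    exponent = torch.floor(torch.log2(x.abs() + eps)).clamp(min=1 - exponent_bias)
    lsb = torch.exp2(exponent - mantissa_bits)  # grid spacing at this magnitude
    return torch.round(x / lsb) * lsb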
17 changes: 14 additions & 3 deletions src/brevitas/core/quant/int.py
@@ -145,6 +145,7 @@ def __init__(
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         int_threshold = self.int_scaling_impl(bit_width)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.int_quant(scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
 
 
@@ -176,6 +180,7 @@ def __init__(
         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
 
 
@@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
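All three integer paths share the same branch: threshold, scale, and zero-point are still derived on every call, so observers keep updating, while the round/clamp/rescale kernel, the dominant cost on large activation tensors, is bypassed. A rough way to see the order of magnitude with plain tensor ops (a toy measurement unrelated to Brevitas internals; numbers vary by machine):

import time

import torch

x = torch.randn(4096, 4096)       # stand-in for a large activation tensor
scale = x.abs().amax() / 127.

def fake_quant(t, s):
    return torch.clamp(torch.round(t / s), -128., 127.) * s

for name, fn in (('observe only', lambda: x.abs().amax() / 127.),
                 ('full fake-quant', lambda: fake_quant(x, scale))):
    start = time.perf_counter()
    for _ in range(10):
        fn()
    print(f'{name}: {time.perf_counter() - start:.3f}s')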
34 changes: 34 additions & 0 deletions src/brevitas/core/stats/stats_op.py
@@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
             m.local_loss_mode = enabled
 
 
+def _set_observer_mode(module, enabled, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            previous_observer_mode[m] = m.observer_only
+            m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            m.observer_only = previous_observer_mode[m]
+
+
 class MSE(torch.nn.Module):
     # References:
     # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
@@ -459,7 +472,12 @@ def __init__(
         self.mse_init_op = mse_init_op
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.num = mse_iters
         self.search_method = mse_search_method
@@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
         self.internal_candidate = candidate
         # Set to local_loss_mode before calling the proxy
         self.set_local_loss_mode(True)
+        self.set_observer_mode(False)
         quant_value = self.proxy_forward(x)
         quant_value = _unpack_quant_tensor(quant_value)
         loss = self.mse_loss_fn(x, quant_value)
         self.set_local_loss_mode(False)
+        self.restore_observer_mode()
         return loss
 
     def mse_grid_search(self, xl, x):
@@ -571,7 +591,12 @@ def __init__(
         self.hqo_init_op = hqo_init_op_scale
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.hqo_iters = hqo_iters_scale
         self.stats_reduce_dim = stats_reduce_dim
@@ -598,8 +623,10 @@ def parameter_search(self, xl, x):
         for i in range(0, self.hqo_iters):
             self.internal_candidate = candidate
             self.set_local_loss_mode(True)
+            self.set_observer_mode(False)
             quant_tensor = self.proxy_forward(x).detach()
             self.set_local_loss_mode(False)
+            self.restore_observer_mode()
             loss = torch.abs(quant_tensor.value - x).mean()
 
             best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
@@ -670,7 +697,12 @@ def __init__(
         self.hqo_init_op_zp = hqo_init_op_zp
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.stats_reduce_dim = stats_reduce_dim
         self.local_loss_mode: bool = False
@@ -688,8 +720,10 @@ def parameter_search(self, xl, x):
         for i in range(0, self.hqo_iters):
             self.internal_candidate = candidate
             self.set_local_loss_mode(True)
+            self.set_observer_mode(False)
             quant_tensor = self.proxy_forward(x).detach()
             self.set_local_loss_mode(False)
+            self.restore_observer_mode()
             qt_value = self.input_view_shape_impl(quant_tensor.value)
             qt_scale = self.input_view_shape_impl(quant_tensor.scale)
             qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
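The save-and-restore pair exists because the MSE and HQO searches run inside calibration, where observer_only may already be True on some quantizers: the search must evaluate its loss with real quantization, then put every module back into whatever mode it was in before. The same idea packaged as a context manager (a sketch, not the Brevitas API):

from contextlib import contextmanager

import torch


@contextmanager
def observer_mode(module: torch.nn.Module, enabled: bool):
    """Temporarily set observer_only on every submodule that defines it,
    restoring each module's previous value on exit (illustrative sketch)."""
    previous = {}
    for m in module.modules():
        if hasattr(m, 'observer_only'):
            previous[m] = m.observer_only
            m.observer_only = enabled
    try:
        yield
    finally:
        for m, old in previous.items():
            m.observer_only = old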
11 changes: 6 additions & 5 deletions src/brevitas/graph/calibrate.py
@@ -200,8 +200,9 @@ def disable_act_quantization(self, model, is_training):
             if isinstance(module, ActQuantProxyFromInjectorBase):
                 module.train(is_training)
                 if self.call_act_quantizer_impl:
-                    hook = module.register_forward_hook(self.disable_act_quant_hook)
-                    self.disable_act_quant_hooks.append(hook)
+                    for m in module.modules():
+                        if hasattr(m, 'observer_only'):
+                            m.observer_only = True
                 else:
                     module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
@@ -228,9 +229,9 @@ def enable_act_quantization(self, model, is_training):
             elif isinstance(module, ActQuantProxyFromInjectorBase):
                 module.disable_quant = False
                 module.train(is_training)
-                for hook in self.disable_act_quant_hooks:
-                    hook.remove()
-                self.disable_act_quant_hooks = []
+                for m in module.modules():
+                    if hasattr(m, 'observer_only'):
+                        m.observer_only = False
 
     def enable_param_quantization(self, model, is_training):
         for module in model.modules():
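At the user level nothing changes: activation calibration is driven exactly as before, it just no longer pays for fake-quantization while statistics are gathered. A typical calibration loop, assuming Brevitas' existing calibration_mode context manager and a quantized model and data loader of your own:

import torch

from brevitas.graph.calibrate import calibration_mode


def calibrate(quant_model: torch.nn.Module, loader) -> torch.nn.Module:
    quant_model.eval()
    with torch.no_grad(), calibration_mode(quant_model):
        for images, _ in loader:
            quant_model(images)  # activation ranges observed; quantization skipped
    return quant_model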
7 changes: 4 additions & 3 deletions tests/brevitas/core/test_float_quant.py
@@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format):
         signed=signed,
         float_clamp_impl=float_clamp)
     expected_out, *_ = float_quant(inp)
-
-    out_quant, scale = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    out_quant, scale = float_quant.quantize(inp, scale)
     exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
     out_quant, *_ = float_quant.float_clamp_impl(
         out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
@@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
         scaling_impl=scaling_impl,
         float_scaling_impl=float_scaling_impl,
         float_clamp_impl=float_clamp)
-    _ = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    _ = float_quant.quantize(inp, scale)
     # scaling implementations should be called exactly once on the input
     float_scaling_impl.assert_called_once_with(
         torch.tensor(exponent_bit_width),
