group size speedups + fixes (#51)
bfineran authored May 9, 2024
1 parent 964276d commit 611f152
Showing 4 changed files with 27 additions and 13 deletions.
5 changes: 2 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,9 +111,8 @@ def fake_quantize(
         for i in range(ceil(columns / group_size)):
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
-
-            sc = scale[:, i].unsqueeze(1)
-            zp = zero_point[:, i].unsqueeze(1)
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1)
 
             idx = i * group_size
             Q = quantize(x[:, idx : (idx + group_size)], sc, zp, args)
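For context, here is a minimal self-contained sketch of the loop this hunk touches. The function name `fake_quantize_groupwise` and its int8 round/clamp body are hypothetical stand-ins for the repo's `quantize`/`dequantize` helpers; the `view(-1, 1)` column selection is the pattern the commit switches to, which adds a trailing size-1 dimension without copying, just like `unsqueeze(1)`.

```python
from math import ceil

import torch


def fake_quantize_groupwise(
    x: torch.Tensor,           # [nchan, columns]
    scale: torch.Tensor,       # [nchan, ngroups]
    zero_point: torch.Tensor,  # [nchan, ngroups]
    group_size: int,
) -> torch.Tensor:
    out = torch.empty_like(x)
    columns = x.shape[1]
    for i in range(ceil(columns / group_size)):
        # scale[:, i] is 1-D ([nchan]); view(-1, 1) adds a trailing size-1
        # dim without allocating, same result here as unsqueeze(1)
        sc = scale[:, i].view(-1, 1)
        zp = zero_point[:, i].view(-1, 1)
        idx = i * group_size
        chunk = x[:, idx : idx + group_size]
        q = torch.clamp(torch.round(chunk / sc + zp), -128, 127)  # int8 range
        out[:, idx : idx + group_size] = (q - zp) * sc
    return out


x = torch.randn(4, 16)
scale = x.abs().amax(dim=1, keepdim=True).repeat(1, 4) / 127  # 4 groups of 4
zp = torch.zeros(4, 4)
print(fake_quantize_groupwise(x, scale, zp, group_size=4).shape)
```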
6 changes: 3 additions & 3 deletions src/compressed_tensors/quantization/observers/base.py
@@ -40,6 +40,7 @@ def __init__(self, quantization_args: QuantizationArgs):
         self._scale = None
         self._zero_point = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
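A small sketch of what the new `@torch.no_grad()` decorator buys (the `MinMaxProbe` module below is a hypothetical stand-in for `Observer`): statistics computed inside the decorated `forward` never enter the autograd graph, so observing tensors adds no gradient bookkeeping even when the input requires grad.

```python
import torch


class MinMaxProbe(torch.nn.Module):
    @torch.no_grad()
    def forward(self, observed: torch.Tensor) -> torch.Tensor:
        # aminmax runs with grad disabled; results have requires_grad=False
        mn, mx = torch.aminmax(observed)
        return mx - mn


x = torch.randn(4, 8, requires_grad=True)
span = MinMaxProbe()(x)
assert not span.requires_grad
```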
@@ -91,9 +92,8 @@ def get_qparams(
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
-
-            self._scale = torch.stack(scales, dim=1)
-            self._zero_point = torch.stack(zero_points, dim=1)
+            self._scale = torch.stack(scales, dim=1, out=self._scale)
+            self._zero_point = torch.stack(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
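The stacking change above reuses the previously allocated buffers via `out=`. A minimal sketch of that pattern, with made-up shapes: on the first call `out=None` simply allocates, and every later call writes into the existing tensor instead of reallocating. This is safe here because the observer's `forward` now runs under `no_grad` (out-variants are not allowed on tensors that need gradients).

```python
import torch

scales_buf = None
for _ in range(3):
    per_group = [torch.rand(16) for _ in range(8)]  # e.g. one scale per group
    # first pass: out=None allocates; later passes reuse scales_buf in place
    scales_buf = torch.stack(per_group, dim=1, out=scales_buf)

print(scales_buf.shape)  # torch.Size([16, 8])
```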
2 changes: 1 addition & 1 deletion src/compressed_tensors/quantization/observers/helpers.py
@@ -41,10 +41,10 @@ def calculate_qparams(
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0, device=device).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
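A runnable sketch of the symmetric branch after this fix (assuming 8-bit quantization; `symmetric_qparams` is an illustrative standalone version, not the library function): the zero points are now materialized with the same shape as the scales, so per-channel and per-group layouts stack and broadcast exactly like the scales do, rather than sharing one scalar zero point.

```python
import torch


def symmetric_qparams(min_vals: torch.Tensor, max_vals: torch.Tensor, num_bits: int = 8):
    bit_range = 2**num_bits - 1
    max_val_pos = torch.max(-min_vals, max_vals)  # elementwise absolute range
    scales = max_val_pos / (float(bit_range) / 2)
    scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
    # one zero point per scale, not a lone scalar
    zero_points = torch.zeros(scales.shape, device=scales.device, dtype=torch.int8)
    return scales, zero_points


s, z = symmetric_qparams(torch.tensor([-1.0, -0.5]), torch.tensor([2.0, 0.25]))
print(s.shape, z.shape)  # torch.Size([2]) torch.Size([2])
```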
27 changes: 21 additions & 6 deletions src/compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,22 +36,33 @@ def __init__(
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = float("inf")
-        self.max_val = -float("inf")
+        self.min_val = None
+        self.max_val = None
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
 
-        min_val, max_val = torch.aminmax(observed)
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
-        if self.min_val == float("inf") and self.max_val == float("-inf"):
+        if self.min_val is None and self.max_val is None:
             self.min_val = min_val
             self.max_val = max_val
         else:
@@ -63,3 +74,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
             )
 
         return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+
+    def get_qparams_along_dim(self, observed, dim: int):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(observed, reduce_dims=reduce_dims)
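A sketch of what the new `reduce_dims` path computes, as a standalone helper (`minmax_along_dim` is hypothetical): reducing over every dimension except `dim` with `keepdim` yields one min/max pair per channel, shaped `(1,)` along the reduced dimensions so the results still broadcast against the observed tensor.

```python
import torch


def minmax_along_dim(observed: torch.Tensor, dim: int):
    # reduce over all dims except `dim`, mirroring get_qparams_along_dim
    reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
    min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
    max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
    return min_val, max_val


x = torch.randn(16, 32)
mn, mx = minmax_along_dim(x, dim=0)
print(mn.shape)  # torch.Size([16, 1]) -- shape (1,) along the reduced dim
```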
