Fix cuda hardcode for inference woq (#5565)
This is a simple fix for the inference weight-only quantization (WOQ) path, replacing the
hardcoded `'cuda'` device string with `get_accelerator().device_name()`.

---------

Co-authored-by: Logan Adams <[email protected]>
Liangliang-Ma and loadams authored Jun 5, 2024
1 parent f4cb866 commit af4356b
Showing 1 changed file with 9 additions and 9 deletions.
deepspeed/inference/quantization/utils.py

@@ -14,14 +14,14 @@
 
 device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu'
 
-quantizer_cuda_module = None
+quantizer_module = None
 
 
-def get_quantizer_cuda_module():
-    global quantizer_cuda_module
-    if quantizer_cuda_module is None:
-        quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
-    return quantizer_cuda_module
+def get_quantizer_module():
+    global quantizer_module
+    if quantizer_module is None:
+        quantizer_module = deepspeed.ops.op_builder.QuantizerBuilder().load()
+    return quantizer_module
 
 
 def tensor_clamp(tensor: Tensor, min, max) -> Tensor:
@@ -107,19 +107,19 @@ def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) ->
         if self.config['group_size'] % 8 == 0 and \
                 (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \
                 self.config['group_dim'] == len(tensor.shape) - 1 and \
-                self.dtype == torch.float16 and device == 'cuda':
+                self.dtype == torch.float16 and device == get_accelerator().device_name():
 
             last_dimension_size = self.config['group_size']
             if self.config['num_bits'] == 4:
                 last_dimension_size = last_dimension_size // 2
-                quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental(
+                quantized_tensor = get_quantizer_module().dequantize_int4_to_half_experimental(
                     tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
                     tensor.numel() // last_dimension_size, self.config['group_size'])
                 shape = list(tensor.shape)
                 shape[-1] = shape[-1] * 2
             elif self.config['num_bits'] == 8:
                 # last_dimension_size = last_dimension_size // 2
-                quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental(
+                quantized_tensor = get_quantizer_module().dequantize_int8_to_half_experimental(
                     tensor.reshape(-1, last_dimension_size), quant_scale, quant_min,
                     tensor.numel() // last_dimension_size, self.config['group_size'])
                 shape = list(tensor.shape)
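For context, here is a minimal sketch (not part of this commit) of the accelerator-abstraction
pattern the fix relies on. `get_accelerator`, `device_name()`, and `is_available()` are the real
DeepSpeed APIs that appear in the diff above; the surrounding script is illustrative only:

```python
# Minimal sketch of DeepSpeed's accelerator abstraction, as used in this fix.
# get_accelerator() returns the active accelerator backend, so code can ask
# for its device string instead of hardcoding 'cuda'.
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()
# device_name() yields the backend's device string, e.g. 'cuda' on NVIDIA
# GPUs or 'xpu' on Intel GPUs; fall back to 'cpu' when none is available.
device = acc.device_name() if acc.is_available() else 'cpu'
print(device)
```

With this change, the fused dequantize kernel path in `utils.py` is gated on the runtime device
rather than on the literal `'cuda'`, so the same check holds on non-CUDA accelerators.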
