From cca53b0bfdf1e09e27571a6b253ccec387c3cecb Mon Sep 17 00:00:00 2001 From: Liangliang-Ma Date: Wed, 5 Jun 2024 23:08:29 +0800 Subject: [PATCH] Fix cuda hardcode for inference woq (#5565) This is a simple fix for the inference WOQ part: it replaces the hardcoded `'cuda'` device string with `get_accelerator().device_name()`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/inference/quantization/utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deepspeed/inference/quantization/utils.py b/deepspeed/inference/quantization/utils.py index 712abc384a44..a5e8f28bdec9 100644 --- a/deepspeed/inference/quantization/utils.py +++ b/deepspeed/inference/quantization/utils.py @@ -14,14 +14,14 @@ device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu' -quantizer_cuda_module = None +quantizer_module = None -def get_quantizer_cuda_module(): - global quantizer_cuda_module - if quantizer_cuda_module is None: - quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load() - return quantizer_cuda_module +def get_quantizer_module(): + global quantizer_module + if quantizer_module is None: + quantizer_module = deepspeed.ops.op_builder.QuantizerBuilder().load() + return quantizer_module def tensor_clamp(tensor: Tensor, min, max) -> Tensor: @@ -107,19 +107,19 @@ def dequantize(self, tensor: Tensor, quant_scale: Tensor, quant_min: Tensor) -> if self.config['group_size'] % 8 == 0 and \ (self.config['num_bits'] == 4 or self.config['num_bits'] == 8) and \ self.config['group_dim'] == len(tensor.shape) - 1 and \ - self.dtype == torch.float16 and device == 'cuda': + self.dtype == torch.float16 and device == get_accelerator().device_name(): last_dimension_size = self.config['group_size'] if self.config['num_bits'] == 4: last_dimension_size = last_dimension_size // 2 - quantized_tensor = get_quantizer_cuda_module().dequantize_int4_to_half_experimental( + quantized_tensor = 
get_quantizer_module().dequantize_int4_to_half_experimental( tensor.reshape(-1, last_dimension_size), quant_scale, quant_min, tensor.numel() // last_dimension_size, self.config['group_size']) shape = list(tensor.shape) shape[-1] = shape[-1] * 2 elif self.config['num_bits'] == 8: # last_dimension_size = last_dimension_size // 2 - quantized_tensor = get_quantizer_cuda_module().dequantize_int8_to_half_experimental( + quantized_tensor = get_quantizer_module().dequantize_int8_to_half_experimental( tensor.reshape(-1, last_dimension_size), quant_scale, quant_min, tensor.numel() // last_dimension_size, self.config['group_size']) shape = list(tensor.shape)