diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py new file mode 100644 index 0000000000000..e87a1783a83f1 --- /dev/null +++ b/tests/models/test_fp8.py @@ -0,0 +1,90 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +Note: these tests will only pass on L4 GPU. +""" +import os + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "meta-llama/Meta-Llama-3-8B-Instruct", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + ], + "meta-llama/Meta-Llama-3-8B-Instruct": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], +} + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) + + +@pytest.mark.skipif(fp8_not_supported, + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +def test_models( + example_prompts, + model_name, +) -> None: + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + quantization="fp8") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(generations) + expected_strs = EXPECTED_STRS_MAP[model_name] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c3faa01fc38e6..8e84c8a86ece6 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -248,6 +248,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -256,6 +260,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -325,7 +335,12 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -339,14 +354,13 @@ def weight_loader(self, current_shard_offset += output_size packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -361,15 +375,14 @@ def weight_loader(self, if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -378,11 +391,17 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -477,7 +496,11 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) if loaded_shard_id is None: # Loaded weight is already packed. @@ -495,14 +518,14 @@ def weight_loader(self, ] packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -524,6 +547,7 @@ def weight_loader(self, shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) @@ -531,8 +555,7 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -545,12 +568,17 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -642,6 +670,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data @@ -650,6 +682,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ba9f3149649c1..b57e1dde81a5f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,23 +1,36 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + class Fp8Config(QuantizationConfig): """Config class for FP8.""" def __init__( self, + is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme @classmethod @@ -30,10 +43,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - # TODO: PyTorch 2.3.0+ is required to run FP8 on - # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to - # be included: https://github.com/pytorch/pytorch/pull/118881 - return 90 + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -41,11 +51,14 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls(activation_scheme) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: if isinstance(layer, LinearBase): return Fp8LinearMethod(self) return None @@ -56,8 +69,12 @@ def get_scaled_act_names(self) -> List[str]: class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. - We now support common FP16/BF16 model checkpoints ONLY. The weight - scaling factor will be initialized after the model weights are loaded. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. @@ -71,6 +88,24 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) + def create_weights( self, layer: torch.nn.Module, @@ -81,46 +116,150 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + del input_size, output_size output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - dtype=params_dtype), + dtype=weight_dtype), requires_grad=False) layer.register_parameter("weight", weight) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - set_weight_attrs(weight, extra_weight_attrs) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) - w_scale = Parameter( - torch.empty(1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("weight_scaling_factor", w_scale) + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + # ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="act_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + return param[shard_id], loaded_weight def process_weights_after_loading(self, layer: Module) -> None: - # Although the quant_method is propagated to all layers, - # only linear layers invoke "create_weights". So we check - # whether "weight_scaling_facor" is registered to determine - # whether the layer is a linear layer that requires quantization. - if not hasattr(layer, "weight_scaling_factor"): + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.act_scale = None return - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight) - # torch._scaled_mm requires column-major in the second - # input (weight), so we transpose the quantized weight. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scaling_factor.data.copy_(weight_scale) + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(layer.weight[start:end, :], + layer.weight_scale[idx]) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max()) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # ACT_SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the act_scales (since they are equal). + if self.quant_config.activation_scheme == "dynamic": + layer.act_scale = None + elif self.quant_config.activation_scheme == "static": + if not all_close_1d(layer.act_scale): + raise ValueError( + "All the act_scales for the logical weights of a layer " + f"must be equal. But got {layer.act_scale}") + layer.act_scale = Parameter(layer.act_scale.max(), + requires_grad=False) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}") def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, x_scale = ops.scaled_fp8_quant(x) + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.act_scale is None and x_scale computed from x. + # If static, layer.act_scale is scalar and x_scale set to act_scale. + qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) + + # Fused GEMM_DQ output, _ = torch._scaled_mm( qinput, layer.weight, out_dtype=x.dtype, scale_a=x_scale, - scale_b=layer.weight_scaling_factor, + scale_b=layer.weight_scale, bias=bias, ) + return output + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight