
[GPTQ UX] Add scheme arg with QuantizationScheme support #2286

Merged: rahul-tuli merged 7 commits into main from add-scheme-support on May 24, 2024

Conversation

@rahul-tuli (Member) commented May 15, 2024

This PR adds support for a scheme arg in GPTQ; the arg can be set to a single QuantizationScheme object.

recipe:

test_stage:
  obcq_modifiers:
    GPTQModifier:
      ignore: ["LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation", "MatMulLeftInput_QK", "MatMulRightInput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV", "MatMulOutput_QK", "MatMulOutput_PV", "lm_head", "Embedding"]
      sequential_update: True
      dampening_frac: 0.001
      block_size: 128
      targets: ["Linear"]
      scheme:
        input_activations: null
        output_activations: null
        weights:
          num_bits: 8
          type: "int"
          symmetric: true
          strategy: "tensor"
          group_size: 128
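
For reference, a minimal sketch of the same scheme built programmatically with compressed-tensors dataclasses (field values mirror the recipe above; passing targets at construction is an assumption, since the modifier may attach its own targets):

# Sketch only, not from this PR: the recipe's scheme as a
# QuantizationScheme object. Assumes compressed-tensors coerces the
# string values for `type` and `strategy` into their enums.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],  # assumption: mirrors the modifier's `targets` arg
    weights=QuantizationArgs(
        num_bits=8,
        type="int",
        symmetric=True,
        strategy="tensor",
        group_size=128,
    ),
    input_activations=None,
    output_activations=None,
)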

test script:

from pathlib import Path
from datetime import datetime
import argparse

from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

tinyllama_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tiny_random_llama_stub = "HuggingFaceH4/tiny-random-LlamaForCausalLM"

parser = argparse.ArgumentParser(description="Get Quant Model")
parser.add_argument('--recipe', default="/root/projects/sparseml/local/feature/recipe.yaml", help='Path to the recipe')
parser.add_argument('--model_stub', default=tinyllama_stub, help='Model stub')
parser.add_argument('--dataset', default="open_platypus", help='Dataset name')
parser.add_argument('--max_seq_length', type=int, default=512, help='Maximum sequence length')
parser.add_argument('--output_dir', default=None, help='Output directory')
parser.add_argument('--num_calibration_samples', type=int, default=512, help='Number of calibration samples')
parser.add_argument('--overwrite_output_dir', action='store_true', help='Overwrite output directory')
parser.add_argument('--small', action='store_true', help='Use a small model')
args = parser.parse_args()


def get_save_dir_name(model_stub):
    # Timestamped directory under ./output, named after the model stub
    dir_name = f"{model_stub.split('/')[-1]}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    return str(Path("output") / dir_name)


recipe = args.recipe
model_stub = tiny_random_llama_stub if args.small else args.model_stub
dataset = args.dataset
max_seq_length = args.max_seq_length
output_dir = args.output_dir or get_save_dir_name(model_stub)
num_calibration_samples = args.num_calibration_samples
device = "cuda"

# Run one-shot GPTQ quantization with the recipe above
oneshot(
    model=model_stub,
    dataset=dataset,
    overwrite_output_dir=True,
    output_dir=output_dir,
    max_seq_length=max_seq_length,
    num_calibration_samples=num_calibration_samples,
    recipe=recipe,
    oneshot_device=device,
)

# Try reloading the quantized model to verify it decompresses correctly
model_new = SparseAutoModelForCausalLM.from_pretrained(output_dir)
print("Model reloaded successfully!")

test command:

python get_quant_model.py --small --recipe ./gptq_ux/recipes/recipe_scheme_quant_scheme.yaml

Output:

Calculating quantization compression ratio: 25it [00:00, 697.59it/s]
2024-05-15 13:26:18 sparseml.pytorch.model_load.helpers INFO     Saving output to /root/projects/sparseml/output/tiny-random-LlamaForCausalLM_2024-05-15-13-26-00
Decompressing model: 0it [00:00, ?it/s]
Model reloaded successfully!
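
As an optional follow-up (not part of the PR), a hedged sanity check one could append to the script above, assuming a tokenizer is available at the original model stub:

# Illustrative only: run a short generation with the reloaded model.
# Loads the tokenizer from the original stub in case one was not saved
# to output_dir.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_stub)
inputs = tokenizer("The capital of France is", return_tensors="pt")
generated = model_new.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))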

@bfineran (Contributor) left a comment

@rahul-tuli the target UX here doesn't take a full scheme; instead, the goal is to have targets and scheme separate (with targets added to the scheme later), i.e.:

GPTQModifier:
  ignore: ["LlamaRotaryEmbedding", "LlamaRMSNorm", "SiLUActivation", "MatMulLeftInput_QK", "MatMulRightInput_QK", "MatMulLeftInput_PV", "MatMulRightInput_PV", "MatMulOutput_QK", "MatMulOutput_PV", "lm_head", "Embedding"]
  sequential_update: True
  dampening_frac: 0.001
  block_size: 128
  targets: ["Linear"]
  scheme:
    input_activations: null
    output_activations: null
    weights:
      num_bits: 8
      type: "int"
      symmetric: true
      strategy: "tensor"
      group_size: 128
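
A hypothetical sketch (not the PR's actual code) of what "targets added later to the scheme" could look like on the modifier side; the helper name and construction are assumptions:

# Hypothetical helper, for illustration only: keep `targets` as a separate
# modifier arg and attach it to the recipe-parsed scheme afterwards.
from typing import Dict, List

from compressed_tensors.quantization import QuantizationScheme


def scheme_with_targets(scheme_kwargs: Dict, targets: List[str]) -> QuantizationScheme:
    # Nested dicts (e.g. the `weights` section) are coerced into
    # QuantizationArgs by the pydantic model.
    return QuantizationScheme(targets=targets, **scheme_kwargs)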

Base automatically changed from preserve-mask-structure-test to gptq-ux-config-groups May 17, 2024 16:16
Base automatically changed from gptq-ux-config-groups to quant-modifier-ux May 20, 2024 19:07
Base automatically changed from quant-modifier-ux to main May 22, 2024 18:48
@rahul-tuli changed the base branch from main to install-compressed-tensors-from-source May 23, 2024 15:10
bfineran previously approved these changes May 23, 2024
Satrat previously approved these changes May 23, 2024
Base automatically changed from install-compressed-tensors-from-source to main May 24, 2024 13:57
@rahul-tuli dismissed stale reviews from Satrat and bfineran May 24, 2024 13:57

The base branch was changed.

@rahul-tuli merged commit 7bb3db3 into main May 24, 2024
15 of 17 checks passed
@rahul-tuli deleted the add-scheme-support branch May 24, 2024 14:23