
How to do Sparsity + Finetuning #1028

Open
Thunderbeee opened this issue Jan 2, 2025 · 4 comments

@Thunderbeee commented Jan 2, 2025

Hi team, thanks so much for your work!

We found that performance was not good after pruning alone, so we want to do finetuning after pruning (as in this example: https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_2of4_sparse_w4a16, but without quantization).

Could you please give us an example of how to achieve sparsity + finetuning, and how to load the final model with vLLM or Hugging Face? Thanks!

Thunderbeee added the bug (Something isn't working) label on Jan 2, 2025
@dsikka (Collaborator) commented Jan 4, 2025

Hi @Thunderbeee,

The recipe in the example you linked can be updated to apply only sparsity + finetuning:

sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    # One-shot 2:4 pruning to 50% sparsity via SparseGPT
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
finetuning_stage:
  run_type: train
  finetuning_modifiers:
    # Keeps the pruning mask fixed during finetuning so the
    # 2:4 sparsity pattern is preserved in the targeted layers
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0

You can then run the same apply(...) call as in the example, while adding one additional argument, save_compressed=False, e.g.:

apply(
    model=model,
    dataset=dataset,
    recipe=recipe,
    bf16=bf16,
    output_dir=output_dir,
    splits=splits,
    max_seq_length=max_seq_length,
    num_calibration_samples=num_calibration_samples,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    save_compressed=False
)

The model produced should then generate coherent results.
The following model was generated using the above code and can then be run using transformers:

from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = 'llm-compressor/examples/quantization_2of4_sparse_w4a16/output_llama7b_2of4_w4a16_channel/stage_finetuning'

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

input_ids = tokenizer("The Toronto Raptors are", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=30)
print(tokenizer.decode(output[0]))

Output:

<s> The Toronto Raptors are a professional basketball team based in Toronto, Canada.</s>
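The same checkpoint should also load in vLLM. A minimal sketch (the prompt and sampling settings here are illustrative, not from the example; MODEL_ID is the same local directory as above):

from vllm import LLM, SamplingParams

# Load the finetuned, dense-saved checkpoint with vLLM
llm = LLM(model=MODEL_ID)
outputs = llm.generate(
    ["The Toronto Raptors are"],
    SamplingParams(max_tokens=30),
)
print(outputs[0].outputs[0].text)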

dsikka self-assigned this on Jan 4, 2025
@Thunderbeee (Author) commented Jan 5, 2025

Hi @dsikka, thank you for your example; it has been incredibly helpful! I have a few conceptual questions to better understand how llm-compressor works:

  1. We noticed that if we don’t add save_compressed=False, a checkpoint_321 appears in the final finetuning stage. This checkpoint can also be loaded using AutoModelForCausalLM. Could you explain what checkpoint_321 is?

  2. Could you elaborate on what save_compressed=False does? How is it related to stage_sparsity, stage_finetuning, and stage_quantization? Additionally, what is its relationship to extra parameters like the bitmask?

  3. The code you provided works well for loading models using AutoModelForCausalLM, but why doesn’t it work with vLLM? (We attempted it but encountered an error.)

  4. Is it possible to use separate datasets for the fine-tuning stage and the sparsity calibration stage?

  5. For example, if we want a specific dataset (such as MMLU) to serve as the fine-tuning dataset, how should we configure it?

Thank you very much for your explanations!

@dsikka (Collaborator) commented Jan 5, 2025

@Thunderbeee

  1. That is just a checkpoint of the model after a certain number of training steps. It is not the final finetuned model.
  2. save_compressed=False will apply the dense sparsity compressor when saving the final model in the compressed-tensors format. This compressor is compatible with vLLM. You can see details on the compressor here: https://github.com/neuralmagic/compressed-tensors/blob/fe4a4427f236ae24a7bf9d6acad81af056069be9/src/compressed_tensors/compressors/sparse_compressors/dense.py#L23
  3. Can you share the error you got in vLLM? I was able to load the model in vLLM as well.
  4. To use two different datasets, I would suggest first running oneshot() to apply 2:4 sparsity and then train() with the specified finetuning dataset; see the sketch below.
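A minimal sketch of that two-step flow (the model name, dataset names, and hyperparameters are placeholders, not from the example; the import path follows this repo's examples and may differ in newer releases):

from llmcompressor.transformers import oneshot, train

# Step 1: one-shot 2:4 pruning, calibrated on one dataset
oneshot(
    model="meta-llama/Llama-2-7b-hf",    # placeholder model
    dataset="open_platypus",             # placeholder calibration dataset
    recipe=sparsity_recipe,              # a recipe containing only the SparseGPTModifier stage
    output_dir="output/stage_sparsity",
    num_calibration_samples=512,
)

# Step 2: finetune the pruned checkpoint on a different dataset;
# a ConstantPruningModifier recipe keeps the 2:4 mask fixed
train(
    model="output/stage_sparsity",
    dataset=finetuning_dataset,          # placeholder; e.g. your MMLU-style finetuning set
    recipe=finetune_recipe,              # a recipe containing only the ConstantPruningModifier stage
    output_dir="output/stage_finetuning",
    num_train_epochs=1,
    save_compressed=False,
)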

@Thunderbeee (Author) commented

Thanks so much for your response! Let us explore it first; we'll let you know if we have any further questions. Thank you again!

kylesayrs added the help wanted (Extra attention is needed) label and removed the bug label on Jan 17, 2025