How to do Sparsity + Finetuning #1028
Hi @Thunderbeee, the recipe in the example you listed can be updated to apply only sparsity + fine-tuning:

```yaml
sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
finetuning_stage:
  run_type: train
  finetuning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
```

The ConstantPruningModifier in the fine-tuning stage keeps the 2:4 mask produced by SparseGPT fixed on the listed projection weights, so the sparsity pattern is preserved throughout training.
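For reference, here is a minimal sketch of the setup the apply(...) call below assumes, following the linked example; the model stub, dataset name, recipe path, and hyperparameter values are illustrative placeholders, not the exact ones from the example script:

```python
import torch
from transformers import AutoModelForCausalLM
from llmcompressor.transformers import apply  # entry point used by the linked example script

# Illustrative setup; adapt from examples/quantization_2of4_sparse_w4a16
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
dataset = "ultrachat-200k"
recipe = "2of4_sparse_only_recipe.yaml"  # the two-stage recipe above, saved to disk
bf16 = True
output_dir = "output_llama7b_2of4_sparse"
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512
num_train_epochs = 0.5
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True
learning_rate = 0.0001
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
```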
You can then run the same apply(...) call as in the example, while adding one additional argument, save_compressed=False:

```python
apply(
    model=model,
    dataset=dataset,
    recipe=recipe,
    bf16=bf16,
    output_dir=output_dir,
    splits=splits,
    max_seq_length=max_seq_length,
    num_calibration_samples=num_calibration_samples,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    save_compressed=False,  # the added argument: save standard dense weights
)
```
The model produced should then generate coherent results:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = 'llm-compressor/examples/quantization_2of4_sparse_w4a16/output_llama7b_2of4_w4a16_channel/stage_finetuning'
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
input_ids = tokenizer("The Toronto Raptors are", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=30)
print(tokenizer.decode(output[0]))
```

Output:

```
<s> The Toronto Raptors are a professional basketball team based in Toronto, Canada.</s>
```
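Since the checkpoint is saved uncompressed (save_compressed=False), it is a standard Hugging Face model directory and should also load in vLLM. A minimal sketch, assuming vLLM is installed and using the same output path as above:

```python
from vllm import LLM, SamplingParams

# Load the fine-tuned checkpoint; any HF-format model directory works here
llm = LLM(model='llm-compressor/examples/quantization_2of4_sparse_w4a16/output_llama7b_2of4_w4a16_channel/stage_finetuning')
sampling = SamplingParams(max_tokens=30)
outputs = llm.generate(["The Toronto Raptors are"], sampling)
print(outputs[0].outputs[0].text)
```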
Hi @dsikka, thank you for your example; it has been incredibly helpful! I have a few conceptual questions I'd like to ask to better understand the llm-compressor process:
Thank you very much for your explanation; it's been incredibly helpful!
Thanks so much for your response! Let us explore it first; we'll let you know if we have any further questions. Thank you again!
Hi team, thanks so much for your work!
We found that performance was not good after pruning alone, so we want to fine-tune after pruning (as in this example https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_2of4_sparse_w4a16, but without quantization).
Could you please give us an example of achieving Sparsity + Finetuning, and of loading the final model with vLLM or Hugging Face? Thanks!