diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md index f10bb0984a..1a48c01afc 100644 --- a/examples/llama7b_sparse_quantized/README.md +++ b/examples/llama7b_sparse_quantized/README.md @@ -73,10 +73,12 @@ run the following in the same Python instance as the previous steps. ```python import torch +import os from sparseml.transformers import SparseAutoModelForCausalLM compressed_output_dir = "output_llama7b_2:4_w4a16_channel_compressed" -model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16) +uncompressed_path = os.path.join(output_dir, "stage_quantization") +model = SparseAutoModelForCausalLM.from_pretrained(uncompressed_path, torch_dtype=torch.bfloat16) model.save_pretrained(compressed_output_dir, save_compressed=True) ```