generate-after-lora.py
import torch
from transformers import AutoTokenizer, GenerationConfig
from llmtune.llms.autollm import AutoLLMForCausalLM
from llmtune.utils import to_half_precision
from llmtune.engine.lora.peft import quant_peft

# model config
model_name = ''
# model_name = './llama-7b-quantized' # can generate local dir via quantize.py
tokenizer_name = 'huggyllama/llama-7b'
DEV = 'cuda'

# load model
llm = AutoLLMForCausalLM.from_pretrained(model_name).to(DEV)
llm.eval()
llm = to_half_precision(llm)
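# note: to_half_precision runs the remaining floating-point weights in fp16;
# the quantized layers are assumed to stay in their packed low-bit format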

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# load lora from existing checkpoint
adapter_path = './llama-7b-quantized-lora' # can generate this via finetune.py
model = quant_peft.PeftModel.from_pretrained(
    llm, adapter_path,
    device_map='auto'
)
print(adapter_path, 'loaded')

# encode prompt
prompt = 'Write a detailed step-by-step recipe for a blueberry lasagna dish'
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEV)

# generation config
min_length = 10
max_length = 200
top_p = 0.95
top_k = 25
temperature = 1.0
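# do_sample=True below enables stochastic decoding: top_k keeps the 25 most likely
# tokens, top_p applies nucleus filtering at 0.95, and temperature scales the logits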

# generate text
with torch.no_grad():
    generated_ids = model.generate(
        inputs=input_ids,
        generation_config=GenerationConfig(
            do_sample=True,
            min_length=min_length,
            max_length=max_length,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
        )
    )

# decode and print
output = tokenizer.decode([el.item() for el in generated_ids[0]])
print(output)
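
# A minimal decoding variant (standard transformers API, not specific to llmtune):
# passing skip_special_tokens=True drops BOS/EOS markers from the printed text, e.g.
#   output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)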