-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
111 lines (88 loc) · 4.33 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import torch
import transformers
import json
import re
from tqdm import tqdm
# Load the ComparisonQA test split into memory; the evaluation loop below
# iterates over the records it contains.
with open('comparisonqa_benchmark/comparisonqa_test.json', 'r', encoding='utf-8') as benchmark_file:
    data = json.loads(benchmark_file.read())
import argparse

# Other checkpoints this script has been configured with via --model_name:
#   google/gemma-2-9b, tiiuae/falcon-11B, mistralai/Mistral-7B-v0.3,
#   mistralai/Mistral-7B-Instruct-v0.3, meta-llama/Meta-Llama-3-8B-Instruct,
#   meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-3B-Instruct
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_name',
    type=str,
    default='meta-llama/Meta-Llama-3-8B-Instruct',
    help='Model identifier passed to transformers from_pretrained().',
)
parser.add_argument(
    '--mode',
    type=str,
    default='zero',
    # The evaluation loop treats every value other than "zero" as few-shot,
    # so restrict the accepted values: a typo must fail loudly instead of
    # silently switching the prompting mode.
    choices=['zero', 'few'],
    help='Prompting mode: "zero" (zero-shot) or "few" (few-shot).',
)
args = parser.parse_args()
# --- Device selection and model loading -----------------------------------
# CUDA-only queries (device name/properties, set_device) are guarded by
# torch.cuda.is_available() so the script also runs on CPU-only hosts.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print('CUDA available:', cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name())
print('Device number:', torch.cuda.device_count())
if cuda_available:
    print(torch.cuda.get_device_properties(device))
    # Prefer GPU index 1 (this script's hard-coded choice) when it exists, and
    # fall back to the default GPU on single-GPU hosts instead of crashing.
    # CUDA_VISIBLE_DEVICES is intentionally not set here: CUDA has already been
    # initialised by the queries above, so changing it would have no effect.
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
    # Pin the explicit index so the model and its inputs use the same GPU.
    device = torch.device('cuda', torch.cuda.current_device())

# Filesystem-friendly model tag used in output file names,
# e.g. 'meta-llama/Meta-Llama-3-8B-Instruct' -> 'Meta_Llama_3_8B_Instruct'.
name = args.model_name.split('/')[-1].replace('.', '').replace('-', '_')
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(args.model_name)
model.eval()
model.to(device)
def generate_text(prompt, max_new_tokens=20):
    """Generate a continuation of *prompt* with the module-level model.

    Uses the global ``tokenizer``, ``model`` and ``device`` and the model's
    default generation settings.

    Args:
        prompt: Text to condition on.
        max_new_tokens: Number of tokens to generate beyond the prompt. The
            default of 20 matches the previous ``max_length=len(ids) + 20``.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, so generate() does not have to infer it (and warn about it).
    encoded = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            # Reuse EOS as the pad id; presumably because some of these
            # tokenizers define no pad token — TODO confirm per model.
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Zero-shot template. Placeholders, in order: the question text, then the
# texts of options A, B, C and D. The evaluation loop keeps whatever the model
# generates after this prompt as its answer.
zero_prompt = """{} A. {}. B. {}. C. {}. D. {}
The correct answer is: """
# Few-shot template: an instruction, four worked examples answered in the
# "Answer: **X.**" format, then the target question filled with the same five
# placeholders as zero_prompt (question, then options A-D).
# NOTE(review): unlike the worked examples, the target question is not
# followed by an "Answer:" cue, and it ends with "D. {}." (trailing period)
# whereas zero_prompt ends with "D. {}" — confirm both are intentional.
few_prompt = """Answer the following multiple choice question. Select only one correct answer from the choices.
Follow these examples:
Which country does Cimarron Firearms belong to? A. Australia. B. Mexico. C. Canada. D. America.
Answer: **D.**
Which specific target does Belimumab inhibit or act against? A. B-cell activating factor (BAFF). B. Human Rhesus factor. C. Tumor Necrosis Factor (TNF). D. Programmed death-1 (PD-1).
Answer: **A.**
In which geographical regions can Leuconotopicus typically be found? A. Central and South America. B. North and South America. C. Only North America. D. Europe and Asia.
Answer: **B.**
What is the primary color range associated with Ochre? A. Blue to green. B. Yellow-brown to bright red. C. Yellow to deep orange or brown. D. Purple to black.
Answer: **C.**
{} A. {}. B. {}. C. {}. D. {}.
"""
def _save_outputs(records):
    """Write *records* to the primary output JSON file and to a backup copy."""
    # Create the output directory up front; otherwise the first checkpoint
    # crashes after the model has already done the expensive work.
    os.makedirs('./experiments', exist_ok=True)
    for suffix in ('', '_copy'):
        path = f'./experiments/longtailqa_test_output_{name}_{args.mode}{suffix}.json'
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(records, file, ensure_ascii=False, indent=4)


def _extract_answer(generated, prompt):
    """Return the text *generated* after *prompt*, stripped of whitespace.

    Decoding may not reproduce the prompt verbatim (tokenizers can normalise
    spacing). When the prompt is not found, the whole decoded text is kept
    rather than slicing at a misaligned offset, which is what
    ``find() + len(prompt)`` would do when find() returns -1.
    """
    start = generated.find(prompt)
    if start == -1:
        return generated.strip()
    return generated[start + len(prompt):].strip()


# Any mode other than "zero" uses the few-shot template.
template = zero_prompt if args.mode == "zero" else few_prompt
new_data = []
for print_count, line in enumerate(tqdm(data), start=1):
    # Each record holds two questions; annotate both in place.
    for q in ["high_question", "low_question"]:
        options = line[q]["options"]
        prompt = template.format(line[q]["question"], *(options[letter] for letter in "ABCD"))
        generated = generate_text(prompt)
        line[q]["model_output"] = _extract_answer(generated, prompt)
    new_data.append(line)
    # Checkpoint after the 1st record and every 1000 records thereafter.
    if print_count % 1000 == 1:
        _save_outputs(new_data)
_save_outputs(new_data)