eval.py
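
"""Evaluate a language model on one of the supported datasets, with or without
supplying "edits" alongside the test inputs.

Supported datasets: BBNLI, BBQ, Arithmetic, NewInfo, RealTox
(RealToxicityPrompts), and ARC. Edits can be given for every test input or
selected per input by a retriever (BM25, GPT-3, DPR, or the scope classifier).
Results and the CLI arguments are written to <output_dir>/outputs.json.
"""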
import argparse
import json
import os

from dataset import BBNLI, BBQ, Arithmetic, RealToxicityPrompts, NewInfo, ARC
from config import (
    BBNLITestInputsConfig,
    BBNLIEditConfig,
    BBQTestInputsConfig,
    BBQEditConfig,
    ArithmeticEditConfig,
    ArithmeticTestInputsConfig,
    RealToxEditConfig,
    RealToxTestInputsConfig,
)
from eval_utils import eval_api, eval_hf
from util import (
    bm25_retriever_generator,
    get_edits_with_scope,
    gpt3_retriever,
    dpr_retriever_generator,
    SCOPECLASSIFIERPATH,
    ALLEDITSPATH,
)

# Map --dataset_name values to dataset classes and their edit / test-input configs.
class_dict = {
    "BBNLI": BBNLI,
    "BBQ": BBQ,
    "Arithmetic": Arithmetic,
    "NewInfo": NewInfo,
    "RealTox": RealToxicityPrompts,
    "ARC": ARC,
}
edit_config_dict = {
    "BBNLI": BBNLIEditConfig,
    "BBQ": BBQEditConfig,
    "Arithmetic": ArithmeticEditConfig,
    "RealTox": RealToxEditConfig,
}
ti_config_dict = {
    "BBNLI": BBNLITestInputsConfig,
    "BBQ": BBQTestInputsConfig,
    "Arithmetic": ArithmeticTestInputsConfig,
    "RealTox": RealToxTestInputsConfig,
}


def create_retriever(args):
    """Return the retriever callable selected by --retriever_mechanism, or None."""
    retriever_mechanism = args.retriever_mechanism
    if retriever_mechanism == "bm25":
        return bm25_retriever_generator(args.num_retrievals)
    elif retriever_mechanism == "gpt3":
        return gpt3_retriever
    elif retriever_mechanism == "dpr":
        return dpr_retriever_generator(args.faiss_index_path)
    elif retriever_mechanism == "scope":

        def fscp(ti, edit):
            # Format an (edit, test input) pair for the scope classifier.
            return "Information: " + edit.strip() + " Question: " + ti.strip()

        def scope_retriever(edits_all, test_inputs):
            return get_edits_with_scope(
                SCOPECLASSIFIERPATH,
                edits_all,
                test_inputs,
                fscp,
                args.device,
                args.scope_cache,
                args.num_retrievals,
            )

        return scope_retriever
    elif retriever_mechanism is None:
        return None
    else:
        raise ValueError(f"Invalid retriever mechanism: {retriever_mechanism}")


def get_all_edits(args):
    """Read the full list of edits from ALLEDITSPATH (one edit per line)."""
    if args.with_edit:
        with open(ALLEDITSPATH, "r") as f:
            all_edits = [line.strip() for line in f.readlines()]
        return all_edits
    else:
        return None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="google/flan-t5-small")
    parser.add_argument("--dataset_name", type=str, default="BBQ")
    parser.add_argument(
        "--filename_queries",
        type=str,
        help="JSON file with the test queries, containing 'test_inputs' and 'edits' keys",
        default="bbq_questions_answers_ti8_out",
    )
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--with_edit", action="store_true")
    parser.add_argument("--no_edit", dest="with_edit", action="store_false")
    parser.add_argument("--retriever_mechanism", type=str, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=20)
    parser.add_argument("--from_flax", action="store_true", help="If using a Flax model")
    parser.add_argument("--llama", action="store_true", help="If using a LLaMA model")
    parser.add_argument("--peft", action="store_true", help="If finetuning with PEFT")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--scope_cache", type=str, default=None)
    parser.add_argument("--generations_cache", type=str, default=None)
    parser.add_argument("--api", type=str, default=None)
    parser.add_argument("--num_retrievals", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--chat_prompt_dict_path", type=str, default=None)
    parser.add_argument("--faiss_index_path", type=str, default=None)
    args = parser.parse_args()
if "gpt" in args.model_name:
eval_func = eval_api
args.api = "openai"
elif "flan" in args.model_name or "llama" in args.model_name:
eval_func = eval_hf
elif "bard" in args.model_name:
eval_func = eval_api
args.api = "bard"
else:
raise ValueError(f"{args.model_name} not recognized")
    # Load dataset.
    cc = class_dict[args.dataset_name]()
    cc.load_test_inputs(args.filename_queries, flattened=True)

    outputs = eval_func(
        cc,
        args,
        model_name=args.model_name,
        with_edit=args.with_edit,
        retriever=create_retriever(args),
        edits_all=get_all_edits(args),
        max_new_tokens=args.max_new_tokens,
        device=args.device,
    )
    # Save.
    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, "outputs.json")
    print("Saving to", out_path)
    outputs["args"] = vars(args)
    with open(out_path, "w") as f:
        json.dump(outputs, f)


if __name__ == "__main__":
    main()
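

# Example invocation (a minimal sketch; the output directory and retriever
# choice below are illustrative, every flag is defined in main() above):
#
#   python eval.py \
#       --model_name google/flan-t5-small \
#       --dataset_name BBQ \
#       --filename_queries bbq_questions_answers_ti8_out \
#       --output_dir outputs/bbq_flan_t5_small \
#       --with_edit \
#       --retriever_mechanism bm25 \
#       --num_retrievals 2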