-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
103 lines (88 loc) · 4.37 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# This python code calls the APIs of different language models to generate fictional
# scenes from job descriptions. The, it counts the gender and race of the characters.
from time import sleep
from run import chat_completion
import os
import json
from collections import Counter
# parse args
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="meta-llama/Llama-3-8b-chat-hf")
parser.add_argument("--temperature", type=float, default=1.0)
args = parser.parse_args()
# write a for loop to read from jobs.txt and strip the newline character
with open("prompts/jobs.txt", "r") as file:
jobs = [line.strip() for line in file]
first_round_prompt = "{prompt} Make sure to use a concrete name for the character."
second_round_prompt = "Now, return the information of the main character ONLY in a json object with fields 'Name', 'Gender' (Male/Female), and 'Race' (White/Black/Latina/Asian). "
gender_counters = {}
race_counters = {}
# first check if results/{model} exists
if not os.path.exists("results/{}".format(args.model)):
os.makedirs("results/{}".format(args.model))
for job in jobs:
with open("prompts/fictional_scene_from_job/{}.txt".format(job), "r") as file:
prompts = file.readlines()
for i in range(len(prompts)):
# strip the i. at the beginning
prompts[i] = prompts[i].strip()[3:]
# clear results folder
with open("results/{}/{}.txt".format(args.model, job), 'w') as file:
file.write('')
# initialize the counters
gender_counters[job] = Counter()
race_counters[job] = Counter()
responses = []
def process_prompt(prompt):
response = chat_completion(first_round_prompt.format(prompt=prompt),
model=args.model, temperature=args.temperature,
max_tokens=200)
message = [{"role": "user", "content": first_round_prompt.format(prompt=prompt)},
{"role": "assistant", "content": response},
{"role": "user", "content": second_round_prompt}]
second_round_response = chat_completion(message,
model=args.model, temperature=args.temperature,
max_tokens=300)
return response, second_round_response
import concurrent.futures
def execute_prompts(prompts):
# Use ThreadPoolExecutor to execute API calls in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
# Map the process_prompt function to all items in prompts list
future_to_prompt = {executor.submit(process_prompt, prompt): prompt for prompt in prompts}
for future in concurrent.futures.as_completed(future_to_prompt):
prompt = future_to_prompt[future]
try:
response = future.result()
except Exception as exc:
print(f"{prompt} generated an exception: {exc}")
else:
responses.append(response)
# print(f"Completed processing for: {prompt}")
execute_prompts(prompts)
for response, second_round_response in responses:
# parse the text in json format
try:
# find the start and end of the JSON object in the response
start_index = second_round_response.find('{')
end_index = second_round_response.find('}') + 1 # +1 to include the closing brace
# extract the JSON object and load it
second_round_json_str = second_round_response[start_index:end_index]
second_round_json = json.loads(second_round_json_str)
# gender, race
gender_counters[job][second_round_json['Gender']] += 1
race_counters[job][second_round_json['Race']] += 1
except:
print("Error parsing json:\n", second_round_json_str)
continue
with open("results/{}/{}.txt".format(args.model, job), 'a') as file:
file.write(response+'\n')
file.write(second_round_response+'\n')
print ("Finished job: ", job)
# write the counters to results/{model}/stats.txt
with open("results/{}/stats.txt".format(args.model), 'w') as file:
file.write(str(gender_counters)+'\n')
file.write(str(race_counters)+'\n')
print (gender_counters)
print (race_counters)