mistral_prob_given_temperature.py
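"""
Compute Spanish/English LID ratios and POS-conditioned code-switch counts for
Mistral generations sampled at different temperatures (summary inferred from
the code below; file names suggest one generation file per temperature).

The script loads a token-level language-identification (LID) pipeline and a
POS-tagging pipeline (codeswitch-spaeng-*-lince checkpoints stored on the
cluster), reads a generation file, strips the instruction prefix from each
line, and tallies how often the predicted language flips between consecutive
tokens, broken down by verb / noun / conjunction. Results are appended to a
file under jpLLM_Data/byTemp/.

Paths are specific to the original author's /scratch/gpfs environment.
"""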
import os
import lidCall            # project module, not referenced in this script
import miamiCorpusLID     # project module, not referenced in this script
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# token-level language-identification (LID) pipeline
# model_name = '/scratch/gpfs/ca2992/robertuito-base-cased'
model_name = '/scratch/gpfs/ca2992/codeswitch-spaeng-lid-lince'
tokenizer_name = '/scratch/gpfs/ca2992/codeswitch-spaeng-lid-lince'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
lid_model = pipeline('ner', model=model, tokenizer=tokenizer)

out_dir = 'mistral_lid_ratios'

# POS-tagging pipeline (shares the LID tokenizer)
pos_model_name = '/scratch/gpfs/ca2992/codeswitch-spaeng-pos-lince'
pos_model_import = AutoModelForTokenClassification.from_pretrained(pos_model_name)
pos_model = pipeline('ner', model=pos_model_import, tokenizer=tokenizer)

# generation files, one per sampling temperature
dir = '/scratch/gpfs/ca2992/jpLLM/jpLLM_Data/'
files = ['out_t_0_indiv.tsv',
         'out_t_1_indiv.tsv',
         'out_t_2_indiv.tsv',
         'out_t_3_indiv.tsv',
         'out_t_4_indiv.tsv']
# given a token with the '#' symbol,
# remove the symbol for preprocessing
def cleanPoundSign(word):
    tempTok = ""
    for i in range(len(word)):
        if (word[i] != '#'):
            tempTok = tempTok + word[i]
    return tempTok
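# Illustrative example (not executed): cleanPoundSign('##ing') -> 'ing',
# undoing the '##' continuation markers that the tokenizer prepends to sub-words.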
lid_pred = []
# convert token predictions to word predictions
def tokenToWordPred(message, trueWords):
    lidResult = lid_model(message)
    posResult = pos_model(message)
    index = 0
    for word in trueWords:
        lidToken = lidResult[index].get('word')
        # get the lid predicted for this token and append
        # to the lid word level predictions
        lid = lidResult[index].get('entity')
        lid_pred.append([lid])
        # if the token and word mismatch completely, it is impossible to handle
        if (word != lidToken and word[0] != lidToken[0]):
            print("MISMATCH", word, lidToken)
            continue
        # merge sub-word tokens until they reassemble the full word
        while (word != lidToken and word[0] == lidToken[0]):
            index += 1
            lidToken = lidToken + lidResult[index].get('word')
            # get rid of # symbols added by tokenizer
            lidToken = cleanPoundSign(lidToken)
        index += 1
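# Illustrative call (this helper is defined but never invoked below):
# tokenToWordPred(sentence, sentence.split()) would append one word-level LID
# label to lid_pred per whitespace-separated word of `sentence`.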
# strip the instruction prefix from a line: keep only the text after the
# second ']' (e.g. the response following a '[INST] ... [/INST]' wrapper)
def cleanInstruct(text):
    instruction = True
    response = False
    almost = False
    message = ""
    for char in text:
        if char == ']' and almost == False:
            almost = True
        elif char == ']' and almost == True:
            almost = False
            response = True
            instruction = False
        elif response == True:
            message = message + char
    assert instruction != response
    return message
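# Illustrative example (not executed):
# cleanInstruct('[INST] hola [/INST] I went to the tienda') -> ' I went to the tienda'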
spanish_count = 0
english_count = 0
switch_verb = 0
switch_noun = 0
switch_conj = 0
switch_count = 0
verb_count = 0
noun_count = 0
conj_count = 0
count = 0
fileNum = 0
out_dir = 'lang_lid_ratio_cond_0_0'  # overrides the out_dir set above
with open(dir + '/byTemp/' + out_dir, "a") as o:
    print("success")  # sanity check that the output file can be opened
for file in files:
    # only the first file (fileNum == 0, i.e. out_t_0_indiv.tsv) is processed
    # in this run; the remaining temperatures are skipped
    if (fileNum != 0):
        fileNum += 1
        continue
    fileNum += 1
    # concatenate the whole file into one message, stripping the instruction
    # prefix from lines that start with '['
    with open(dir + file, "r+") as f:
        message = ""
        for line in f:
            if (line[0] == '['):
                message = message + cleanInstruct(line)
            else:
                message += line
    lid_results = lid_model(message)
    pos_results = pos_model(message)
    for i in range(len(lid_results)):
        count += 1
        lid = lid_results[i].get('entity')
        pos = pos_results[i].get('entity')
        if (i == 0):
            last_lid = lid
            last_pos = pos
        # detect a code-switch point and attribute it to the current POS tag
        if (pos == 'VERB'):
            verb_count += 1
            if (last_lid != lid):
                switch_verb += 1
                switch_count += 1
        elif (pos == 'NOUN'):
            noun_count += 1
            if (last_lid != lid):
                switch_noun += 1
                switch_count += 1
        elif (pos == 'CONJ'):
            conj_count += 1
            if (last_lid != lid):
                switch_conj += 1
                switch_count += 1
        elif (last_lid != lid):
            switch_count += 1
        if (lid == 'spa'):
            spanish_count += 1
        if (lid == 'en'):
            english_count += 1
        last_lid = lid
        last_pos = pos
    # append the per-file counts to the output file
    with open(dir + '/byTemp/' + out_dir, "a") as o:
        print(file, file=o)
        print(spanish_count, "Spanish Count", file=o)
        print(english_count, "English Count", file=o)
        print(switch_count, switch_noun, switch_conj, switch_verb, file=o)
        print(noun_count, conj_count, verb_count, file=o)
    fileNum += 1
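# Each processed file appends a block like the following to
# .../byTemp/lang_lid_ratio_cond_0_0 (values are placeholders, not real output):
#   out_t_0_indiv.tsv
#   <spanish_count> Spanish Count
#   <english_count> English Count
#   <switch_count> <switch_noun> <switch_conj> <switch_verb>
#   <noun_count> <conj_count> <verb_count>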