forked from kjhayes/wildhacks-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cite_transcript_model.py
30 lines (27 loc) · 1.03 KB
/
cite_transcript_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import re
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

# Extractive question-answering model fine-tuned on SQuAD 2.0.
model_name = "deepset/roberta-base-squad2"
# Load the model and tokenizer once, then hand the loaded objects to
# pipeline().  The original passed the model *name* to pipeline() and then
# called from_pretrained() separately, instantiating everything twice.
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
def cite_transcript(context, question):
    """Find the answer to *question* inside *context* and return it with
    the surrounding sentence fragments as a citation.

    Parameters
    ----------
    context : str
        Transcript text to search.
    question : str
        Natural-language question to answer from the transcript.

    Returns
    -------
    tuple[str, str, str] | None
        ``(prefix, answer, suffix)`` where *prefix* runs from the nearest
        sentence boundary before the answer to the answer's start, and
        *suffix* runs from the answer's end to the next boundary.
        ``None`` when the model's confidence score is <= 0.1.
    """
    # Candidate sentence-boundary offsets: the start of the text, the
    # position just after each period, and just after each newline run.
    # NOTE(review): for a run of newlines, m.start() + 1 points just past
    # the FIRST newline char, not past the whole run — m.end() may have
    # been intended; preserved as-is pending confirmation.
    starts = (
        [0]
        + [m.start() + 1 for m in re.finditer(r'\.', context)]
        + [m.start() + 1 for m in re.finditer(r'[\r\n]+', context)]
    )
    QA_input = {
        'question': question,
        'context': context}
    out = nlp(QA_input)
    score = out["score"]
    start = out["start"]
    end = out["end"]
    ans = out["answer"]
    # Discard low-confidence answers rather than citing noise.
    if score <= 0.1:
        return None
    # Snap the citation window outward to the nearest boundary on each side.
    ss = max([0] + [x for x in starts if x <= start])
    ee = min([len(context)] + [x for x in starts if x >= end])
    return (context[ss:start], ans, context[end:ee])