Skip to content

Commit

Permalink
edited get_transcripts_from_utterances to fix type checker issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivitsa Shankar authored and Vivitsa Shankar committed Jan 4, 2024
1 parent c9a82c8 commit f8147ae
Showing 1 changed file with 50 additions and 35 deletions.
85 changes: 50 additions & 35 deletions dialogy/plugins/text/dob_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
import re
import traceback
from typing import Any, List, Optional, Union, Dict
from typing import Any, Callable, List
from dialogy.types import Utterance

from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd
from loguru import logger
from tqdm import tqdm
from word2number import w2n

import dialogy.constants as const
from dialogy.base import Guard, Input, Output, Plugin
from dialogy.types import Utterance
from dialogy.utils import normalize
from word2number import w2n
import re


def clean_string(input_string: str) -> str:
# Remove all instances of "."
cleaned_string = input_string.replace(".", "")
Expand All @@ -23,9 +23,10 @@ def clean_string(input_string: str) -> str:
cleaned_string = cleaned_string.replace("for", "4")
return cleaned_string

def _class_6(transcript: str) ->str:

def _class_6(transcript: str) -> str:
"""
input: transcript that looks like- " X Y" where X is a string of numbers with or without space. Y is string- which is a
input: transcript that looks like- " X Y" where X is a string of numbers with or without space. Y is string- which is a
number written in words.
description: keep X as is. Only if X exists, look for Y, then convert Y into numeric digits
output: X (as is) Y(converted into numbers)
Expand All @@ -47,7 +48,7 @@ def _class_6(transcript: str) ->str:
"""

# Regex pattern to capture numeric part and remaining words separately
pattern = re.compile(r'(\b\d+\s*\d*\b|\b\d+\b)(.*)', re.IGNORECASE)
pattern = re.compile(r"(\b\d+\s*\d*\b|\b\d+\b)(.*)", re.IGNORECASE)

match = re.search(pattern, transcript)

Expand All @@ -56,23 +57,25 @@ def _class_6(transcript: str) ->str:
remaining_words = match.group(2) or ""
try:
words_converted_to_number = str(w2n.word_to_num(remaining_words))
if len(words_converted_to_number)<=2:
# Construct the transformed substring
transformed_transcript = numeric_part + " " + words_converted_to_number
return transformed_transcript.strip()
if len(words_converted_to_number) <= 2:

# Construct the transformed substring
transformed_transcript = numeric_part + " " + words_converted_to_number
return transformed_transcript.strip()
except:
pass
return transcript # Return original transcript if no match
def _class_5(transcript: str) ->str:


def _class_5(transcript: str) -> str:
"""
input: transcript that looks like-"xx:yy "
description: replace ":" with " "; if yy is "00", drop it; if yy is "0y", drop the leading "0"
output: xx yy
"""

# Your regex pattern
pattern = re.compile(r'\b(\d{1,2}):(\d{2})\b')
pattern = re.compile(r"\b(\d{1,2}):(\d{2})\b")

match = pattern.search(transcript)

Expand All @@ -86,7 +89,7 @@ def _class_5(transcript: str) ->str:
minutes = ""
elif minutes.startswith("0"):
minutes = "" + minutes[1]

# Combining the modified hours and minutes
result = hours + " " + minutes
transformed_transcript = transcript[:start_idx] + result + transcript[end_idx:]
Expand All @@ -99,7 +102,7 @@ def _class_5(transcript: str) ->str:

def _transform_invalid_date(transcript: str) -> str:
"""
input: transcripts that are responses for when the user is asked their
input: transcripts that are responses for when the user is asked their
dob for authentication and are not recognised as dates by duckling
output: transformed transcript recognised by duckling as date (closest valid date)
description: handling class 5 error
Expand All @@ -109,30 +112,36 @@ def _transform_invalid_date(transcript: str) -> str:
transcript = _class_6(transcript)
return transcript

def get_transcripts_from_utterances(utterances: List[Utterance], func_transcript: Callable[[str], str]) -> List[str]:

def get_transcripts_from_utterances(
utterances: List[Utterance], func_transcript: Callable[[str], str]
) -> List[str]:
"""
input: utterances = [
[{'transcript': '102998', 'confidence': None},
{'transcript': '10 29 98', 'confidence': None},
[{'transcript': '102998', 'confidence': None},
{'transcript': '10 29 98', 'confidence': None},
{'transcript': '1029 niniety eight', 'confidence': None}]
]
description: access each transcript, confidence score pair, get
the result of <any func(transcript)>;
get a dictionary containing all results;
description: access each transcript, confidence score pair, get
the result of <any func(transcript)>;
get a dictionary containing all results;
order this dictionary in decreasing order of confidence score
output:
output:
transcripts: list of transformed transcripts, ordered by decreasing confidence score
"""
result_dict: Dict[str, Any] = {}
transcripts: List[str] = []

for utterance_set in utterances:
for utterance in utterance_set:
transcript = utterance.get('transcript')
confidence = utterance.get('confidence')
confidence = 0 if confidence is None else confidence # Ensure confidence is not None
result = func_transcript(transcript)
if result==None:
transcript = utterance.get("transcript", "")
confidence = utterance.get("confidence", 0)

confidence = (
0 if confidence is None else confidence
) # Ensure confidence is not None
result = func_transcript(str(transcript))
if result == None:
result = ""
confidence = 0
if result in result_dict:
Expand All @@ -141,11 +150,16 @@ def get_transcripts_from_utterances(utterances: List[Utterance], func_transcript
result_dict[result] = confidence

# Sort the result_dict based on confidence in descending order
sorted_result = {k: v for k, v in sorted(result_dict.items(), key=lambda item: item[1], reverse=True) if v is not None}
sorted_result = {
k: v
for k, v in sorted(result_dict.items(), key=lambda item: item[1], reverse=True)
if v is not None
}
transcripts = sorted(sorted_result, key=lambda x: sorted_result[x], reverse=True)

return transcripts


def get_dob(utterances: List[Utterance]) -> List[str]:
try:
# print("UTTERS:", utterances)
Expand All @@ -158,15 +172,16 @@ def get_dob(utterances: List[Utterance]) -> List[str]:
else:
# best_transcript = _format_date(utterances)
# print("transcripts = ", transcripts)
transcripts = get_transcripts_from_utterances(utterances=utterances, func_transcript=_transform_invalid_date)
transcripts = get_transcripts_from_utterances(
utterances=utterances, func_transcript=_transform_invalid_date
)
# print("dob output:",transcripts)
return transcripts
except TypeError as type_error:
raise TypeError("`transcript` is expected in the ASR output.") from type_error


class DOBPlugin(Plugin):

def __init__(
self,
input_column: str = const.ALTERNATIVES,
Expand All @@ -175,7 +190,7 @@ def __init__(
dest: Optional[str] = None,
guards: Optional[List[Guard]] = None,
debug: bool = False,
**kwargs: Any
**kwargs: Any,
) -> None:
super().__init__(
dest=dest,
Expand All @@ -184,7 +199,7 @@ def __init__(
output_column=output_column,
use_transform=use_transform,
debug=debug,
**kwargs
**kwargs,
)

async def utility(self, input: Input, _: Output) -> Any:
Expand Down

0 comments on commit f8147ae

Please sign in to comment.