Skip to content

Commit

Permalink
Make punctuation recovery deterministic. And start to implement some …
Browse files Browse the repository at this point in the history
…thing for (insertion) disfluencies
  • Loading branch information
Jeronymous committed Jun 13, 2023
1 parent 732b8c5 commit 8748f8b
Showing 1 changed file with 53 additions and 2 deletions.
55 changes: 53 additions & 2 deletions punctuation/recasepunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def init_random(seed):
# make sure everything is deterministic
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.use_deterministic_algorithms(True)
set_seed(seed)

def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
Expand Down Expand Up @@ -138,14 +141,19 @@ def load_model(checkpoint_path="/usr/src/app/model-store/model", config=None):
return config


def generate_predictions(config, line):
def generate_predictions(config, line, ignore_disfluencies=False):
if isinstance(line, list):
return [generate_predictions(config, l) for l in line]

model = config.model
set_seed(config.seed)

# also drop punctuation that we may generate
line = ''.join([c for c in line if c not in mapped_punctuation])
if ignore_disfluencies:
line = collapse_whitespace(line)
line = re.sub(r"(\w) *' *(\w)", r"\1'\2", line) # glue apostrophes to words
disfluencies, line = remove_simple_disfluences(line)
output = ''
if config.debug:
print(line)
Expand Down Expand Up @@ -202,6 +210,10 @@ def generate_predictions(config, line):
output += '.'
# Glue apostrophes back to words
output = re.sub(r"(\w) *' *(\w)", r"\1'\2", output)

if ignore_disfluencies:
output = collapse_whitespace(output)
output = reconstitute_text(output, disfluencies)
return output

mapped_punctuation = {
Expand All @@ -227,7 +239,7 @@ def generate_predictions(config, line):
'【': 'COMMA',
'】': 'COMMA',
'└': 'COMMA',
'└ ': 'COMMA',
#'└ ': 'COMMA',
'_': 'O',
'。': 'PERIOD',
'、': 'COMMA', # enumeration comma
Expand All @@ -251,6 +263,10 @@ def generate_predictions(config, line):
'〕': 'COMMA',
}

def collapse_whitespace(text):
return re.sub(r'\s+', ' ', text).strip()


# modification of the wordpiece tokenizer to keep case information even if vocab is lower cased
# forked from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/tokenization_bert.py

Expand Down Expand Up @@ -400,3 +416,38 @@ def init(config):
print('WARNING: reverting to cpu as cuda is not available', file=sys.stderr)
config.device = torch.device(config.device if torch.cuda.is_available() else 'cpu')

def remove_simple_disfluences(text, language=None):
if language is None:
# Get language from environment
language = os.environ.get("LANGUAGE","")
language = language.lower()[:2]
disfluencies = DISFLUENCIES.get(language, [])
all_hits = []
for disfluency in disfluencies:
all_hits += re.finditer(r" *"+disfluency+r" *", text)
all_hits = sorted(all_hits, key=lambda x: x.start())
to_be_inserted = [(hit.start(), hit.group()) for hit in all_hits]
new_text = text
for hit in all_hits[::-1]:
new_text = new_text[:hit.start()] + " " + new_text[hit.end():]
return to_be_inserted, new_text

punctuation_regex = r"["+re.escape("".join(mapped_punctuation.keys()))+r"]"

def reconstitute_text(text, to_be_inserted):
if len(to_be_inserted) == 0:
return text
pos_punc = [s.start() for s in re.finditer(punctuation_regex, text)]
for start, token in to_be_inserted:
start += len([p for p in pos_punc if p < start])
text = text[:start] + token.rstrip(" ") + text[start:]
print(text)
return text


DISFLUENCIES = {
"fr": [
"euh",
"heu",
]
}

0 comments on commit 8748f8b

Please sign in to comment.