Commit

add mixed corruption strategy
markus583 committed Apr 23, 2024
1 parent 6b9e2b8 commit 37824a4
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions wtpsplit/utils.py
@@ -1,17 +1,17 @@
 import json
+import logging
 import os
 import random
+from collections import defaultdict
 from dataclasses import dataclass, field
-from cached_property import cached_property
 from pathlib import Path
 from typing import List
-import logging
-from collections import defaultdict
 
 import numpy as np
 import pandas as pd
-from transformers import AutoTokenizer
+from cached_property import cached_property
 from mosestokenizer import MosesTokenizer
+from transformers import AutoTokenizer
 
 # same as in CANINE
 PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
@@ -86,7 +86,8 @@ class LabelArgs:
     case_corruption_prob_after_newline: float = 0.0
     case_corruption_prob_after_punct: float = 0.0
     corrupt_entire_chunk_prob: float = 0.0
-    corrupt_entire_chunk_strategy: str = "full"
+    corrupt_entire_chunk_strategy: str = "mix"
+    corrupt_entire_chunk_prob_full: float = 0.5
 
     def __post_init__(self):
         if self.custom_punctuation_file:
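
The default corruption strategy changes from "full" to the new "mix", which resolves per chunk to either "full" or "asr"; corrupt_entire_chunk_prob_full sets the weight of the "full" branch. A minimal configuration sketch (field names are taken from the diff; all other LabelArgs fields are assumed to keep their defaults):

```python
from wtpsplit.utils import LabelArgs

# Corrupt 10% of training chunks entirely; of those, apply the "full"
# strategy (lowercase + strip punctuation characters) half of the time
# and the "asr" strategy otherwise.
label_args = LabelArgs(
    corrupt_entire_chunk_prob=0.1,
    corrupt_entire_chunk_strategy="mix",
    corrupt_entire_chunk_prob_full=0.5,
)
```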
@@ -239,27 +240,32 @@ def corrupt_training(
     input_ids = input_ids.copy()
     block_ids = block_ids.copy()
     if random.random() < label_args.corrupt_entire_chunk_prob:
-        # lowercase all text
+        # choose corruption strategy
+        if label_args.corrupt_entire_chunk_strategy == "mix":
+            corrupt_strategy = "full" if random.random() < label_args.corrupt_entire_chunk_prob_full else "asr"
+        else:
+            corrupt_strategy = label_args.corrupt_entire_chunk_strategy
+
         input_text = tokenizer.decode(input_ids)
-        if label_args.corrupt_entire_chunk_strategy == "tokenizer":
+        if corrupt_strategy == "tokenizer":
             if not tokenizer:
                 raise NotImplementedError()
             corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=False)
-            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False, verbose=False)
             # remove ALL punct *tokens*
             auxiliary_remove_prob = 1.0
-        elif label_args.corrupt_entire_chunk_strategy == "full":
+        elif corrupt_strategy == "full":
             # remove all punct *characters*
             corrupted = corrupt(input_text, do_lowercase=True, do_remove_punct=True)
-            input_ids = tokenizer.encode(corrupted, add_special_tokens=False)
+            input_ids = tokenizer.encode(corrupted, add_special_tokens=False, verbose=False)
             auxiliary_remove_prob = 1.0  # just for safety/consistency
-        elif label_args.corrupt_entire_chunk_strategy == "asr":
+        elif corrupt_strategy == "asr":
             if not tokenizer:
                 raise NotImplementedError()
             corrupted_sentences = corrupt_asr(input_text, lang)
             corrupted_text = "\n".join(corrupted_sentences)
-            input_ids = tokenizer.encode(corrupted_text, add_special_tokens=False)
-            auxiliary_remove_prob = 0.0
+            input_ids = tokenizer.encode(corrupted_text, add_special_tokens=False, verbose=False)
+            auxiliary_remove_prob = 0.0  # do not remove additional characters
         block_ids = [0] * len(input_ids)
 
     else:
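
With "mix", the block above first reduces the configured strategy to a concrete one via a weighted two-way draw, then dispatches to the existing branches. A standalone sketch of just that selection step (choose_corrupt_strategy is a hypothetical helper for illustration, not part of the module):

```python
import random

def choose_corrupt_strategy(strategy: str, prob_full: float) -> str:
    # With "mix", pick "full" with probability prob_full and "asr"
    # otherwise; any other configured strategy passes through unchanged.
    if strategy == "mix":
        return "full" if random.random() < prob_full else "asr"
    return strategy

# Rough sanity check of the expected split at prob_full=0.5: ~50/50.
picks = [choose_corrupt_strategy("mix", 0.5) for _ in range(10_000)]
print(round(picks.count("full") / len(picks), 2))  # ≈ 0.5
```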
