From b11dafe00889173fe04e994daa1a60ab1985a1b9 Mon Sep 17 00:00:00 2001
From: markus583
Date: Fri, 26 Apr 2024 12:33:19 +0000
Subject: [PATCH] add label token hierarchy

---
 wtpsplit/utils.py | 127 +++++++++++++++++++++++++---------------------
 1 file changed, 68 insertions(+), 59 deletions(-)

diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index adf0f00c..75023601 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -83,12 +83,13 @@ class LabelArgs:
     custom_punctuation_file: str = None
     retain_first_consecutive_punctuation: bool = True
     non_whitespace_remove_spaces: bool = True
-    non_whitespace_retokenize: bool = False
     case_corruption_prob_after_newline: float = 0.0
     case_corruption_prob_after_punct: float = 0.0
     corrupt_entire_chunk_prob: float = 0.0
     corrupt_entire_chunk_strategy: str = "mix"
     corrupt_entire_chunk_prob_full: float = 0.5
+    use_all_labels: bool = False
+    use_all_labels_max_length: int = 3
 
     def __post_init__(self):
         if self.custom_punctuation_file:
@@ -124,6 +125,19 @@ def get_subword_label_dict(label_args, tokenizer):
         )
         if token_id == tokenizer.unk_token_id:
             n_unks += 1
+        if label_args.use_all_labels:
+            # check where c is in tokenizer.vocab keys
+            for i_t, (token, token_idx) in enumerate(tokenizer.vocab.items()):
+                if (
+                    c in token
+                    and token_idx not in label_dict
+                    and len(token) < label_args.use_all_labels_max_length
+                    and not any(i.isdigit() for i in token)
+                ):
+                    label_dict[token_idx] = 1 + Constants.AUX_OFFSET + i
+                    logger.warning(
+                        f"Special auxiliary character {c} has token ID {token_idx} and label {label_dict[token_idx]}, decoded: {tokenizer.decode([token_idx])}"
+                    )
 
     logger.info(f"found {n_unks} UNK tokens in auxiliary characters")
 
@@ -329,53 +343,43 @@ def corrupt_training(
                         labels.insert(last_index_in_block, 0)
                     else:
                         del block_ids[i + 1]
-                    if tokenizer and separator == "":
-                        if label_args.non_whitespace_remove_spaces and i + 1 < len(input_ids):
-                            # tokenizer.decode() retains the space that leaks the information
-                            # so we need to get the position within the tokenized text and then remove the space
-                            # (so there is no more space when fed into the tokenizer call)
-                            if input_ids[i + 1] == tokenizer.convert_tokens_to_ids("▁"):
-                                # remove artificial space
-                                del input_ids[i + 1]
-                                del labels[i + 1]
-                                del block_ids[i + 1]
-                            if i + 1 < len(input_ids):
-                                next_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
-                                if next_token.startswith("▁"):
-                                    # next token starts with _ --> remove the _ from the token and re-tokenize
-                                    remove_next = False
-                                    remaining_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
-                                    if len(remaining_token) > 1:
-                                        # ▁Test --> Test
-                                        remaining_token = remaining_token[1:]
-                                    else:
-                                        # ▁ --> remove
-                                        remove_next = True
-                                    remaining_id = tokenizer.convert_tokens_to_ids(remaining_token)
-                                    # replace the token with the remaining token
-                                    if remaining_id != tokenizer.unk_token_id:
-                                        input_ids[i + 1] = remaining_id
-                                    else:
-                                        # UNK token, remove it
-                                        remove_next = True
-                                    if remove_next:
-                                        del input_ids[i + 1]
-                                        del labels[i + 1]
-                                        del block_ids[i + 1]
-                        elif label_args.non_whitespace_retokenize:
-                            # re-tokenize full span
-                            # this is a bit more expensive, but it is the only way to ensure that the space is removed
-                            # and the token is re-tokenized correctly
-                            # but: only tokenize from current position onwards to keep re-usage labels
-                            input_ids = input_ids[: i + 1] + tokenizer.encode(
-                                tokenizer.decode(input_ids[i + 1 :]), add_special_tokens=False
-                            )
-                            # if new_input_ids != input_ids:
-                            #     print("new_input_ids", tokenizer.decode(new_input_ids))
-                            #     print("input_ids", tokenizer.decode(input_ids))
-                            # input_ids = new_input_ids
-                            labels = labels[: i + 1] + label(input_ids[i + 1 :], label_dict)
-
+                    if (
+                        tokenizer
+                        and separator == ""
+                        and label_args.non_whitespace_remove_spaces
+                        and i + 1 < len(input_ids)
+                    ):
+                        # tokenizer.decode() retains the space that leaks the information
+                        # so we need to get the position within the tokenized text and then remove the space
+                        # (so there is no more space when fed into the tokenizer call)
+                        if input_ids[i + 1] == tokenizer.convert_tokens_to_ids("▁"):
+                            # remove artificial space
+                            del input_ids[i + 1]
+                            del labels[i + 1]
+                            del block_ids[i + 1]
+                        if i + 1 < len(input_ids):
+                            next_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
+                            if next_token.startswith("▁"):
+                                # next token starts with _ --> remove the _ from the token and re-tokenize
+                                remove_next = False
+                                remaining_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
+                                if len(remaining_token) > 1:
+                                    # ▁Test --> Test
+                                    remaining_token = remaining_token[1:]
+                                else:
+                                    # ▁ --> remove
+                                    remove_next = True
+                                remaining_id = tokenizer.convert_tokens_to_ids(remaining_token)
+                                # replace the token with the remaining token
+                                if remaining_id != tokenizer.unk_token_id:
+                                    input_ids[i + 1] = remaining_id
+                                else:
+                                    # UNK token, remove it
+                                    remove_next = True
+                                if remove_next:
+                                    del input_ids[i + 1]
+                                    del labels[i + 1]
+                                    del block_ids[i + 1]
                 if random.random() < label_args.case_corruption_prob_after_newline and i + 1 < len(input_ids):
                     input_ids, labels, block_ids = _corrupt_case(tokenizer, input_ids, labels, block_ids, i)
 
@@ -535,23 +539,28 @@ def reconstruct_sentences(text, partial_sentences):
         newline_whitespace_prob=1.0,
         case_corruption_prob_after_punct=1.0,
         case_corruption_prob_after_newline=1.0,
+        use_all_labels=True,
     )
     label_dict = get_subword_label_dict(label_args, tokenizer)
-
+    print(len(label_dict))
+    # print all tokens with a number in it (from label_dict only)
+    print(
+        [tokenizer.decode(input_id) for input_id in input_ids if any(c.isdigit() for c in tokenizer.decode(input_id))]
+    )
     # corrupt
     input_ids, block_ids, labels = corrupt_training(
         input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
     )
-    print(input_ids)
-    print(labels)
-    print(tokenizer.tokenize(text))
-    print([(tokenizer.decode([input_id]), label) for input_id, label in zip(input_ids, labels)])
-    print("newline labels in text:")
-    print(np.where(np.array(labels) == 1))
-    print("newline ids in output text:")
-    print(np.where(np.array(input_ids) == tokenizer.all_special_ids[-1]))
-    print(tokenizer.decode(input_ids))
-    print(tokenizer.decode(input_ids))
+    # print(input_ids)
+    # print(labels)
+    # print(tokenizer.tokenize(text))
+    # print([(tokenizer.decode([input_id]), label) for input_id, label in zip(input_ids, labels)])
+    # print("newline labels in text:")
+    # print(np.where(np.array(labels) == 1))
+    # print("newline ids in output text:")
+    # print(np.where(np.array(input_ids) == tokenizer.all_special_ids[-1]))
+    # print(tokenizer.decode(input_ids))
+    # print(tokenizer.decode(input_ids))
 
     # ords = [ord(c) for c in text]
     # block_ords = [0] * len(ords)