add label token hierarchy
markus583 committed Apr 26, 2024
1 parent e5844cc commit b11dafe
Showing 1 changed file with 68 additions and 59 deletions.
127 changes: 68 additions & 59 deletions wtpsplit/utils.py
@@ -83,12 +83,13 @@ class LabelArgs:
custom_punctuation_file: str = None
retain_first_consecutive_punctuation: bool = True
non_whitespace_remove_spaces: bool = True
non_whitespace_retokenize: bool = False
case_corruption_prob_after_newline: float = 0.0
case_corruption_prob_after_punct: float = 0.0
corrupt_entire_chunk_prob: float = 0.0
corrupt_entire_chunk_strategy: str = "mix"
corrupt_entire_chunk_prob_full: float = 0.5
use_all_labels: bool = False
use_all_labels_max_length: int = 3

def __post_init__(self):
if self.custom_punctuation_file:
@@ -124,6 +125,19 @@ def get_subword_label_dict(label_args, tokenizer):
)
if token_id == tokenizer.unk_token_id:
n_unks += 1
if label_args.use_all_labels:
# scan the tokenizer vocab for short tokens that contain the auxiliary character c
for i_t, (token, token_idx) in enumerate(tokenizer.vocab.items()):
if (
c in token
and token_idx not in label_dict
and len(token) < label_args.use_all_labels_max_length
and not any(i.isdigit() for i in token)
):
label_dict[token_idx] = 1 + Constants.AUX_OFFSET + i
logger.warning(
f"Vocab token {token} containing auxiliary character {c} has token ID {token_idx} and label {label_dict[token_idx]}, decoded: {tokenizer.decode([token_idx])}"
)

logger.info(f"found {n_unks} UNK tokens in auxiliary characters")
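
For context, the use_all_labels branch added above builds a small label token hierarchy: any sufficiently short vocabulary token that contains an auxiliary character (and no digits) inherits that character's auxiliary label. A minimal, self-contained sketch of the scan under that reading, with a toy vocabulary and illustrative constants standing in for Constants.AUX_OFFSET, the auxiliary character list, and the real tokenizer; the enclosing loop over auxiliary characters (the i used in the diff) is written out explicitly here:

# Illustrative sketch only -- toy vocab, not the HF tokenizer or wtpsplit constants.
AUX_OFFSET = 2                  # stands in for Constants.AUX_OFFSET
MAX_TOKEN_LEN = 3               # mirrors use_all_labels_max_length
aux_chars = [",", "-"]          # auxiliary punctuation characters
vocab = {"▁the": 5, ",": 6, "),": 7, "-3": 8, "--": 9}   # token -> token ID

label_dict = {}
for i, c in enumerate(aux_chars):
    for token, token_idx in vocab.items():
        if (
            c in token
            and token_idx not in label_dict
            and len(token) < MAX_TOKEN_LEN
            and not any(ch.isdigit() for ch in token)
        ):
            # the token inherits the label of the auxiliary character it contains
            label_dict[token_idx] = 1 + AUX_OFFSET + i

print(label_dict)   # {6: 3, 7: 3, 9: 4}; "-3" is skipped because it contains a digit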

@@ -329,53 +343,43 @@ def corrupt_training(
labels.insert(last_index_in_block, 0)
else:
del block_ids[i + 1]
if tokenizer and separator == "":
if label_args.non_whitespace_remove_spaces and i + 1 < len(input_ids):
# tokenizer.decode() retains the space that leaks the information
# so we need to get the position within the tokenized text and then remove the space
# (so there is no more space when fed into the tokenizer call)
if input_ids[i + 1] == tokenizer.convert_tokens_to_ids("▁"):
# remove artificial space
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if i + 1 < len(input_ids):
next_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if next_token.startswith("▁"):
# next token starts with ▁ --> remove the ▁ from the token and re-tokenize
remove_next = False
remaining_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if len(remaining_token) > 1:
# ▁Test --> Test
remaining_token = remaining_token[1:]
else:
# ▁ --> remove
remove_next = True
remaining_id = tokenizer.convert_tokens_to_ids(remaining_token)
# replace the token with the remaining token
if remaining_id != tokenizer.unk_token_id:
input_ids[i + 1] = remaining_id
else:
# UNK token, remove it
remove_next = True
if remove_next:
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
elif label_args.non_whitespace_retokenize:
# re-tokenize full span
# this is a bit more expensive, but it is the only way to ensure that the space is removed
# and the token is re-tokenized correctly
# but: only re-tokenize from the current position onwards so the earlier labels can be kept
input_ids = input_ids[: i + 1] + tokenizer.encode(
tokenizer.decode(input_ids[i + 1 :]), add_special_tokens=False
)
# if new_input_ids != input_ids:
# print("new_input_ids", tokenizer.decode(new_input_ids))
# print("input_ids", tokenizer.decode(input_ids))
# input_ids = new_input_ids
labels = labels[: i + 1] + label(input_ids[i + 1 :], label_dict)

if (
tokenizer
and separator == ""
and label_args.non_whitespace_remove_spaces
and i + 1 < len(input_ids)
):
# tokenizer.decode() retains the space that leaks the information
# so we need to get the position within the tokenized text and then remove the space
# (so there is no more space when fed into the tokenizer call)
if input_ids[i + 1] == tokenizer.convert_tokens_to_ids("▁"):
# remove artificial space
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if i + 1 < len(input_ids):
next_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if next_token.startswith("▁"):
# next token starts with ▁ --> strip the ▁ and look the remaining token up again
remove_next = False
remaining_token = tokenizer.convert_ids_to_tokens(input_ids[i + 1])
if len(remaining_token) > 1:
# ▁Test --> Test
remaining_token = remaining_token[1:]
else:
# ▁ --> remove
remove_next = True
remaining_id = tokenizer.convert_tokens_to_ids(remaining_token)
# replace the token with the remaining token
if remaining_id != tokenizer.unk_token_id:
input_ids[i + 1] = remaining_id
else:
# UNK token, remove it
remove_next = True
if remove_next:
del input_ids[i + 1]
del labels[i + 1]
del block_ids[i + 1]
if random.random() < label_args.case_corruption_prob_after_newline and i + 1 < len(input_ids):
input_ids, labels, block_ids = _corrupt_case(tokenizer, input_ids, labels, block_ids, i)
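
The consolidated condition above handles the non-whitespace case (separator == "") after a newline between two blocks has been removed: a leftover sentencepiece marker "▁" on the next token would still leak the old boundary, so a bare "▁" token is deleted, and a token starting with "▁" is stripped and looked up again (and dropped if the remainder maps to UNK). A self-contained sketch of that stripping step, with a toy vocabulary and a hypothetical helper in place of the tokenizer's convert_ids_to_tokens / convert_tokens_to_ids calls:

# Illustrative sketch only -- toy sentencepiece-style vocab, not the HF tokenizer.
vocab = {"<unk>": 0, "▁": 10, "▁Test": 11, "Test": 12}   # token -> token ID
ids_to_tokens = {v: k for k, v in vocab.items()}
unk_token_id = vocab["<unk>"]

def strip_leading_marker(token_id):
    """Return the ID to keep after removing the leaked space, or None to drop the token."""
    token = ids_to_tokens[token_id]
    if token == "▁":                       # the token is only the space marker -> remove it
        return None
    if token.startswith("▁"):              # ▁Test -> Test, then look the remainder up again
        remaining_id = vocab.get(token[1:], unk_token_id)
        return None if remaining_id == unk_token_id else remaining_id
    return token_id                        # no leading marker, keep the token as-is

print(strip_leading_marker(11))   # 12   (▁Test becomes Test)
print(strip_leading_marker(10))   # None (bare ▁ is deleted entirely)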

@@ -535,23 +539,28 @@ def reconstruct_sentences(text, partial_sentences):
newline_whitespace_prob=1.0,
case_corruption_prob_after_punct=1.0,
case_corruption_prob_after_newline=1.0,
use_all_labels=True,
)
label_dict = get_subword_label_dict(label_args, tokenizer)

print(len(label_dict))
# print all decoded input tokens that contain a digit
print(
[tokenizer.decode(input_id) for input_id in input_ids if any(c.isdigit() for c in tokenizer.decode(input_id))]
)
# corrupt
input_ids, block_ids, labels = corrupt_training(
input_ids, block_ids, "en", label_args, label_dict, tokenizer=tokenizer
)
print(input_ids)
print(labels)
print(tokenizer.tokenize(text))
print([(tokenizer.decode([input_id]), label) for input_id, label in zip(input_ids, labels)])
print("newline labels in text:")
print(np.where(np.array(labels) == 1))
print("newline ids in output text:")
print(np.where(np.array(input_ids) == tokenizer.all_special_ids[-1]))
print(tokenizer.decode(input_ids))
print(tokenizer.decode(input_ids))
# print(input_ids)
# print(labels)
# print(tokenizer.tokenize(text))
# print([(tokenizer.decode([input_id]), label) for input_id, label in zip(input_ids, labels)])
# print("newline labels in text:")
# print(np.where(np.array(labels) == 1))
# print("newline ids in output text:")
# print(np.where(np.array(input_ids) == tokenizer.all_special_ids[-1]))
# print(tokenizer.decode(input_ids))
# print(tokenizer.decode(input_ids))

# ords = [ord(c) for c in text]
# block_ords = [0] * len(ords)
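
As a possible extension of the debug prints above (not part of this commit), one could group the expanded label_dict by label to see which vocabulary tokens the hierarchy assigned to each auxiliary class, reusing the tokenizer and label_dict already in scope:

# Illustrative only: inspect which decoded tokens ended up under each auxiliary label.
aux_by_label = {}
for token_id, lbl in label_dict.items():
    aux_by_label.setdefault(lbl, []).append(tokenizer.decode([token_id]))
for lbl, tokens in sorted(aux_by_label.items()):
    print(lbl, tokens[:10])   # show the first few tokens per label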
