Skip to content

Commit

Permalink
eval: skip parts of the datasets, exclude every k-th sentence boundary, XLM-R compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 12, 2024
1 parent a27bd89 commit 4f27691
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 40 deletions.
41 changes: 41 additions & 0 deletions configs/peft/lora_lyrics_xlmr.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlm-roberta-base_lora-v2_ep30_mldbW-verses_bs512",
"block_size": 512,
"do_train": true,
"do_eval": true,
"per_device_train_batch_size": 32,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 2,
"eval_accumulation_steps": 8,
"evaluation_strategy": "epoch",
"dataloader_num_workers": 1,
"preprocessing_num_workers": 1,
"learning_rate": 3e-4,
"fp16": false,
"num_train_epochs": 30,
"logging_steps": 50,
"report_to": "wandb",
"wandb_project": "lyrics-peft",
"save_steps": 100000000,
"remove_unused_columns": false,
"one_sample_per_line": true,
"do_sentence_training": true,
"do_auxiliary_training": false,
"warmup_ratio": 0.1,
"non_punctuation_sample_ratio": null,
"prediction_loss_only": true,
"use_auxiliary": false,
"ddp_timeout": 3600,
"use_subwords": true,
"custom_punctuation_file": "punctuation_xlmr_unk.txt",
"log_level": "warning",
"adapter_config": "lora[r=16,alpha=32,intermediate_lora=True]",
"weight_decay": 0.0,
"auxiliary_remove_prob": 0.0,
"text_path": "data/all_data_11_05-lyrics.pth",
"skip_eval_loss": false,
"shuffle": false,
"train_adapter": true,
"subsample": null
}
27 changes: 25 additions & 2 deletions wtpsplit/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,38 @@ def get_labels(lang_code, sentences, after_space=True):
return labels


def evaluate_sentences(lang_code, sentences, predicted_sentences, return_indices: bool = False):
def evaluate_sentences(
lang_code, sentences, predicted_sentences, return_indices: bool = False, exclude_every_k: int = 0
):
separator = Constants.SEPARATORS[lang_code]

text = separator.join(sentences)

assert len(text) == len("".join(predicted_sentences))

labels = get_labels(lang_code, sentences)

predicted_end_indices = np.cumsum(np.array([len(s) for s in predicted_sentences]))
predictions = np.zeros_like(labels)
predictions[predicted_end_indices] = 1

assert len(labels) == len(predictions)

if exclude_every_k > 0:
true_end_indices = np.where(labels == 1)[0]
# every k-th from those where labels are 1
indices_to_remove = true_end_indices[exclude_every_k-1::exclude_every_k]

# mask for indices to keep
mask = np.ones_like(labels, dtype=bool)
mask[indices_to_remove] = False
mask[-1] = False # last is always excluded

# remove indices
labels = labels[mask]
predictions = predictions[mask]

assert len(labels) == len(predictions)

return f1_score(labels, predictions), {
"recall": recall_score(labels, predictions),
Expand Down Expand Up @@ -102,6 +122,7 @@ def evaluate_mixture(
test_x,
true_sentences,
return_indices,
exclude_every_k,
clf,
features,
threshold_transformed,
Expand Down Expand Up @@ -130,6 +151,7 @@ def evaluate_mixture(
true_sentences,
reconstruct_sentences(text, indices_to_sentences(text, predicted_indices_newline)),
return_indices,
exclude_every_k,
)

indices_newline_info = {
Expand All @@ -152,6 +174,7 @@ def evaluate_mixture(
true_sentences,
reconstruct_sentences(text, indices_to_sentences(text, predicted_indices_transformed)),
return_indices,
exclude_every_k,
)
indices_transformed_info = {
"true_indices": info_transformed.pop("true_indices"),
Expand Down
75 changes: 49 additions & 26 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Args:
clf_from_scratch: bool = False
return_indices: bool = False
skip_punct: bool = True
exclude_every_k: int = 0


def process_logits(text, model, lang_code, args):
Expand Down Expand Up @@ -138,7 +139,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st

total_test_time = 0 # Initialize total test processing time

with h5py.File(logits_path, "a") as f, torch.no_grad():
with h5py.File(logits_path, "w") as f, torch.no_grad(): # FIXME
for lang_code in tqdm(use_langs, desc="Languages"):
if args.include_langs is not None and lang_code not in args.include_langs:
continue
Expand Down Expand Up @@ -167,12 +168,32 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
if args.skip_corrupted and "corrupted" in dataset_name:
continue
if "Alternative" not in dataset_name:
continue
elif "nllb" in dataset_name:
continue
if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
print("SKIP: ", lang_code, dataset_name)
continue
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
print("SKIP: ", lang_code, dataset_name)
continue
if lang_code == "en" and dataset_name == "legal-all-laws":
# not available.
print("SKIP: ", lang_code, dataset_name)
continue
try:
if args.adapter_path:
if args.clf_from_scratch:
model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1)
elif model.model.classifier.out_features == 2:
# we train XLM-R using our wrapper, needs to be adapted for adapters to be loaded
model.model.classifier = torch.nn.Linear(
model.model.classifier.in_features,
1, # FIXME: hardcoded?
)
model.model.__class__.__name__ = 'SubwordXLMForTokenClassification'

if (
any(code in lang_code for code in ["ceb", "jv", "mn", "yo"])
and "ted2020" not in dataset_name
Expand All @@ -189,15 +210,6 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
with_head=True,
load_as="text",
)
if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")) and not os.path.exists(
os.path.join(args.model_path, "model.safetensors")
):
model_path = os.path.join(args.model_path, dataset_name, "en")
if not os.path.exists(model_path):
model_path = args.model_path
model = PyTorchWrapper(
AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
)
except Exception as e:
print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
continue
Expand Down Expand Up @@ -339,6 +351,8 @@ def main(args):
f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)

save_str += f"_u{args.threshold}"
if args.exclude_every_k > 0:
save_str += f"_k{args.exclude_every_k}"

# now, compute the intrinsic scores.
results = {}
Expand All @@ -360,9 +374,16 @@ def main(args):

for dataset_name, dataset in dsets["sentence"].items():
sentences = dataset["data"][: args.max_n_test_sentences]
if len(sentences) == 0:
continue
if lang_code not in f or dataset_name not in f[lang_code]:
continue

if "lyrics" in dataset_name or "short" in dataset_name:
exclude_every_k = 0
else:
exclude_every_k = args.exclude_every_k

if "train_logits" in f[lang_code][dataset_name] and not args.skip_adaptation:
feature_indices = None
clf = train_mixture(
Expand All @@ -386,6 +407,7 @@ def main(args):
f[lang_code][dataset_name]["test_logits"][:][start:end],
list(short_seq),
args.return_indices,
exclude_every_k,
*clf,
)
score_t.append(single_score_t)
Expand All @@ -395,13 +417,11 @@ def main(args):
info["info_transformed"]["correct_pairwise"] if info["info_transformed"] else None
)
# indices: accumulate from start
t_indices.extend(
[idx + start for idx in cur_t_indices["pred_indices"]]
if cur_t_indices and cur_t_indices["pred_indices"]
else []
t_indices.append(
cur_t_indices["pred_indices"] if cur_t_indices and cur_t_indices["pred_indices"] else []
)
punct_indices.extend(
[idx + start for idx in cur_punct_indices["pred_indices"]]
punct_indices.append(
cur_punct_indices["pred_indices"]
if cur_punct_indices and cur_punct_indices["pred_indices"]
else []
)
Expand All @@ -412,6 +432,7 @@ def main(args):
f[lang_code][dataset_name]["test_logits"][:],
sentences,
args.return_indices,
exclude_every_k,
*clf,
)

Expand All @@ -428,30 +449,32 @@ def main(args):
acc_u = []
score_u = []
u_indices, true_indices = [], []
length = 0
length = []
for i, short_seq in enumerate(sentences):
start, end = f[lang_code][dataset_name]["test_logit_lengths"][i]
single_score_u, _, info, cur_u_indices, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:][start:end],
list(short_seq),
args.return_indices,
exclude_every_k,
*clf,
)
score_u.append(single_score_u)
acc_u.append(info["info_newline"]["correct_pairwise"])
# indices: accumulate from start
u_indices.extend(
[idx + start for idx in cur_u_indices["pred_indices"]] if cur_u_indices["pred_indices"] else []
)
true_indices.extend(
[idx + start for idx in cur_u_indices["true_indices"]] if cur_u_indices["true_indices"] else []
)
length += cur_u_indices["length"] - 1
u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else [])
true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else [])
length.append(cur_u_indices["length"])

else:
score_u, _, _, u_indices, _ = evaluate_mixture(
lang_code, f[lang_code][dataset_name]["test_logits"][:], sentences, args.return_indices, *clf
lang_code,
f[lang_code][dataset_name]["test_logits"][:],
sentences,
args.return_indices,
exclude_every_k,
*clf,
)

if isinstance(sentences[0], list):
Expand Down Expand Up @@ -551,7 +574,7 @@ def main(args):
"w",
),
default=int,
# indent=4,
indent=4,
)
print(Constants.CACHE_DIR / "intrinsic" / f"{save_str}_IDX.json")
print("Indices saved to file.")
Expand Down
32 changes: 21 additions & 11 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def extract(
"""
if "xlm" in model.config.model_type:
use_subwords = True
tokenizer = AutoTokenizer.from_pretrained(model.config.base_model)
tokenizer = AutoTokenizer.from_pretrained(
model.config.base_model if hasattr(model.config, "base_model") else model.config._name_or_path
)
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False)
# remove CLS and SEP tokens, they are added later anyhow
Expand Down Expand Up @@ -172,10 +174,7 @@ def extract(
# containers for the final logits
all_logits = [
np.zeros(
(
length,
model.config.num_labels
),
(length, model.config.num_labels),
dtype=np.float16,
)
for length in text_lengths
Expand Down Expand Up @@ -220,12 +219,23 @@ def extract(

kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {}

logits = model(
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=None if use_subwords else batch_input_hashes,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]
if use_subwords and model.config.model_type == "xlm-roberta":
# TODO: generalize
import torch
with torch.no_grad():
logits = model.model(
input_ids=torch.from_numpy(batch_input_ids).to(model.model.device),
attention_mask=torch.from_numpy(batch_attention_mask).to(model.model.device),
**kwargs,
)["logits"].cpu().numpy()
else:
logits = model(
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=None if use_subwords else batch_input_hashes,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]

if use_subwords:
logits = logits[:, 1:-1, :] # remove CLS and SEP tokens

Expand Down
4 changes: 3 additions & 1 deletion wtpsplit/train/train_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,13 +392,15 @@ def maybe_pad(text):
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
print("SKIP: ", lang, dataset_name)
continue
if "media" in dataset_name:
continue
if lang == "en" and dataset_name == "legal-all-laws":
# not available.
print("SKIP: ", lang, dataset_name)
continue
print("RUNNING:", dataset_name, lang)
# skip langs starting with a, b, ..., k
# if lang.startswith(tuple("abcd")):
# if not lang.startswith(tuple("k")) and not "en-de" in lang:
# print(f"Skipping {lang} {dataset_name}")
# continue
# do model stuff here; otherwise, head params would be overwritten every time
Expand Down

0 comments on commit 4f27691

Please sign in to comment.