eval + train ADP on Igor's models
markus583 committed Apr 13, 2024
1 parent a334575 commit 02dba86
Showing 3 changed files with 85 additions and 32 deletions.
36 changes: 36 additions & 0 deletions configs/peft/adapter_igor.json
@@ -0,0 +1,36 @@
+{
+    "model_name_or_path": "xlmr-multilingual-sentence-segmentation-09-04-3L-256BS-UD-OPUS-TED",
+    "output_dir": "xlmr-3l-v3-igor-mixture_adapter_rf16_ep30_v2",
+    "block_size": 256,
+    "eval_stride": 128,
+    "do_train": true,
+    "do_eval": true,
+    "per_device_train_batch_size": 64,
+    "per_device_eval_batch_size": 32,
+    "gradient_accumulation_steps": 1,
+    "eval_accumulation_steps": 8,
+    "dataloader_num_workers": 8,
+    "preprocessing_num_workers": 1,
+    "learning_rate": 3e-4,
+    "fp16": false,
+    "num_train_epochs": 30,
+    "logging_steps": 50,
+    "report_to": "wandb",
+    "save_steps": 100000000,
+    "remove_unused_columns": false,
+    "one_sample_per_line": false,
+    "do_sentence_training": true,
+    "do_auxiliary_training": false,
+    "warmup_ratio": 0.1,
+    "non_punctuation_sample_ratio": null,
+    "prediction_loss_only": true,
+    "use_auxiliary": true,
+    "ddp_timeout": 3600,
+    "use_subwords": true,
+    "custom_punctuation_file": "punctuation_xlmr_unk.txt",
+    "log_level": "warning",
+    "adapter_config": "seq_bn[reduction_factor=16]",
+    "weight_decay": 0.0,
+    "auxiliary_remove_prob": 0.0,
+    "train_adapter": true
+}
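
Note: the "adapter_config" value "seq_bn[reduction_factor=16]" names a sequential bottleneck adapter whose bottleneck width is hidden_size / 16. A minimal sketch of what that string means in the adapters library (the backbone checkpoint and adapter name below are placeholders, not the exact wtpsplit training path):

import adapters
from transformers import AutoModelForTokenClassification

# Placeholder backbone; the config above points at a custom XLM-R checkpoint.
model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-base")
adapters.init(model)  # retrofit adapter support onto the plain HF model

# "seq_bn[reduction_factor=16]": sequential bottleneck adapter, bottleneck dim = hidden dim / 16
model.add_adapter("sentseg", config="seq_bn[reduction_factor=16]")
model.train_adapter("sentseg")  # freeze the backbone; only adapter weights get gradients
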
27 changes: 16 additions & 11 deletions wtpsplit/evaluation/intrinsic.py
@@ -23,6 +23,7 @@
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
+
 @dataclass
 class Args:
     model_path: str
@@ -148,12 +149,16 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
                     with_head=True,
                     load_as="text",
                 )
-                if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+                if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")) and not os.path.exists(
+                    os.path.join(args.model_path, "model.safetensors")
+                ):
                     model_path = os.path.join(args.model_path, dataset_name, "en")
                     if not os.path.exists(model_path):
                         model_path = args.model_path
                     print(model_path)
-                    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
+                    model = PyTorchWrapper(
+                        AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)
+                    )
             except Exception as e:
                 print(f"Error loading adapter for {dataset_name} in {lang_code}: {e}")
                 continue
@@ -227,9 +232,7 @@ def main(args):
     save_model_path = args.model_path
     if args.adapter_path:
         save_model_path = args.adapter_path
-    save_str = (
-        f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}"
-    )
+    save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}"
     if args.do_lowercase:
         save_str += "_lc"
     if args.do_remove_punct:
@@ -243,7 +246,9 @@
 
     print("Loading model...")
     # if model_path does not contain a model, take first subfolder
-    if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")):
+    if not os.path.exists(os.path.join(args.model_path, "pytorch_model.bin")) and not os.path.exists(
+        os.path.join(args.model_path, "model.safetensors")
+    ):
         try:
             model_path = os.path.join(args.model_path, os.listdir(args.model_path)[0], "en")
         except:
@@ -261,15 +266,15 @@
         model.model.config.model_type = model_type
         if "meta-clf" in args.adapter_path:
             clf = model.model.classifier
-            model.model.classifier = torch.nn.Sequential(
-                clf,
-                torch.nn.Linear(clf.out_features, 1)
-            )
+            model.model.classifier = torch.nn.Sequential(clf, torch.nn.Linear(clf.out_features, 1))
 
     # first, logits for everything.
     f, total_test_time = load_or_compute_logits(args, model, eval_data, valid_data, save_str)
 
     save_str += f"_u{args.threshold}{args.save_suffix}"
 
+    if "multilingual" in model_path:
+        Constants.NEWLINE_INDEX += 1
+
     # now, compute the intrinsic scores.
     results = {}
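
Note: both hunks above repeat the same existence check because HF models may be serialized either as pytorch_model.bin or as model.safetensors. A standalone sketch of that logic and the subfolder fallback (helper names are made up for illustration):

import os

def has_model_weights(model_dir: str) -> bool:
    # A checkpoint dir is usable if it holds either HF weight format.
    return os.path.exists(os.path.join(model_dir, "pytorch_model.bin")) or os.path.exists(
        os.path.join(model_dir, "model.safetensors")
    )

def resolve_model_path(model_dir: str, dataset_name: str) -> str:
    # Mirrors the diff: fall back to a per-dataset "en" subfolder, then to the dir itself.
    if has_model_weights(model_dir):
        return model_dir
    candidate = os.path.join(model_dir, dataset_name, "en")
    return candidate if os.path.exists(candidate) else model_dir
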
54 changes: 33 additions & 21 deletions wtpsplit/train/train_adapter.py
@@ -40,7 +40,7 @@ class Args:
     model_name_or_path: str
     base_model: str = "xlm-roberta-base"
     shuffle: bool = True
-    text_path: str = "data/eval.pth"
+    text_path: str = "data/all_data.pth"
     include_languages: List[str] = None
     preprocessing_num_workers: int = 1
     block_size: int = 512
@@ -67,6 +67,7 @@ class Args:
     do_process: bool = False
     meta_clf: bool = False
     wandb_project: str = "sentence"
+    eval_every: int = 5
     # corruption
     do_lowercase: bool = False
     do_remove_punct: bool = False
@@ -92,6 +93,9 @@ def main():
         if (label_args.use_auxiliary or args.do_auxiliary_training or args.meta_clf)
         else 0
     )
+    if "multilingual" in args.model_name_or_path:
+        # Igor's models were not trained with aux. objective.
+        num_labels = 2
     config = SubwordXLMConfig.from_pretrained(
         args.model_name_or_path,
         num_labels=num_labels,
@@ -106,7 +110,7 @@ def prepare_dataset(
     split="train",
     do_lowercase=False,
     do_remove_punct=False,
-    subsample: Union[None, int, float] = None
+    subsample: Union[None, int, float] = None,
 ):
     # maybe we use more than 1 lang later at once.
     with training_args.main_process_first():
@@ -164,8 +168,8 @@ def prepare_dataset(
             subsample = min(subsample, len(dataset))
             dataset = dataset.select(range(subsample))
         elif isinstance(subsample, float):
-                dataset = dataset.select(range(int(subsample * len(dataset))))
-                logger.warning(f"Subsampled {len(dataset)} examples from {old_length}.")
+            dataset = dataset.select(range(int(subsample * len(dataset))))
+            logger.warning(f"Subsampled {len(dataset)} examples from {old_length}.")
 
     # very likely not relevant / used only for the compound part
     if args.ignore_non_hyphen:
@@ -371,9 +375,9 @@ def maybe_pad(text):
         with training_args.main_process_first():
             dataset = dataset.map(
                 lambda x: {
-                    "input_ids": [
-                        tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
-                    ] + x["input_ids"] + [tokenizer.convert_tokens_to_ids(tokenizer.eos_token)]
+                    "input_ids": [tokenizer.convert_tokens_to_ids(tokenizer.bos_token)]
+                    + x["input_ids"]
+                    + [tokenizer.convert_tokens_to_ids(tokenizer.eos_token)]
                 },
                 batched=False,
             )
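
Note: the hunk above only reflows the BOS/EOS wrapping; the behavior is unchanged. What one mapped example looks like with XLM-R's actual tokenizer (a small illustrative snippet, not code from this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
ids = tok("Hello world", add_special_tokens=False)["input_ids"]
# XLM-R: bos_token is "<s>" (id 0), eos_token is "</s>" (id 2)
wrapped = [tok.convert_tokens_to_ids(tok.bos_token)] + ids + [tok.convert_tokens_to_ids(tok.eos_token)]
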
@@ -516,7 +520,7 @@ def compute_metrics(trainer):
     label_dict = (
         get_subword_label_dict(label_args, tokenizer) if args.use_subwords else get_label_dict(label_args)
     )
-    
+
     if adapter_args.train_adapter:
         # init new adapter
         model.backbone.add_adapter(
@@ -529,7 +533,7 @@ def compute_metrics(trainer):
         training_args.adapter_warmup_steps = args.adapter_warmup_steps
         training_args.adapter_lr_multiplier = args.adapter_lr_multiplier
         kwargs = {}
-        
+
         with training_args.main_process_first():
             logger.warning(model.backbone.adapter_summary())
 
@@ -552,17 +556,23 @@ def compute_metrics(trainer):
                 torch.nn.Linear(clf.out_features, 1),
             )
             model.backbone.config.num_labels = 1
 
-    if args.one_sample_per_line:
-        # eval only 10x during the entire training
-        training_args.evaluation_strategy = "steps"
-        training_args.eval_steps = max(len(train_dataset) // training_args.per_device_train_batch_size, 5)
-        # log twice as often
-        training_args.logging_steps = training_args.eval_steps // 2
-
-    trainer_cls = AdapterTrainer if adapter_args.train_adapter else Trainer
+    # if args.one_sample_per_line:
+    # eval only args.eval_every times (default: 5) during the entire training
+    training_args.evaluation_strategy = "steps"
+    training_args.eval_steps = (
+        len(train_dataset)
+        // training_args.per_device_train_batch_size
+        // training_args.gradient_accumulation_steps
+        // args.eval_every
+        * training_args.num_train_epochs
+    )
+    # log 4x more often than we evaluate
+    training_args.logging_steps = training_args.eval_steps // 4
+
+    trainer_cls = AdapterTrainer if adapter_args.train_adapter else Trainer
+    # add logging_prefix and skip_eval_loss as args to trainer_cls if trainer_cls is AdapterTrainer only
 
     trainer = trainer_cls(
         model,
         training_args,
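
Note: the rewritten eval_steps expression schedules evaluation so it fires roughly args.eval_every times over the whole run, instead of the old 10x heuristic. Worked through with hypothetical numbers:

# Hypothetical sizes; only the arithmetic is taken from the diff above.
train_examples = 9_600
batch_size = 64          # per_device_train_batch_size
grad_accum = 1           # gradient_accumulation_steps
epochs = 30              # num_train_epochs
eval_every = 5           # new Args field, default 5

steps_per_epoch = train_examples // batch_size // grad_accum  # 150
total_steps = steps_per_epoch * epochs                        # 4500
eval_steps = steps_per_epoch // eval_every * epochs           # 900
# 4500 total steps / 900 = 5 evaluations, i.e. eval_every of them
logging_steps = eval_steps // 4                               # 225: logs 4x per eval interval
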
@@ -605,14 +615,16 @@ def compute_metrics(trainer):
     else:
         eval_function = "intrinsic"
     if args.do_lowercase and args.do_remove_punct:
-            suffix = "--do_lowercase --do_remove_punct"
+        suffix = "--do_lowercase --do_remove_punct"
+    elif "multilingual" in args.model_name_or_path:
+        suffix = "--threshold 0.5"
     else:
-            suffix = ""
+        suffix = ""
     if "adapter" in training_args.output_dir:
         model_info = f"--model_path {args.model_name_or_path} --adapter_path {training_args.output_dir}"
     else:
         model_info = f"--model_path {training_args.output_dir}"
 
     if "verses" in args.text_path or "lines" in args.text_path:
         cmd = f"python3 wtpsplit/evaluation/{eval_function}.py {model_info} --threshold 0.1 --custom_language_list data/mldb_langs.csv --eval_data_path {args.text_path} {suffix}"
     else:
