-
Notifications
You must be signed in to change notification settings - Fork 5
/
t0_finetune.py
730 lines (661 loc) · 27.6 KB
/
t0_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
#!/usr/bin/env python
# coding=utf-8
# Copyright BigScience, The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning T0 in PyTorch, optionally few-shot.
This script is adapted from
https://github.com/huggingface/transformers/blob/master/examples/pytorch/multiple-choice/run_swag_no_trainer.py
as well as
https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization_no_trainer.py
"""
import argparse
import logging
import os
import random
from dataclasses import dataclass
from itertools import chain
from typing import Optional, Union
import csv
import math
import datasets
import torch
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedTokenizerBase,
default_data_collator,
DataCollatorForSeq2Seq,
AdamW,
SchedulerType,
get_scheduler,
set_seed,
)
from transformers.file_utils import PaddingStrategy
from promptsource.templates import DatasetTemplates
# Module-level logger; main() sets its level to INFO on the local main process and ERROR elsewhere.
logger = logging.getLogger(__name__)
def parse_args(args=None):
    """Build the command-line interface of the fine-tuning script and parse it.

    Args:
        args: Optional list of argument strings to parse instead of ``sys.argv[1:]``
            (handy for tests or programmatic use). ``None`` (the default) keeps
            argparse's standard behaviour of reading the real command line.

    Returns:
        argparse.Namespace: the parsed options. ``--dataset_name``, ``--template_name``,
        ``--output_dir`` and ``--model_name_or_path`` are required, so argparse exits
        with a usage error when any of them is missing.
    """
    parser = argparse.ArgumentParser(description="Fine-tuning T0 in PyTorch, optionally few-shot.")
    parser.add_argument(
        "-d",
        "--dataset_name",
        type=str,
        default=None,
        required=True,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "-s",
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name (usually a subset) of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "-t",
        "--template_name",
        type=str,
        default=None,
        required=True,
        help="The template/prompt name in `promptsource`.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Where to store the results CSV and (TODO) optionally the final model.",
    )
    parser.add_argument(
        "-m",
        "--model_name_or_path",
        type=str,
        required=True,
        help=(
            "Path to pretrained model or model identifier from huggingface.co/models. "
            "The list of T0 variants can be found on `https://huggingface.co/bigscience/T0_3B`"
        ),
    )
    parser.add_argument(
        "-pa",
        "--parallelize",
        action="store_true",
        help=(
            "If passed, will call `model.parallelize` which splits the model on all GPUs available (model parallelism). "
            "Note that this feature is still experimental in HF Transformers."
        ),
    )
    parser.add_argument(
        "-eb",
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader. Will be multiplied by the number of answer choices.",
    )
    parser.add_argument(
        "-tb",
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "-ns",
        "--num_shots",
        type=int,
        default=None,
        help="Number of training examples for few-shot learning. Default is None, which uses the entire train set.",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "-ep",
        "--num_train_epochs",
        type=int,
        default=10,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "-ms",
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "-ga",
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "-ie",
        "--input_eos",
        action="store_true",
        help=(
            # Trailing space matters: the two literals are concatenated.
            "T0 was trained without EOS in its input sequences, which is the default in this script. "
            "However, T5 was pretrained with EOS in its input sequences. See README for more info."
        ),
    )
    parser.add_argument(
        "-db",
        "--debug",
        action="store_true",
        help="Activate debug mode and run training only with a subset of data.",
    )
    parser.add_argument(
        "-wb",
        "--wandb_proj",
        type=str,
        default=None,
        help="Project name for Weights & Biases. By default, W&B is disabled.",
    )
    parser.add_argument(
        "-sd",
        "--seed",
        type=int,
        default=42,
        help="Especially important for few-shot example sampling.",
    )
    parser.add_argument(
        "-cf",
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "-tk",
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "-il",
        "--max_length",
        type=int,
        default=1024,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
            " sequences shorter will be padded if `--pad_to_max_length` is passed."
        ),
    )
    parser.add_argument(
        "-tl",
        "--target_max_length",
        type=int,
        default=256,
        help="Target max length. Sequences longer than this will be truncated.",
    )
    parser.add_argument(
        "-pml",
        "--pad_to_max_length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument(
        "-st",
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "-wd",
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay for the AdamW optimizer.",
    )
    parser.add_argument(
        "-ls",
        "--lr_scheduler_type",
        # The string value is converted to a SchedulerType enum member; "linear" etc. are its values.
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "-ws",
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    # `None` makes argparse fall back to sys.argv[1:], preserving the original behaviour.
    return parser.parse_args(args)
@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that flattens multiple-choice examples into one row per answer choice
    and dynamically pads them.

    Each incoming feature maps ``input_ids``, ``attention_mask``, ``labels`` and
    ``labels_attention_mask`` to one entry per answer choice, plus ``targets`` to the
    index of the correct choice (the layout produced by ``preprocess_eval`` in ``main``).
    Every (example, choice) pair becomes one row of the output batch. The number of
    choices is read per example, so examples with different choice counts are flattened
    correctly.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data; its ``pad`` method pads the encoder
            inputs and its ``pad_token_id`` pads ``labels``.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Strategy forwarded to ``tokenizer.pad`` for the encoder inputs, among:

            * :obj:`True` or :obj:`'longest'` (this class's default): Pad to the longest
              sequence in the batch (or no padding if only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument
              :obj:`max_length` or to the maximum acceptable input length for the model if
              that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch
              with sequences of different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pads the encoder inputs to a multiple of the provided value, which
            helps Tensor Core utilisation on NVIDIA hardware.

    Returns (from ``__call__``):
        A dict of tensors. Encoder keys and ``labels`` / ``labels_attention_mask`` have one
        row per (example, choice) pair. ``labels`` is right-padded with
        ``tokenizer.pad_token_id`` and ``labels_attention_mask`` with 0, up to the longest
        label in the batch. ``targets`` has one entry per example.

    Note: running T0 inference in fp16 is strongly discouraged, as the predictions can
    differ greatly from those obtained in fp32.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        # One row per (example, choice). Every key except "targets" holds one entry per
        # choice; the choice count is taken from each example rather than the first one.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != "targets"}
            for feature in features
            for i in range(len(feature["input_ids"]))
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        # tokenizer.pad leaves the decoder-side keys untouched, so pad them by hand to the
        # longest label sequence in the batch.
        max_label_length = max(len(elem["labels"]) for elem in flattened_features)
        pad_token_id = self.tokenizer.pad_token_id
        batch["labels"] = [
            elem["labels"] + [pad_token_id] * (max_label_length - len(elem["labels"]))
            for elem in flattened_features
        ]
        batch["labels_attention_mask"] = [
            elem["labels_attention_mask"] + [0] * (max_label_length - len(elem["labels_attention_mask"]))
            for elem in flattened_features
        ]

        # Convert to tensors
        batch = {k: torch.tensor(v) for k, v in batch.items()}
        # Read (rather than pop) the gold indices so the caller's features are not mutated.
        batch["targets"] = torch.tensor([f["targets"] for f in features])
        return batch
def main():
    """Fine-tune a seq2seq model on one promptsource template and evaluate it after every epoch.

    Pipeline:
      1. parse CLI args, seed RNGs, set up Accelerate and logging;
      2. load the train and validation splits (ANLI uses per-round split names);
      3. load config, tokenizer and model;
      4. apply the promptsource template: (input, target) pairs for training, and
         one row per answer choice for evaluation;
      5. train with AdamW and an LR scheduler, optionally with gradient accumulation;
      6. after each epoch, score each answer choice by the summed log-probability of its
         tokens, predict the argmax and report accuracy;
      7. on the main process, write per-epoch results to ``<output_dir>/results.csv``
         and optionally log to Weights & Biases.
    """
    args = parse_args()
    set_seed(args.seed)

    # Initialize the accelerator. We will let the accelerator handle device placement for us.
    accelerator = Accelerator()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the output directory creation
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # In distributed evaluation, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        if args.dataset_name == "anli":
            # ANLI splits are named per round; dataset_config_name = "r1", "r2", or "r3".
            raw_train_dataset = load_dataset(args.dataset_name, split=f'train_{args.dataset_config_name}')
            raw_eval_dataset = load_dataset(args.dataset_name, split=f'dev_{args.dataset_config_name}')
        else:
            raw_train_dataset = load_dataset(args.dataset_name, args.dataset_config_name, split="train")
            raw_eval_dataset = load_dataset(args.dataset_name, args.dataset_config_name, split="validation")
    else:
        raise ValueError('Please specify `args.dataset_name` and `args.dataset_config_name` as appear in `promptsource`.')
    # TODO(Victor): enable loading pre-processed dataset from https://huggingface.co/datasets/bigscience/P3

    # In debug mode, keep at most 100 examples of each split.
    if args.debug:
        raw_train_dataset = raw_train_dataset.select(range(min(100, len(raw_train_dataset))))
        raw_eval_dataset = raw_eval_dataset.select(range(min(100, len(raw_eval_dataset))))

    # NOTE: also used to read/remove the train split's columns; assumes both splits share one schema.
    column_names = raw_eval_dataset.column_names

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(
            "Either `args.config_name` or `args.model_name_or_path` should be provided."
        )

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForSeq2SeqLM.from_config(config)

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    padding = "max_length" if args.pad_to_max_length else False

    # Get the prompt to apply and the possible targets.
    # TODO(Victor): If pulling from pre-processed data, remove this logic.
    if args.dataset_name == 'anli':
        prompts = DatasetTemplates('anli', None)
    else:
        prompts = DatasetTemplates(
            f"{args.dataset_name}"
            if args.dataset_config_name is None
            else f"{args.dataset_name}/{args.dataset_config_name}"
        )
    template = prompts[args.template_name]

    def preprocess_train(examples):
        """Apply the template to a batch of training rows and tokenize inputs and targets."""
        bs = len(examples[column_names[0]])

        input_texts = []
        target_texts = []
        for i in range(bs):
            ex = {k: examples[k][i] for k in column_names}
            input_text, target = template.apply(ex)
            ex_answer_choices = template.get_answer_choices_list(ex)
            assert target in ex_answer_choices
            input_texts.append(input_text)
            target_texts.append(target)

        model_inputs = tokenizer(
            input_texts,
            padding=padding,
            max_length=args.max_length,
            truncation=True,
            add_special_tokens=args.input_eos,
        )

        with tokenizer.as_target_tokenizer():
            tokenized_targets = tokenizer(
                target_texts,
                padding=padding,
                max_length=args.target_max_length,
                truncation=True,
                add_special_tokens=False,
            )
        # Replace label padding with -100, the same label pad value the training collator uses below.
        model_inputs['labels'] = [
            [(t if t != tokenizer.pad_token_id else -100) for t in targets]
            for targets in tokenized_targets["input_ids"]
        ]
        return model_inputs

    def preprocess_eval(examples):
        """Build multiple-choice evaluation features for a batch of rows.

        Each prompted input is repeated once per answer choice, every answer choice is
        tokenized as a candidate label sequence, and ``targets`` holds the index of the
        gold answer within that example's choice list.
        """
        bs = len(examples[column_names[0]])

        input_texts = []
        target_texts = []
        answer_choices_texts = []
        for i in range(bs):
            ex = {k: examples[k][i] for k in column_names}
            input_text, target = template.apply(ex)
            ex_answer_choices = template.get_answer_choices_list(ex)
            assert target in ex_answer_choices
            input_texts.append(input_text)
            target_texts.append(target)
            answer_choices_texts.append(ex_answer_choices)

        tokenized_inputs = tokenizer(
            input_texts,
            padding=padding,
            max_length=args.max_length,
            truncation=True,
            # Same as preprocess_train, so --input_eos applies to training and evaluation alike.
            add_special_tokens=args.input_eos,
        )
        # NOTE(review): answer choices are tokenized with the tokenizer's default special tokens,
        # whereas training targets use add_special_tokens=False; confirm this asymmetry is intended.
        tokenized_targets = [
            tokenizer(
                ans_choi,
                padding=True,
                max_length=args.target_max_length,
                truncation=True,
            )
            for ans_choi in answer_choices_texts
        ]

        # Repeat each tokenized input once per answer choice of its example.
        features = {
            k: [
                [elem for _ in range(len(tokenized_targets[idx]["input_ids"]))]
                for idx, elem in enumerate(v)
            ]
            for k, v in tokenized_inputs.items()
        }
        features["labels"] = [tokenized_targets[idx]["input_ids"] for idx in range(bs)]
        features["labels_attention_mask"] = [tokenized_targets[idx]["attention_mask"] for idx in range(bs)]
        features["targets"] = [
            answer_choices_texts[idx].index(t)
            for idx, t in enumerate(target_texts)
        ]
        return features

    with accelerator.main_process_first():
        eval_dataset = raw_eval_dataset.map(preprocess_eval, batched=True, remove_columns=column_names)

        if args.num_shots is not None:
            num_shots = args.num_shots
            if num_shots > len(raw_train_dataset):
                # e.g. --debug trims the train split to 100 rows; sample the whole split instead of crashing.
                logger.warning(
                    f"num_shots={num_shots} exceeds the {len(raw_train_dataset)} available training "
                    f"examples; using all of them."
                )
                num_shots = len(raw_train_dataset)
            sample_indices = random.sample(range(0, len(raw_train_dataset)), k=num_shots)
            raw_train_dataset = raw_train_dataset.select(sample_indices)

        train_dataset = raw_train_dataset.map(preprocess_train, batched=True, remove_columns=column_names)

    # Log a few random examples (fewer when a split has under 3 rows, e.g. --num_shots 2).
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.debug(f"Sample {index} of the training set: {train_dataset[index]}.")
    for index in random.sample(range(len(eval_dataset)), min(3, len(eval_dataset))):
        logger.debug(f"Sample {index} of the evaluation set: {eval_dataset[index]}.")

    # DataLoaders creation:
    train_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8 if accelerator.use_fp16 else None
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_collator,
        batch_size=args.per_device_train_batch_size
    )

    # Eval features are nested (one entry per answer choice) and their labels differ in
    # length across examples, which default_data_collator cannot stack. The multiple-choice
    # collator flattens and pads them; when inputs were already padded to max_length, its
    # "longest" padding leaves them unchanged.
    eval_collator = DataCollatorForMultipleChoice(
        tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
    )
    eval_dataloader = DataLoader(eval_dataset, collate_fn=eval_collator, batch_size=args.per_device_eval_batch_size)

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    if args.parallelize:
        num_gpus = torch.cuda.device_count()
        assert num_gpus > 1, "You need at least 2 GPUs to use `model.parallelize()`."
        model.parallelize()
        optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            optimizer, train_dataloader, eval_dataloader)
    else:
        model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            model, optimizer, train_dataloader, eval_dataloader)

    # Scheduler and math around the number of training steps.
    # This must come after accelerator.prepare: in multi-process runs the prepared training
    # dataloader is sharded and therefore shorter, and the step counts must use that length.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Metrics
    metric = load_metric("accuracy")

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    global_steps = 0

    if args.wandb_proj and accelerator.is_main_process:
        import wandb
        extra_metadata = {
            'template_jinja': template.jinja,
            'template_answer_choices': template.answer_choices,
            'template_reflects_original_task': template.metadata.original_task,
            'template_choices_in_prompt': template.metadata.choices_in_prompt,
            'template_comment': template.reference,
        }
        # Copy so that adding the template metadata does not mutate `args`.
        run_config = dict(vars(args))
        run_config.update(extra_metadata)
        wandb.init(
            project=args.wandb_proj,
            config=run_config,
            # name=f'S{len(train_set)} {args.template_name} R{args.seed}',  # uncomment to customize each run's name
            # reinit=True,  # uncomment if running multiple runs in one script
        )

    result_table = []
    for epoch in range(1, args.num_train_epochs + 1):
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            # Scale so gradients accumulated over several micro-batches average instead of summing.
            loss = outputs.loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # Update once every `gradient_accumulation_steps` micro-batches (counting from 1, so the
            # first update also sees a full window) and at the last micro-batch of the epoch.
            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                global_steps += 1

                # Report the unscaled loss of the latest micro-batch, once per optimizer update.
                loss_value = outputs.loss.item()
                if accelerator.is_main_process:
                    tqdm.write(f"epoch = {epoch}, step = {global_steps}, loss = {loss_value}")
                if args.wandb_proj and accelerator.is_main_process:
                    wandb.log({"loss": loss_value}, step=global_steps)

            if global_steps >= args.max_train_steps:
                break

        # Evaluate every epoch
        total_batch_size = args.per_device_eval_batch_size * accelerator.num_processes
        logger.info("***** Running evaluation *****")
        logger.info(f" Num examples = {len(eval_dataset)}")
        logger.info(f" Instantaneous batch size per device = {args.per_device_eval_batch_size}")
        logger.info(f" Total eval batch size (w. parallel, distributed) = {total_batch_size}")

        model.eval()
        for batch in eval_dataloader:
            model_inputs = {
                k: batch[k]
                for k in ["input_ids", "attention_mask", "labels"]
            }
            with torch.no_grad():
                logits = model(**model_inputs).logits
            # Log-probability of each label token, then zero out padded positions and sum per choice.
            # Gathering before masking avoids materialising a masked copy of the full-vocabulary tensor.
            token_log_probs = torch.gather(
                torch.log_softmax(logits, dim=-1), -1, batch["labels"].unsqueeze(-1)
            ).squeeze(dim=-1)
            seq_log_prob = (batch["labels_attention_mask"] * token_log_probs).sum(dim=-1)
            # TODO(Victor): this reshape works based on the assumption that all examples have the same number of
            # choices. The pre-processing doesn't make this assumption.
            seq_log_prob = seq_log_prob.view(batch["targets"].size(0), -1)
            predictions = seq_log_prob.argmax(dim=-1)

            metric.add_batch(
                predictions=accelerator.gather(predictions),
                references=accelerator.gather(batch["targets"]),
            )

        eval_metric = metric.compute()
        score = eval_metric["accuracy"]  # TODO support other metrics; currently hardcoded at load_metric() anyway
        accelerator.print(f"Accuracy: {score}")
        result_table.append({
            "dataset_name": args.dataset_name,
            "dataset_config_name": args.dataset_config_name,
            "template_name": args.template_name,
            "epoch": epoch,
            "step": global_steps,
            "metric": 'accuracy',
            "score": score,
        })
        if args.wandb_proj and accelerator.is_main_process:
            wandb.log({"accuracy": score}, step=global_steps)

        # Stop once the step budget is spent instead of starting another epoch.
        if global_steps >= args.max_train_steps:
            break
    # End training loop

    if accelerator.is_main_process:
        if args.output_dir is not None:
            if result_table:
                # newline="" lets the csv module control line endings (avoids blank rows on Windows).
                with open(os.path.join(args.output_dir, "results.csv"), "w", newline="") as f:
                    writer = csv.DictWriter(f, fieldnames=result_table[0].keys())
                    writer.writeheader()
                    writer.writerows(result_table)
            else:
                logger.warning("No evaluation results were produced; results.csv was not written.")
        if args.wandb_proj:
            wandb.finish()


if __name__ == "__main__":
    main()