From 786cd4568d55662030e2b951a7a10d85e0753bd0 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 00:19:03 +0900
Subject: [PATCH 01/13] docs: update example cmd w/ required flag

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 78fc126..5bbccf9 100644
--- a/README.md
+++ b/README.md
@@ -75,5 +75,5 @@ A [simple benchmark](https://github.com/bigscience-workshop/Megatron-DeepSpeed/i
 [WMT](https://huggingface.co/datasets/wmt19) and [TyDi QA](https://huggingface.co/datasets/tydiqa)
 E.g.
 ```shell
-python3 -m evaluation.eval --model_name_or_path=gpt2 --eval_tasks tydiqa_secondary
+python3 -m evaluation.eval --model_name_or_path=gpt2 --eval_tasks tydiqa_secondary --output_dir outputs
 ```
From 92e93004cd522364d0a96830c502f805adb26bcf Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 00:20:33 +0900
Subject: [PATCH 02/13] feat: implement `load_task_args()`, modify load logic

load model and tokenizer for each task instead of recycling them
---
 evaluation/eval.py | 36 ++++++++++++-------
 evaluation/tasks/auto_task.py | 36 +++++++++++++++----
 .../tasks/tydiqa_secondary/english.json | 3 ++
 .../tydiqa_secondary/tydiqa_secondary.py | 9 ++++-
 4 files changed, 64 insertions(+), 20 deletions(-)
 create mode 100644 evaluation/tasks/tydiqa_secondary/english.json

diff --git a/evaluation/eval.py b/evaluation/eval.py
index 053a291..60a0ba9 100644
--- a/evaluation/eval.py
+++ b/evaluation/eval.py
@@ -38,7 +38,11 @@ class EvaluationArguments:
     tag: Optional[str] = field(
         default=None,
         metadata={"help": "Identifier for the evaluation run."}
-    )
+    )
+    is_english_only: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to run evaluation in English only."}
+    )
 
 
 def main():
@@ -54,18 +58,18 @@ def main():
     logger = get_logger()
     logger.info(f"Beginning evaluation on device {train_args.device}")
 
-    # Load model & tokenizer
-    logger.info("Loading model...")
-    tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
-    tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "left"
+    # # Load model & tokenizer
+    # logger.info("Loading model...")
+    # tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
+    # tokenizer.pad_token = tokenizer.eos_token
+    # tokenizer.padding_side = "left"
 
-    model = AutoModelForCausalLM.from_pretrained(
-        eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token,
-    )
-    model.config.pad_token_id = model.config.eos_token_id
-    model.resize_token_embeddings(len(tokenizer))
-    model.to(device)
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token,
+    # )
+    # model.config.pad_token_id = model.config.eos_token_id
+    # model.resize_token_embeddings(len(tokenizer))
+    # model.to(device)
 
     # Exporting results
     tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S")
@@ -74,7 +78,13 @@ def main():
 
     for eval_task in eval_args.eval_tasks:
         logger.info(f"Benchmarking {eval_task}...")
-        task = AutoTask.from_task_name(eval_task, tokenizer=tokenizer, model=model, device=device)
+        task = AutoTask.from_task_name(
+            eval_task,
+            model_name_or_path=eval_args.model_name_or_path,
+            tokenizer_name=eval_args.tokenizer_name,
+            device=device,
+            is_english_only=eval_args.is_english_only,
+        )
         set_seed(train_args.seed)
         task.evaluate()
         task.save_metrics(output_dir, logger)
diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index dfa6f3b..ae5fc08 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -1,25 +1,45 @@
 from abc import ABC, abstractmethod
+from typing import Dict
 import os
 
-from evaluation.utils.io import save_json
+from transformers import AutoTokenizer
+
+from evaluation.utils.io import save_json, load_json
+from evaluation.models import load_model
 
 
 class AutoTask(ABC):
-    def __init__(self, tokenizer, model, device):
-        self.tokenizer = tokenizer
-        self.model = model
+    def __init__(
+        self, model_name_or_path, device, is_english_only, tokenizer_name,
+    ):
+        self.model = load_model(model_name_or_path).to(device)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
         self.device = device
         self.metrics = {}
+        self.task_config = self.load_task_args(is_english_only)
 
     @classmethod
-    def from_task_name(cls, task_name: str, tokenizer, model, device):
+    def from_task_name(
+        cls, task_name: str, model_name_or_path, device, is_english_only, tokenizer_name="",
+    ):
        all_tasks = cls.__subclasses__()
         for task in all_tasks:
             if task.get_display_name() == task_name:
-                return task(tokenizer=tokenizer, model=model, device=device)
+                return task(
+                    model_name_or_path=model_name_or_path,
+                    device=device,
+                    tokenizer_name=tokenizer_name,
+                    is_english_only=is_english_only,
+                )
 
         raise ValueError(f'Invalid task: {task_name}')
 
+    def load_task_args(self, is_english_only) -> Dict:
+        task_root = os.path.join("evaluation", "tasks", self.get_display_name())
+        if is_english_only:
+            return load_json(os.path.join(task_root, "english.json"))
+        return load_json(os.path.join(task_root, "multiligual.json"))
+
     @staticmethod
     @abstractmethod
     def get_display_name() -> str:
@@ -29,6 +49,10 @@ def get_display_name() -> str:
     def evaluate(self) -> None:
         pass
 
+    def train(self) -> None:
+        # TODO: convert to `abstractmethod` once simple_benchmark is ready
+        raise NotImplementedError
+
     def save_metrics(self, output_dir, logger=None) -> str:
         output_filename = os.path.join(output_dir, f"{self.get_display_name()}.json")
         save_json(self.metrics, output_filename)
diff --git a/evaluation/tasks/tydiqa_secondary/english.json b/evaluation/tasks/tydiqa_secondary/english.json
new file mode 100644
index 0000000..319b5d8
--- /dev/null
+++ b/evaluation/tasks/tydiqa_secondary/english.json
@@ -0,0 +1,3 @@
+{
+    "target_langs": ["english"]
+}
\ No newline at end of file
diff --git a/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
index 7d97e42..77bdf17 100644
--- a/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
+++ b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
@@ -70,9 +70,16 @@ class TydiqaSecondaryTask(AutoTask):
     @staticmethod
     def get_display_name() -> str:
         return 'tydiqa_secondary'
+
+    def configure_tokenizer(self):
+        # configure tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "left"
 
     def evaluate(self) -> None:
-        dataset = TyDiQADataset(self.tokenizer, target_langs=["english"])
+        self.configure_tokenizer()
+
+        dataset = TyDiQADataset(self.tokenizer, target_langs=self.task_config["target_langs"])
 
         substring_matches = 0
         for sample in tqdm(dataset, desc=f'Evaluating {self.get_display_name()}'):
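Note on the config mechanism introduced in PATCH 02: each task now resolves a JSON config that lives next to its module via `load_task_args()`, keyed off `get_display_name()`. The snippet below is an illustrative sketch only (the `my_task` name and recorded metric are hypothetical, not part of this series) of how a task subclass is expected to consume `self.task_config`:

```python
# Hypothetical task, for illustration only; assumes a config file exists at
# evaluation/tasks/my_task/english.json with a "target_langs" key, mirroring
# the english.json added for tydiqa_secondary in this patch.
from evaluation.tasks.auto_task import AutoTask


class MyTask(AutoTask):
    @staticmethod
    def get_display_name() -> str:
        return "my_task"  # also the directory name the config is loaded from

    def evaluate(self) -> None:
        target_langs = self.task_config["target_langs"]  # populated by load_task_args()
        self.metrics["num_target_langs"] = len(target_langs)
```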
From bb958d5bc0524a53375f1fd88f1985d5e2a9e6f9 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 00:21:41 +0900
Subject: [PATCH 03/13] feat: implement `load_json()`

---
 evaluation/utils/io.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/evaluation/utils/io.py b/evaluation/utils/io.py
index 9c817cf..9af2fcc 100644
--- a/evaluation/utils/io.py
+++ b/evaluation/utils/io.py
@@ -5,3 +5,7 @@
 def save_json(content: Dict, path: str, indent: int = 4, **kwargs) -> None:
     with open(path, "w") as f:
         json.dump(content, f, indent=indent, sort_keys=True, **kwargs)
+
+def load_json(path: str) -> Dict:
+    with open(path) as f:
+        return json.load(f)

From 367de69edd32cc820f01c9036028fb6285478e04 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 00:22:19 +0900
Subject: [PATCH 04/13] chore: separate out model loading logic

will need to develop this function for mt5, megatron, etc
---
 evaluation/models/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)
 create mode 100644 evaluation/models/__init__.py

diff --git a/evaluation/models/__init__.py b/evaluation/models/__init__.py
new file mode 100644
index 0000000..365bda3
--- /dev/null
+++ b/evaluation/models/__init__.py
@@ -0,0 +1,4 @@
+from transformers import AutoModelForCausalLM
+
+def load_model(model_name_or_path):
+    return AutoModelForCausalLM.from_pretrained(model_name_or_path)
\ No newline at end of file
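A minimal usage sketch of the two helpers added in PATCH 03 and PATCH 04 (the `gpt2` checkpoint is the README example; at this point in the series `load_model` still lives in `evaluation/models/__init__.py`):

```python
from evaluation.models import load_model
from evaluation.utils.io import load_json

# english.json was added in PATCH 02 and contains {"target_langs": ["english"]}
config = load_json("evaluation/tasks/tydiqa_secondary/english.json")
model = load_model("gpt2")  # thin wrapper around AutoModelForCausalLM.from_pretrained

print(config["target_langs"], model.config.model_type)
```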
tydiqa_secondary"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name."} + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."} + ) + tag: Optional[str] = field( + default=None, + metadata={"help": "Identifier for the evaluation run."} + ) + + +def main(): + parser = HfArgumentParser((EvaluationArguments, TrainingArguments)) + eval_args, train_args = parser.parse_args_into_dataclasses() + + if not eval_args.eval_tasks: + raise ValueError('Must provide at least one eval task!') + + # initialize device + device = torch.device(train_args.device) + + logger = get_logger() + logger.info(f"Beginning evaluation on device {train_args.device}") + + # Load model & tokenizer + logger.info("Loading model...") + tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + model = AutoModelForCausalLM.from_pretrained( + eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token, + ) + model.config.pad_token_id = model.config.eos_token_id + model.resize_token_embeddings(len(tokenizer)) + model.to(device) + + # Exporting results + tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S") + output_dir = os.path.join(train_args.output_dir, tag) + os.makedirs(output_dir, exist_ok=True) + + for eval_task in eval_args.eval_tasks: + logger.info(f"Benchmarking {eval_task}...") + task = AutoTask.from_task_name(eval_task, tokenizer=tokenizer, model=model, device=device) + set_seed(train_args.seed) + task.train() + task.save_metrics(output_dir, logger) + + +if __name__ == "__main__": + main() From 96e28f54e875c31ac467bf797eaf62312ae80ff7 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Wed, 18 Aug 2021 00:41:46 +0900 Subject: [PATCH 06/13] feat: allow model & tokenizer init from preloaded objects --- evaluation/eval.py | 26 +++++++------- evaluation/tasks/auto_task.py | 35 +++++++++++++++---- .../tydiqa_secondary/tydiqa_secondary.py | 8 +---- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/evaluation/eval.py b/evaluation/eval.py index 60a0ba9..4e0cc71 100644 --- a/evaluation/eval.py +++ b/evaluation/eval.py @@ -58,18 +58,18 @@ def main(): logger = get_logger() logger.info(f"Beginning evaluation on device {train_args.device}") - # # Load model & tokenizer - # logger.info("Loading model...") - # tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path) - # tokenizer.pad_token = tokenizer.eos_token - # tokenizer.padding_side = "left" + # Load model & tokenizer + logger.info("Loading model...") + tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" - # model = AutoModelForCausalLM.from_pretrained( - # eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token, - # ) - # model.config.pad_token_id = model.config.eos_token_id - # model.resize_token_embeddings(len(tokenizer)) - # model.to(device) + model = AutoModelForCausalLM.from_pretrained( + eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token, + ) + model.config.pad_token_id = model.config.eos_token_id + model.resize_token_embeddings(len(tokenizer)) + model.to(device) # Exporting results tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S") @@ -80,8 +80,8 @@ 
From 96e28f54e875c31ac467bf797eaf62312ae80ff7 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 00:41:46 +0900
Subject: [PATCH 06/13] feat: allow model & tokenizer init from preloaded objects

---
 evaluation/eval.py | 26 +++++++-------
 evaluation/tasks/auto_task.py | 35 +++++++++++++++----
 .../tydiqa_secondary/tydiqa_secondary.py | 8 +----
 3 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/evaluation/eval.py b/evaluation/eval.py
index 60a0ba9..4e0cc71 100644
--- a/evaluation/eval.py
+++ b/evaluation/eval.py
@@ -58,18 +58,18 @@ def main():
     logger = get_logger()
     logger.info(f"Beginning evaluation on device {train_args.device}")
 
-    # # Load model & tokenizer
-    # logger.info("Loading model...")
-    # tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
-    # tokenizer.pad_token = tokenizer.eos_token
-    # tokenizer.padding_side = "left"
+    # Load model & tokenizer
+    logger.info("Loading model...")
+    tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"
 
-    # model = AutoModelForCausalLM.from_pretrained(
-    #     eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token,
-    # )
-    # model.config.pad_token_id = model.config.eos_token_id
-    # model.resize_token_embeddings(len(tokenizer))
-    # model.to(device)
+    model = AutoModelForCausalLM.from_pretrained(
+        eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token,
+    )
+    model.config.pad_token_id = model.config.eos_token_id
+    model.resize_token_embeddings(len(tokenizer))
+    model.to(device)
 
     # Exporting results
     tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S")
     os.makedirs(output_dir, exist_ok=True)
@@ -80,8 +80,8 @@ def main():
         logger.info(f"Benchmarking {eval_task}...")
         task = AutoTask.from_task_name(
             eval_task,
-            model_name_or_path=eval_args.model_name_or_path,
-            tokenizer_name=eval_args.tokenizer_name,
+            model=model,
+            tokenizer=tokenizer,
             device=device,
             is_english_only=eval_args.is_english_only,
         )
diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index ae5fc08..b867d5c 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -10,26 +10,49 @@ class AutoTask(ABC):
     def __init__(
-        self, model_name_or_path, device, is_english_only, tokenizer_name,
+        self,
+        device,
+        is_english_only: bool,
+        model=None,
+        tokenizer=None,
+        model_name_or_path="",
+        tokenizer_name="",
     ):
-        self.model = load_model(model_name_or_path).to(device)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
+        assert model or model_name_or_path, "Expected either `model` or `model_name_or_path`"
+        assert (
+            tokenizer or tokenizer_name or model_name_or_path
+        ), "Expected either `tokenizer` or `model_name_or_path` or `tokenizer_name`"
+        if model is None:
+            model = load_model(model_name_or_path).to(device)
+        self.model = model
+        if tokenizer is None:
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
+        self.tokenizer = tokenizer
         self.device = device
         self.metrics = {}
         self.task_config = self.load_task_args(is_english_only)
 
     @classmethod
     def from_task_name(
-        cls, task_name: str, model_name_or_path, device, is_english_only, tokenizer_name="",
+        cls,
+        task_name: str,
+        device,
+        is_english_only: bool,
+        model=None,
+        tokenizer=None,
+        model_name_or_path="",
+        tokenizer_name="",
     ):
         all_tasks = cls.__subclasses__()
         for task in all_tasks:
             if task.get_display_name() == task_name:
                 return task(
-                    model_name_or_path=model_name_or_path,
                     device=device,
-                    tokenizer_name=tokenizer_name,
                     is_english_only=is_english_only,
+                    model=model,
+                    tokenizer=tokenizer,
+                    model_name_or_path=model_name_or_path,
+                    tokenizer_name=tokenizer_name,
                 )
 
         raise ValueError(f'Invalid task: {task_name}')
diff --git a/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
index 77bdf17..8911f22 100644
--- a/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
+++ b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
@@ -28,6 +28,7 @@ class TyDiQADataset(Dataset):
     def __init__(self, tokenizer, target_langs):
         super().__init__()
+        assert tokenizer.pad_token == tokenizer.eos_token
         tydiqa = load_dataset("tydiqa", "secondary_task", split="validation")
         self.items = []
@@ -70,15 +71,8 @@ class TydiqaSecondaryTask(AutoTask):
     @staticmethod
     def get_display_name() -> str:
         return 'tydiqa_secondary'
-
-    def configure_tokenizer(self):
-        # configure tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
-        self.tokenizer.padding_side = "left"
 
     def evaluate(self) -> None:
-        self.configure_tokenizer()
-
         dataset = TyDiQADataset(self.tokenizer, target_langs=self.task_config["target_langs"])
 
         substring_matches = 0
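After PATCH 06, a task can be built either from preloaded objects or from names. A minimal sketch of both paths, assuming the `gpt2` checkpoint from the README and the `tydiqa_secondary` task (importing `evaluation.tasks` is what lets `from_task_name` discover subclasses):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import evaluation.tasks  # registers AutoTask subclasses
from evaluation.tasks.auto_task import AutoTask

device = torch.device("cpu")

# Path 1: reuse objects loaded once by the caller (what eval.py goes back to doing here).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
task = AutoTask.from_task_name(
    "tydiqa_secondary", model=model, tokenizer=tokenizer, device=device, is_english_only=True
)

# Path 2: hand over names and let AutoTask load the model and tokenizer itself.
# Note that the caller is still responsible for pad/eos configuration before evaluate().
task = AutoTask.from_task_name(
    "tydiqa_secondary", model_name_or_path="gpt2", device=device, is_english_only=True
)
```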
From b6671dfaf800c555507d9ca9bc1c7e54aec18042 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 17:43:07 +0900
Subject: [PATCH 07/13] Update evaluation/tasks/auto_task.py

Co-authored-by: Wilson Lee
---
 evaluation/tasks/auto_task.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index b867d5c..9d32dc5 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -59,9 +59,8 @@ def from_task_name(
     def load_task_args(self, is_english_only) -> Dict:
         task_root = os.path.join("evaluation", "tasks", self.get_display_name())
-        if is_english_only:
-            return load_json(os.path.join(task_root, "english.json"))
-        return load_json(os.path.join(task_root, "multiligual.json"))
+        config_filename = "english.json" if is_english_only else "multiligual.json"
+        return load_json(os.path.join(task_root, config_filename))
 
     @staticmethod
     @abstractmethod
     def get_display_name() -> str:

From 443eb650fbbdbacdbcf7ad511c63154c7e1de4be Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 17:43:17 +0900
Subject: [PATCH 08/13] Update evaluation/utils/io.py

Co-authored-by: Wilson Lee
---
 evaluation/utils/io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/evaluation/utils/io.py b/evaluation/utils/io.py
index 9af2fcc..a390f4b 100644
--- a/evaluation/utils/io.py
+++ b/evaluation/utils/io.py
@@ -7,5 +7,5 @@ def save_json(content: Dict, path: str, indent: int = 4, **kwargs) -> None:
         json.dump(content, f, indent=indent, sort_keys=True, **kwargs)
 
 def load_json(path: str) -> Dict:
-    with open(path) as f:
+    with open(path, "r") as f:
         return json.load(f)

From 191ffb9559c9a4babaa735e5aca4a6ce9119a9cb Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 18:23:49 +0900
Subject: [PATCH 09/13] chore: `is_english_only` -> `english_only`

---
 evaluation/eval.py | 4 ++--
 evaluation/tasks/auto_task.py | 12 ++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/evaluation/eval.py b/evaluation/eval.py
index 4e0cc71..1c9be0a 100644
--- a/evaluation/eval.py
+++ b/evaluation/eval.py
@@ -39,7 +39,7 @@ class EvaluationArguments:
         default=None,
         metadata={"help": "Identifier for the evaluation run."}
     )
-    is_english_only: Optional[bool] = field(
+    english_only: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to run evaluation in English only."}
     )
@@ -83,7 +83,7 @@ def main():
             model=model,
             tokenizer=tokenizer,
             device=device,
-            is_english_only=eval_args.is_english_only,
+            english_only=eval_args.english_only,
         )
         set_seed(train_args.seed)
         task.evaluate()
diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index 9d32dc5..13a420c 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -12,7 +12,7 @@ class AutoTask(ABC):
     def __init__(
         self,
         device,
-        is_english_only: bool,
+        english_only: bool,
         model=None,
         tokenizer=None,
         model_name_or_path="",
@@ -30,14 +30,14 @@ def __init__(
         self.tokenizer = tokenizer
         self.device = device
         self.metrics = {}
-        self.task_config = self.load_task_args(is_english_only)
+        self.task_config = self.load_task_args(english_only)
 
     @classmethod
     def from_task_name(
         cls,
         task_name: str,
         device,
-        is_english_only: bool,
+        english_only: bool,
         model=None,
         tokenizer=None,
         model_name_or_path="",
@@ -48,7 +48,7 @@ def from_task_name(
             if task.get_display_name() == task_name:
                 return task(
                     device=device,
-                    is_english_only=is_english_only,
+                    english_only=english_only,
                     model=model,
                     tokenizer=tokenizer,
                     model_name_or_path=model_name_or_path,
@@ -57,9 +57,9 @@ def from_task_name(
 
         raise ValueError(f'Invalid task: {task_name}')
 
-    def load_task_args(self, is_english_only) -> Dict:
+    def load_task_args(self, english_only) -> Dict:
         task_root = os.path.join("evaluation", "tasks", self.get_display_name())
-        config_filename = "english.json" if is_english_only else "multiligual.json"
+        config_filename = "english.json" if english_only else "multiligual.json"
         return load_json(os.path.join(task_root, config_filename))
From 86c11a37a611357e8a19fa5daba4c75e5f906d6a Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 18:26:49 +0900
Subject: [PATCH 10/13] chore: mv `load_model` from `__init__.py`

---
 evaluation/models/__init__.py | 5 +----
 evaluation/models/loader.py | 4 ++++
 2 files changed, 5 insertions(+), 4 deletions(-)
 create mode 100644 evaluation/models/loader.py

diff --git a/evaluation/models/__init__.py b/evaluation/models/__init__.py
index 365bda3..bff8447 100644
--- a/evaluation/models/__init__.py
+++ b/evaluation/models/__init__.py
@@ -1,4 +1 @@
-from transformers import AutoModelForCausalLM
-
-def load_model(model_name_or_path):
-    return AutoModelForCausalLM.from_pretrained(model_name_or_path)
\ No newline at end of file
+from .loader import load_model
\ No newline at end of file
diff --git a/evaluation/models/loader.py b/evaluation/models/loader.py
new file mode 100644
index 0000000..365bda3
--- /dev/null
+++ b/evaluation/models/loader.py
@@ -0,0 +1,4 @@
+from transformers import AutoModelForCausalLM
+
+def load_model(model_name_or_path):
+    return AutoModelForCausalLM.from_pretrained(model_name_or_path)
\ No newline at end of file

From c8fc3dc6e60a64c6177872eeb0d20621548c8e04 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 18 Aug 2021 18:33:16 +0900
Subject: [PATCH 11/13] chore: keep init empty, use abs path for import

---
 evaluation/models/__init__.py | 1 -
 evaluation/tasks/auto_task.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/evaluation/models/__init__.py b/evaluation/models/__init__.py
index bff8447..e69de29 100644
--- a/evaluation/models/__init__.py
+++ b/evaluation/models/__init__.py
@@ -1 +0,0 @@
-from .loader import load_model
\ No newline at end of file
diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index 13a420c..a2f3e48 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -5,7 +5,7 @@
 from transformers import AutoTokenizer
 
 from evaluation.utils.io import save_json, load_json
-from evaluation.models import load_model
+from evaluation.models.loader import load_model
 
 
 class AutoTask(ABC):
From 0d7a4674e224fa3ef4c3cbb76f6e8f488313345c Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Thu, 19 Aug 2021 00:54:16 +0900
Subject: [PATCH 12/13] refactor: simplify init, establish spec entrypoint

---
 evaluation/tasks/auto_task.py | 67 ++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 29 deletions(-)

diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index a2f3e48..f2188eb 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -11,51 +11,60 @@ class AutoTask(ABC):
     def __init__(
         self,
+        model,
+        tokenizer,
         device,
         english_only: bool,
-        model=None,
-        tokenizer=None,
-        model_name_or_path="",
-        tokenizer_name="",
     ):
-        assert model or model_name_or_path, "Expected either `model` or `model_name_or_path`"
-        assert (
-            tokenizer or tokenizer_name or model_name_or_path
-        ), "Expected either `tokenizer` or `model_name_or_path` or `tokenizer_name`"
-        if model is None:
-            model = load_model(model_name_or_path).to(device)
         self.model = model
-        if tokenizer is None:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
         self.tokenizer = tokenizer
         self.device = device
         self.metrics = {}
         self.task_config = self.load_task_args(english_only)
+
+    @classmethod
+    def _get_task(cls, task_name):
+        all_tasks = cls.__subclasses__()
+        for task in all_tasks:
+            if task.get_display_name() == task_name:
+                return task
+        raise ValueError(f'Invalid task: {task_name}')
 
     @classmethod
     def from_task_name(
         cls,
         task_name: str,
+        model,
+        tokenizer,
         device,
         english_only: bool,
-        model=None,
-        tokenizer=None,
-        model_name_or_path="",
-        tokenizer_name="",
     ):
-        all_tasks = cls.__subclasses__()
-        for task in all_tasks:
-            if task.get_display_name() == task_name:
-                return task(
-                    device=device,
-                    english_only=english_only,
-                    model=model,
-                    tokenizer=tokenizer,
-                    model_name_or_path=model_name_or_path,
-                    tokenizer_name=tokenizer_name,
-                )
-
-        raise ValueError(f'Invalid task: {task_name}')
+        task = cls._get_task(task_name)
+        return task(
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            english_only=english_only,
+        )
+
+    @classmethod
+    def from_spec(
+        cls,
+        task_name: str,
+        model_name_or_path: str,
+        tokenizer_name: str,
+        device,
+        english_only: bool,
+    ):
+        task = cls._get_task(task_name)
+        model = load_model(model_name_or_path)
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
+        return task(
+            model=model,
+            tokenizer=tokenizer,
+            device=device,
+            english_only=english_only,
+        )
 
     def load_task_args(self, english_only) -> Dict:
         task_root = os.path.join("evaluation", "tasks", self.get_display_name())

From fc320fa456105410c434f7a4d0861fe765673d94 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Thu, 19 Aug 2021 04:32:39 +0900
Subject: [PATCH 13/13] docs: add type hints

---
 evaluation/tasks/auto_task.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
index f2188eb..b2d4d74 100644
--- a/evaluation/tasks/auto_task.py
+++ b/evaluation/tasks/auto_task.py
@@ -2,7 +2,8 @@
 from typing import Dict
 import os
 
-from transformers import AutoTokenizer
+import torch
+from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
 
 from evaluation.utils.io import save_json, load_json
 from evaluation.models.loader import load_model
@@ -11,9 +12,9 @@ class AutoTask(ABC):
     def __init__(
         self,
-        model,
-        tokenizer,
-        device,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerFast,
+        device: torch.device,
         english_only: bool,
     ):
         self.model = model
@@ -34,9 +35,9 @@ def _get_task(cls, task_name):
     def from_task_name(
         cls,
         task_name: str,
-        model,
-        tokenizer,
-        device,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerFast,
+        device: torch.device,
         english_only: bool,
     ):
         task = cls._get_task(task_name)
@@ -53,7 +54,7 @@ def from_spec(
         cls,
         task_name: str,
         model_name_or_path: str,
         tokenizer_name: str,
-        device,
+        device: torch.device,
         english_only: bool,
     ):
         task = cls._get_task(task_name)
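Taken together, PATCH 12 and PATCH 13 leave `AutoTask` with two typed entry points. Below is a minimal end-state usage sketch, assuming the series is applied through PATCH 13 and using the `gpt2` example from the README (the output directory is a placeholder):

```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import evaluation.tasks  # registers AutoTask subclasses for _get_task()
from evaluation.tasks.auto_task import AutoTask

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# from_spec(): build the model and tokenizer from names/paths inside AutoTask.
task = AutoTask.from_spec(
    task_name="tydiqa_secondary",
    model_name_or_path="gpt2",
    tokenizer_name="",  # falls back to model_name_or_path
    device=device,
    english_only=True,
)

# from_task_name(): reuse objects the caller already configured (as eval.py does).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
task = AutoTask.from_task_name(
    "tydiqa_secondary", model=model, tokenizer=tokenizer, device=device, english_only=True
)

os.makedirs("outputs", exist_ok=True)
task.evaluate()
task.save_metrics("outputs")
```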