build: apply a typical project boilerplate of BigScience repo
Prepare GitHub Actions for more regression tests in the future.

A side note: Poetry is optional for managing the venv and dependencies, but
syncing requirements(-dev).txt must be done manually for now (see the note below).
tianjianjiang committed Aug 17, 2021
1 parent 77b0910 commit 61bd30c
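
On that side note: assuming a Poetry 1.x workflow, the usual way to regenerate the pip files these workflows install from would be poetry export -f requirements.txt --output requirements.txt, plus its --dev variant for requirements-dev.txt; that export step is an assumption here and is not automated by this commit.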
Showing 15 changed files with 1,852 additions and 56 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/code_quality.yml
@@ -0,0 +1,23 @@
+name: Code quality
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - name: Install dependencies
+        run: |
+          python -m pip install -U pip
+          python -m pip install -r requirements-dev.txt
+      - name: Check code quality
+        run: make quality
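
The make quality target invoked in the last step is defined by the new Makefile later in this commit.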
25 changes: 25 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,25 @@
+name: Test
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  test:
+    name: Test
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python 3.8
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - name: Install dependencies
+        run: |
+          python -m pip install -U pip
+          python -m pip install .
+          python -m pip install -r requirements-dev.txt
+      - name: Test
+        run: python -m pytest tests
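
Unlike the code-quality workflow, this one also runs python -m pip install . so that the evaluation package itself is importable when pytest collects the tests.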
12 changes: 12 additions & 0 deletions Makefile
@@ -0,0 +1,12 @@
+.PHONY: quality style
+
+check_dirs := .
+
+quality: # Check that source code meets quality standards
+	black --check $(check_dirs)
+	isort --check-only $(check_dirs)
+	flake8 $(check_dirs) --max-line-length 119
+
+style: # Format source code automatically
+	black $(check_dirs)
+	isort $(check_dirs)
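
With these targets, make style rewrites files in place while make quality only checks, matching what CI runs. The 119-character flake8 limit suggests black and isort are configured to the same width elsewhere (presumably in a pyproject.toml not shown in this diff).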
40 changes: 15 additions & 25 deletions evaluation/eval.py
@@ -1,53 +1,42 @@
+import os
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Optional, List
-import os
+from typing import List, Optional
 
 import torch
-from transformers import (
-    HfArgumentParser,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    TrainingArguments,
-    set_seed,
-)
-import evaluation.tasks  # needed for AutoTask.__subclass__() to work correctly
+from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
+
+import evaluation.tasks  # noqa: F401; needed for AutoTask.__subclass__() to work correctly
 from evaluation.tasks.auto_task import AutoTask
 from evaluation.utils.log import get_logger
 
 
 @dataclass
 class EvaluationArguments:
     """
-    Arguments for any adjustable params in this evaluation script
+    Arguments for any adjustable params in this evaluation script
     """
 
     model_name_or_path: str = field(
         metadata={"help": "The model checkpoint that we want to evaluate, could be name or the path."}
     )
-    eval_tasks: List[str] = field(
-        metadata={"help": "A list of tasks to run the evaluation on, e.g. tydiqa_secondary"}
-    )
+    eval_tasks: List[str] = field(metadata={"help": "A list of tasks to run the evaluation on, e.g. tydiqa_secondary"})
     config_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained config name or path if not the same as model_name."}
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name."}
     )
     tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
     )
-    tag: Optional[str] = field(
-        default=None,
-        metadata={"help": "Identifier for the evaluation run."}
-    )
+    tag: Optional[str] = field(default=None, metadata={"help": "Identifier for the evaluation run."})
 
 
 def main():
     parser = HfArgumentParser((EvaluationArguments, TrainingArguments))
     eval_args, train_args = parser.parse_args_into_dataclasses()
 
     if not eval_args.eval_tasks:
-        raise ValueError('Must provide at least one eval task!')
+        raise ValueError("Must provide at least one eval task!")
 
     # initialize device
     device = torch.device(train_args.device)

@@ -61,7 +50,8 @@ def main():
     tokenizer.padding_side = "left"
 
     model = AutoModelForCausalLM.from_pretrained(
-        eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token,
+        eval_args.model_name_or_path,
+        pad_token_id=tokenizer.eos_token,
     )
     model.config.pad_token_id = model.config.eos_token_id
     model.resize_token_embeddings(len(tokenizer))
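
The entry point itself is unchanged by this commit; a run would still look something like python evaluation/eval.py --model_name_or_path gpt2 --eval_tasks tydiqa_secondary --output_dir outputs, where the model name and output directory are illustrative placeholders (--output_dir being the one argument TrainingArguments always requires).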
1 change: 1 addition & 0 deletions evaluation/tasks/__init__.py
@@ -2,6 +2,7 @@
 # source: https://stackoverflow.com/questions/3365740/how-to-import-all-submodules
 import pkgutil
 
+
 __all__ = []
 for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
     __all__.append(module_name)
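
The collapsed tail of this file presumably does the actual importing. For reference, the usual completion of this idiom from the StackOverflow answer cited above looks roughly like the following sketch, in which the importlib call is the assumed part:

import importlib
import pkgutil

__all__ = []
for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
    __all__.append(module_name)
    # Importing each submodule is the point: it defines the task classes as a
    # side effect, so AutoTask can later find them via __subclasses__().
    importlib.import_module(f"{__name__}.{module_name}")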
6 changes: 3 additions & 3 deletions evaluation/tasks/auto_task.py
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod
 import os
+from abc import ABC, abstractmethod
 
 from evaluation.utils.io import save_json
 
@@ -17,8 +17,8 @@ def from_task_name(cls, task_name: str, tokenizer, model, device):
         for task in all_tasks:
             if task.get_display_name() == task_name:
                 return task(tokenizer=tokenizer, model=model, device=device)
-        raise ValueError(f'Invalid task: {task_name}')
-
+        raise ValueError(f"Invalid task: {task_name}")
+
     @staticmethod
     @abstractmethod
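
For readers unfamiliar with the dispatch pattern here: from_task_name works because Python classes track their subclasses at runtime. A minimal, self-contained sketch of the idea, stripped of the real class's tokenizer/model/device wiring:

from abc import ABC, abstractmethod


class AutoTask(ABC):
    @staticmethod
    @abstractmethod
    def get_display_name() -> str:
        ...

    @classmethod
    def from_task_name(cls, task_name: str):
        # __subclasses__() only sees classes that have been imported somewhere,
        # which is why eval.py imports evaluation.tasks purely for side effects.
        for task in cls.__subclasses__():
            if task.get_display_name() == task_name:
                return task()
        raise ValueError(f"Invalid task: {task_name}")


class DummyTask(AutoTask):
    @staticmethod
    def get_display_name() -> str:
        return "dummy"


assert isinstance(AutoTask.from_task_name("dummy"), DummyTask)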
20 changes: 11 additions & 9 deletions evaluation/tasks/tydiqa_primary/tydiqa_primary.py
@@ -4,6 +4,7 @@
 from jinja2 import Template
 from torch.utils.data import Dataset
 
+
 TEMPLATE = Template(
     """
 {%- set _blank=["passage", "text", "text snippet", "context"]|random -%}
@@ -16,30 +17,31 @@
{{"\n"}}{{context}}
{%- endif -%}
{{"\n"}}Answer:
"""
""" # noqa W291
)


class TyDiQADataset(Dataset):
def __init__(self, data, tokenizer, target_langs):
super(TyDiQADataset, self).__init__()
self.items = []

for sample_id, sample in enumerate(data):
lang = sample["id"].split("-")[0]
if lang in target_langs:
# Filter out samples in languages that are not used during training
prompt = TEMPLATE.render(
id = sample["id"],
context = sample["context"],
question = sample["question"],
id=sample["id"],
context=sample["context"],
question=sample["question"],
)
prompt = prompt.strip() # Remove trailing white space and newline

# Tokenize and construct this sample
inputs = tokenizer(
prompt,
padding=True,
return_tensors='pt',
return_tensors="pt",
)
self.items.append(
{
@@ -48,12 +50,12 @@ def __init__(self, data, tokenizer, target_langs):
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"input_len": inputs["attention_mask"].shape[1],
"target_answer": [ans.lower() for ans in sample["answers"]['text']],
"target_answer": [ans.lower() for ans in sample["answers"]["text"]],
}
)

def __len__(self):
return len(self.items)

def __getitem__(self, index):
return self.items[index]
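
To make the template's behavior concrete, here is a condensed, self-contained render of the same pattern; the template text is abridged and the question wording is invented for illustration (the real TEMPLATE above is longer and partially collapsed in this view):

from jinja2 import Template

# Abridged stand-in for the TEMPLATE above; |random picks one synonym per render.
DEMO_TEMPLATE = Template(
    """
{%- set _blank = ["passage", "text", "text snippet", "context"] | random -%}
Given the following {{ _blank }}, answer the question.
{{ "\n" }}{{ context }}
{{ "\n" }}Question: {{ question }}
{{ "\n" }}Answer:
"""
)

prompt = DEMO_TEMPLATE.render(
    context="TyDi QA is a question answering dataset covering 11 typologically diverse languages.",
    question="How many languages does TyDi QA cover?",
)
prompt = prompt.strip()  # as in the dataset classes: drop stray leading/trailing whitespace
print(prompt)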
31 changes: 14 additions & 17 deletions evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
@@ -1,14 +1,13 @@
 # Module for any additional processing required for the TyDi QA dataset
 # HuggingFace dataset link: https://huggingface.co/datasets/tydiqa
-from typing import Dict
-
+from datasets import load_dataset
 from jinja2 import Template
 from torch.utils.data import Dataset
-from datasets import load_dataset
 from tqdm import tqdm
 
 from evaluation.tasks.auto_task import AutoTask
 
+
 TEMPLATE = Template(
     """
 {%- set _blank=["passage", "text", "text snippet", "context"]|random -%}
@@ -21,7 +20,7 @@
{{"\n"}}{{context}}
{%- endif -%}
{{"\n"}}Answer:
"""
""" # noqa W291
)


@@ -30,23 +29,23 @@ def __init__(self, tokenizer, target_langs):
         super().__init__()
         tydiqa = load_dataset("tydiqa", "secondary_task", split="validation")
         self.items = []
 
         for sample in tydiqa:
             lang = sample["id"].split("-")[0]
             if lang in target_langs:
                 # Filter out samples in languages that are not used during training
                 prompt = TEMPLATE.render(
-                    id = sample["id"],
-                    context = sample["context"],
-                    question = sample["question"],
+                    id=sample["id"],
+                    context=sample["context"],
+                    question=sample["question"],
                 )
                 prompt = prompt.strip()  # Remove trailing white space and newline
 
                 # Tokenize and construct this sample
                 inputs = tokenizer(
                     prompt,
                     padding=True,
-                    return_tensors='pt',
+                    return_tensors="pt",
                 )
                 self.items.append(
                     {
@@ -55,27 +54,27 @@
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"input_len": inputs["attention_mask"].shape[1],
"target_answer": [ans.lower() for ans in sample["answers"]['text']],
"target_answer": [ans.lower() for ans in sample["answers"]["text"]],
}
)

def __len__(self):
return len(self.items)

def __getitem__(self, index):
return self.items[index]


class TydiqaSecondaryTask(AutoTask):
@staticmethod
def get_display_name() -> str:
return 'tydiqa_secondary'
return "tydiqa_secondary"

def evaluate(self) -> None:
dataset = TyDiQADataset(self.tokenizer, target_langs=["english"])

substring_matches = 0
for sample in tqdm(dataset, desc=f'Evaluating {self.get_display_name()}'):
for sample in tqdm(dataset, desc=f"Evaluating {self.get_display_name()}"):
output = self.model.generate(
input_ids=sample["input_ids"].to(self.device),
attention_mask=sample["attention_mask"].to(self.device),
@@ -90,6 +89,4 @@ def evaluate(self) -> None:
             substring_match = any([target_answer in predicted_answer.lower() for target_answer in target_answers])
             substring_matches += substring_match
 
-        self.metrics = {
-            "substring_matches": substring_matches / len(dataset) * 100
-        }
+        self.metrics = {"substring_matches": substring_matches / len(dataset) * 100}
4 changes: 2 additions & 2 deletions evaluation/utils/log.py
@@ -4,11 +4,11 @@
 def get_logger():
     logger = logging.getLogger("evaluation")
     formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
     )
     handler = logging.StreamHandler()
     handler.setFormatter(formatter)
     logger.addHandler(handler)
     logger.setLevel(logging.INFO)
-    return logger
\ No newline at end of file
+    return logger