Commit
Merge pull request #63 from tedhtchang/Enable-pylint
Enable pylint in the GitHub workflow
anhuong authored Mar 6, 2024
2 parents a93e3bc + a6cfa6a commit 8e0a8f8
Showing 10 changed files with 68 additions and 54 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/format.yml
@@ -35,4 +35,6 @@ jobs:
           python -m pip install -r setup_requirements.txt
       - name: Check Formatting
         run: tox -e fmt
+      - name: Run pylint
+        run: tox -e lint

4 changes: 3 additions & 1 deletion .pylintrc
@@ -443,7 +443,9 @@ disable=raw-checker-failed,
         attribute-defined-outside-init,
         abstract-method,
         pointless-statement,
-        wrong-import-order
+        wrong-import-order,
+        duplicate-code,
+        unbalanced-tuple-unpacking

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
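Both newly disabled checks fire on patterns this codebase uses deliberately. As a minimal illustration of unbalanced-tuple-unpacking (W0632), with hypothetical names rather than code from this repo:

    def parse(line):
        # The result length is only known at runtime, so pylint cannot
        # prove the unpacking below is balanced.
        return line.split(",")

    first, second = parse("a,b")  # W0632: possibly unbalanced tuple unpacking

Disabling a check here in .pylintrc silences it repo-wide; later hunks in this commit instead use inline # pylint: disable=... comments where a check should stay enabled everywhere else.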
26 changes: 18 additions & 8 deletions build/launch_training.py
@@ -48,7 +48,7 @@ def get_highest_checkpoint(dir_path):
     for curr_dir in os.listdir(dir_path):
         if curr_dir.startswith("checkpoint"):
             if checkpoint_dir:
-                curr_dir_num = int(checkpoint_dir.split("-")[-1])
+                curr_dir_num = int(checkpoint_dir.rsplit("-", maxsplit=1)[-1])
                 new_dir_num = int(curr_dir.split("-")[-1])
                 if new_dir_num > curr_dir_num:
                     checkpoint_dir = curr_dir
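The rsplit change addresses use-maxsplit-arg (C0207): when only the last piece of the split is needed, splitting once from the right avoids building a list of all the other pieces first. A quick sketch of the equivalence:

    path = "checkpoint-1234"  # hypothetical checkpoint directory name
    assert path.split("-")[-1] == "1234"               # splits at every "-"
    assert path.rsplit("-", maxsplit=1)[-1] == "1234"  # exactly one split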
@@ -87,13 +87,13 @@ def main():
         ) = parser.parse_json_file(json_path, allow_extra_keys=True)

         contents = ""
-        with open(json_path, "r") as f:
+        with open(json_path, "r", encoding="utf-8") as f:
             contents = json.load(f)
         peft_method_parsed = contents.get("peft_method")
-        logging.debug(f"Input params parsed: {contents}")
+        logging.debug("Input params parsed: %s", contents)
     elif json_env_var:
         job_config_dict = txt_to_obj(json_env_var)
-        logging.debug(f"Input params parsed: {job_config_dict}")
+        logging.debug("Input params parsed: %s", job_config_dict)

         (
             model_args,
@@ -106,7 +106,8 @@ def main():
         peft_method_parsed = job_config_dict.get("peft_method")
     else:
         raise ValueError(
-            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
+            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
+            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
         )

     tune_config = None
@@ -118,7 +119,12 @@ def main():
             tune_config = prompt_tuning_config

     logging.debug(
-        f"Parameters used to launch training: model_args {model_args}, data_args {data_args}, training_args {training_args}, tune_config {tune_config}"
+        "Parameters used to launch training: \
+        model_args %s, data_args %s, training_args %s, tune_config %s",
+        model_args,
+        data_args,
+        training_args,
+        tune_config,
     )

     original_output_dir = training_args.output_dir
@@ -138,7 +144,9 @@ def main():
         )

         logging.info(
-            f"Merging lora tuned checkpoint {lora_checkpoint_dir} with base model into output path: {export_path}"
+            "Merging lora tuned checkpoint %s with base model into output path: %s",
+            lora_checkpoint_dir,
+            export_path,
         )

         create_merged_model(
@@ -151,7 +159,9 @@ def main():
         # copy last checkpoint into mounted output dir
         pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
         logging.info(
-            f"Copying last checkpoint {pt_checkpoint_dir} into output dir {original_output_dir}"
+            "Copying last checkpoint %s into output dir %s",
+            pt_checkpoint_dir,
+            original_output_dir,
         )
         shutil.copytree(
             os.path.join(training_args.output_dir, pt_checkpoint_dir),
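Two checks recur throughout this file and the rest of the commit. W1514 (unspecified-encoding) flags open() calls without encoding=, which otherwise fall back to the platform default and can parse the same file differently across systems. W1203 (logging-fstring-interpolation) flags f-strings in logging calls, since an f-string is formatted eagerly even when the log level is disabled; %s arguments are interpolated only if the record is actually emitted. A small sketch with a hypothetical file:

    import json
    import logging

    with open("job_config.json", "r", encoding="utf-8") as f:  # explicit encoding
        contents = json.load(f)

    # Interpolated lazily, only when DEBUG records are handled:
    logging.debug("Input params parsed: %s", contents)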
11 changes: 6 additions & 5 deletions scripts/run_inference.py
@@ -91,11 +91,11 @@ def _apply_config_changes(self, overrides: dict) -> dict:
         # If we have no overrides, this context manager is a noop; no need to do anything
         if not overrides:
             return {}
-        with open(self.config_path, "r") as config_file:
+        with open(self.config_path, "r", encoding="utf-8") as config_file:
             adapter_config = json.load(config_file)
         overridden_values = self._get_old_config_values(adapter_config, overrides)
         adapter_config = {**adapter_config, **overrides}
-        with open(self.config_path, "w") as config_file:
+        with open(self.config_path, "w", encoding="utf-8") as config_file:
             json.dump(adapter_config, config_file, indent=4)
         return overridden_values

@@ -227,7 +227,8 @@ def main():
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models \
+            [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(
@@ -257,7 +258,7 @@ def main():
     if args.text:
         texts = [args.text]
     else:
-        with open(args.text_file, "r") as text_file:
+        with open(args.text_file, "r", encoding="utf-8") as text_file:
             texts = [line.strip() for line in text_file.readlines()]

     # TODO: we should add batch inference support
@@ -270,7 +271,7 @@ def main():
     ]

     # Export the results to a file
-    with open(args.out_file, "w") as out_file:
+    with open(args.out_file, "w", encoding="utf-8") as out_file:
         json.dump(results, out_file, sort_keys=True, indent=4)

     print(f"Exported results to: {args.out_file}")
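For context, the _apply_config_changes hunk above follows a read-merge-write pattern: load the adapter config, remember the values about to be overridden so they can be restored later, merge in the overrides, and write the file back. A simplified sketch of that shape, with illustrative names:

    import json

    def apply_overrides(config_path, overrides):
        with open(config_path, "r", encoding="utf-8") as config_file:
            adapter_config = json.load(config_file)
        # Remember prior values so the caller can restore them on exit.
        overridden = {k: adapter_config[k] for k in overrides if k in adapter_config}
        adapter_config = {**adapter_config, **overrides}
        with open(config_path, "w", encoding="utf-8") as config_file:
            json.dump(adapter_config, config_file, indent=4)
        return overridden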
3 changes: 2 additions & 1 deletion tox.ini
@@ -17,5 +17,6 @@ allowlist_externals = ./scripts/fmt.sh
 [testenv:lint]
 description = lint with pylint
 deps = pylint>=2.16.2,<=3.1.0
-commands = pylint tuning scripts/*.py
+    -r requirements.txt
+commands = pylint tuning scripts/*.py build/*.py
 allowlist_externals = pylint
5 changes: 3 additions & 2 deletions tuning/config/configs.py
@@ -14,7 +14,7 @@

 # Standard
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Union
+from typing import Optional, Union

 # Third Party
 import torch
@@ -64,7 +64,8 @@ class TrainingArguments(transformers.TrainingArguments):
     model_max_length: int = field(
         default=DEFAULT_CONTEXT_LENGTH,
         metadata={
-            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+            "help": "Maximum sequence length. Sequences will be right padded \
+            (and possibly truncated)."
         },
     )
     packing: bool = field(
6 changes: 4 additions & 2 deletions tuning/config/peft_config.py
@@ -24,8 +24,10 @@ class LoraConfig:
     target_modules: List[str] = field(
         default_factory=lambda: ["q_proj", "v_proj"],
         metadata={
-            "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
-            'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
+            "help": "The names of the modules to apply LORA to. LORA selects modules which either \
+            completely match or "
+            'end with one of the strings. If the value is ["all-linear"], \
+            then LORA selects all linear and Conv1D '
             "modules except for the output layer."
         },
     )
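One caveat with the wrapping style used in these help strings: a backslash continuation inside a string literal keeps the next source line's leading spaces as part of the string, so the help text now contains runs of embedded whitespace. Implicit concatenation of adjacent literals avoids that. A minimal comparison:

    WRAPPED = "first half \
        second half"
    JOINED = "first half " "second half"

    assert "    second half" in WRAPPED  # the indentation leaks into the string
    assert JOINED == "first half second half"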
10 changes: 1 addition & 9 deletions tuning/data/tokenizer_data_utils.py
@@ -13,19 +13,11 @@
 # limitations under the License.

 # Standard
-from typing import Dict, Sequence
-import copy
-import json
-import logging
+from typing import Dict

 # Third Party
-from torch.utils.data import Dataset
-import torch
 import transformers

-# Local
-from tuning.config import configs
-

 def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
50 changes: 27 additions & 23 deletions tuning/sft_trainer.py
@@ -17,6 +17,7 @@
 from typing import Optional, Union
 import json
 import os
+import sys

 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -94,15 +95,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
         return

     # append the current log to the jsonl file
-    with open(log_file, "a") as f:
+    with open(log_file, "a", encoding="utf-8") as f:
         f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")


 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_config: Optional[  # pylint: disable=redefined-outer-name
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
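The inline disable is needed because the parameter name peft_config shadows the peft_config module imported at the top of the file, and the annotation still needs the module, so the shadowing is deliberate. The same situation in miniature, with hypothetical names:

    config = {"lr": 1e-4}  # module-level name

    def train(config):  # pylint: disable=redefined-outer-name
        # Inside train(), config is the parameter, not the module-level
        # dict; pylint flags the shadowing (W0621) unless silenced.
        return config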
@@ -154,9 +155,7 @@ def train(
     )

     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -165,33 +164,36 @@
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )

-    """TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info("Model max length %s", model_max_length)
     if train_args.model_max_length > tokenizer.model_max_length:
         logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+            "model_max_length %s exceeds tokenizer.model_max_length \
+            %s, using tokenizer.model_max_length %s",
+            train_args.model_max_length,
+            tokenizer.model_max_length,
+            tokenizer.model_max_length,
         )

     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
@@ -219,19 +221,21 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path

-    format_dataset = lambda example: {
+    format_dataset = lambda example: {  # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }

     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info("Training dataset length is %s", len(formatted_train_dataset))

     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            "Validation dataset length is %s", len(formatted_validation_dataset)
+        )

     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
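Assigning a lambda to a name triggers C3001 (unnecessary-lambda-assignment); the commit keeps the lambda and silences the check, presumably because it is passed straight to map(). The def form pylint would otherwise suggest looks roughly like this, with the closed-over names made explicit and purely illustrative:

    def make_formatter(text_field, eos_token):
        def format_dataset(example):
            return {text_field: example[text_field] + eos_token}
        return format_dataset

    fmt = make_formatter("output", "</s>")
    assert fmt({"output": "hi"}) == {"output": "hi</s>"}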
@@ -248,13 +252,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,
@@ -284,7 +288,7 @@ def train(
     trainer.train()


-def main(**kwargs):
+def main(**kwargs):  # pylint: disable=unused-argument
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,
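The exit to sys.exit swaps earlier in this file address consider-using-sys-exit (R1722): the exit() builtin is added by the site module for interactive sessions and may be absent (for example under python -S), while sys.exit() is always available and raises SystemExit reliably. In outline:

    import sys

    def require(value, message):
        # Hypothetical helper mirroring the error paths in train() above.
        if value is None:
            print(message)
            sys.exit(-1)  # raises SystemExit(-1)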
5 changes: 2 additions & 3 deletions tuning/utils/merge_model_utils.py
@@ -14,7 +14,6 @@

 # Standard
 from typing import Union
-import argparse
 import json
 import os

@@ -41,7 +40,7 @@ def create_merged_model(
     References:
         - https://github.com/huggingface/peft/issues/1040
         - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
-        - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter
+        - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long

     Args:
         checkpoint_model: Union[str, list[str]]
@@ -96,7 +95,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
     if not os.path.isfile(adapter_config):
         raise FileNotFoundError("Unable to locate adapter config to infer base model!")

-    with open(adapter_config, "r") as cfg:
+    with open(adapter_config, "r", encoding="utf-8") as cfg:
         adapter_dict = json.load(cfg)
     if "base_model_name_or_path" not in adapter_dict:
         raise KeyError(
