From 17d69b093334f6b02e556cf6cf0b98c1e0635f6b Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Mon, 22 Jul 2024 11:05:22 +0530 Subject: [PATCH 1/7] fix: Added correct link for the trainer-controller readme, to main readme.md (#254) Signed-off-by: Padmanabha V Seshadri --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d0b3fd3b..579727e2 100644 --- a/README.md +++ b/README.md @@ -547,7 +547,7 @@ Trainer controller is a framework for controlling the trainer loop using user-de This framework helps users define rules to capture scenarios like criteria for stopping an ongoing training (E.g validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing etc). -For details about how you can use set a custom stopping criteria and perform custom operations, see [examples/trainer_controller/README.md](examples/trainer_controller/README.md) +For details about how you can use set a custom stopping criteria and perform custom operations, see [examples/trainercontroller_configs/Readme.md](examples/trainercontroller_configs/Readme.md) ## More Examples From 006f2b93de411af3d5c5ebada505a38b7380d7c9 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Mon, 22 Jul 2024 10:33:54 -0600 Subject: [PATCH 2/7] trainer controller doc updates (#244) * trainer controller doc updates Signed-off-by: Alex-Brooks * Add missing link Signed-off-by: Alex-Brooks --------- Signed-off-by: Alex-Brooks --- examples/trainercontroller_configs/Readme.md | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/trainercontroller_configs/Readme.md b/examples/trainercontroller_configs/Readme.md index 236acb87..174a1ce6 100644 --- a/examples/trainercontroller_configs/Readme.md +++ b/examples/trainercontroller_configs/Readme.md @@ -4,11 +4,11 @@ Trainer controller is a framework for controlling the trainer loop using user-de ### Motivation -This frameworks helps user define rules to capture scenarios like criteria for stopping an ongoing training (E.g validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing etc). +This frameworks helps user define rules to capture scenarios like criteria for stopping an ongoing training (e.g., validation loss reaching a certain target, validation loss increasing with epoch, training loss values for last 100 steps increasing, etc). ### Usage *Note: Evaluation loss and validation loss are the same.* -1. The trainer controller feature can be used and its behavior is controlled by a configuration file (we will illustrate the configuration file below) supplied by the user at the start of the training. Here is a sample of how the user can initiate a trainer controller for a training job, by specifying path to an existing configuration `loss.yaml` in the `./examples/trainercontroller_configs` directory using the flag `--trainer_controller_config_file`: +1. The trainer controller feature can be controlled by a configuration file supplied by the user at the start of the training. Here is a sample of how the user can initiate a trainer controller for a training job, by specifying path to an existing configuration `loss.yaml` in the `./examples/trainercontroller_configs` directory using the flag `--trainer_controller_config_file`: ```shell python ./tuning/sft_trainer.py \ ... @@ -32,15 +32,15 @@ This frameworks helps user define rules to capture scenarios like criteria for s operations: - hfcontrols.should_training_stop ``` - Here is a brief primer on the above configuration. More details could be found [here](./architecture_records/001-trainer-controller-framework.md). + Here is a brief primer on the above configuration. More details could be found [here](../../architecture_records/001-trainer-controller-framework.md). Note that in the following descriptions, we use `metric` and `metric handler` interchangeably to describe a class which exposes numeric information about the training state / relevant computations for use in a `rule` for early termination. - *Description:* The above configuration stops the training when a **evaluation loss** decreases below 2.25 after two epochs. - - *Metrics:* The configuration uses two metrics listed under `controller-metrics` section. One is named `evalmetric`, which uses an in-built metric class called `EvalMetrics` to expose evaluation loss and the other (`trainer_state`) uses `TrainingState` to expose the current epoch. These are referred to in the `rule` as shown above. There are other metrics also which could be used in place of `evalmetric` and . Here is a list of supported metric classes: - - `Loss`: Exposes the **training loss** after every `on_log` event. See more on trainer events [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback). - - `TrainerState`: This metric exposes the **trainer state** (more on trainer state can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerState)). [Here](tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml) is an example metric which uses both the `TrainerState` and `Loss` metric. - - `EvalMetrics`: This metric exposes all the evaluation metrics used in the training job (E.g evaluation/validation loss). [Here](tests/data/trainercontroller/exposed_metrics.yaml) is an example metric which uses both the `EvalMetrics`. + - *Metrics:* The configuration uses two metrics listed under `controller_metrics` section. One is named `evalmetric`, which uses the built-in metric handler class `EvalMetrics` to expose evaluation loss, and the other, `trainer_state`, uses the built-in metric handler class `TrainingState` to expose the current epoch. These are referred to in the `rule` as shown above. There are other metrics that could also be used in place of `evalmetric` and `trainer_state`. At the time of writing, the supported metric handler classes are as follows: + - `Loss`: This metric handler exposes the **training loss** after every `on_log` event. See more on trainer events [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback). + - `TrainerState`: This metric exposes the **trainer state** (more on trainer state can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerState)). [Here](../../tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml) is an example metric which uses both the `TrainerState` and `Loss` metric. + - `EvalMetrics`: This metric exposes all the evaluation metrics used in the training job (E.g evaluation/validation loss). [Here](../../tests/data/trainercontroller/exposed_metrics.yaml) is an example config which uses the `EvalMetric`'s `eval_loss`. - `HistoryBasedMetric`: This metric exposes a moving **window** of evaluation metrics and training loss. It is useful to create rules on a history of values (i.e. evaluation metrics and training loss). Following are some examples which illustrate how this metric could be used: - - [epoch-level-eval-loss-patience.yaml](tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml): This configuration performs a threshold test for evaluation loss with a **patience threshold** of 2. I.e suppose the evaluation loss lower threshold is 2, and patience threshold is 3, then the trainer controller will not take an action (E.g. stop training) when the rule becomes true (i.e. evaluation loss is lower than 2) for for three consecutive times. - - [non-decreasing-training-loss.yaml](tests/data/trainercontroller/non-decreasing-training-loss.yaml): This configuration compares the first and last values of a window of training loss samples and determines if the training loss has increased or not. If there is an increase, the training is stopped. + - [epoch-level-eval-loss-patience.yaml](../../tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml): This configuration performs a threshold test for evaluation loss with a **patience threshold** of 2. I.e., suppose the evaluation loss lower threshold is 2, and patience threshold is 3, then the trainer controller will not take an action, e.g., stop the training, when the rule becomes true. i.e., evaluation loss is lower than 2, three consecutive times. + - [non-decreasing-training-loss.yaml](../../tests/data/trainercontroller/non-decreasing-training-loss.yaml): This configuration compares the first and last values of a window of training loss samples and determines if the training loss has increased or not. If there is an increase, the training is stopped. Let us assume use the below example to understand the usage: ```yaml @@ -80,10 +80,10 @@ This frameworks helps user define rules to capture scenarios like criteria for s ``` 1. To access the first value in window of evaluation metric `eval_loss`, here is the illustration `history_window["metrics"]["eval_loss"][0]`. In the above YAML, the last element is accessed as follows: `history_window["metrics"]["eval_loss"][-1]`. 1. Similarly, the `history_window["metrics"]["global_step"][0]` is global_step at the time of generation of this evaluation metric and `history_window["metrics"]["epoch"][0]` is the corresponding epoch. - 1. Similar approach is followed to access training loss (i.e. `history_window["training_loss"]["loss"][0]` givest the first training loss). + 1. A similar approach is followed to access training loss (i.e., `history_window["training_loss"]["loss"][0]` gives the first training loss). - - *Trigger:* There is also a trigger event to decide when the `rule` needs to be evaluated. This event has to be one of the trainer events listed [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback). - - *Rule:* The `rule` is a python statement which could use the metric name (e.g. `loss` in the above case) to define conditions which, when satisfied (it is a boolean condition and should evaluate to True to be satisfied) will trigger the operation(s) listed in `operations`. + - *Trigger:* There is also a trigger event to decide *when* the `rule` needs to be evaluated. This event has to be one of the trainer events listed [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerCallback). The choice of even to trigger on allows for more control, e.g., controlling the times at which we should consider early training termination. + - *Rule:* The `rule` is a python statement which could use the metric name, e.g., `loss` in the above case, to define boolean conditions which, when satisfied, will trigger the operation(s) listed in `operations`. - *Operation:* The `operations` section lists the operations that could be performed when the `rule` is satisfied (i.e. condition becomes True). Currently, we support only one type of operation class `HFControls` (In this particular example, the class and corresponding operation name `hfcontrols` are not specified explicitly as they are considered default and can be omitted). The `HFControls` class supports all operations listed below. More on these operations can be found [here](https://huggingface.co/docs/transformers/v4.41.3/en/main_classes/callback#transformers.TrainerControl). - `hfcontrols.should_training_stop`: Stops the training. - `hfcontrols.should_epoch_stop`: Interrupts the current epoch. From 297544e9cf9cb08549d6ba566568969433f7a8cd Mon Sep 17 00:00:00 2001 From: Hari Date: Thu, 25 Jul 2024 12:00:32 +0530 Subject: [PATCH 3/7] docs: fix the instructions for running with LORA (#265) Signed-off-by: Harikrishnan Balagopal --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 579727e2..ac65f1c2 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ python tuning/sft_trainer.py \ --r 8 \ --lora_dropout 0.05 \ --lora_alpha 16 \ ---target_modules ["c_attn", "c_proj"] +--target_modules c_attn c_proj ``` Equally you can pass in a JSON configuration for running tuning. See [build doc](./build/README.md) for more details. The above can also be passed in as JSON: From 6d15cf9595265f50e17b96ae994adb1aeca3bd53 Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Thu, 25 Jul 2024 15:45:08 -0600 Subject: [PATCH 4/7] refactor code to preprocess datasets (#259) * refactor code to preprocess datasets Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * fix formatting Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 * code cleanup Co-authored-by: Alex-Brooks Signed-off-by: Sukriti-Sharma4 --------- Signed-off-by: Sukriti-Sharma4 Co-authored-by: Alex-Brooks --- tests/data/__init__.py | 1 + tests/test_sft_trainer.py | 2 +- tests/utils/test_data_utils.py | 23 ++++--- tests/utils/test_preprocessing_utils.py | 63 ++++++++++++++++++- tuning/sft_trainer.py | 61 ++++-------------- tuning/utils/data_utils.py | 15 +++-- tuning/utils/preprocessing_utils.py | 84 +++++++++++++++++++++---- 7 files changed, 170 insertions(+), 79 deletions(-) diff --git a/tests/data/__init__.py b/tests/data/__init__.py index 1d086821..cf88ece9 100644 --- a/tests/data/__init__.py +++ b/tests/data/__init__.py @@ -26,3 +26,4 @@ TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(DATA_DIR, "twitter_complaints_json.json") EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") +MODEL_NAME = "Maykeye/TinyLLama-v0" diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 57ff216c..7c96ccce 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -33,6 +33,7 @@ from tests.data import ( EMPTY_DATA, MALFORMATTED_DATA, + MODEL_NAME, TWITTER_COMPLAINTS_DATA, TWITTER_COMPLAINTS_JSON_FORMAT, ) @@ -41,7 +42,6 @@ from tuning import sft_trainer from tuning.config import configs, peft_config -MODEL_NAME = "Maykeye/TinyLLama-v0" MODEL_ARGS = configs.ModelArguments( model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" ) diff --git a/tests/utils/test_data_utils.py b/tests/utils/test_data_utils.py index 471f2859..f027d9fd 100644 --- a/tests/utils/test_data_utils.py +++ b/tests/utils/test_data_utils.py @@ -34,12 +34,13 @@ def test_apply_custom_formatting_template(): "### Input: @HMRCcustomers No this is my first job" + " \n\n ### Response: no complaint" ) - formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( - json_dataset, template + formatted_dataset_field = "formatted_data_field" + formatted_dataset = data_utils.apply_custom_formatting_template( + json_dataset, template, formatted_dataset_field ) # a new dataset_text_field is created in Dataset - assert dataset_text_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][dataset_text_field] == expected_response + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response def test_apply_custom_formatting_template_adds_eos_token(): @@ -50,17 +51,21 @@ def test_apply_custom_formatting_template_adds_eos_token(): "### Input: @HMRCcustomers No this is my first job" + " \n\n ### Response: no complaintEOS" ) - formatted_dataset, dataset_text_field = data_utils.apply_custom_formatting_template( - json_dataset, template, "EOS" + formatted_dataset_field = "formatted_data_field" + formatted_dataset = data_utils.apply_custom_formatting_template( + json_dataset, template, formatted_dataset_field, "EOS" ) # a new dataset_text_field is created in Dataset - assert dataset_text_field in formatted_dataset["train"][0] - assert formatted_dataset["train"][0][dataset_text_field] == expected_response + assert formatted_dataset_field in formatted_dataset["train"][0] + assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): """Tests that the formatting function will throw error if wrong keys are passed to template""" json_dataset = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA) template = "### Input: {{not found}} \n\n ### Response: {{text_label}}" + formatted_dataset_field = "formatted_data_field" with pytest.raises(KeyError): - data_utils.apply_custom_formatting_template(json_dataset, template, "EOS") + data_utils.apply_custom_formatting_template( + json_dataset, template, formatted_dataset_field, "EOS" + ) diff --git a/tests/utils/test_preprocessing_utils.py b/tests/utils/test_preprocessing_utils.py index 7a807da9..e24cf710 100644 --- a/tests/utils/test_preprocessing_utils.py +++ b/tests/utils/test_preprocessing_utils.py @@ -8,6 +8,7 @@ # First Party from tests.data import ( MALFORMATTED_DATA, + MODEL_NAME, TWITTER_COMPLAINTS_DATA, TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, ) @@ -16,7 +17,9 @@ from tuning.config import configs from tuning.utils.preprocessing_utils import ( combine_sequence, + format_dataset, get_data_trainer_kwargs, + get_formatted_dataset_with_single_sequence, get_preprocessed_dataset, load_hf_dataset_from_jsonl_file, validate_data_args, @@ -84,7 +87,7 @@ def test_load_hf_dataset_from_jsonl_file_duplicate_keys(): # Tests for custom masking / preprocessing logic @pytest.mark.parametrize("max_sequence_length", [1, 10, 100, 1000]) def test_get_preprocessed_dataset(max_sequence_length): - tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) preprocessed_data = get_preprocessed_dataset( data_path=TWITTER_COMPLAINTS_DATA, tokenizer=tokenizer, @@ -128,7 +131,7 @@ def test_get_trainer_kwargs_with_response_template_and_text_field( packing=packing, response_template="\n### Label:", max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"), + tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), dataset_text_field="output", ) assert len(trainer_kwargs) == 3 @@ -161,7 +164,7 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): packing=False, response_template=None, max_sequence_length=100, - tokenizer=AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0"), + tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME), dataset_text_field=None, ) assert len(trainer_kwargs) == 4 @@ -207,3 +210,57 @@ def test_get_trainer_kwargs_with_custom_masking(use_validation_data): def test_validate_args(data_args, packing): with pytest.raises(ValueError): validate_data_args(data_args, packing) + + +@pytest.mark.parametrize( + "data_path, dataset_text_field, data_formatter_template", + [ + (TWITTER_COMPLAINTS_DATA, "output", None), + ( + TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + "formatted_field", + "### Text:{{input}} \n\n### Label: {{output}}", + ), + ], +) +def test_get_formatted_dataset_with_single_sequence( + data_path, dataset_text_field, data_formatter_template +): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + formatted_dataset = get_formatted_dataset_with_single_sequence( + data_path, dataset_text_field, tokenizer, data_formatter_template + ) + assert isinstance(formatted_dataset, Dataset) + assert dataset_text_field in formatted_dataset.column_names + + +@pytest.mark.parametrize( + "data_args", + [ + # single sequence and response template + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA, + validation_data_path=TWITTER_COMPLAINTS_DATA, + dataset_text_field="output", + response_template="\n### Label:", + ) + ), + # data formatter template with input/output JSON + ( + configs.DataArguments( + training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT, + dataset_text_field="formatted_field", + data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}", + ) + ), + ], +) +def test_format_dataset(data_args): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + train_set, eval_set, dataset_text_field = format_dataset(data_args, tokenizer) + assert isinstance(train_set, Dataset) + assert isinstance(eval_set, Dataset) + assert dataset_text_field in train_set.column_names + assert dataset_text_field in eval_set.column_names diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 6e7f2eb6..0e360ad4 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -35,7 +35,6 @@ ) from transformers.utils import is_accelerate_available, logging from trl import SFTConfig, SFTTrainer -import datasets import fire import transformers @@ -56,13 +55,16 @@ from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config, get_json_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.utils.data_utils import apply_custom_formatting_template from tuning.utils.error_logging import ( INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE, write_termination_log, ) -from tuning.utils.preprocessing_utils import get_data_collator, validate_data_args +from tuning.utils.preprocessing_utils import ( + format_dataset, + get_data_collator, + validate_data_args, +) def train( @@ -261,52 +263,13 @@ def train( # Validate if data args are set properly validate_data_args(data_args, packing) - data_collator = get_data_collator(packing, data_args.response_template, tokenizer) - - # load the data by parsing JSON - ### TODO: all the jSON file formatting will be moved to a separate function - data_files = {"train": data_args.training_data_path} - if data_args.validation_data_path: - data_files["validation"] = data_args.validation_data_path - format_dataset = lambda example: { # pylint: disable=unnecessary-lambda-assignment - f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"] - + tokenizer.eos_token - } - - json_dataset = datasets.load_dataset("json", data_files=data_files) - if data_args.data_formatter_template: - ( - formatted_train_dataset, - data_args.dataset_text_field, - ) = apply_custom_formatting_template( - json_dataset["train"], - data_args.data_formatter_template, - tokenizer.eos_token, - ) - else: - formatted_train_dataset = json_dataset["train"].map(format_dataset) - logger.info("Training dataset length is %s", len(formatted_train_dataset)) - - formatted_validation_dataset = None - if data_args.validation_data_path: - if data_args.data_formatter_template: - ( - formatted_validation_dataset, - data_args.dataset_text_field, - ) = apply_custom_formatting_template( - json_dataset["validation"], - data_args.data_formatter_template, - tokenizer.eos_token, - ) - else: - formatted_validation_dataset = json_dataset["validation"].map( - format_dataset - ) - logger.info( - "Validation dataset length is %s", len(formatted_validation_dataset) - ) - ### JSON file formatting ends here + ( + formatted_train_dataset, + formatted_validation_dataset, + dataset_text_field, + ) = format_dataset(data_args, tokenizer) + data_collator = get_data_collator(packing, data_args.response_template, tokenizer) if framework is not None and framework.requires_agumentation: model, (peft_config,) = framework.augmentation( @@ -337,7 +300,7 @@ def train( eval_dataset=formatted_validation_dataset, packing=packing, data_collator=data_collator, - dataset_text_field=data_args.dataset_text_field, + dataset_text_field=dataset_text_field, args=training_args, max_seq_length=max_seq_length, callbacks=trainer_callbacks, diff --git a/tuning/utils/data_utils.py b/tuning/utils/data_utils.py index 3e67cc56..db5ff0f0 100644 --- a/tuning/utils/data_utils.py +++ b/tuning/utils/data_utils.py @@ -2,21 +2,28 @@ import re -def apply_custom_formatting_template(dataset, template, eos_token=""): +def apply_custom_formatting_template( + dataset, template, formatted_dataset_field, eos_token="" +): """Function to format datasets with Alpaca style / other templates. Args: dataset: the HF Dataset element loaded from a JSON or DatasetDict object. template: Template to format data with. Features of Dataset should be referred to by {{key}} + formatted_dataset_field: Dataset_text_field eos_token: string EOS token to be appended while formatting data to a single sequence. Defaults to empty Returns: - Formatted HF Dataset, dataset_field name that contains formatted data. + Formatted HF Dataset """ - formatted_dataset_field = "formatted_data_field" template += eos_token + if not formatted_dataset_field: + raise ValueError( + "Unable to apply custom formatting because the formatted_dataset_field was not provided" + ) + def formatter(element): def replace_text(match_obj): captured_groups = match_obj.groups() @@ -37,4 +44,4 @@ def replace_text(match_obj): ) } - return dataset.map(formatter), formatted_dataset_field + return dataset.map(formatter) diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py index 545e1635..88db911a 100644 --- a/tuning/utils/preprocessing_utils.py +++ b/tuning/utils/preprocessing_utils.py @@ -18,11 +18,15 @@ # Third Party from datasets import Dataset from transformers import AutoTokenizer, DataCollatorForSeq2Seq +from transformers.utils import logging from trl import DataCollatorForCompletionOnlyLM import datasets # Local from tuning.config import configs +from tuning.utils.data_utils import apply_custom_formatting_template + +logger = logging.get_logger("sft_trainer_preprocessing") def validate_data_args(data_args: configs.DataArguments, packing: bool): @@ -110,6 +114,42 @@ def get_data_collator( # tokenizer=tokenizer, padding=True, max_length=max_sequence_length # ) # 2. add anything needed for preprocessed input + raise ValueError( + "Could not pick a data collator. Please refer to supported data formats" + ) + + +def format_dataset(data_args: configs.DataArguments, tokenizer: AutoTokenizer): + """ + Args: + data_args: tuning.config.configs.DataArguments + tokenizer: AutoTokenizer + Returns: + Tuple(Dataset, Dataset, str) + tuple containing train_dataset, eval_dataset and dataset_text_field + """ + eval_dataset = None + dataset_text_field = data_args.dataset_text_field + if data_args.data_formatter_template or dataset_text_field: + if dataset_text_field is None: + dataset_text_field = "new_formatted_field" + train_dataset = get_formatted_dataset_with_single_sequence( + data_args.training_data_path, + dataset_text_field, + tokenizer, + data_args.data_formatter_template, + ) + logger.info("Training dataset length is %s", len(train_dataset)) + if data_args.validation_data_path: + (eval_dataset) = get_formatted_dataset_with_single_sequence( + data_args.validation_data_path, + dataset_text_field, + tokenizer, + data_args.data_formatter_template, + ) + logger.info("Validation dataset length is %s", len(eval_dataset)) + # TODO: add a else here for preprocessing + return train_dataset, eval_dataset, dataset_text_field ################################################################################### @@ -222,13 +262,11 @@ def get_data_trainer_kwargs( output_field_name="output", ) else: - # Collator is a DataCollatorForCompletionOnlyLM or None; - # Load it as JSON and apply our normal preprocessing logic - train_dataset = get_formatted_dataset( + train_dataset = get_formatted_dataset_with_single_sequence( training_data_path, dataset_text_field, tokenizer ) if validation_data_path: - eval_dataset = get_formatted_dataset( + eval_dataset = get_formatted_dataset_with_single_sequence( validation_data_path, dataset_text_field, tokenizer ) @@ -238,8 +276,11 @@ def get_data_trainer_kwargs( return data_kwargs -def get_formatted_dataset( - data_path: str, dataset_text_field: str, tokenizer: AutoTokenizer +def get_formatted_dataset_with_single_sequence( + data_path: str, + dataset_text_field: str, + tokenizer: AutoTokenizer, + data_formatter_template: Optional[str] = None, ) -> Dataset: """Applies formatting to the loaded dataset instance; does NOT pretokenize data. @@ -247,21 +288,38 @@ def get_formatted_dataset( data_path: str Path to the file to be loaded. dataset_text_field: str - Dataset text field fto be used for formatting by TRL. + Dataset text field to be used for formatting. + If data_formatter_template specified, \ + this will be the new field creating single sequence. tokenizer: AutoTokenizer Loaded tokenizer object to be used by the collator. + data_formatter_template: str + Template to apply to create single sequence and store it in dataset_text_field Returns: Dataset HF Dataset with formatted [str] data. """ - format_dataset = lambda example: { # pylint: disable=unnecessary-lambda-assignment - f"{dataset_text_field}": example[f"{dataset_text_field}"] + tokenizer.eos_token - } + json_dataset = datasets.load_dataset("json", data_files=data_path) - return json_dataset.map(format_dataset)[ - "train" - ] # HACK - for now, we just do both datasets separately; train is the default split + format_dataset_EOS = ( + lambda example: { # pylint: disable=unnecessary-lambda-assignment + f"{dataset_text_field}": example[f"{dataset_text_field}"] + + tokenizer.eos_token + } + ) + if data_formatter_template: + formatted_train_dataset = apply_custom_formatting_template( + json_dataset["train"], + data_formatter_template, + dataset_text_field, + tokenizer.eos_token, + ) + else: + formatted_train_dataset = json_dataset.map(format_dataset_EOS)[ + "train" + ] # HACK - for now, we just do both datasets separately; train is the default split + return formatted_train_dataset def get_preprocessed_dataset( From 7dfd4e71a0ded17ab65654925e18bf9a1d76b0fc Mon Sep 17 00:00:00 2001 From: Joe Olson <118190512+olson-ibm@users.noreply.github.com> Date: Fri, 26 Jul 2024 11:24:52 -0500 Subject: [PATCH 5/7] Replace shutil.copytree() to fix permission error (#251) * Closes #1089 Signed-off-by: Joe Olson * added unit tests, fixed other issues raised by review. Signed-off-by: Joe Olson * Closes 1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson * Closes #1089 Signed-off-by: Joe Olson --------- Signed-off-by: Joe Olson Co-authored-by: Anh Uong --- build/accelerate_launch.py | 7 +- build/utils.py | 17 +++++ tests/build/test_utils.py | 132 +++++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 4 deletions(-) diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index 9816dde0..9af5ad80 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -35,6 +35,7 @@ process_accelerate_launch_args, serialize_args, get_highest_checkpoint, + copy_checkpoint, ) from tuning.utils.config_utils import get_json_config from tuning.config.tracker_configs import FileLoggingTrackerConfig @@ -124,10 +125,8 @@ def main(): pt_checkpoint_dir, original_output_dir, ) - shutil.copytree( - os.path.join(tempdir, pt_checkpoint_dir), - original_output_dir, - dirs_exist_ok=True, + copy_checkpoint( + os.path.join(tempdir, pt_checkpoint_dir), original_output_dir ) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) diff --git a/build/utils.py b/build/utils.py index 96c432d9..3d2f89d1 100644 --- a/build/utils.py +++ b/build/utils.py @@ -21,6 +21,23 @@ # Third Party import torch from accelerate.commands.launch import launch_command_parser +import shutil + + +def copy_checkpoint(source, destination): + if not os.path.exists(destination): + os.makedirs(destination) + shutil.copystat(source, destination) + # Have a list of directory objects, now iterate over them. + for item in os.listdir(source): + source_file = os.path.join(source, item) + destination_file = os.path.join(destination, item) + if os.path.isdir(source_file): + # recursive call for subdirectories + copy_checkpoint(source_file, destination_file) + else: + # straight copy. + shutil.copy2(source_file, destination_file) def get_highest_checkpoint(dir_path): diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py index dddd861c..fde0ffb2 100644 --- a/tests/build/test_utils.py +++ b/tests/build/test_utils.py @@ -17,12 +17,15 @@ import json import os from unittest.mock import patch +import tempfile # Third Party import pytest +import filecmp # Local from build.utils import process_accelerate_launch_args +from build.utils import copy_checkpoint HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( os.path.dirname(__file__), "dummy_job_config.json" @@ -108,3 +111,132 @@ def test_process_accelerate_launch_custom_config_file(patch_path_exists): temp_job_config = {"accelerate_launch_args": {"config_file": dummy_config_path}} args = process_accelerate_launch_args(temp_job_config) assert args.config_file == dummy_config_path + + +class CopyCheckpointTestConfig: + def __init__(self, temp_root): + + # Create the following file tree for testing: + # test_root + # test_copytree_source + # tf1.txt + # tf2.txt + # tf3.txt + # subdir1 + # tf4.txt + # tf5.txt + # tf6.txt + + self.test_root = temp_root + self.source_dir = os.path.join(self.test_root, "test_copytree_source") + self.source_sub_dir = os.path.join(self.source_dir, "subdir1") + + os.mkdir(self.source_dir) + for file_number in range(2): + with open( + os.path.join(self.source_dir, f"tf{file_number+1}.txt"), + "a", + encoding="utf-8", + ) as f: + f.close() + + os.mkdir(self.source_sub_dir) + for file_number in range(2): + with open( + os.path.join(self.source_sub_dir, f"tf{file_number+4}.txt"), + "a", + encoding="utf-8", + ) as f: + f.close() + + def are_dir_trees_equal(self, dir1, dir2): + + dirs_cmp = filecmp.dircmp(dir1, dir2) + if ( + len(dirs_cmp.left_only) > 0 + or len(dirs_cmp.right_only) > 0 + or len(dirs_cmp.funny_files) > 0 + ): + return False + (_, mismatch, errors) = filecmp.cmpfiles( + dir1, dir2, dirs_cmp.common_files, shallow=False + ) + if len(mismatch) > 0 or len(errors) > 0: + return False + for common_dir in dirs_cmp.common_dirs: + new_dir1 = os.path.join(dir1, common_dir) + new_dir2 = os.path.join(dir2, common_dir) + if not self.are_dir_trees_equal(new_dir1, new_dir2): + return False + return True + + +def test_copy_checkpoint_dest_dir_does_not_exist(): + + # Init source directory + with tempfile.TemporaryDirectory() as test_root: + config = CopyCheckpointTestConfig(test_root) + + target_dir_does_not_exist = os.path.join( + config.test_root, "test_copytree_target" + ) + + # Execute the copy + copy_checkpoint(config.source_dir, target_dir_does_not_exist) + assert config.are_dir_trees_equal(config.source_dir, target_dir_does_not_exist) + + +def test_copy_checkpoint_dest_dir_does_exist(): + + # Init source directory + with tempfile.TemporaryDirectory() as test_root: + config = CopyCheckpointTestConfig(test_root) + + # Init target directory + target_dir_does_exist = os.path.join(config.test_root, "test_copytree_target2") + os.mkdir(target_dir_does_exist) + # Add a file to the target. This file will be overwritten during the copy. + with open( + os.path.join(target_dir_does_exist, "tf1.txt"), + "a", + encoding="utf-8", + ) as f: + f.close() + # Add a file to the target. This file does not exist in source. + with open( + os.path.join(target_dir_does_exist, "tf9.txt"), + "a", + encoding="utf-8", + ) as f: + f.close() + # Execute the copy + copy_checkpoint(config.source_dir, target_dir_does_exist) + assert os.path.exists(os.path.join(target_dir_does_exist, "tf9.txt")) + # Remove it so we can validate the dir trees are equal. + os.remove(os.path.join(target_dir_does_exist, "tf9.txt")) + assert config.are_dir_trees_equal(config.source_dir, target_dir_does_exist) + + +def test_copy_checkpoint_dest_dir_not_writeable(): + + # Init source directory + with tempfile.TemporaryDirectory() as test_root: + config = CopyCheckpointTestConfig(test_root) + + # Init target directory + target_dir_not_writeable = os.path.join( + config.test_root, "test_copytree_notwriteable" + ) + + os.makedirs(target_dir_not_writeable, mode=0o446) + + # Execute the copy. Should FAIL + with pytest.raises(PermissionError) as e: + copy_checkpoint(config.source_dir, target_dir_not_writeable) + assert "Permission denied:" in str(e.value) + + +def test_copy_checkpoint_source_dir_does_not_exist(): + with pytest.raises(FileNotFoundError) as e: + copy_checkpoint("/doesnotexist", "/tmp") + assert "No such file or directory" in str(e.value) From 55ca61219264a26ce68dbbcdcc463ae6af3ec3fa Mon Sep 17 00:00:00 2001 From: Hari Date: Mon, 29 Jul 2024 16:56:38 +0530 Subject: [PATCH 6/7] fix: logic for getting tracker config (#267) Signed-off-by: Harikrishnan Balagopal --- tuning/trackers/tracker_factory.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 4196705a..98771c14 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -149,10 +149,9 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory): C = meta["config"] T = meta["tracker"] - if tracker_configs is not None: - _conf = _get_tracker_config_by_name(name, tracker_configs) - if _conf is not None: - config = C(**_conf) - else: - config = C() + _conf = _get_tracker_config_by_name(name, tracker_configs) + if _conf is not None: + config = C(**_conf) + else: + config = C() return T(config) From 537215f6c5be6a124c422a07eae790c13be46c7e Mon Sep 17 00:00:00 2001 From: Sukriti Sharma Date: Mon, 29 Jul 2024 14:20:02 -0600 Subject: [PATCH 7/7] fix: remove lm_head for granite with llama arch models (#258) * initial code for deleting lm_head Signed-off-by: Anh-Uong * fix logic for copying checkpoint Signed-off-by: Anh-Uong * fix check that embed_tokens and lm_head weights are the same Signed-off-by: Anh-Uong * fix warning assertion Signed-off-by: Anh-Uong * fix lm_head check, remove test Signed-off-by: Anh-Uong * small fixes from code review Signed-off-by: Anh-Uong * fmt Signed-off-by: Anh-Uong --------- Signed-off-by: Anh-Uong Co-authored-by: Anh-Uong --- build/accelerate_launch.py | 103 +++++++++++++++++++++++++++++++++---- 1 file changed, 94 insertions(+), 9 deletions(-) diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index 9af5ad80..ee8718b5 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -26,9 +26,13 @@ import tempfile import shutil from pathlib import Path +import json # Third Party from accelerate.commands.launch import launch_command +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import PeftModel +from torch import bfloat16 # Local from build.utils import ( @@ -44,10 +48,18 @@ USER_ERROR_EXIT_CODE, INTERNAL_ERROR_EXIT_CODE, ) +from tuning.data import tokenizer_data_utils ERROR_LOG = "/dev/termination-log" +def get_base_model_from_adapter_config(adapter_config): + """Given path to adapter_config.json file, returns the base model name""" + with open(adapter_config, "r", encoding="utf-8") as config_file: + adapter_config = json.load(config_file) + return adapter_config.get("base_model_name_or_path") + + def main(): LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() logging.basicConfig(level=LOGLEVEL) @@ -118,16 +130,89 @@ def main(): sys.exit(INTERNAL_ERROR_EXIT_CODE) try: - # copy last checkpoint into mounted output dir - pt_checkpoint_dir = get_highest_checkpoint(tempdir) - logging.info( - "Copying last checkpoint %s into output dir %s", - pt_checkpoint_dir, - original_output_dir, - ) - copy_checkpoint( - os.path.join(tempdir, pt_checkpoint_dir), original_output_dir + last_checkpoint_dir = get_highest_checkpoint(tempdir) + last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir) + + use_flash_attn = job_config.get("use_flash_attn", True) + adapter_config_path = os.path.join( + last_checkpoint_path, "adapter_config.json" ) + tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path) + + if os.path.exists(adapter_config_path): + base_model_path = get_base_model_from_adapter_config( + adapter_config_path + ) + base_model = AutoModelForCausalLM.from_pretrained( + base_model_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) + + # since the peft library (PEFTModelForCausalLM) does not handle cases + # where the model's layers are modified, in our case the embedding layer + # is modified, so we resize the backbone model's embedding layer with our own + # utility before passing it along to load the PEFT model. + tokenizer_data_utils.tokenizer_and_embedding_resize( + {}, tokenizer=tokenizer, model=base_model + ) + model = PeftModel.from_pretrained( + base_model, + last_checkpoint_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + last_checkpoint_path, + attn_implementation="flash_attention_2" if use_flash_attn else None, + torch_dtype=bfloat16 if use_flash_attn else None, + ) + + model_arch = model.config.model_type + # check that it is a granite model with llama architecture with tied weights + # ie. lm_head is duplicate of embeddings + + # a fine tuned model will have params_dict.get("model.embed_tokens.weight") + # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight") + # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight") + copy_checkpoint_bool = True + if model_arch == "llama" and hasattr(model, "lm_head"): + if ( + # lora tuned model has an addt model layer + ( + hasattr(model.model, "model") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.model.embed_tokens.weight.untyped_storage().data_ptr() + ) + # prompt tuned model or fine tuned model + or ( + hasattr(model.model, "embed_tokens") + and model.lm_head.weight.untyped_storage().data_ptr() + == model.model.embed_tokens.weight.untyped_storage().data_ptr() + ) + ): + + copy_checkpoint_bool = False + logging.info("Removing lm_head from checkpoint") + del model.lm_head.weight + + if hasattr(model, "lm_head.weight"): + logging.warning("Failed to delete lm_head.weight from model") + + logging.info("Saving checkpoint to %s", original_output_dir) + model.save_pretrained(original_output_dir) + # save tokenizer with model + tokenizer.save_pretrained(original_output_dir) + + # copy last checkpoint into mounted output dir + if copy_checkpoint_bool: + logging.info( + "Copying last checkpoint %s into output dir %s", + last_checkpoint_dir, + original_output_dir, + ) + copy_checkpoint(last_checkpoint_path, original_output_dir) except Exception as e: # pylint: disable=broad-except logging.error(traceback.format_exc()) write_termination_log(