From 3ca2311797b526181904b83a5697a76ea65b5082 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Thu, 24 Oct 2024 13:53:59 -0500
Subject: [PATCH] Choose evaluation example from test set (#1804)

---
 litgpt/args.py                |  5 ++-
 litgpt/finetune/adapter.py    |  3 +-
 litgpt/finetune/adapter_v2.py |  3 +-
 litgpt/finetune/full.py       |  3 +-
 litgpt/finetune/lora.py       |  4 ++-
 litgpt/utils.py               | 31 +++++++++++++++++++
 tests/test_utils.py           | 58 +++++++++++++++++++++++++++++++++++
 7 files changed, 102 insertions(+), 5 deletions(-)

diff --git a/litgpt/args.py b/litgpt/args.py
index f870965358..62c644f423 100644
--- a/litgpt/args.py
+++ b/litgpt/args.py
@@ -1,7 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import warnings
 
@@ -85,3 +85,6 @@ class EvalArgs:
     """Whether to evaluate on the validation set at the beginning of the training"""
     final_validation: bool = True
     """Whether to evaluate on the validation set at the end of the training"""
+    evaluate_example: Union[str, int] = "first"
+    """How to pick an example instruction to evaluate periodically during training.
+    Can be "first", "random", or an integer index to pick a specific example."""
diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index 1070ecee02..2f7801b8f1 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -40,6 +40,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )
 
 
@@ -381,7 +382,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 # the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index 44e05a224f..f05fd0d4d3 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -40,6 +40,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )
 
 
@@ -382,7 +383,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 # the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index f71e712753..b507aa58e4 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -36,6 +36,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )
 
 
@@ -348,7 +349,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index 6292ea3ef9..af88afb0ec 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -41,6 +41,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )
 
 
@@ -413,7 +414,8 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
+
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
diff --git a/litgpt/utils.py b/litgpt/utils.py
index 21e956f47c..10e3831745 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -6,6 +6,7 @@
 import math
 import os
 import pickle
+import random
 import re
 import shutil
 import sys
@@ -784,3 +785,33 @@ def create_finetuning_performance_report(training_time, token_counts, device_typ
 
     output += "=======================================================\n"
     return output
+
+
+def select_sft_generate_example(eval, data):
+
+    if eval.evaluate_example == "first":
+        if len(data.test_dataset.data):
+            instruction = data.test_dataset.data[0]["instruction"]
+        else:
+            instruction = data.train_dataset.data[0]["instruction"]
+
+    elif eval.evaluate_example == "random":
+        if len(data.test_dataset.data):
+            random_idx = random.randint(0, len(data.test_dataset.data) - 1)
+            instruction = data.test_dataset.data[random_idx]["instruction"]
+        else:
+            random_idx = random.randint(0, len(data.train_dataset.data) - 1)
+            instruction = data.train_dataset.data[random_idx]["instruction"]
+
+    elif isinstance(eval.evaluate_example, int):
+        index = eval.evaluate_example
+        if len(data.test_dataset.data) > index:
+            instruction = data.test_dataset.data[index]["instruction"]
+        elif len(data.train_dataset.data) > index:
+            instruction = data.train_dataset.data[index]["instruction"]
+        else:
+            raise IndexError(f"Index {index} is out of range for both test and training datasets.")
+
+    else:
+        raise ValueError(f"Unknown evaluation example type: {eval.evaluate_example}")
+    return instruction
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 14d4ff4a56..071305c892 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -42,6 +42,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )
 
 
@@ -757,3 +758,60 @@ def test_fix_and_load_json():
 
     result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
     assert result_missing_commas == expected_output_missing_commas
+
+
+def test_select_sft_generate_example():
+    eval_mock = mock.MagicMock()
+    data_mock = mock.MagicMock()
+
+    test_dataset = {"data": [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]}
+    train_dataset = {"data": [{"instruction": "Train instruction 1"}, {"instruction": "Train instruction 2"}]}
+
+    data_mock.test_dataset.data = test_dataset["data"]
+    data_mock.train_dataset.data = train_dataset["data"]
+
+    # Test "first" instruction from test dataset
+    eval_mock.evaluate_example = "first"
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 1"
+
+    # Test "first" instruction from train dataset when test dataset is empty
+    data_mock.test_dataset.data = []
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 1"
+
+    # Test random selection from test dataset
+    eval_mock.evaluate_example = "random"
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
+    with mock.patch('random.randint', return_value=1):
+        instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 2"
+
+    # Test random selection from train dataset when test dataset is empty
+    data_mock.test_dataset.data = []
+    with mock.patch('random.randint', return_value=1):
+        instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 2"
+
+    # Test specific index from test dataset
+    eval_mock.evaluate_example = 1
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 2"
+
+    # Test specific index from train dataset when test dataset has fewer elements
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 2"
+
+    # Test out-of-range index
+    eval_mock.evaluate_example = 2
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
+    data_mock.train_dataset.data = [{"instruction": "Train instruction 1"}]
+    with pytest.raises(IndexError):
+        select_sft_generate_example(eval_mock, data_mock)
+
+    # Test unknown evaluation type
+    eval_mock.evaluate_example = "unknown"
+    with pytest.raises(ValueError):
+        select_sft_generate_example(eval_mock, data_mock)
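
Usage sketch (illustrative only, not part of the patch): the new select_sft_generate_example helper reads only eval.evaluate_example, data.test_dataset.data, and data.train_dataset.data, so its three selection modes can be exercised with simple stand-in objects instead of real EvalArgs and DataModule instances; the instruction strings below are made up for the example.

# Illustrative sketch: SimpleNamespace objects stand in for EvalArgs and the DataModule.
from types import SimpleNamespace

from litgpt.utils import select_sft_generate_example

data = SimpleNamespace(
    test_dataset=SimpleNamespace(data=[{"instruction": "Summarize the plot of Hamlet."}]),
    train_dataset=SimpleNamespace(data=[{"instruction": "Explain what a LoRA adapter is."}]),
)

# "first" (the default) picks the first test-set instruction, falling back to the train set.
eval_args = SimpleNamespace(evaluate_example="first")
print(select_sft_generate_example(eval_args, data))  # Summarize the plot of Hamlet.

# "random" samples an index with random.randint; an integer picks that index directly,
# preferring the test set and raising IndexError if both splits are too short.
eval_args.evaluate_example = 0
print(select_sft_generate_example(eval_args, data))  # Summarize the plot of Hamlet.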