Commit 32673d0

rasbt committed Oct 24, 2024
2 parents 2ac93a4 + 3ca2311
Showing 7 changed files with 102 additions and 5 deletions.
5 changes: 4 additions & 1 deletion litgpt/args.py
@@ -1,7 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import math
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 import warnings


@@ -85,3 +85,6 @@ class EvalArgs:
"""Whether to evaluate on the validation set at the beginning of the training"""
final_validation: bool = True
"""Whether to evaluate on the validation set at the end of the training"""
evaluate_example: Union[str, int] = "first"
"""How to pick an example instruction to evaluate periodically during training.
Can be "first", "random", or an integer index to pick a specific example."""
3 changes: 2 additions & 1 deletion litgpt/finetune/adapter.py
@@ -40,6 +40,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )


@@ -381,7 +382,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 # the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
3 changes: 2 additions & 1 deletion litgpt/finetune/adapter_v2.py
@@ -40,6 +40,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )


@@ -382,7 +383,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva
 # the adapter "kv cache" cannot be initialized under `inference_mode`
 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
3 changes: 2 additions & 1 deletion litgpt/finetune/full.py
@@ -36,6 +36,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )


@@ -348,7 +349,7 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva

 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
4 changes: 3 additions & 1 deletion litgpt/finetune/lora.py
@@ -41,6 +41,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )


@@ -413,7 +414,8 @@ def validate(fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: Eva

 @torch.no_grad()
 def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    instruction = select_sft_generate_example(eval, data)
+
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)
31 changes: 31 additions & 0 deletions litgpt/utils.py
@@ -6,6 +6,7 @@
 import math
 import os
 import pickle
+import random
 import re
 import shutil
 import sys
@@ -784,3 +785,33 @@ def create_finetuning_performance_report(training_time, token_counts, device_typ
     output += "=======================================================\n"
 
     return output
+
+
+def select_sft_generate_example(eval, data):
+
+    if eval.evaluate_example == "first":
+        if len(data.test_dataset.data):
+            instruction = data.test_dataset.data[0]["instruction"]
+        else:
+            instruction = data.train_dataset.data[0]["instruction"]
+
+    elif eval.evaluate_example == "random":
+        if len(data.test_dataset.data):
+            random_idx = random.randint(0, len(data.test_dataset.data) - 1)
+            instruction = data.test_dataset.data[random_idx]["instruction"]
+        else:
+            random_idx = random.randint(0, len(data.train_dataset.data) - 1)
+            instruction = data.train_dataset.data[random_idx]["instruction"]
+
+    elif isinstance(eval.evaluate_example, int):
+        index = eval.evaluate_example
+        if len(data.test_dataset.data) > index:
+            instruction = data.test_dataset.data[index]["instruction"]
+        elif len(data.train_dataset.data) > index:
+            instruction = data.train_dataset.data[index]["instruction"]
+        else:
+            raise IndexError(f"Index {index} is out of range for both test and training datasets.")
+
+    else:
+        raise ValueError(f"Unknown evaluation example type: {eval.evaluate_example}")
+    return instruction
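
One subtlety in the helper above: with an integer index it quietly falls back from the test split to the train split when the test split is too short, so the same index can resolve against different datasets. A small self-contained sketch of that behavior (SimpleNamespace stands in for the real EvalArgs and DataModule objects, and the instruction strings are invented for illustration):

```python
from types import SimpleNamespace

from litgpt.utils import select_sft_generate_example

eval_args = SimpleNamespace(evaluate_example=1)
data = SimpleNamespace(
    test_dataset=SimpleNamespace(data=[{"instruction": "Summarize this article."}]),
    train_dataset=SimpleNamespace(
        data=[{"instruction": "Write a haiku."}, {"instruction": "Explain recursion."}]
    ),
)

# Index 1 is out of range for the one-element test split, so the helper
# falls back to the train split and returns "Explain recursion."
print(select_sft_generate_example(eval_args, data))
```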
58 changes: 58 additions & 0 deletions tests/test_utils.py
@@ -42,6 +42,7 @@
     num_parameters,
     parse_devices,
     save_hyperparameters,
+    select_sft_generate_example,
 )


@@ -757,3 +758,60 @@ def test_fix_and_load_json():

     result_missing_commas = fix_and_load_json(invalid_json_missing_commas)
     assert result_missing_commas == expected_output_missing_commas
+
+
+def test_select_sft_generate_example():
+    eval_mock = mock.MagicMock()
+    data_mock = mock.MagicMock()
+
+    test_dataset = {"data": [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]}
+    train_dataset = {"data": [{"instruction": "Train instruction 1"}, {"instruction": "Train instruction 2"}]}
+
+    data_mock.test_dataset.data = test_dataset["data"]
+    data_mock.train_dataset.data = train_dataset["data"]
+
+    # Test "first" instruction from test dataset
+    eval_mock.evaluate_example = "first"
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 1"
+
+    # Test "first" instruction from train dataset when test dataset is empty
+    data_mock.test_dataset.data = []
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 1"
+
+    # Test random selection from test dataset
+    eval_mock.evaluate_example = "random"
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
+    with mock.patch('random.randint', return_value=1):
+        instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 2"
+
+    # Test random selection from train dataset when test dataset is empty
+    data_mock.test_dataset.data = []
+    with mock.patch('random.randint', return_value=1):
+        instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 2"
+
+    # Test specific index from test dataset
+    eval_mock.evaluate_example = 1
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}, {"instruction": "Test instruction 2"}]
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Test instruction 2"
+
+    # Test specific index from train dataset when test dataset has fewer elements
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
+    instruction = select_sft_generate_example(eval_mock, data_mock)
+    assert instruction == "Train instruction 2"
+
+    # Test out-of-range index
+    eval_mock.evaluate_example = 2
+    data_mock.test_dataset.data = [{"instruction": "Test instruction 1"}]
+    data_mock.train_dataset.data = [{"instruction": "Train instruction 1"}]
+    with pytest.raises(IndexError):
+        select_sft_generate_example(eval_mock, data_mock)
+
+    # Test unknown evaluation type
+    eval_mock.evaluate_example = "unknown"
+    with pytest.raises(ValueError):
+        select_sft_generate_example(eval_mock, data_mock)
