Skip to content

Commit

Permalink
Fix bug when the test split has fewer rows than num_samples (100)
Browse files: browse the repository at this point in the history
  • Loading branch information
rparundekar committed May 26, 2024
1 parent a1add9a commit 30f3ce8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/aihero/research/finetuning/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def __init__(
# Sample a few rows from the test split to generate a table of predictions
# for visual inspection a.k.a. spot checking
# Randomly select indices for the samples
selected_indices = random.sample(range(test_split.num_rows), num_samples)
if num_samples >= test_split.num_rows:
selected_indices = list(range(0, test_split.num_rows))
else:
selected_indices = random.sample(range(test_split.num_rows), num_samples)
# Retrieve the selected samples from the dataset
test_split_list = list(test_split)
self.sample_split = []
Expand Down
2 changes: 1 addition & 1 deletion src/aihero/research/finetuning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def train(self) -> None:
trainer,
task,
test_split,
num_samples=100,
num_samples=test_split.num_rows if test_split.num_rows < 100 else 100,
max_new_tokens=self.training_job.trainer.max_seq_length,
run_tests_str=run_tests_str,
run_metrics_str=run_metrics_str,
Expand Down

0 comments on commit 30f3ce8

Please sign in to comment.