Fix inference runner initialization
shankarg87 committed Jul 25, 2024
1 parent f015df4 · commit 6531560
Showing 1 changed file with 13 additions and 40 deletions.
53 changes: 13 additions & 40 deletions src/aihero/research/finetuning/infer.py
```diff
@@ -37,27 +37,28 @@ def __init__(self, batch_inference_job: BatchInferenceJob):
         if self.batch_inference_job.eval:
             run_tests_str = self.batch_inference_job.eval.tests or ""
             run_metrics_str = self.batch_inference_job.eval.metrics or ""
-            size = self.batch_inference_job.size or 0
-            randomize = self.batch_inference_job.randomize or False
+            size = self.batch_inference_job.eval.size or 0
+            randomize = self.batch_inference_job.eval.randomize or False
         else:
             run_tests_str = ""
             run_metrics_str = ""
             size = 0
             randomize = False
 
+        print(self.dataset_dict)
         self.batch_inference_split = self.dataset_dict["batch_inference"]
         if size:
             if randomize:
                 self.batch_inference_split = self.batch_inference_split.shuffle()
             self.batch_inference_split = self.batch_inference_split.select(range(size))
-            self.batch_inference_with_eval = BatchInferenceWithEval(
-                model=self.model,
-                tokenizer=self.tokenizer,
-                task=self.batch_inference_job.task,
-                run_tests_str=run_tests_str,
-                run_metrics_str=run_metrics_str,
-                max_new_tokens=self.batch_inference_job.generator.max_seq_length or MAX_NEW_TOKENS,
-            )
+        self.batch_inference_with_eval = BatchInferenceWithEval(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            task=self.batch_inference_job.task,
+            run_tests_str=run_tests_str,
+            run_metrics_str=run_metrics_str,
+            max_new_tokens=self.batch_inference_job.generator.max_seq_length or MAX_NEW_TOKENS,
+        )
 
     def load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
         """Load the model from HuggingFace Hub or S3."""
```
```diff
@@ -173,45 +174,17 @@ def fetch_dataset(self) -> DatasetDict:
             )
         elif self.batch_inference_job.dataset.type == "local":
             print("Loading dataset locally: ", os.listdir(self.batch_inference_job.dataset.path))
-            splits["train"] = Dataset.from_generator(
+            splits["batch_inference"] = Dataset.from_generator(
                 dataset_generator,
                 gen_kwargs={
                     "dataset": self.batch_inference_job.dataset.path,
-                    "split": "train",
+                    "split": "batch_inference",
                     "from_disk": True,
                     "task": self.batch_inference_job.task,
                     "bos_token": bos_token,
                     "eos_token": eos_token,
                 },
             )
-            try:
-                splits["val"] = Dataset.from_generator(
-                    dataset_generator,
-                    gen_kwargs={
-                        "dataset": self.batch_inference_job.dataset.path,
-                        "split": "val",
-                        "from_disk": True,
-                        "task": self.batch_inference_job.task,
-                        "bos_token": bos_token,
-                        "eos_token": eos_token,
-                    },
-                )
-            except:  # pylint: disable=bare-except  # noqa: E722
-                print("Unable to create val dataset")
-            try:
-                splits["test"] = Dataset.from_generator(
-                    dataset_generator,
-                    gen_kwargs={
-                        "dataset": self.batch_inference_job.dataset.path,
-                        "split": "test",
-                        "from_disk": True,
-                        "task": self.batch_inference_job.task,
-                        "bos_token": bos_token,
-                        "eos_token": eos_token,
-                    },
-                )
-            except:  # pylint: disable=bare-except  # noqa: E722
-                print("Unable to create test dataset")
         else:
             raise ValueError(f"Unknown dataset_type: {self.batch_inference_job.dataset.type}")
```
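The local branch above now builds a single `batch_inference` split with `Dataset.from_generator`, passing the generator its arguments through `gen_kwargs`, and drops the previous best-effort loading of `val` and `test` splits. A minimal sketch of that pattern follows; the toy generator is a stand-in for the repository's `dataset_generator`, whose simplified signature here is an assumption for illustration.

```python
# Minimal sketch of Dataset.from_generator with gen_kwargs, as used above.
# `toy_generator` is a stand-in for the repo's dataset_generator, which
# really takes dataset/split/from_disk/task/bos_token/eos_token.
from datasets import Dataset

def toy_generator(split: str, bos_token: str, eos_token: str):
    """Yield one dict per example; dict keys become dataset columns."""
    for i in range(3):
        yield {"text": f"{bos_token}{split} example {i}{eos_token}"}

ds = Dataset.from_generator(
    toy_generator,
    gen_kwargs={
        "split": "batch_inference",
        "bos_token": "<s>",
        "eos_token": "</s>",
    },
)
print(ds[0]["text"])  # <s>batch_inference example 0</s>
```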
