Fix batching code

  * Process a large number of rows in batches
  * Hardcode some parameters
shankarg87 committed Jul 28, 2024
1 parent 6531560 commit 69a467d
Showing 2 changed files with 22 additions and 9 deletions.
5 changes: 5 additions & 0 deletions src/aihero/research/finetuning/callback.py
@@ -36,6 +36,7 @@ def __init__(
             run_tests_str=run_tests_str,
             run_metrics_str=run_metrics_str,
             max_new_tokens=max_new_tokens,
+            batch_size=trainer.args.per_device_eval_batch_size,
         )

         # Sample a few rows from the test split to generate a table of predictions
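
For reference, per_device_eval_batch_size is a standard transformers.TrainingArguments field, so the callback now reuses whatever evaluation batch size the trainer was configured with. A minimal sketch with hypothetical values:

from transformers import TrainingArguments

# Hypothetical trainer configuration; the callback reads this same value
# back through trainer.args.per_device_eval_batch_size.
args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=8,
)
print(args.per_device_eval_batch_size)  # 8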
@@ -56,7 +57,9 @@ def initialize(self: "LLMSampleCB") -> None:
         """Generate initial predictions for the sample split and log them to WANDB."""
         self._wandb.init()

+        self.batch_inference.model.eval()
         _, (records_table, metrics) = self.batch_inference.run_initial_predictions(self.sample_split)
+        self.batch_inference.model.train()

         # Log the table of sample predictions to W&B
         self._wandb.log({"sample_predictions": records_table})
@@ -69,8 +72,10 @@ def on_evaluate(self, args: Any, state: Any, control: Any, **kwargs: dict[str, A
         """Log the sample predictions and metrics to WANDB on eval callback."""
         super().on_evaluate(args, state, control, **kwargs)

+        self.batch_inference.model.eval()
         # Generate the table of sample predictions
         _, (records_table, metrics) = self.batch_inference.infer(self.sample_split)
+        self.batch_inference.model.train()

         # Log the table of sample predictions to W&B
         self._wandb.log({"sample_predictions": records_table})
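
The eval()/train() bracketing in both callbacks matters because the model is mid-training here: without it, generation would run with dropout and other train-mode behavior still active. A standalone sketch of the pattern with a toy module:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.eval()  # disable dropout so inference is deterministic
with torch.inference_mode():
    preds = model(torch.ones(1, 4))
model.train()  # restore train mode before the trainer resumes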
26 changes: 17 additions & 9 deletions src/aihero/research/finetuning/infer.py
@@ -45,7 +45,7 @@ def __init__(self, batch_inference_job: BatchInferenceJob):
         size = 0
         randomize = False

-        print(self.dataset_dict)
+        self.model.eval()
         self.batch_inference_split = self.dataset_dict["batch_inference"]
         if size:
             if randomize:
@@ -58,6 +58,7 @@ def __init__(self, batch_inference_job: BatchInferenceJob):
             run_tests_str=run_tests_str,
             run_metrics_str=run_metrics_str,
             max_new_tokens=self.batch_inference_job.generator.max_seq_length or MAX_NEW_TOKENS,
+            batch_size=8,  # Needs to be added to eval arguments
         )

     def load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
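
The in-diff comment flags the hardcoded batch_size=8 as something that should come from the job's eval arguments. A hedged sketch of what that wiring could look like, using a hypothetical batch_size field on the generator config (not a field the repository is confirmed to have):

from dataclasses import dataclass
from typing import Optional

@dataclass
class GeneratorConfig:
    """Hypothetical stand-in for the job's generator settings."""
    max_seq_length: Optional[int] = None
    batch_size: Optional[int] = None  # hypothetical field replacing the hardcoded 8

config = GeneratorConfig(batch_size=16)
batch_size = config.batch_size or 8  # fall back to the current default
print(batch_size)  # 16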
@@ -249,6 +250,7 @@ def __init__(
         """Initialize the batch inference class."""
         self.gen_config = GenerationConfig.from_pretrained(model.name_or_path, max_new_tokens=max_new_tokens)
         self.model = model
+        self.batch_size = batch_size
         self.tokenizer = tokenizer
         self.task = task
         self.run_tests_str = run_tests_str
@@ -265,15 +267,21 @@ def __init__(

     def generate(self, prompts: List[str]) -> Any:
         """Generate a completion from a prompt."""
-        tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"].cuda()
+        tokens = self.tokenizer(prompts, return_tensors="pt", padding=True)
+        outputs = []
         with torch.inference_mode():
-            output = self.model.generate(
-                inputs=tokenized_prompts,
-                generation_config=self.gen_config,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
-        decoded_outputs = self.tokenizer.batch_decode(output, skip_special_tokens=True)
-        return [output[len(tokens) :] for tokens, output in zip(tokenized_prompts, decoded_outputs)]
+            for i in tqdm(range(0, len(prompts), self.batch_size), leave=False):
+                output = self.model.generate(
+                    inputs=tokens["input_ids"][i : i + self.batch_size].cuda(),
+                    attention_mask=tokens["attention_mask"][i : i + self.batch_size].cuda(),
+                    generation_config=self.gen_config,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.2,  # TODO: Add to generation config?
+                    num_return_sequences=1,
+                )
+                outputs.append(output)
+        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return [output[len(tokens) :] for tokens, output in zip(tokens["input_ids"].cuda(), decoded_outputs)]

     def run_initial_predictions(self, rows: Dataset) -> Tuple[list[dict[str, Any]], Tuple[Table, dict[str, Any]]]:
         """Generate initial predictions for the sample split."""
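Putting the pieces together, here is a condensed, self-contained sketch of the batched-generation approach this commit lands. It is an illustration rather than the repository's exact code: gpt2 stands in for the real checkpoint, the batch_size and max_new_tokens values are placeholders, and the prompt is stripped by slicing token ids before decoding rather than slicing the decoded string:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal-LM checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-style tokenizers ship without a pad token
tokenizer.padding_side = "left"  # left-pad so generated tokens follow the real text
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

def generate_batched(prompts, batch_size=8, max_new_tokens=32):
    """Generate completions in fixed-size batches instead of one giant batch."""
    completions = []
    with torch.inference_mode():
        for i in range(0, len(prompts), batch_size):
            # Tokenize per batch so padding is local to each batch, not global.
            batch = tokenizer(prompts[i : i + batch_size], return_tensors="pt", padding=True)
            output = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Decode only the generated tail, dropping the prompt tokens.
            completions += tokenizer.batch_decode(
                output[:, batch["input_ids"].shape[1] :], skip_special_tokens=True
            )
    return completions

print(generate_batched(["Hello, my name is", "The capital of France is"], batch_size=2))

Batching bounds peak memory at roughly one batch of padded sequences instead of the whole dataset, which is the point of the fix.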
