From 69b94b7267783ca59e83fd9328f491a0fa01f0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 10 Jan 2025 13:06:38 +0100 Subject: [PATCH 1/2] Fix handling empty list statistics --- src/distilabel/steps/tasks/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index dba92588cd..ae19a1038f 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -521,15 +521,15 @@ def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": gen_length = len(output["generations"]) for stat_key, stat_values in output["statistics"].items(): - current_length = len(stat_values) + current_length = len(stat_values) # type: ignore - if current_length < gen_length: + if current_length > 0 and current_length < gen_length: # Calculate how many times to repeat the tokens repeats = gen_length // current_length remainder = gen_length % current_length # Create new list with repeated values - new_values = stat_values * repeats + stat_values[:remainder] + new_values = stat_values * repeats + stat_values[:remainder] # type: ignore output["statistics"][stat_key] = new_values return output @@ -552,7 +552,11 @@ def iterate_generations_with_stats( ] for i, generation in enumerate(outputs["generations"]): # Create a new dictionary with the statistics for this index - stats = {key: values[i] for key, values in outputs["statistics"].items()} # type: ignore + stats = { + key: values[i] # type: ignore + for key, values in outputs["statistics"].items() + if values + } # Extra keys returned by the `LLM` extra = {key: outputs[key][i] for key in extra_keys} yield generation, stats, extra From 5a36802df2a871f209e8134c210486ad5d76bb56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Fri, 10 Jan 2025 13:09:21 +0100 Subject: [PATCH 2/2] Do not include empty input/output tokens if `None` --- src/distilabel/models/llms/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/distilabel/models/llms/utils.py b/src/distilabel/models/llms/utils.py index afb09cab4d..ef97e53e1f 100644 --- a/src/distilabel/models/llms/utils.py +++ b/src/distilabel/models/llms/utils.py @@ -57,11 +57,15 @@ def prepare_output( """ output: "GenerateOutput" = { "generations": generations, - "statistics": { - "input_tokens": input_tokens or [], - "output_tokens": output_tokens or [], - }, + "statistics": {}, } + + if input_tokens: + output["statistics"]["input_tokens"] = input_tokens + + if output_tokens: + output["statistics"]["output_tokens"] = output_tokens + if logprobs: output["logprobs"] = logprobs return output