[BUG] Input data size != output data size when task batch size < batch size of predecessor #972

Open
zye1996 opened this issue Sep 12, 2024 · 4 comments

@zye1996
Contributor

zye1996 commented Sep 12, 2024

Describe the bug
The behavior is intermittent. When the text generation task's input batch size is smaller than the batch size of the previous step and replicas > 1, the final output can be missing some samples. This does not happen on every run, but it happens frequently. I suspect it is related to batch scheduling across multiple processes.

In the following case, LoadDataFromDicts uses its default batch size of 50, and the input_batch_size of the TextGeneration task is set lower than that, to 17. There are 60 input samples in total, but only 52 samples are saved to disk. When the TextGeneration batch size is set greater than 50, all samples are saved successfully.
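For reference, the arithmetic of the expected behavior can be sketched in a few lines. This is a minimal illustration, not distilabel's batching code: `rebatch` is a hypothetical helper, and the 50 + 10 split assumes the loader emits its 60 rows in batches of 50.

```python
from typing import Iterable, Iterator, List


def rebatch(batches: Iterable[List[dict]], size: int) -> Iterator[List[dict]]:
    """Regroup incoming batches into batches of at most `size` rows, keeping every row."""
    buffer: List[dict] = []
    for batch in batches:
        buffer.extend(batch)
        # Emit as many full batches as the buffer currently allows.
        while len(buffer) >= size:
            yield buffer[:size]
            buffer = buffer[size:]
    if buffer:
        # Upstream is exhausted: flush the remainder so no row is dropped.
        yield buffer


rows = [{"instruction": f"question {i}"} for i in range(60)]
upstream = [rows[:50], rows[50:]]  # assumed split: 50 + 10
downstream = list(rebatch(upstream, size=17))
batch_sizes = [len(batch) for batch in downstream]
print(batch_sizes, sum(batch_sizes))
```

Under these assumptions the downstream step should see every one of the 60 rows, whatever its batch size is.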

To Reproduce
Code to reproduce

# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.llms import AnthropicLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, StepResources
from distilabel.steps.tasks import TextGeneration

resources = StepResources(replicas=8)

with Pipeline(
    name="Knowledge-Graphs",
    description=(
        "Generate knowledge graphs to answer questions, this type of dataset can be used to "
        "steer a model to answer questions with a knowledge graph."
    ),
) as pipeline:
    sample_questions = [
        "Teach me about quantum mechanics",
        "Who is who in The Simpsons family?",
        "Tell me about the evolution of programming languages",
    ] * 20

    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "system_prompt": "You are a knowledge graph expert generator. Help me understand by describing everything as a detailed knowledge graph.",
                "instruction": f"{question}",
            }
            for question in sample_questions
        ],

    )

    text_generation = TextGeneration(
        name="knowledge_graph_generation",
        llm=AnthropicLLM(
            model="claude-3-5-sonnet-20240620",
            generation_kwargs={"max_tokens": 4096, "temperature": 0.5},
        ),
        # Smaller than the loader's default batch size of 50; this is what triggers the issue.
        input_batch_size=17,
        output_mappings={"model_name": "generation_model"},
        resources=resources,
    )
    load_dataset >> text_generation


if __name__ == "__main__":
    from pathlib import Path

    distiset = pipeline.run(
        parameters={
            text_generation.name: {
                "llm": {"generation_kwargs": {"max_tokens": 2048}}
            }
        },
        use_cache=False,
    )
    distiset.save_to_disk(
        Path("test_out"),
        save_card=False,
        save_pipeline_log=False,
        save_pipeline_config=False,
    )


Screenshots
Screenshot 2024-09-11 at 9 14 03 PM

Desktop

  • Package version: 1.3.2 and 1.4.0 (develop)
  • Python version: 3.10.13


@zye1996
Contributor Author

zye1996 commented Sep 12, 2024

Looks like some batches are processed twice, which points to a multiprocessing issue.

Screenshot 2024-09-11 at 9 45 38 PM
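One way to confirm that batches are processed twice is to compare the rows that went in with the rows that came out. The sketch below is not part of distilabel: `diff_rows` is a hypothetical helper, and it assumes each input row carries a unique identifier (the instructions in the reproduction repeat, so the text alone cannot tell rows apart).

```python
from collections import Counter
from typing import Dict, Hashable, Iterable, Tuple


def diff_rows(
    inputs: Iterable[Hashable], outputs: Iterable[Hashable]
) -> Tuple[Dict[Hashable, int], Dict[Hashable, int]]:
    """Return (missing, extra): how many times each row id was lost or duplicated."""
    expected = Counter(inputs)
    actual = Counter(outputs)
    # Counter subtraction keeps only positive counts, so each side
    # reports just the ids that are short (missing) or surplus (extra).
    missing = dict(expected - actual)
    extra = dict(actual - expected)
    return missing, extra


# Toy data: row "a" came out twice, rows "c" and "d" never came out.
missing, extra = diff_rows(["a", "b", "c", "d"], ["a", "a", "b"])
print("missing:", missing)
print("extra:", extra)
```

A non-empty `extra` alongside a non-empty `missing` would be consistent with some batches being processed twice while others are dropped.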

@gabrielmbmb gabrielmbmb self-assigned this Sep 12, 2024
@gabrielmbmb gabrielmbmb added the bug Something isn't working label Sep 12, 2024
@gabrielmbmb
Member

Thanks for reporting @zye1996. I'll take a look.

@zye1996
Contributor Author

zye1996 commented Sep 18, 2024

@gabrielmbmb should this line return False? Otherwise, if the last batch arrives earlier than the previous batches, the data is forced on to the next step, and some rows could go missing if they are not enough to form another batch. Let me know if a PR is needed.
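The scenario described above can be reproduced with a toy model. This is a sketch under stated assumptions, not distilabel's actual code: `consume`, the `wait_for_all` flag, and the sequence numbers are all hypothetical, and the referenced line itself is not shown here. The model buffers rows, emits full batches, and differs only in when it flushes the remainder.

```python
from typing import List, Optional, Set, Tuple

Batch = Tuple[int, List[int], bool]  # (sequence number, row ids, is_last flag)


def consume(arrivals: List[Batch], size: int, wait_for_all: bool) -> List[List[int]]:
    """Toy accumulator: emit batches of `size` rows and flush the remainder once.

    wait_for_all=False flushes as soon as the batch flagged as last arrives.
    wait_for_all=True flushes only after every sequence number up to the last arrived.
    """
    buffer: List[int] = []
    emitted: List[List[int]] = []
    seen: Set[int] = set()
    last_seq: Optional[int] = None
    for seq, rows, is_last in arrivals:
        buffer.extend(rows)
        seen.add(seq)
        if is_last:
            last_seq = seq
        while len(buffer) >= size:
            emitted.append(buffer[:size])
            buffer = buffer[size:]
        complete = last_seq is not None and seen == set(range(last_seq + 1))
        flush = complete if wait_for_all else is_last
        if flush and buffer:
            emitted.append(buffer)
            buffer = []
    return emitted  # rows still left in `buffer` were never sent downstream


# Three upstream batches of 20 rows; the last one overtakes the second.
out_of_order: List[Batch] = [
    (0, list(range(0, 20)), False),
    (2, list(range(40, 60)), True),
    (1, list(range(20, 40)), False),
]
eager_total = sum(len(b) for b in consume(out_of_order, size=17, wait_for_all=False))
safe_total = sum(len(b) for b in consume(out_of_order, size=17, wait_for_all=True))
print(eager_total, safe_total)
```

In this toy model the eager policy strands the rows that arrive after the early flush and cannot fill a batch of 17, while waiting for all predecessors delivers all 60 rows.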

@thesven

thesven commented Oct 7, 2024

I've also started noticing this on a pipeline I've created. Using an input_batch_size of one on some text generation tasks led to the final dataset containing only one row for each processed batch of the previous step's output, which had been created using a step mixin and could not have an enforced batch size. @gabrielmbmb I have some code that exhibits the issue and can share it as well.

@gabrielmbmb gabrielmbmb added this to the 1.5.0 milestone Oct 8, 2024