feat: worker abstractions; fix: various improvements and bug fixes #328

Draft
wants to merge 32 commits into
base: main

Conversation

tazlin
Member

@tazlin tazlin commented Jan 26, 2025

--todo--

tazlin added 30 commits January 26, 2025 09:39
This includes scenarios where, for example, JobIDs and WorkerIDs.
This more accurately reflects the usage of this field type

fix: rename `job_id` vars to `gen_id`
`object` and `any` have distinct meanings. particularly, I do in fact mean "any" here rather than expecting an object compatible with `object`
`id_factory`, `default_testing_image_bytes`, `default_testing_image_PIL`
- Fixes the `SharedKeyCreateRequest` class to use the `PUT` method instead of `POST`.
- A new mixin class, `MessageSpecifiesSharedKeyMixin`
- Accordingly updates the api <-> sdk map json
- Adds a warning when shared keys are empty (probably unintended but still useful for testing)
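For illustration only, the mixin-plus-warning idea in the bullets above might look roughly like the sketch below. The class names, the `sharedkey_id` field, and the `http_method` attribute are all hypothetical stand-ins and are not taken from the SDK.

```python
import warnings


class ExampleSpecifiesSharedKeyMixin:
    """Hypothetical stand-in for a mixin that carries a shared key ID."""

    def __init__(self, sharedkey_id: str = "") -> None:
        if not sharedkey_id:
            # An empty key is probably unintended, but it is still allowed
            # because it can be useful in tests.
            warnings.warn("Shared key ID is empty.", stacklevel=2)
        self.sharedkey_id = sharedkey_id


class ExampleSharedKeyRequest(ExampleSpecifiesSharedKeyMixin):
    """Hypothetical request type that declares its HTTP method explicitly."""

    http_method = "PUT"
```

Declaring the method as a class attribute keeps the verb next to the request type, which makes a `POST`/`PUT` mix-up easier to spot in review.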

docs: add missing shared keys mkdocs stub
The `time_constructed` property is the time at which the model is constructed and will be used to infer pop times if not otherwise specified for jobs. It may also serve some purpose with testing and/or debugging.
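A minimal sketch of the fallback described above, assuming a plain class rather than the SDK's actual models. `ExampleTimestampedModel` and its `pop_time` property are hypothetical names used only to show the idea.

```python
import time
from typing import Optional


class ExampleTimestampedModel:
    """Hypothetical model that records when it was constructed."""

    def __init__(self, pop_time: Optional[float] = None) -> None:
        self._time_constructed = time.time()
        self._pop_time = pop_time

    @property
    def time_constructed(self) -> float:
        """The time at which this object was constructed."""
        return self._time_constructed

    @property
    def pop_time(self) -> float:
        """The explicit pop time, or the construction time if none was given."""
        if self._pop_time is not None:
            return self._pop_time
        return self._time_constructed
```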
An env var was not intended to be the long term method for getting this behavior
docs: include `haidra-assests` submodule on rtd.io
The ordering of this function with the continue control statements caused any query fields specified by `GenericQueryFields` types to not be rewritten from their python names, such as in the case of using a protected namespace (e.g., `model_` from pydantic) or when shadowing a builtin (such as `id`).
`model_` is reserved for pydantic usage and there is already a rewriting mechanism in place for this exact reason (i.e., python field name conflicts with builtins)
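To make the ordering issue concrete, here is one plausible shape of the fixed loop. It is a sketch under assumptions: `PYTHON_TO_API_NAMES`, `build_query_params`, and the field names are hypothetical and are not the SDK's real rewriting mechanism.

```python
# Hypothetical mapping from Python-safe field names to the names the API expects.
PYTHON_TO_API_NAMES = {"id_": "id", "model_name_": "model_name"}


def build_query_params(fields: dict, query_field_names: set) -> dict:
    """Collect query parameters, rewriting Python names to API names."""
    params = {}
    for python_name, value in fields.items():
        if value is None:
            continue  # unset fields are never sent

        # Resolve the API name before any filtering below, so that no
        # later `continue` can skip the rewrite for a field that is kept.
        api_name = PYTHON_TO_API_NAMES.get(python_name, python_name)

        if python_name not in query_field_names:
            continue  # not a query field; handled elsewhere

        params[api_name] = value
    return params
```

With this ordering, `build_query_params({"id_": "abc"}, {"id_"})` yields a parameter named `id` rather than `id_`.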
tazlin and others added 2 commits January 26, 2025 13:37
These common IO exceptions tend to occur during many common IO operations throughout the SDK and for the sake of future maintainability it would be best to define a tuple of them available to the entire namespace
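As a rough sketch of the idea, a shared tuple could be defined once and reused in `except` clauses. The name `COMMON_IO_EXCEPTIONS`, its members, and the helper below are assumptions for illustration; the commit's actual choices may differ.

```python
from pathlib import Path
from typing import Optional

# Hypothetical name and membership; which exceptions belong here is an assumption.
COMMON_IO_EXCEPTIONS = (OSError, EOFError, UnicodeDecodeError)


def read_text_or_none(path: str) -> Optional[str]:
    """Return the file's text, or None if a common IO error occurs."""
    try:
        return Path(path).read_text(encoding="utf-8")
    except COMMON_IO_EXCEPTIONS:
        return None
```

Because `except` accepts a tuple of exception types, call sites stay short and the list only has to be maintained in one place.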
@tazlin tazlin changed the title from "Even better tests rebase" to "feat: worker abstractions; fix: various improvements and bug fixes" Jan 26, 2025
@tazlin
Member Author

tazlin commented Jan 26, 2025

@CodiumAI-Agent /review

@tazlin
Member Author

tazlin commented Jan 26, 2025

@CodiumAI-Agent /improve

@CodiumAI-Agent

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 5 🔵🔵🔵🔵🔵
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Type Consistency

The replacement of JobID with GenerationID across multiple methods and return types should be carefully reviewed to ensure type consistency and compatibility with existing code and API contracts.

    gen_id: GenerationID,
) -> ImageGenerateCheckResponse:
    """Check if a pending image request has finished generating from the AI-Horde API.

    Not to be confused with `get_generate_status` which returns the images too.

    Args:
        gen_id (GenerationID): The ID of the request to check.

    Returns:
        ImageGenerateCheckResponse: The response from the API.
    """
    api_request = ImageGenerateCheckRequest(id=gen_id)

    api_response = self.submit_request(api_request, api_request.get_default_success_response_type())
    if isinstance(api_response, RequestErrorResponse):  # pragma: no cover
        self._handle_api_error(api_response, api_request.get_api_endpoint_url())
        raise AIHordeRequestError(api_response)

    return api_response

def get_generate_status(
    self,
    gen_id: GenerationID,
) -> ImageGenerateStatusResponse:
    """Get the status and any generated images for a pending image request from the AI-Horde API.

    *Do not use this method more often than is necessary.* The AI-Horde API will rate limit you if you do.
    Use `get_generate_check` instead to check the status of a pending image request.

    Args:
        gen_id (GenerationID): The ID of the request to check.

    Returns:
        ImageGenerateStatusResponse: The status response from the API, including any generated images.
    """
    api_request = ImageGenerateStatusRequest(id=gen_id)

    api_response = self.submit_request(api_request, api_request.get_default_success_response_type())
    if isinstance(api_response, RequestErrorResponse):  # pragma: no cover
        self._handle_api_error(api_response, api_request.get_api_endpoint_url())
        raise AIHordeRequestError(api_response)

    return api_response

def delete_pending_image(
    self,
    gen_id: GenerationID,
) -> ImageGenerateStatusResponse:
Test Coverage

The extensive test cases for HordeSingleGeneration permutations and error handling should be validated for completeness and alignment with the new GenerationID changes.

import base64
from collections.abc import Callable, Iterable, Mapping
from io import BytesIO
from typing import Any
from uuid import UUID

import pytest
from loguru import logger
from PIL import Image

from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS
from horde_sdk.ai_horde_worker.generations import (
    AlchemySingleGeneration,
    ImageSingleGeneration,
    TextSingleGeneration,
)
from horde_sdk.ai_horde_worker.generations_base import HordeSingleGeneration


class GenerationPermutation:
    """A permutation of possible generation configurations.

    For example, text generation may not require post-processing or safety checks, while image generation may require
    both. For testing, we can create permutations of these configurations to ensure that the generation process works
    as expected across all possible configurations.

    """

    def __init__(
        self,
        *,
        include_safety_check: bool,
        include_preloading: bool,
        include_post_processing: bool,
    ) -> None:
        """Initialize the permutation.

        Args:
            include_safety_check (bool): Whether to include a safety check in the generation process.
            include_preloading (bool): Whether to include preloading in the generation process.
            include_post_processing (bool): Whether to include post-processing in the generation process.

        """

        self.include_safety_check = include_safety_check
        self.include_preloading = include_preloading
        self.include_post_processing = include_post_processing


@pytest.fixture(scope="session")
def image_permutations() -> list[GenerationPermutation]:
    """Return the supported configurations for an `ImageSingleGeneration` object."""
    return [
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=True,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=True,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=False,
            include_post_processing=True,
        ),
    ]


@pytest.fixture(scope="session")
def alchemy_permutations() -> list[GenerationPermutation]:
    """Return the supported configurations for an `AlchemySingleGeneration` object."""
    return [
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=True,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=True,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=False,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=True,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=False,
            include_post_processing=True,
        ),
    ]


@pytest.fixture(scope="session")
def text_permutations() -> list[GenerationPermutation]:
    """Return the supported configurations for a `TextSingleGeneration` object."""
    return [
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=True,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=False,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=True,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=False,
            include_post_processing=False,
        ),
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=True,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=False,
            include_preloading=False,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=True,
            include_post_processing=True,
        ),
        GenerationPermutation(
            include_safety_check=True,
            include_preloading=False,
            include_post_processing=True,
        ),
    ]


class TestHordeSingleGeneration:
    """Test the `HordeSingleGeneration` class."""

    _shared_image: Image.Image

    @pytest.fixture(autouse=True)
    def setup(self, default_testing_image_base64: str) -> None:
        self._shared_image = Image.open(BytesIO(base64.b64decode(default_testing_image_base64)))

    def test_none_generation_init(
        self,
    ) -> None:
        """Test that an exception is raised when a generation is initialized with a `None` ID."""

        with pytest.raises(TypeError):
            ImageSingleGeneration(generation_id=None)  # type: ignore

    @staticmethod
    def shared_check_generation_init(
        generation: HordeSingleGeneration[Any],
        generation_id: GenerationID,
    ) -> None:
        """Confirm that the `HordeSingleGeneration` was initialized correctly."""
        assert generation.generation_id == generation_id

        first_state, _ = generation._progress_history[0]
        assert first_state == GENERATION_PROGRESS.NOT_STARTED

        assert generation._state_error_limits is not None
        assert generation.errored_states is not None
        assert len(generation.errored_states) == 0
        assert generation.generation_metadata is not None

        assert generation.is_nsfw is None
        assert generation.is_csam is None

        assert generation.generation_result is None

    def test_alchemy_single_generation_init(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test that an `AlchemySingleGeneration` object can be initialized correctly."""

        from horde_sdk.ai_horde_worker.consts import default_alchemy_generate_progress_transitions

        generation = AlchemySingleGeneration(
            generation_id=single_id,
        )

        TestHordeSingleGeneration.shared_check_generation_init(
            generation=generation,
            generation_id=single_id,
        )

        assert generation._generate_progress_transitions == default_alchemy_generate_progress_transitions

    def test_image_single_generation_init(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test that an `ImageSingleGeneration` object can be initialized correctly."""

        from horde_sdk.ai_horde_worker.consts import default_image_generate_progress_transitions

        generation = ImageSingleGeneration(
            generation_id=single_id,
        )

        TestHordeSingleGeneration.shared_check_generation_init(
            generation=generation,
            generation_id=single_id,
        )

        assert generation._generate_progress_transitions == default_image_generate_progress_transitions

    def test_text_single_generation_init(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test that a `TextSingleGeneration` object can be initialized correctly."""

        from horde_sdk.ai_horde_worker.consts import default_text_generate_progress_transitions

        generation = TextSingleGeneration(
            generation_id=single_id,
        )

        TestHordeSingleGeneration.shared_check_generation_init(
            generation=generation,
            generation_id=single_id,
        )

        assert generation._generate_progress_transitions == default_text_generate_progress_transitions

    def test_invalid_step_raises_error(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test that an exception is raised when an invalid step is passed to the generation."""

        generation = ImageSingleGeneration(generation_id=single_id)

        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        with pytest.raises(
            ValueError,
            match=f"Invalid state {GENERATION_PROGRESS.PENDING_SAFETY_CHECK} "
            r"\(current state: "
            f"{GENERATION_PROGRESS.NOT_STARTED}"
            r"\)",
        ):
            generation.step(GENERATION_PROGRESS.PENDING_SAFETY_CHECK)

        generation.step(GENERATION_PROGRESS.PRELOADING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

    def test_wrong_order_of_steps(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test that an exception is raised when the generation steps are called in the wrong order.

        - It should only be possible to transition according to the default transition \
            progressions defined in `horde_sdk/ai_horde_worker/consts.py`.
        - It should not be possible to transition to the state the generation is already in. \
            This is a safety check to prevent infinite loops or bad implementations.
        """

        generation = ImageSingleGeneration(generation_id=single_id)

        def assert_raises_value_error(func: Callable[..., Any], match: str) -> None:
            with pytest.raises(ValueError, match=match):
                func()

        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        assert_raises_value_error(
            generation.on_generation_work_complete,
            f"Invalid transition from {GENERATION_PROGRESS.NOT_STARTED} to {GENERATION_PROGRESS.PENDING_SUBMIT}",
        )

        # Normal progression to preloading
        generation.on_preloading()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

        assert_raises_value_error(
            generation.on_preloading,
            f"is already in state {GENERATION_PROGRESS.PRELOADING}",
        )

        assert_raises_value_error(
            generation.on_generation_work_complete,
            f"Invalid transition from {GENERATION_PROGRESS.PRELOADING} to {GENERATION_PROGRESS.PENDING_SUBMIT}",
        )

        # Normal progression to preloading complete
        generation.on_preloading_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        assert_raises_value_error(
            generation.on_preloading_complete,
            f"is already in state {GENERATION_PROGRESS.PRELOADING_COMPLETE}",
        )

        # Normal progression to generating
        generation.on_generating()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

        assert_raises_value_error(
            generation.on_generating,
            f"is already in state {GENERATION_PROGRESS.GENERATING}",
        )

        assert_raises_value_error(
            generation.on_preloading,
            f"Invalid transition from {GENERATION_PROGRESS.GENERATING} to {GENERATION_PROGRESS.PRELOADING}",
        )

        assert_raises_value_error(
            generation.on_preloading_complete,
            f"Invalid transition from {GENERATION_PROGRESS.GENERATING} to {GENERATION_PROGRESS.PRELOADING_COMPLETE}",
        )

    def test_set_safety_check_result_without_generation_result(self, single_id: GenerationID) -> None:
        """Test that an exception is raised when setting a safety check result without setting a generation result."""
        generation = ImageSingleGeneration(generation_id=single_id)

        with pytest.raises(ValueError, match="Generation result must be set before setting safety check result"):
            generation._set_safety_check_result(is_nsfw=True, is_csam=False)

    def test_reference_run_generation_process_image(self) -> None:
        """Run a reference generation process from start to finish, without testing-specific magic or helpers.

        The purpose of this test is to have a bare-minimum usage of the `HordeSingleGeneration` class to ensure
        that the most straightforward use case works as expected and isn't lost in the complexity of the test suite.
        """
        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS
        from horde_sdk.ai_horde_worker.generations import ImageSingleGeneration

        dummy_id = GenerationID(UUID("00000000-0000-0000-0000-000000000000"))
        generation = ImageSingleGeneration(generation_id=dummy_id, requires_post_processing=False)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        generation.on_preloading()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

        generation.on_preloading_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        generation.on_generating()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

        generation.on_generation_work_complete()

        dummy_image = Image.new("RGB", (100, 100))
        generation.set_work_result(dummy_image)

        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SAFETY_CHECK

        generation.on_safety_checking()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SAFETY_CHECKING

        generation.on_safety_check_complete(is_csam=False, is_nsfw=False)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SUBMIT

        generation.on_submitting()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        generation.on_submit_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMIT_COMPLETE

    def test_reference_run_generation_process_text(self) -> None:
        """Run a reference generation process from start to finish, without testing-specific magic or helpers.

        The purpose of this test is to have a bare-minimum usage of the `HordeSingleGeneration` class to ensure
        that the most straightforward use case works as expected and isn't lost in the complexity of the test suite.
        """
        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS
        from horde_sdk.ai_horde_worker.generations import TextSingleGeneration

        dummy_id = GenerationID(UUID("00000000-0000-0000-0000-000000000000"))
        generation = TextSingleGeneration(generation_id=dummy_id)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        generation.on_preloading()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

        generation.on_preloading_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        generation.on_generating()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

        generation.on_generation_work_complete()
        generation.set_work_result("This is a test")
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SUBMIT

        generation.on_submitting()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        generation.on_submit_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMIT_COMPLETE

    def test_reference_run_generation_process_alchemy(self) -> None:
        """Run a reference generation process from start to finish, without testing-specific magic or helpers.

        The purpose of this test is to have a bare-minimum usage of the `HordeSingleGeneration` class to ensure
        that the most straightforward use case works as expected and isn't lost in the complexity of the test suite.
        """
        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS
        from horde_sdk.ai_horde_worker.generations import AlchemySingleGeneration

        dummy_id = GenerationID(UUID("00000000-0000-0000-0000-000000000000"))
        generation = AlchemySingleGeneration(generation_id=dummy_id)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        generation.on_preloading()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

        generation.on_preloading_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        generation.on_post_processing()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.POST_PROCESSING

        dummy_image = Image.new("RGB", (100, 100))
        generation.set_work_result(dummy_image)

        generation.on_post_processing_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SUBMIT

        generation.on_submitting()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        generation.on_submit_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMIT_COMPLETE

    def run_generation_process(
        self,
        generation: HordeSingleGeneration[Any],
        include_preloading: bool,
        include_generation: bool,
        include_post_processing: bool,
        include_safety_check: bool,
    ) -> None:
        """Run a generation process from start to finish.

        This function will run the generation process from start to finish, including preloading, generation,
        post-processing, safety checks, and submission. It will also check that the generation progresses through the
        correct states.

        If a step is not requested, it will be skipped.
        """

        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        if include_preloading:
            assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

            generation.on_preloading()
            assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

            generation.on_preloading_complete()
            assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        if include_generation:
            generation.on_generating()
            assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

            generation.on_generation_work_complete()
            if not include_post_processing:
                generation.set_work_result(self._shared_image)

        if include_post_processing:
            generation.on_post_processing()
            assert generation.get_generation_progress() == GENERATION_PROGRESS.POST_PROCESSING

            generation.on_post_processing_complete()
            generation.set_work_result(self._shared_image)

        assert generation.generation_result == self._shared_image

        if include_safety_check:
            assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SAFETY_CHECK

            generation.on_safety_checking()
            assert generation.get_generation_progress() == GENERATION_PROGRESS.SAFETY_CHECKING

            generation.on_safety_check_complete(is_csam=False, is_nsfw=False)
            assert generation.is_csam is False
            assert generation.is_nsfw is False
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SUBMIT

        generation.on_submitting()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        generation.on_submit_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMIT_COMPLETE

    def test_happy_path_image_start_to_finish(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `ImageSingleGeneration` from start to finish."""

        generation = ImageSingleGeneration(generation_id=single_id, requires_post_processing=True)

        self.process_generation(
            generation,
            include_preloading=True,
            include_safety_check=True,
            include_generation=True,
            include_post_processing=True,
        )

        generation_no_post_processing = ImageSingleGeneration(generation_id=single_id, requires_post_processing=False)
        self.process_generation(
            generation_no_post_processing,
            include_preloading=True,
            include_safety_check=True,
            include_generation=True,
            include_post_processing=False,
        )

    def test_happy_path_image_no_preloading(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `ImageSingleGeneration` from start to finish without preloading."""

        generation = ImageSingleGeneration(generation_id=single_id, requires_post_processing=True)

        self.run_generation_process(
            generation,
            include_preloading=False,
            include_safety_check=True,
            include_generation=True,
            include_post_processing=True,
        )

        generation_no_post_processing = ImageSingleGeneration(generation_id=single_id, requires_post_processing=False)
        self.run_generation_process(
            generation_no_post_processing,
            include_preloading=False,
            include_safety_check=True,
            include_generation=True,
            include_post_processing=False,
        )

    def test_happy_path_alchemy_start_to_finish(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `AlchemySingleGeneration` from start to finish."""

        generation = AlchemySingleGeneration(generation_id=single_id)

        self.run_generation_process(
            generation,
            include_preloading=True,
            include_safety_check=False,
            include_generation=False,
            include_post_processing=True,
        )

    def test_happy_path_alchemy_no_preloading(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `AlchemySingleGeneration` from start to finish without preloading."""

        generation = AlchemySingleGeneration(generation_id=single_id, requires_post_processing=True)

        self.run_generation_process(
            generation,
            include_preloading=False,
            include_safety_check=False,
            include_generation=False,
            include_post_processing=True,
        )

    def test_happy_path_text_start_to_finish(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `TextSingleGeneration` from start to finish."""

        generation = TextSingleGeneration(generation_id=single_id)

        self.run_generation_process(
            generation,
            include_preloading=True,
            include_safety_check=False,
            include_generation=True,
            include_post_processing=False,
        )

    def test_happy_path_text_no_preloading(
        self,
        single_id: GenerationID,
    ) -> None:
        """Test the happy path for average `TextSingleGeneration` from start to finish without preloading."""

        generation = TextSingleGeneration(generation_id=single_id)

        self.run_generation_process(
            generation,
            include_preloading=False,
            include_safety_check=False,
            include_generation=True,
            include_post_processing=False,
        )

    @staticmethod
    def handle_error(
        generation: HordeSingleGeneration[Any],
        error_message: str,
        error_exception: Exception,
        errors_count: int,
    ) -> None:
        generation.on_error(
            failed_message=error_message,
            failure_exception=error_exception,
        )
        assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

    def process_generation(
        self,
        generation: HordeSingleGeneration[Any],
        include_preloading: bool,
        include_generation: bool,
        include_post_processing: bool,
        include_safety_check: bool,
        error_on_preloading: bool = False,
        error_on_generation: bool = False,
        error_on_post_processing: bool = False,
        error_on_safety_check: bool = False,
        error_on_submit: bool = False,
    ) -> None:
        """Process a generation with the given configurations.

        This will step the `HordeSingleGeneration` through the entire generation process, as requested by the
        arguments. If an error is requested, the generation will be marked as errored and the error count will be
        incremented.

        """

        error_flags = {
            "preloading": error_on_preloading and include_preloading,
            "generation": error_on_generation and include_generation,
            "post_processing": error_on_post_processing and include_post_processing,
            "safety_check": error_on_safety_check and include_safety_check,
            "submit": error_on_submit,
        }

        if not generation.does_class_requires_generation() and not include_generation and not include_post_processing:
            logger.trace(
                f"Skipping generation for {generation.__class__.__name__} as it does not require generation "
                "and generation and post-processing are not included",
            )
            return

        target_errors_count = sum(error_flags.values())
        errors_count = 0

        if include_preloading:
            errors_count = self._simulate_preloading(generation, error_on_preloading, errors_count)

        if include_generation:
            errors_count = self._simulate_generation(
                generation,
                error_on_generation=error_on_generation,
                include_post_processing=include_post_processing,
                errors_count=errors_count,
            )

        if include_post_processing:
            errors_count = self._simulate_post_processing(
                generation,
                error_on_post_processing=error_on_post_processing,
                errors_count=errors_count,
            )

        if include_safety_check:
            errors_count = self._simulate_safety_check(
                generation,
                error_on_safety_check=error_on_safety_check,
                errors_count=errors_count,
            )

        errors_count = self._simulate_submission(
            generation,
            error_on_submit=error_on_submit,
            errors_count=errors_count,
        )

        assert generation.generation_failure_count == target_errors_count

    def _simulate_preloading(
        self,
        generation: HordeSingleGeneration[Any],
        error_on_preloading: bool,
        errors_count: int,
    ) -> int:
        """Simulate expected actions for the preloading step for a `HordeSingleGeneration`."""

        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        assert generation.get_generation_progress() == GENERATION_PROGRESS.NOT_STARTED

        generation.step(GENERATION_PROGRESS.PRELOADING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

        if error_on_preloading:
            errors_count += 1
            generation.on_error(
                failed_message="Failed to preload",
                failure_exception=Exception("Failed to preload exception"),
            )
            assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

            generation.step(GENERATION_PROGRESS.PRELOADING)
            assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING

            assert generation.generation_failure_count == errors_count
            assert generation.errored_states is not None
            error_state, error_time = generation.errored_states[-1]
            assert error_state == GENERATION_PROGRESS.PRELOADING
            assert error_time != 0

        generation.on_preloading_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.PRELOADING_COMPLETE

        return errors_count

    def _simulate_generation(
        self,
        generation: HordeSingleGeneration[Any],
        error_on_generation: bool,
        include_post_processing: bool,
        errors_count: int,
    ) -> int:
        """Simulate expected actions for the generation step for a `HordeSingleGeneration`."""

        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        generation.step(GENERATION_PROGRESS.GENERATING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

        if error_on_generation:
            errors_count += 1
            generation.on_error(
                failed_message="Failed to generate",
                failure_exception=Exception("Failed to generate exception"),
            )
            assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

            generation.step(GENERATION_PROGRESS.GENERATING)
            assert generation.get_generation_progress() == GENERATION_PROGRESS.GENERATING

            assert generation.generation_failure_count == errors_count
            assert generation.errored_states is not None
            error_state, error_time = generation.errored_states[-1]
            assert error_state == GENERATION_PROGRESS.GENERATING
            assert error_time != 0

        generation.on_generation_work_complete()
        if not include_post_processing:
            # Without a post-processing step, the work result is final at this point.
            generation.set_work_result(self._shared_image)

        return errors_count

    def _simulate_post_processing(
        self,
        generation: HordeSingleGeneration[Any],
        error_on_post_processing: bool,
        errors_count: int,
    ) -> int:
        """Simulate expected actions for the post-processing step for a `HordeSingleGeneration`."""

        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        generation.step(GENERATION_PROGRESS.POST_PROCESSING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.POST_PROCESSING

        if error_on_post_processing:
            errors_count += 1
            generation.on_error(
                failed_message="Failed during post-processing",
                failure_exception=Exception("Failed during post-processing exception"),
            )
            assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

            generation.step(GENERATION_PROGRESS.POST_PROCESSING)
            assert generation.get_generation_progress() == GENERATION_PROGRESS.POST_PROCESSING

        generation.on_post_processing_complete()
        generation.set_work_result(self._shared_image)

        assert generation.generation_result == self._shared_image

        return errors_count

    def _simulate_safety_check(
        self,
        generation: HordeSingleGeneration[Any],
        error_on_safety_check: bool,
        errors_count: int,
    ) -> int:
        """Simulate expected actions for the safety check step for a `HordeSingleGeneration`."""
        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SAFETY_CHECK

        generation.step(GENERATION_PROGRESS.SAFETY_CHECKING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SAFETY_CHECKING

        if error_on_safety_check:
            errors_count += 1
            generation.on_error(
                failed_message="Failed during safety check",
                failure_exception=Exception("Failed during safety check exception"),
            )
            assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

            generation.step(GENERATION_PROGRESS.SAFETY_CHECKING)
            assert generation.get_generation_progress() == GENERATION_PROGRESS.SAFETY_CHECKING

        generation.on_safety_check_complete(is_csam=False, is_nsfw=False)
        assert generation.is_csam is False
        assert generation.is_nsfw is False

        return errors_count

    def _simulate_submission(
        self,
        generation: HordeSingleGeneration[Any],
        error_on_submit: bool,
        errors_count: int,
    ) -> int:
        """Simulate expected actions for the submission step for a `HordeSingleGeneration`."""
        from horde_sdk.ai_horde_worker.consts import GENERATION_PROGRESS

        assert generation.get_generation_progress() == GENERATION_PROGRESS.PENDING_SUBMIT

        generation.step(GENERATION_PROGRESS.SUBMITTING)
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        if error_on_submit:
            errors_count += 1
            generation.on_error(
                failed_message="Failed during submission",
                failure_exception=Exception("Failed during submission exception"),
            )
            assert generation.get_generation_progress() == GENERATION_PROGRESS.ERROR

            generation.step(GENERATION_PROGRESS.SUBMITTING)
            assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMITTING

        generation.on_submit_complete()
        assert generation.get_generation_progress() == GENERATION_PROGRESS.SUBMIT_COMPLETE

        return errors_count

    def run_generation_test_permutations(
        self,
        generation_class: type[HordeSingleGeneration[Any]],
        generation_id: GenerationID | Callable[[], GenerationID],
        permutations: list[GenerationPermutation],
        process_function: Callable[..., None],
        include_generation: bool,
        requires_generation: bool | None = None,
        **kwargs: Any,  # noqa
    ) -> None:
        """Run permutations of generation configurations.

        See the docstring for `GenerationPermutation` for more information on the possible configurations.

        Args:
            generation_class (type[HordeSingleGeneration[Any]]): The generation class to test.
            generation_id (GenerationID | Callable[[], GenerationID]): The generation ID or generation ID factory to\
                use for the test.
            permutations (list[GenerationPermutation]): The permutations to test.
            process_function (Callable[..., None]): The function to process the generation.
            include_generation (bool): Whether to include generation in the process.
            requires_generation (bool | None): Whether the generation requires generation.
            **kwargs (Any): Additional keyword arguments to pass to the process function.

        """
        for permutation in permutations:
            from horde_sdk.ai_horde_worker.consts import base_generate_progress_transitions

            transition_override: Mapping[GENERATION_PROGRESS, Iterable[GENERATION_PROGRESS]] = (
                generation_class.default_generate_progress_transitions()
            )
            if permutation.include_safety_check:
                transition_override = base_generate_progress_transitions
                logger.trace(f"Using safety check transitions for {generation_class.__name__}")
            else:
                logger.trace(
                    f"Using default transitions for {generation_class.__name__} as defined by the class "
                    "function `default_generate_progress_transitions(...)`",
                )

            if requires_generation and not generation_class.does_class_requires_generation():
                generation = generation_class(
                    generation_id=generation_id() if callable(generation_id) else generation_id,
                    requires_post_processing=permutation.include_post_processing,
                    requires_generation=requires_generation,
                    generate_progress_transitions=transition_override,
                    extra_logging=False,
                )
            else:
                generation = generation_class(
                    generation_id=generation_id() if callable(generation_id) else generation_id,
                    requires_post_processing=permutation.include_post_processing,
                    generate_progress_transitions=transition_override,
                    extra_logging=False,
                )

            if generation_class.does_class_requires_generation():
                logger.trace(
                    f"Overriding include_generation to True for {generation_class.__name__} "
                    "as the class requires generation",
                )
                include_generation = True

            process_function(
                self,
                generation,
                include_generation=include_generation,
                include_safety_check=permutation.include_safety_check,
                include_preloading=permutation.include_preloading,
                include_post_processing=permutation.include_post_processing,
                **kwargs,
            )

    @pytest.mark.parametrize(
        "generation_class,process_function,include_generation,requires_generation",
        [
            (ImageSingleGeneration, process_generation, True, True),
            (ImageSingleGeneration, process_generation, True, False),
            (ImageSingleGeneration, process_generation, False, True),
            (ImageSingleGeneration, process_generation, False, False),
            (AlchemySingleGeneration, process_generation, True, True),
            (AlchemySingleGeneration, process_generation, True, False),
            (AlchemySingleGeneration, process_generation, False, True),
            (AlchemySingleGeneration, process_generation, False, False),
            (TextSingleGeneration, process_generation, True, True),
            (TextSingleGeneration, process_generation, True, False),
            (TextSingleGeneration, process_generation, False, True),
            (TextSingleGeneration, process_generation, False, False),
        ],
    )
    def test_error_handling(
        self,
        generation_class: type[HordeSingleGeneration[Any]],
        process_function: Callable[..., None],
        include_generation: bool,
        requires_generation: bool,
        id_factory: Callable[[], GenerationID],
        image_permutations: list[GenerationPermutation],
        alchemy_permutations: list[GenerationPermutation],
        text_permutations: list[GenerationPermutation],
    ) -> None:
        """Test error handling for all permutations of generation configurations."""
        error_permutations = [
            (a, b, c, d, e)
            for a in [True, False]
            for b in [True, False]
            for c in [True, False]
            for d in [True, False]
            for e in [True, False]
        ]

        permutations_map = {
            ImageSingleGeneration: image_permutations,
            AlchemySingleGeneration: alchemy_permutations,
            TextSingleGeneration: text_permutations,
        }

        for error_permutation in error_permutations:
            permutations = permutations_map.get(generation_class)
            if permutations is None:
                raise ValueError(f"Permutations not found for {generation_class.__name__}")

            try:
                self.run_generation_test_permutations(
                    generation_class,
                    id_factory,
                    permutations,
                    process_function,
                    include_generation=include_generation,
                    requires_generation=requires_generation,
                    error_on_preloading=error_permutation[0],
                    error_on_generation=error_permutation[1],
                    error_on_post_processing=error_permutation[2],
                    error_on_safety_check=error_permutation[3],
                    error_on_submit=error_permutation[4],
                )
            except Exception:
                # Log the traceback once, with the full context of the failing permutation.
                logger.exception(
                    f"Error running permutations for {generation_class.__name__}: "
                    f"error permutation: {error_permutation}, "
                    f"generation permutations: {permutations}, "
                    f"include generation: {include_generation}, "
                    f"requires generation: {requires_generation}",
                )
                raise
Validation Logic

The new validate_id method for GenerationID should be reviewed to ensure it handles edge cases and does not introduce unintended behavior, especially with the fallback to generating a UUID.

    return v

@field_validator("id_", mode="before")
def validate_id(cls, v: str | GenerationID) -> GenerationID | str:
    if isinstance(v, str) and v == "":
        logger.warning("Job ID is empty")
        return GenerationID(root=uuid.uuid4())
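
The edge cases mentioned above can be made concrete with a small stand-alone sketch. It does not use pydantic or the SDK's `GenerationID`; `coerce_generation_id` is a hypothetical helper, and treating whitespace-only strings as empty is an assumption rather than something the PR states.

```python
import uuid
from typing import Union


def coerce_generation_id(value: Union[str, uuid.UUID]) -> uuid.UUID:
    """Return a UUID for `value`, generating a fresh one when the input is empty.

    Hypothetical helper for illustration only; it is not part of horde_sdk.
    """
    if isinstance(value, uuid.UUID):
        return value

    # Assumption: whitespace-only input is treated the same as an empty string.
    if value.strip() == "":
        # Fallback: every empty input gets a new, unrelated random ID.
        return uuid.uuid4()

    # Malformed strings raise ValueError here instead of being passed through.
    return uuid.UUID(value)
```

One consequence worth reviewing: two empty inputs produce two different IDs, so anything that later looks a generation up by ID cannot rely on an ID created by the fallback.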

@CodiumAI-Agent

PR Code Suggestions ✨

Category    Suggestion    Score
General
Add error handling for file operations

Add error handling for the open calls to ensure the program gracefully handles cases
where the files do not exist or cannot be read.

docs/build_docs.py [84-88]

-with open("docs/ai-horde/api_to_sdk_payload_map.json") as f:
-    api_to_sdk_payload_map = json.load(f)
+try:
+    with open("docs/ai-horde/api_to_sdk_payload_map.json") as f:
+        api_to_sdk_payload_map = json.load(f)
 
-with open("docs/ai-horde/api_to_sdk_response_map.json") as f:
-    api_to_sdk_response_map = json.load(f)
+    with open("docs/ai-horde/api_to_sdk_response_map.json") as f:
+        api_to_sdk_response_map = json.load(f)
+except FileNotFoundError as e:
+    logger.error(f"File not found: {e}")
+    return
+except json.JSONDecodeError as e:
+    logger.error(f"Error decoding JSON: {e}")
+    return
Suggestion importance[1-10]: 9

Why: Adding error handling for file operations ensures the program can gracefully handle missing or unreadable files, which is critical for robustness. The suggestion is accurate and directly improves the code's reliability.
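
A bare `return` is only valid inside a function, so the suggested diff assumes these lines sit in one. A minimal sketch of the same idea as a reusable helper follows; `load_json_map` is a hypothetical name and is not taken from `docs/build_docs.py`.

```python
import json
import logging
from pathlib import Path
from typing import Any, Optional

logger = logging.getLogger(__name__)


def load_json_map(path: Path) -> Optional[dict[str, Any]]:
    """Load a JSON object from `path`; return None if it is missing or malformed."""
    try:
        with path.open(encoding="utf-8") as f:
            loaded = json.load(f)
    except FileNotFoundError:
        logger.error("File not found: %s", path)
        return None
    except json.JSONDecodeError as e:
        logger.error("Error decoding JSON in %s: %s", path, e)
        return None

    # The map files are expected to hold a JSON object at the top level.
    if not isinstance(loaded, dict):
        logger.error("Expected a JSON object in %s, got %s", path, type(loaded).__name__)
        return None

    return loaded
```

The caller then decides what a missing map means, for example skipping the docs step or exiting with a non-zero status.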

9
Add error handling for file writes

Add error handling for the aiofiles.open call to ensure the program handles file
write errors gracefully.

examples/ai_horde_client/text/async_text_generate.py [91-92]

-async with aiofiles.open(example_path / f"{gen_id}_async_example.txt", "w") as f:
-    await f.write(status_response.model_dump_json(indent=4))
+try:
+    async with aiofiles.open(example_path / f"{gen_id}_async_example.txt", "w") as f:
+        await f.write(status_response.model_dump_json(indent=4))
+except OSError as e:
+    logger.error(f"Error writing to file: {e}")
Suggestion importance[1-10]: 8

Why: Adding error handling for file writes ensures the program can handle issues like permission errors or disk space problems gracefully. This improves the code's robustness and aligns with best practices.

8
Validate kudos for non-negativity

Add a check to ensure kudos in SharedKeySettings is non-negative to prevent invalid
configurations.

horde_sdk/ai_horde_api/apimodels/_sharedkeys.py [21]

 kudos: int
+@field_validator("kudos", mode="before")
+@classmethod
+def validate_kudos(cls, v: int) -> int:
+    if v < 0:
+        raise ValueError("Kudos must be non-negative")
+    return v
Suggestion importance[1-10]: 8

Why: Ensuring that kudos is non-negative prevents invalid configurations and maintains the logical consistency of the application, which is a valuable improvement.
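
Two details of the suggested validator deserve care. With `mode="before"`, pydantic hands the validator the raw input before type coercion, so `v < 0` can raise a `TypeError` when a string is passed. It is also worth checking whether the API reserves a negative sentinel (such as -1) before rejecting all negative values. The sketch below shows the range check with a plain dataclass; `SharedKeySettingsSketch` and `min_kudos` are hypothetical stand-ins, not the SDK model.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SharedKeySettingsSketch:
    """Hypothetical stand-in for `SharedKeySettings`; only `kudos` is modelled."""

    kudos: int
    # Assumption: the lowest accepted value. Use -1 if the API allows a sentinel.
    min_kudos: int = 0

    def __post_init__(self) -> None:
        # bool is a subclass of int, so it is rejected explicitly.
        if isinstance(self.kudos, bool) or not isinstance(self.kudos, int):
            raise TypeError(f"kudos must be an int, got {type(self.kudos).__name__}")
        if self.kudos < self.min_kudos:
            raise ValueError(f"kudos must be >= {self.min_kudos}, got {self.kudos}")
```

Checking the type first keeps the comparison from failing with an unrelated error on unexpected input.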

8
Validate state transitions for correctness

Add a safeguard in the base_generate_progress_transitions dictionary to ensure no
invalid state transitions are defined, preventing logical errors in state
management.

horde_sdk/ai_horde_worker/consts.py [83-90]

+for state, transitions in base_generate_progress_transitions.items():
+    if not all(isinstance(transition, GENERATION_PROGRESS) for transition in transitions):
+        raise ValueError(f"Invalid transition states defined for {state}")
 base_generate_progress_transitions: dict[GENERATION_PROGRESS, list[GENERATION_PROGRESS]] = {
     GENERATION_PROGRESS.NOT_STARTED: [
         GENERATION_PROGRESS.PRELOADING,
         GENERATION_PROGRESS.GENERATING,
         GENERATION_PROGRESS.PENDING_POST_PROCESSING,
         GENERATION_PROGRESS.POST_PROCESSING,
     ],
     ...
 }
Suggestion importance[1-10]: 8

Why: Adding a safeguard to validate state transitions ensures logical consistency in the base_generate_progress_transitions dictionary. This prevents potential logical errors in state management and improves maintainability.
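
A check like this has to run after the dictionary is defined; placed before the assignment it would fail with a `NameError`. A minimal sketch of the ordering follows, using a reduced, hypothetical `Progress` enum in place of `GENERATION_PROGRESS`; the states and transitions shown are illustrative only.

```python
from enum import Enum, auto


class Progress(Enum):
    """Hypothetical, reduced stand-in for GENERATION_PROGRESS."""

    NOT_STARTED = auto()
    PRELOADING = auto()
    GENERATING = auto()
    ERROR = auto()


# Illustrative table only; the real transitions live in horde_sdk.
transitions: dict[Progress, list[Progress]] = {
    Progress.NOT_STARTED: [Progress.PRELOADING, Progress.GENERATING],
    Progress.PRELOADING: [Progress.GENERATING, Progress.ERROR],
    Progress.GENERATING: [Progress.ERROR],
    Progress.ERROR: [Progress.PRELOADING, Progress.GENERATING],
}


def validate_transitions(table: dict) -> None:
    """Raise ValueError if any key or target is not a Progress member."""
    for state, targets in table.items():
        if not isinstance(state, Progress):
            raise ValueError(f"Invalid source state: {state!r}")
        for target in targets:
            if not isinstance(target, Progress):
                raise ValueError(f"Invalid transition target for {state}: {target!r}")


# Runs once at import time, after the table exists.
validate_transitions(transitions)
```

Since the table is already typed, a static type checker catches most of these mistakes; the runtime check mainly guards tables passed in as overrides.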

8
Improve timeout logging details

Log additional details when a timeout occurs in _handle_progress_response to aid in
debugging and monitoring.

horde_sdk/ai_horde_api/ai_horde_clients.py [586]

 logger.warning(
-    f"Timeout reached, cancelling generations still outstanding: {gen_id}: {check_response}:",
+    f"Timeout reached for GenerationID {gen_id}. Details: {check_response.log_safe_model_dump()}",
 )
Suggestion importance[1-10]: 7

Why: Enhancing the timeout logging with additional details improves debugging and monitoring capabilities, making it easier to identify and resolve issues.

7
Validate filename_base before usage

Ensure that the filename_base variable is defined before being used in the
save_image_and_json function to avoid runtime errors.

examples/ai_horde_client/image/async_simple_client_example.py [78-80]

 for image, gen_id in downloaded_images:
     filename_base = f"{gen_id}_simple_async_example"
-    save_image_and_json(image, generation, example_path, filename_base)
+    if filename_base:
+        save_image_and_json(image, generation, example_path, filename_base)
+    else:
+        logger.error("Filename base is not defined.")
Suggestion importance[1-10]: 3

Why: The suggestion to validate filename_base is unnecessary because it is always defined in the loop. Adding this check introduces redundant code without improving functionality.

3
Security
Validate gen_id before API usage

Ensure that the gen_id parameter is properly validated before being used in API
requests to prevent potential misuse or injection attacks.

horde_sdk/ai_horde_api/ai_horde_clients.py [237]

+if not isinstance(gen_id, GenerationID):
+    raise ValueError("Invalid GenerationID")
 api_request = ImageGenerateCheckRequest(id=gen_id)
Suggestion importance[1-10]: 9

Why: Adding validation for gen_id ensures that only valid GenerationID instances are used in API requests, preventing potential misuse or injection attacks. This is a critical security enhancement.

9
Validate expiry field format

Add stricter validation for the expiry field in SharedKeySettings to ensure it
follows a valid date format and is not in the past.

horde_sdk/ai_horde_api/apimodels/_sharedkeys.py [23]

 expiry: str
+@field_validator("expiry", mode="before")
+@classmethod
+def validate_expiry(cls, v: str) -> str:
+    if not is_valid_date_format(v) or is_past_date(v):
+        raise ValueError("Invalid expiry date")
+    return v
Suggestion importance[1-10]: 8

Why: Adding stricter validation for the expiry field ensures that it follows a valid date format and is not in the past, which is important for maintaining data integrity and preventing misconfigurations.
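
The helpers `is_valid_date_format` and `is_past_date` in the suggestion are not defined anywhere in the diff. One way to fill them in with the standard library is sketched below. It assumes `expiry` is an ISO 8601 string and that values without a timezone are UTC; neither assumption is confirmed by the PR.

```python
from datetime import datetime, timezone
from typing import Optional


def validate_expiry(value: str, now: Optional[datetime] = None) -> str:
    """Return `value` unchanged if it parses as ISO 8601 and lies in the future.

    `now` must be timezone-aware; it exists so tests can pin the current time.
    """
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError as e:
        raise ValueError(f"Invalid expiry date format: {value!r}") from e

    # Assumption: a timestamp without timezone information is interpreted as UTC.
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)

    current = now if now is not None else datetime.now(timezone.utc)
    if parsed <= current:
        raise ValueError(f"Expiry date is in the past: {value!r}")

    return value
```

Rejecting past dates is only appropriate on requests that set an expiry; applying the same check to responses would reject keys that have legitimately expired.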

8
Possible issue
Validate generation_id during initialization

Add a validation check in the __init__ methods of generation classes to ensure
generation_id is not empty or invalid to prevent initialization errors.

horde_sdk/ai_horde_worker/generations.py [37-41]

+if not generation_id:
+    raise ValueError("generation_id must not be empty or None")
 generation_id: GenerationID,
 requires_post_processing: bool = False,
 state_error_limits: (
     Mapping[GENERATION_PROGRESS, int] | None
 ) = HordeWorkerConfigDefaults.DEFAULT_STATE_ERROR_LIMITS,
Suggestion importance[1-10]: 9

Why: Adding a validation check for generation_id ensures that invalid or empty values are caught early during initialization, preventing potential runtime errors. This is a critical improvement for data integrity.

9
Add a null check for comparisons

Add a check in the __eq__ method to ensure that other is not None before performing
comparisons to avoid potential runtime errors.

horde_sdk/ai_horde_api/fields.py [53-57]

-if (
+if other is None or (
     not (isinstance(self.__class__, uuid.UUID) or isinstance(other, uuid.UUID))
     and self.__class__ != other.__class__
 ):
     logger.warning(f"Comparing {self.root.__class__} with {other.__class__}")
Suggestion importance[1-10]: 8

Why: Adding a null check for other in the __eq__ method is a valid improvement to prevent potential runtime errors when None is passed. This enhances the robustness of the comparison logic.
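
Note that `isinstance(self.__class__, uuid.UUID)` in the quoted condition is always false, because a class object is never a `UUID` instance. A common alternative to special-casing `None` is to return `NotImplemented` for unsupported types, which makes `==` evaluate to `False` without extra branches. The sketch below uses a hypothetical `IdentifierSketch` wrapper and is not the SDK's implementation.

```python
import uuid


class IdentifierSketch:
    """Hypothetical UUID wrapper used only to illustrate the comparison logic."""

    def __init__(self, root: uuid.UUID) -> None:
        self.root = root

    def __eq__(self, other: object) -> bool:
        if isinstance(other, IdentifierSketch):
            return self.root == other.root
        if isinstance(other, uuid.UUID):
            return self.root == other
        # None and any other type: defer to Python's default handling.
        return NotImplemented

    def __hash__(self) -> int:
        # Defining __eq__ removes the default __hash__, so restore a consistent one.
        return hash(self.root)
```

With this shape, comparing against `None` needs no dedicated check and never touches attributes of `other`.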

8
Validate input for failing state checks

Ensure that the is_generation_state_failing function handles cases where progress is
not a valid GENERATION_PROGRESS enum value to prevent unexpected behavior.

horde_sdk/ai_horde_worker/consts.py [70-77]

+if not isinstance(progress, GENERATION_PROGRESS):
+    logger.error(f"Invalid progress state: {progress}")
+    return False
 return progress in {
     GENERATION_PROGRESS.ERROR,
     GENERATION_PROGRESS.ABORTED,
     GENERATION_PROGRESS.REPORTED_FAILED,
     GENERATION_PROGRESS.USER_REQUESTED_ABORT,
     GENERATION_PROGRESS.ABANDONED,
 }
Suggestion importance[1-10]: 7

Why: Ensuring that the progress parameter is a valid GENERATION_PROGRESS enum value before performing checks is a good safeguard against unexpected behavior. This improves the function's reliability and error handling.

7
