From da0a6209b4e6c4dc6a7df06ebabc51a094b6d398 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 7 Feb 2024 10:33:14 -0500 Subject: [PATCH 01/19] Split the Gradio UI into a training page and a dataset annotation page. --- pyproject.toml | 2 + .../scripts/invoke_train_ui.py | 8 +- src/invoke_training/ui/app.py | 129 +++--------------- src/invoke_training/ui/index.html | 17 +++ src/invoke_training/ui/pages/data_page.py | 27 ++++ src/invoke_training/ui/pages/training_page.py | 121 ++++++++++++++++ 6 files changed, 188 insertions(+), 116 deletions(-) create mode 100644 src/invoke_training/ui/index.html create mode 100644 src/invoke_training/ui/pages/data_page.py create mode 100644 src/invoke_training/ui/pages/training_page.py diff --git a/pyproject.toml b/pyproject.toml index 7fcc5dff..b75a7ce6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "accelerate~=0.25.0", "datasets~=2.14.3", "diffusers~=0.25.0", + "fastapi", "gradio", "numpy", "omegaconf", @@ -32,6 +33,7 @@ dependencies = [ "torchvision", "tqdm", "transformers~=4.36.0", + "uvicorn[standard]", "xformers>=0.0.23", ] diff --git a/src/invoke_training/scripts/invoke_train_ui.py b/src/invoke_training/scripts/invoke_train_ui.py index 646b466a..9cf02432 100644 --- a/src/invoke_training/scripts/invoke_train_ui.py +++ b/src/invoke_training/scripts/invoke_train_ui.py @@ -1,9 +1,11 @@ -from invoke_training.ui.app import App +import uvicorn + +from invoke_training.ui.app import build_app def main(): - app = App() - app.launch() + app = build_app() + uvicorn.run(app) if __name__ == "__main__": diff --git a/src/invoke_training/ui/app.py b/src/invoke_training/ui/app.py index 3a344a3d..1bb3b0e2 100644 --- a/src/invoke_training/ui/app.py +++ b/src/invoke_training/ui/app.py @@ -1,121 +1,24 @@ -import os -import subprocess -import tempfile -import time +from pathlib import Path import gradio as gr -import yaml +from fastapi import FastAPI +from fastapi.responses import FileResponse -from invoke_training.config.pipeline_config import PipelineConfig -from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig -from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig -from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig -from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( - SdxlLoraAndTextualInversionConfig, -) -from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig -from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup -from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup -from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import ( - SdxlLoraAndTextualInversionConfigGroup, -) -from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup -from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup -from invoke_training.ui.pipeline_tab import PipelineTab -from invoke_training.ui.utils import get_assets_dir_path, get_config_dir_path +from invoke_training.ui.pages.data_page import DataPage +from invoke_training.ui.pages.training_page import TrainingPage -class App: - def __init__(self): - self._config_temp_directory = tempfile.TemporaryDirectory() - self._training_process = None +def build_app(): + training_page = TrainingPage() + data_page = DataPage() - logo_path = get_assets_dir_path() / "logo.png" - with gr.Blocks(title="invoke-training", analytics_enabled=False) as app: - with gr.Column(): - gr.Image( - value=logo_path, - label="Invoke Training App", - width=200, - interactive=False, - container=False, - ) - with gr.Row(): - gr.Markdown( - "*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" - " Learn more about Invoke at [invoke.com](https://www.invoke.com/)" - ) - with gr.Tab(label="SD LoRA"): - PipelineTab( - name="SD LoRA", - default_config_file_path=str(get_config_dir_path() / "sd_lora_pokemon_1x8gb.yaml"), - pipeline_config_cls=SdLoraConfig, - config_group_cls=SdLoraConfigGroup, - run_training_cb=self._run_training, - app=app, - ) - with gr.Tab(label="SDXL LoRA"): - PipelineTab( - name="SDXL LoRA", - default_config_file_path=str(get_config_dir_path() / "sdxl_lora_pokemon_1x24gb.yaml"), - pipeline_config_cls=SdxlLoraConfig, - config_group_cls=SdxlLoraConfigGroup, - run_training_cb=self._run_training, - app=app, - ) - with gr.Tab(label="SD Textual Inversion"): - PipelineTab( - name="SD Textual Inversion", - default_config_file_path=str(get_config_dir_path() / "sd_textual_inversion_gnome_1x8gb.yaml"), - pipeline_config_cls=SdTextualInversionConfig, - config_group_cls=SdTextualInversionConfigGroup, - run_training_cb=self._run_training, - app=app, - ) - with gr.Tab(label="SDXL Textual Inversion"): - PipelineTab( - name="SDXL Textual Inversion", - default_config_file_path=str(get_config_dir_path() / "sdxl_textual_inversion_gnome_1x24gb.yaml"), - pipeline_config_cls=SdxlTextualInversionConfig, - config_group_cls=SdxlTextualInversionConfigGroup, - run_training_cb=self._run_training, - app=app, - ) - with gr.Tab(label="SDXL LoRA and Textual Inversion"): - PipelineTab( - name="SDXL LoRA and Textual Inversion", - default_config_file_path=str(get_config_dir_path() / "sdxl_lora_and_ti_gnome_1x24gb.yaml"), - pipeline_config_cls=SdxlLoraAndTextualInversionConfig, - config_group_cls=SdxlLoraAndTextualInversionConfigGroup, - run_training_cb=self._run_training, - app=app, - ) + app = FastAPI() - self._app = app + @app.get("/") + async def root(): + index_path = Path(__file__).parent / "index.html" + return FileResponse(index_path) - def launch(self): - self._app.launch() - - def _run_training(self, config: PipelineConfig): - # Check if there is already a training process running. - if self._training_process is not None: - if self._training_process.poll() is None: - print( - "Tried to start a new training process, but another training process is already running. " - "Terminate the existing process first." - ) - return - else: - self._training_process = None - - print(f"Starting {config.type} training...") - - # Write the config to a temporary config file where the training subprocess can read it. - timestamp = str(time.time()).replace(".", "_") - config_path = os.path.join(self._config_temp_directory.name, f"{timestamp}.yaml") - with open(config_path, "w") as f: - yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) - - self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)]) - - print(f"Started {config.type} training.") + app = gr.mount_gradio_app(app, training_page.app(), "/train") + app = gr.mount_gradio_app(app, data_page.app(), "/data") + return app diff --git a/src/invoke_training/ui/index.html b/src/invoke_training/ui/index.html new file mode 100644 index 00000000..861cfa41 --- /dev/null +++ b/src/invoke_training/ui/index.html @@ -0,0 +1,17 @@ + + + + + + invoke-training + + +

invoke-training

+

+ Training +

+

+ Dataset Annotation +

+ + diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py new file mode 100644 index 00000000..8fb531c0 --- /dev/null +++ b/src/invoke_training/ui/pages/data_page.py @@ -0,0 +1,27 @@ +import gradio as gr + +from invoke_training.ui.utils import get_assets_dir_path + + +class DataPage: + def __init__(self): + logo_path = get_assets_dir_path() / "logo.png" + with gr.Blocks(title="invoke-training", analytics_enabled=False) as app: + with gr.Column(): + gr.Image( + value=logo_path, + label="Invoke Training App", + width=200, + interactive=False, + container=False, + ) + with gr.Row(): + gr.Markdown( + "*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" + " Learn more about Invoke at [invoke.com](https://www.invoke.com/)" + ) + + self._app = app + + def app(self): + return self._app diff --git a/src/invoke_training/ui/pages/training_page.py b/src/invoke_training/ui/pages/training_page.py new file mode 100644 index 00000000..b84cf986 --- /dev/null +++ b/src/invoke_training/ui/pages/training_page.py @@ -0,0 +1,121 @@ +import os +import subprocess +import tempfile +import time + +import gradio as gr +import yaml + +from invoke_training.config.pipeline_config import PipelineConfig +from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig +from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig +from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig +from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( + SdxlLoraAndTextualInversionConfig, +) +from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig +from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup +from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup +from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import ( + SdxlLoraAndTextualInversionConfigGroup, +) +from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup +from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup +from invoke_training.ui.pipeline_tab import PipelineTab +from invoke_training.ui.utils import get_assets_dir_path, get_config_dir_path + + +class TrainingPage: + def __init__(self): + self._config_temp_directory = tempfile.TemporaryDirectory() + self._training_process = None + + logo_path = get_assets_dir_path() / "logo.png" + with gr.Blocks(title="invoke-training", analytics_enabled=False) as app: + with gr.Column(): + gr.Image( + value=logo_path, + label="Invoke Training App", + width=200, + interactive=False, + container=False, + ) + with gr.Row(): + gr.Markdown( + "*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" + " Learn more about Invoke at [invoke.com](https://www.invoke.com/)" + ) + with gr.Tab(label="SD LoRA"): + PipelineTab( + name="SD LoRA", + default_config_file_path=str(get_config_dir_path() / "sd_lora_pokemon_1x8gb.yaml"), + pipeline_config_cls=SdLoraConfig, + config_group_cls=SdLoraConfigGroup, + run_training_cb=self._run_training, + app=app, + ) + with gr.Tab(label="SDXL LoRA"): + PipelineTab( + name="SDXL LoRA", + default_config_file_path=str(get_config_dir_path() / "sdxl_lora_pokemon_1x24gb.yaml"), + pipeline_config_cls=SdxlLoraConfig, + config_group_cls=SdxlLoraConfigGroup, + run_training_cb=self._run_training, + app=app, + ) + with gr.Tab(label="SD Textual Inversion"): + PipelineTab( + name="SD Textual Inversion", + default_config_file_path=str(get_config_dir_path() / "sd_textual_inversion_gnome_1x8gb.yaml"), + pipeline_config_cls=SdTextualInversionConfig, + config_group_cls=SdTextualInversionConfigGroup, + run_training_cb=self._run_training, + app=app, + ) + with gr.Tab(label="SDXL Textual Inversion"): + PipelineTab( + name="SDXL Textual Inversion", + default_config_file_path=str(get_config_dir_path() / "sdxl_textual_inversion_gnome_1x24gb.yaml"), + pipeline_config_cls=SdxlTextualInversionConfig, + config_group_cls=SdxlTextualInversionConfigGroup, + run_training_cb=self._run_training, + app=app, + ) + with gr.Tab(label="SDXL LoRA and Textual Inversion"): + PipelineTab( + name="SDXL LoRA and Textual Inversion", + default_config_file_path=str(get_config_dir_path() / "sdxl_lora_and_ti_gnome_1x24gb.yaml"), + pipeline_config_cls=SdxlLoraAndTextualInversionConfig, + config_group_cls=SdxlLoraAndTextualInversionConfigGroup, + run_training_cb=self._run_training, + app=app, + ) + + self._app = app + + def app(self): + return self._app + + def _run_training(self, config: PipelineConfig): + # Check if there is already a training process running. + if self._training_process is not None: + if self._training_process.poll() is None: + print( + "Tried to start a new training process, but another training process is already running. " + "Terminate the existing process first." + ) + return + else: + self._training_process = None + + print(f"Starting {config.type} training...") + + # Write the config to a temporary config file where the training subprocess can read it. + timestamp = str(time.time()).replace(".", "_") + config_path = os.path.join(self._config_temp_directory.name, f"{timestamp}.yaml") + with open(config_path, "w") as f: + yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) + + self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)]) + + print(f"Started {config.type} training.") From 0d3a3872d0447d55414f25a532f1ab21b9e32dd5 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 10:17:18 -0500 Subject: [PATCH 02/19] Add utils for saving and loading .jsonl files. --- .../datasets/image_caption_jsonl_dataset.py | 13 ++++++------ .../datasets/image_pair_preference_dataset.py | 20 ++++--------------- src/invoke_training/_shared/utils/jsonl.py | 19 ++++++++++++++++++ .../datasets/test_hf_image_caption_dataset.py | 6 ++---- .../_shared/data/image_dir_fixture.py | 9 ++------- .../_shared/utils/test_jsonl.py | 13 ++++++++++++ 6 files changed, 46 insertions(+), 34 deletions(-) create mode 100644 src/invoke_training/_shared/utils/jsonl.py create mode 100644 tests/invoke_training/_shared/utils/test_jsonl.py diff --git a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py index 10a2d262..4dd85e14 100644 --- a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py @@ -1,4 +1,3 @@ -import json import typing from pathlib import Path @@ -6,6 +5,7 @@ from PIL import Image from invoke_training._shared.data.utils.resolution import Resolution +from invoke_training._shared.utils.jsonl import load_jsonl class ImageCaptionJsonlDataset(torch.utils.data.Dataset): @@ -22,12 +22,11 @@ def __init__( self._jsonl_path = Path(jsonl_path) self._data: list[dict[str, typing.Any]] = [] - with open(jsonl_path) as f: - while (line := f.readline()) != "": - line_json = json.loads(line) - assert image_column in line_json - assert caption_column in line_json - self._data.append(line_json) + data = load_jsonl(self._jsonl_path) + for d in data: + assert image_column in d + assert caption_column in d + self._data = data self._image_column = image_column self._caption_column = caption_column diff --git a/src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py b/src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py index 87c100df..33ab8367 100644 --- a/src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py @@ -1,4 +1,3 @@ -import json import os import typing from pathlib import Path @@ -6,22 +5,15 @@ import torch.utils.data from PIL import Image +from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl + class ImagePairPreferenceDataset(torch.utils.data.Dataset): def __init__(self, dataset_dir: str): super().__init__() self._dataset_dir = dataset_dir - self._metadata = self.load_metadata(self._dataset_dir) - - @classmethod - def load_metadata(cls, dataset_dir: Path | str) -> list[dict[str, typing.Any]]: - """Load the dataset metadata from metadata.jsonl.""" - metadata: list[dict[str, typing.Any]] = [] - with open(Path(dataset_dir) / "metadata.jsonl") as f: - while (line := f.readline()) != "": - metadata.append(json.loads(line)) - return metadata + self._metadata = load_jsonl(Path(dataset_dir) / "metadata.jsonl") @classmethod def save_metadata( @@ -29,11 +21,7 @@ def save_metadata( ) -> Path: """Load the dataset metadata from metadata.jsonl.""" metadata_path = Path(dataset_dir) / metadata_file - with open(metadata_path, "w") as f: - for m in metadata: - json.dump(m, f) - f.write("\n") - + save_jsonl(metadata, metadata_path) return metadata_path def __len__(self) -> int: diff --git a/src/invoke_training/_shared/utils/jsonl.py b/src/invoke_training/_shared/utils/jsonl.py new file mode 100644 index 00000000..85236866 --- /dev/null +++ b/src/invoke_training/_shared/utils/jsonl.py @@ -0,0 +1,19 @@ +import json +from pathlib import Path +from typing import Any + + +def load_jsonl(jsonl_path: Path | str) -> list[Any]: + """Load a JSONL file.""" + data = [] + with open(jsonl_path) as f: + while (line := f.readline()) != "": + data.append(json.loads(line)) + return data + + +def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None: + """Save a list of objects to a JSONL file.""" + with open(jsonl_path, "w") as f: + for line in data: + f.write(json.dumps(line) + "\n") diff --git a/tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py b/tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py index 602bda5d..76aa4983 100644 --- a/tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py +++ b/tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py @@ -1,4 +1,3 @@ -import json from pathlib import Path import numpy as np @@ -10,6 +9,7 @@ HFImageCaptionDataset, ) from invoke_training._shared.data.utils.resolution import Resolution +from invoke_training._shared.utils.jsonl import save_jsonl ################################################ # Tests for HFImageCaptionDataset.from_dir(...) @@ -39,9 +39,7 @@ def create_hf_imagefolder_dataset(tmp_dir: Path, num_images: int): # Write the metadata.jsonl to disk. metadata_path = tmp_dir / "metadata.jsonl" - with open(metadata_path, "w") as f: - for metadata_line in metadata: - f.write(json.dumps(metadata_line) + "\n") + save_jsonl(metadata, metadata_path) @pytest.fixture(scope="session") diff --git a/tests/invoke_training/_shared/data/image_dir_fixture.py b/tests/invoke_training/_shared/data/image_dir_fixture.py index 1442172f..92ea6dcb 100644 --- a/tests/invoke_training/_shared/data/image_dir_fixture.py +++ b/tests/invoke_training/_shared/data/image_dir_fixture.py @@ -1,10 +1,9 @@ -import json - import numpy as np import PIL.Image import pytest from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset +from invoke_training._shared.utils.jsonl import save_jsonl @pytest.fixture(scope="session") @@ -73,11 +72,7 @@ def image_caption_jsonl(tmp_path_factory: pytest.TempPathFactory): data.append({"image": str(rgb_rel_path), "text": f"caption {i}"}) data_jsonl_path = tmp_dir / "data.jsonl" - with open(data_jsonl_path, "w") as f: - for d in data: - json.dump(d, f) - f.write("\n") - + save_jsonl(data, data_jsonl_path) return data_jsonl_path diff --git a/tests/invoke_training/_shared/utils/test_jsonl.py b/tests/invoke_training/_shared/utils/test_jsonl.py new file mode 100644 index 00000000..14337f4f --- /dev/null +++ b/tests/invoke_training/_shared/utils/test_jsonl.py @@ -0,0 +1,13 @@ +from pathlib import Path + +from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl + + +def test_jsonl_roundtrip(tmp_path: Path): + in_objs = [{"a": 1, "b": 2}, {"a": 1, "b": 2}] + jsonl_path = tmp_path / "test.jsonl" + + save_jsonl(in_objs, jsonl_path) + out_objs = load_jsonl(jsonl_path) + + assert in_objs == out_objs From cf57a7595d2ae4feb942b23171f106bd556d7cfa Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 11:10:38 -0500 Subject: [PATCH 03/19] Add UI section for loading and creating a dataset for editing. --- .../datasets/image_caption_jsonl_dataset.py | 38 +++++++------ src/invoke_training/ui/pages/data_page.py | 53 ++++++++++++++++++- 2 files changed, 70 insertions(+), 21 deletions(-) diff --git a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py index 4dd85e14..aad0a87c 100644 --- a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py @@ -3,36 +3,34 @@ import torch.utils.data from PIL import Image +from pydantic import BaseModel from invoke_training._shared.data.utils.resolution import Resolution from invoke_training._shared.utils.jsonl import load_jsonl +class _ImageCaptionExample(BaseModel): + image_path: str + caption: str + + class ImageCaptionJsonlDataset(torch.utils.data.Dataset): """A dataset that loads images and captions from a directory of image files and .txt files.""" - def __init__( - self, - jsonl_path: str, - image_column: str = "image", - caption_column: str = "text", - ): - """Initialize an ImageCaptionDirDataset""" + def __init__(self, jsonl_path: Path, image_column: str = "image", caption_column: str = "text"): super().__init__() - - self._jsonl_path = Path(jsonl_path) - self._data: list[dict[str, typing.Any]] = [] - data = load_jsonl(self._jsonl_path) - for d in data: - assert image_column in d - assert caption_column in d - self._data = data - + self._jsonl_path = jsonl_path self._image_column = image_column self._caption_column = caption_column + data = load_jsonl(jsonl_path) + examples: list[_ImageCaptionExample] = [] + for d in data: + examples.append(_ImageCaptionExample(image_path=d[image_column], caption=d[caption_column])) + self._examples = examples + def _get_image_path(self, idx: int) -> str: - image_path = self._data[idx][self._image_column] + image_path = self._examples[idx].image_path # image_path could be either absolute, or relative to the jsonl file. if not image_path.startswith("/"): @@ -52,15 +50,15 @@ def get_image_dimensions(self) -> list[Resolution]: calculate this dynamically every time. """ image_dims: list[Resolution] = [] - for i in range(len(self._data)): + for i in range(len(self._examples)): image = Image.open(self._get_image_path(i)) image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: - return len(self._data) + return len(self._examples) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: image = self._load_image(self._get_image_path(idx)) - return {"id": str(idx), "image": image, "caption": self._data[idx][self._caption_column]} + return {"id": str(idx), "image": image, "caption": self._examples[idx].caption} diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 8fb531c0..843558bd 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -1,5 +1,9 @@ +from pathlib import Path + import gradio as gr +from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset +from invoke_training._shared.utils.jsonl import save_jsonl from invoke_training.ui.utils import get_assets_dir_path @@ -21,7 +25,54 @@ def __init__(self): " Learn more about Invoke at [invoke.com](https://www.invoke.com/)" ) - self._app = app + gr.Markdown("# Data Annotation") + + gr.Markdown("To get started, either create a new dataset or load an existing one.") + gr.Markdown( + "Note: This UI creates datasets in `IMAGE_CAPTION_JSONL_DATASET` format. For more information about " + "this format see [the docs](https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/)" + ) + + gr.Markdown("## Setup") + with gr.Group(): + # TODO: Expose image_column and caption_column as inputs? + self._load_path_textbox = gr.Textbox( + label=".jsonl Path", + info="Enter the path to the .jsonl file to load or create.", + placeholder="/path/to/dataset.jsonl", + ) + self._load_dataset_button = gr.Button("Load or Create Dataset") + + gr.Markdown("## Editing ") + self._current_jsonl_textbox = gr.Textbox( + label="Currently editing", interactive=False, placeholder="No dataset loaded" + ) + self._current_len_number = gr.Number(label="Dataset length", interactive=False) + + self._load_dataset_button.click( + self._on_load_dataset_button_click, + inputs=set([self._load_path_textbox]), + outputs=[self._current_jsonl_textbox, self._current_len_number], + ) + + self._app = app + + def _on_load_dataset_button_click(self, data: dict): + jsonl_path = Path(data[self._load_path_textbox]) + jsonl_path = jsonl_path.resolve() + if jsonl_path.exists(): + print(f"Loading dataset from '{jsonl_path}'.") + else: + print(f"Creating new dataset at '{jsonl_path}'.") + assert jsonl_path.suffix == ".jsonl" + jsonl_path.parent.mkdir(parents=True, exist_ok=True) + # Create an empty jsonl file. + save_jsonl([], jsonl_path) + + # Initialize the dataset to validate the jsonl file, and to get the length. + dataset = ImageCaptionJsonlDataset(jsonl_path) + + return {self._current_jsonl_textbox: jsonl_path, self._current_len_number: len(dataset)} def app(self): return self._app From 6488e386630cab626950dee13c0f47445cbe708a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 12:04:03 -0500 Subject: [PATCH 04/19] Add image display to the dataset editing UI. --- .../datasets/image_caption_jsonl_dataset.py | 7 ++- src/invoke_training/ui/pages/data_page.py | 43 +++++++++++++++---- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py index aad0a87c..1fc37fb6 100644 --- a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py @@ -8,6 +8,9 @@ from invoke_training._shared.data.utils.resolution import Resolution from invoke_training._shared.utils.jsonl import load_jsonl +IMAGE_COLUMN_DEFAULT = "image" +CAPTION_COLUMN_DEFAULT = "text" + class _ImageCaptionExample(BaseModel): image_path: str @@ -17,7 +20,9 @@ class _ImageCaptionExample(BaseModel): class ImageCaptionJsonlDataset(torch.utils.data.Dataset): """A dataset that loads images and captions from a directory of image files and .txt files.""" - def __init__(self, jsonl_path: Path, image_column: str = "image", caption_column: str = "text"): + def __init__( + self, jsonl_path: Path, image_column: str = IMAGE_COLUMN_DEFAULT, caption_column: str = CAPTION_COLUMN_DEFAULT + ): super().__init__() self._jsonl_path = jsonl_path self._image_column = image_column diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 843558bd..69a8ddc1 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -2,7 +2,11 @@ import gradio as gr -from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset +from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ( + CAPTION_COLUMN_DEFAULT, + IMAGE_COLUMN_DEFAULT, + ImageCaptionJsonlDataset, +) from invoke_training._shared.utils.jsonl import save_jsonl from invoke_training.ui.utils import get_assets_dir_path @@ -41,22 +45,30 @@ def __init__(self): info="Enter the path to the .jsonl file to load or create.", placeholder="/path/to/dataset.jsonl", ) + self._image_column_textbox = gr.Textbox( + label="Image Column (Optional)", placeholder=IMAGE_COLUMN_DEFAULT + ) + self._caption_column_textbox = gr.Textbox( + label="Caption Column (Optional)", placeholder=CAPTION_COLUMN_DEFAULT + ) self._load_dataset_button = gr.Button("Load or Create Dataset") - gr.Markdown("## Editing ") + gr.Markdown("## Edit") self._current_jsonl_textbox = gr.Textbox( label="Currently editing", interactive=False, placeholder="No dataset loaded" ) self._current_len_number = gr.Number(label="Dataset length", interactive=False) + self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) + self._cur_caption = gr.Textbox(label="Caption", interactive=True) + self._app = app + self._load_dataset_button.click( self._on_load_dataset_button_click, - inputs=set([self._load_path_textbox]), - outputs=[self._current_jsonl_textbox, self._current_len_number], + inputs=set([self._load_path_textbox, self._image_column_textbox, self._caption_column_textbox]), + outputs=[self._current_jsonl_textbox, self._current_len_number, self._cur_image, self._cur_caption], ) - self._app = app - def _on_load_dataset_button_click(self, data: dict): jsonl_path = Path(data[self._load_path_textbox]) jsonl_path = jsonl_path.resolve() @@ -70,9 +82,24 @@ def _on_load_dataset_button_click(self, data: dict): save_jsonl([], jsonl_path) # Initialize the dataset to validate the jsonl file, and to get the length. - dataset = ImageCaptionJsonlDataset(jsonl_path) + dataset = ImageCaptionJsonlDataset( + jsonl_path=jsonl_path, + image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, + caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, + ) + image = None + caption = None + if len(dataset) > 0: + example = dataset[0] + image = example["image"] + caption = example["caption"] - return {self._current_jsonl_textbox: jsonl_path, self._current_len_number: len(dataset)} + return { + self._current_jsonl_textbox: jsonl_path, + self._current_len_number: len(dataset), + self._cur_image: image, + self._cur_caption: caption, + } def app(self): return self._app From 4ab3e1bb8b76c297876ef1753b7949aed1a1fdd3 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 13:03:16 -0500 Subject: [PATCH 05/19] Add ability to update image captions from the dataset editing UI. --- .../datasets/image_caption_jsonl_dataset.py | 24 ++-- src/invoke_training/ui/pages/data_page.py | 126 ++++++++++++++---- 2 files changed, 114 insertions(+), 36 deletions(-) diff --git a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py index 1fc37fb6..cacba471 100644 --- a/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py +++ b/src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py @@ -6,13 +6,13 @@ from pydantic import BaseModel from invoke_training._shared.data.utils.resolution import Resolution -from invoke_training._shared.utils.jsonl import load_jsonl +from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl IMAGE_COLUMN_DEFAULT = "image" CAPTION_COLUMN_DEFAULT = "text" -class _ImageCaptionExample(BaseModel): +class ImageCaptionExample(BaseModel): image_path: str caption: str @@ -29,13 +29,19 @@ def __init__( self._caption_column = caption_column data = load_jsonl(jsonl_path) - examples: list[_ImageCaptionExample] = [] + examples: list[ImageCaptionExample] = [] for d in data: - examples.append(_ImageCaptionExample(image_path=d[image_column], caption=d[caption_column])) - self._examples = examples + examples.append(ImageCaptionExample(image_path=d[image_column], caption=d[caption_column])) + self.examples = examples + + def save_jsonl(self): + data = [] + for example in self.examples: + data.append({self._image_column: example.image_path, self._caption_column: example.caption}) + save_jsonl(data, self._jsonl_path) def _get_image_path(self, idx: int) -> str: - image_path = self._examples[idx].image_path + image_path = self.examples[idx].image_path # image_path could be either absolute, or relative to the jsonl file. if not image_path.startswith("/"): @@ -55,15 +61,15 @@ def get_image_dimensions(self) -> list[Resolution]: calculate this dynamically every time. """ image_dims: list[Resolution] = [] - for i in range(len(self._examples)): + for i in range(len(self.examples)): image = Image.open(self._get_image_path(i)) image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: - return len(self._examples) + return len(self.examples) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: image = self._load_image(self._get_image_path(idx)) - return {"id": str(idx), "image": image, "caption": self._examples[idx].caption} + return {"id": str(idx), "image": image, "caption": self.examples[idx].caption} diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 69a8ddc1..3ee65781 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -40,37 +40,97 @@ def __init__(self): gr.Markdown("## Setup") with gr.Group(): # TODO: Expose image_column and caption_column as inputs? - self._load_path_textbox = gr.Textbox( + self._jsonl_path_textbox = gr.Textbox( label=".jsonl Path", info="Enter the path to the .jsonl file to load or create.", placeholder="/path/to/dataset.jsonl", ) - self._image_column_textbox = gr.Textbox( - label="Image Column (Optional)", placeholder=IMAGE_COLUMN_DEFAULT - ) - self._caption_column_textbox = gr.Textbox( - label="Caption Column (Optional)", placeholder=CAPTION_COLUMN_DEFAULT - ) + with gr.Row(): + self._image_column_textbox = gr.Textbox( + label="Image Column (Optional)", placeholder=IMAGE_COLUMN_DEFAULT + ) + self._caption_column_textbox = gr.Textbox( + label="Caption Column (Optional)", placeholder=CAPTION_COLUMN_DEFAULT + ) self._load_dataset_button = gr.Button("Load or Create Dataset") - gr.Markdown("## Edit") - self._current_jsonl_textbox = gr.Textbox( - label="Currently editing", interactive=False, placeholder="No dataset loaded" - ) - self._current_len_number = gr.Number(label="Dataset length", interactive=False) + gr.Markdown("## Edit Captions") + self._cur_len_number = gr.Number(label="Dataset length", interactive=False) + self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=False) self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) self._cur_caption = gr.Textbox(label="Caption", interactive=True) + with gr.Row(): + self._save_and_prev_button = gr.Button("Save and Go-To Previous") + self._save_and_next_button = gr.Button("Save and Go-To Next") + self._app = app self._load_dataset_button.click( self._on_load_dataset_button_click, - inputs=set([self._load_path_textbox, self._image_column_textbox, self._caption_column_textbox]), - outputs=[self._current_jsonl_textbox, self._current_len_number, self._cur_image, self._cur_caption], + inputs=set([self._jsonl_path_textbox, self._image_column_textbox, self._caption_column_textbox]), + outputs=[ + self._cur_len_number, + self._cur_example_index, + self._cur_image, + self._cur_caption, + ], + ) + self._save_and_prev_button.click( + self._on_save_and_prev_button_click, + inputs=set( + [ + self._jsonl_path_textbox, + self._image_column_textbox, + self._caption_column_textbox, + self._cur_example_index, + self._cur_caption, + ] + ), + outputs=[ + self._cur_len_number, + self._cur_example_index, + self._cur_image, + self._cur_caption, + ], ) + self._save_and_next_button.click( + self._on_save_and_next_button_click, + inputs=set( + [ + self._jsonl_path_textbox, + self._image_column_textbox, + self._caption_column_textbox, + self._cur_example_index, + self._cur_caption, + ] + ), + outputs=[ + self._cur_len_number, + self._cur_example_index, + self._cur_image, + self._cur_caption, + ], + ) + + def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): + idx = idx + image = None + caption = None + if 0 <= idx and idx < len(dataset): + example = dataset[idx] + image = example["image"] + caption = example["caption"] + + return { + self._cur_len_number: len(dataset), + self._cur_example_index: idx, + self._cur_image: image, + self._cur_caption: caption, + } def _on_load_dataset_button_click(self, data: dict): - jsonl_path = Path(data[self._load_path_textbox]) + jsonl_path = Path(data[self._jsonl_path_textbox]) jsonl_path = jsonl_path.resolve() if jsonl_path.exists(): print(f"Loading dataset from '{jsonl_path}'.") @@ -87,19 +147,31 @@ def _on_load_dataset_button_click(self, data: dict): image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, ) - image = None - caption = None - if len(dataset) > 0: - example = dataset[0] - image = example["image"] - caption = example["caption"] - return { - self._current_jsonl_textbox: jsonl_path, - self._current_len_number: len(dataset), - self._cur_image: image, - self._cur_caption: caption, - } + return self._update_state(dataset, 0) + + def _on_save_and_go_button_click(self, data: dict, idx_change: int): + jsonl_path = Path(data[self._jsonl_path_textbox]) + dataset = ImageCaptionJsonlDataset( + jsonl_path=jsonl_path, + image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, + caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, + ) + + # Update the current caption and re-save the jsonl file. + idx: int = data[self._cur_example_index] + print(f"Updating caption for example {idx} of '{jsonl_path}'.") + caption = data[self._cur_caption] + dataset.examples[idx].caption = caption + dataset.save_jsonl() + + return self._update_state(dataset, idx + idx_change) + + def _on_save_and_next_button_click(self, data: dict): + return self._on_save_and_go_button_click(data, 1) + + def _on_save_and_prev_button_click(self, data: dict): + return self._on_save_and_go_button_click(data, -1) def app(self): return self._app From 303496dc75dae9e6b10faf4d9f4a314179b2c906 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 17:19:34 -0500 Subject: [PATCH 06/19] Add ability to add images to a dataset through the UI. --- src/invoke_training/ui/pages/data_page.py | 77 +++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 3ee65781..58288be5 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -5,11 +5,14 @@ from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ( CAPTION_COLUMN_DEFAULT, IMAGE_COLUMN_DEFAULT, + ImageCaptionExample, ImageCaptionJsonlDataset, ) from invoke_training._shared.utils.jsonl import save_jsonl from invoke_training.ui.utils import get_assets_dir_path +IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png"] + class DataPage: def __init__(self): @@ -54,6 +57,16 @@ def __init__(self): ) self._load_dataset_button = gr.Button("Load or Create Dataset") + gr.Markdown("## Add Images") + with gr.Group(): + self._image_source_textbox = gr.Textbox( + label="Image Source", + info="Enter the path to a single image or a directory containing images. If a directory path is " + "passed, it will be searched recursively for image files.", + placeholder="/path/to/image_dir", + ) + self._add_images_button = gr.Button("Add Images") + gr.Markdown("## Edit Captions") self._cur_len_number = gr.Number(label="Dataset length", interactive=False) @@ -113,6 +126,24 @@ def __init__(self): ], ) + self._add_images_button.click( + self._on_add_images_button_click, + inputs=set( + [ + self._jsonl_path_textbox, + self._image_column_textbox, + self._caption_column_textbox, + self._image_source_textbox, + ] + ), + outputs=[ + self._cur_len_number, + self._cur_example_index, + self._cur_image, + self._cur_caption, + ], + ) + def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): idx = idx image = None @@ -173,5 +204,51 @@ def _on_save_and_next_button_click(self, data: dict): def _on_save_and_prev_button_click(self, data: dict): return self._on_save_and_go_button_click(data, -1) + def _on_add_images_button_click(self, data: dict): + """Add images to the dataset.""" + image_source_path = Path(data[self._image_source_textbox]) + + if not image_source_path.exists(): + raise ValueError(f"'{image_source_path}' does not exist.") + + jsonl_path = Path(data[self._jsonl_path_textbox]) + dataset = ImageCaptionJsonlDataset( + jsonl_path=jsonl_path, + image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, + caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, + ) + + # Determine the list of image paths to add to the dataset. + image_paths = [] + if image_source_path.is_file(): + if image_source_path.suffix.lower() not in IMAGE_EXTENSIONS: + raise ValueError( + f"'{image_source_path}' is not a valid image file. Expected one of {IMAGE_EXTENSIONS}." + ) + + image_paths.append(image_source_path) + else: + # Recursively search for image files in the image_source_path directory. + for file_path in image_source_path.glob("**/*"): + if file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS: + image_paths.append(file_path) + + # Avoid adding duplicate images. + cur_image_paths = set([Path(example.image_path) for example in dataset.examples]) + image_paths = set(image_paths) + new_image_paths = image_paths - cur_image_paths + if len(new_image_paths) < len(image_paths): + print(f"Skipping {len(image_paths) - len(new_image_paths)} images that are already in the dataset.") + + # Add the new images to the dataset. + print(f"Adding {len(new_image_paths)} images to '{jsonl_path}'.") + for image_path in new_image_paths: + dataset.examples.append(ImageCaptionExample(image_path=str(image_path), caption="")) + + # Save the updated dataset. + dataset.save_jsonl() + + return self._update_state(dataset, 0) + def app(self): return self._app From 8bb6a2141c792efbad81ccda90f33e71d1adc23b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 17:49:56 -0500 Subject: [PATCH 07/19] Show full dataset .jsonl in the dataset annotation UI. --- src/invoke_training/_shared/utils/jsonl.py | 2 +- src/invoke_training/ui/pages/data_page.py | 44 +++++++++------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/invoke_training/_shared/utils/jsonl.py b/src/invoke_training/_shared/utils/jsonl.py index 85236866..ea0c2685 100644 --- a/src/invoke_training/_shared/utils/jsonl.py +++ b/src/invoke_training/_shared/utils/jsonl.py @@ -7,7 +7,7 @@ def load_jsonl(jsonl_path: Path | str) -> list[Any]: """Load a JSONL file.""" data = [] with open(jsonl_path) as f: - while (line := f.readline()) != "": + while (line := f.readline().strip()) != "": data.append(json.loads(line)) return data diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 58288be5..59a18338 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -77,17 +77,22 @@ def __init__(self): self._save_and_prev_button = gr.Button("Save and Go-To Previous") self._save_and_next_button = gr.Button("Save and Go-To Next") + gr.Markdown("## Raw JSONL") + self._data_jsonl = gr.Code(label="Dataset .jsonl", language="json", interactive=False) + self._app = app + standard_outputs = [ + self._cur_len_number, + self._cur_example_index, + self._cur_image, + self._cur_caption, + self._data_jsonl, + ] self._load_dataset_button.click( self._on_load_dataset_button_click, inputs=set([self._jsonl_path_textbox, self._image_column_textbox, self._caption_column_textbox]), - outputs=[ - self._cur_len_number, - self._cur_example_index, - self._cur_image, - self._cur_caption, - ], + outputs=standard_outputs, ) self._save_and_prev_button.click( self._on_save_and_prev_button_click, @@ -100,12 +105,7 @@ def __init__(self): self._cur_caption, ] ), - outputs=[ - self._cur_len_number, - self._cur_example_index, - self._cur_image, - self._cur_caption, - ], + outputs=standard_outputs, ) self._save_and_next_button.click( self._on_save_and_next_button_click, @@ -118,12 +118,7 @@ def __init__(self): self._cur_caption, ] ), - outputs=[ - self._cur_len_number, - self._cur_example_index, - self._cur_image, - self._cur_caption, - ], + outputs=standard_outputs, ) self._add_images_button.click( @@ -136,12 +131,7 @@ def __init__(self): self._image_source_textbox, ] ), - outputs=[ - self._cur_len_number, - self._cur_example_index, - self._cur_image, - self._cur_caption, - ], + outputs=standard_outputs, ) def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): @@ -153,11 +143,13 @@ def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): image = example["image"] caption = example["caption"] + jsonl_str = "\n".join([example.model_dump_json() for example in dataset.examples]) return { self._cur_len_number: len(dataset), self._cur_example_index: idx, self._cur_image: image, self._cur_caption: caption, + self._data_jsonl: jsonl_str, } def _on_load_dataset_button_click(self, data: dict): @@ -226,12 +218,12 @@ def _on_add_images_button_click(self, data: dict): f"'{image_source_path}' is not a valid image file. Expected one of {IMAGE_EXTENSIONS}." ) - image_paths.append(image_source_path) + image_paths.append(image_source_path.resolve()) else: # Recursively search for image files in the image_source_path directory. for file_path in image_source_path.glob("**/*"): if file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS: - image_paths.append(file_path) + image_paths.append(file_path.resolve()) # Avoid adding duplicate images. cur_image_paths = set([Path(example.image_path) for example in dataset.examples]) From aa60d0ea2c8132883574d7547d007577570311e1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 17:56:36 -0500 Subject: [PATCH 08/19] Display a warning when the index is beyond the dataset limits. --- src/invoke_training/ui/pages/data_page.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 59a18338..40048b76 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -71,6 +71,7 @@ def __init__(self): self._cur_len_number = gr.Number(label="Dataset length", interactive=False) self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=False) + self._beyond_dataset_limits_warning = gr.Markdown("**Current index is beyond dataset limits.**") self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) self._cur_caption = gr.Textbox(label="Caption", interactive=True) with gr.Row(): @@ -87,6 +88,7 @@ def __init__(self): self._cur_example_index, self._cur_image, self._cur_caption, + self._beyond_dataset_limits_warning, self._data_jsonl, ] self._load_dataset_button.click( @@ -138,7 +140,9 @@ def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): idx = idx image = None caption = None + beyond_limits = True if 0 <= idx and idx < len(dataset): + beyond_limits = False example = dataset[idx] image = example["image"] caption = example["caption"] @@ -149,6 +153,7 @@ def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): self._cur_example_index: idx, self._cur_image: image, self._cur_caption: caption, + self._beyond_dataset_limits_warning: gr.Markdown(visible=beyond_limits), self._data_jsonl: jsonl_str, } From ff83eda36dbc0a0e359596c4fd511a5ed4b9943a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 8 Feb 2024 18:18:18 -0500 Subject: [PATCH 09/19] Add ability to jump directly to any index in the dataset when editing captions. --- src/invoke_training/ui/pages/data_page.py | 30 ++++++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 40048b76..979d8d96 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -68,9 +68,9 @@ def __init__(self): self._add_images_button = gr.Button("Add Images") gr.Markdown("## Edit Captions") - self._cur_len_number = gr.Number(label="Dataset length", interactive=False) - - self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=False) + with gr.Row(): + self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=True) + self._cur_len_number = gr.Number(label="Dataset length", interactive=False) self._beyond_dataset_limits_warning = gr.Markdown("**Current index is beyond dataset limits.**") self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) self._cur_caption = gr.Textbox(label="Caption", interactive=True) @@ -122,7 +122,6 @@ def __init__(self): ), outputs=standard_outputs, ) - self._add_images_button.click( self._on_add_images_button_click, inputs=set( @@ -135,6 +134,19 @@ def __init__(self): ), outputs=standard_outputs, ) + self._cur_example_index.input( + self._on_cur_example_index_change, + inputs=set( + [ + self._jsonl_path_textbox, + self._image_column_textbox, + self._caption_column_textbox, + self._cur_example_index, + self._cur_caption, + ] + ), + outputs=standard_outputs, + ) def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): idx = idx @@ -201,6 +213,16 @@ def _on_save_and_next_button_click(self, data: dict): def _on_save_and_prev_button_click(self, data: dict): return self._on_save_and_go_button_click(data, -1) + def _on_cur_example_index_change(self, data: dict): + jsonl_path = Path(data[self._jsonl_path_textbox]) + dataset = ImageCaptionJsonlDataset( + jsonl_path=jsonl_path, + image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, + caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, + ) + + return self._update_state(dataset, data[self._cur_example_index]) + def _on_add_images_button_click(self, data: dict): """Add images to the dataset.""" image_source_path = Path(data[self._image_source_textbox]) From 222aebf47187e9d8f19c8f57843fcb1557109172 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 9 Feb 2024 13:47:19 -0500 Subject: [PATCH 10/19] Toggle between showing dataset selection and dataset editing parts of the UI. --- src/invoke_training/ui/pages/data_page.py | 193 +++++++++++----------- 1 file changed, 94 insertions(+), 99 deletions(-) diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index 979d8d96..bcadc29d 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -16,6 +16,10 @@ class DataPage: def __init__(self): + # The dataset that is currently being edited. + self._jsonl_path: str | None = None + self._dataset: ImageCaptionJsonlDataset | None = None + logo_path = get_assets_dir_path() / "logo.png" with gr.Blocks(title="invoke-training", analytics_enabled=False) as app: with gr.Column(): @@ -33,15 +37,15 @@ def __init__(self): ) gr.Markdown("# Data Annotation") - - gr.Markdown("To get started, either create a new dataset or load an existing one.") gr.Markdown( "Note: This UI creates datasets in `IMAGE_CAPTION_JSONL_DATASET` format. For more information about " "this format see [the docs](https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/)" ) - gr.Markdown("## Setup") - with gr.Group(): + # HACK: I use a column as a wrapper to control visbility of this group of UI elements. gr.Group sounds like + # a more natural choice for this purpose, but it applies some styling that makes the group look weird. + with gr.Column() as select_dataset_group: + gr.Markdown("## Setup") # TODO: Expose image_column and caption_column as inputs? self._jsonl_path_textbox = gr.Textbox( label=".jsonl Path", @@ -57,33 +61,45 @@ def __init__(self): ) self._load_dataset_button = gr.Button("Load or Create Dataset") - gr.Markdown("## Add Images") - with gr.Group(): - self._image_source_textbox = gr.Textbox( - label="Image Source", - info="Enter the path to a single image or a directory containing images. If a directory path is " - "passed, it will be searched recursively for image files.", - placeholder="/path/to/image_dir", - ) - self._add_images_button = gr.Button("Add Images") - - gr.Markdown("## Edit Captions") - with gr.Row(): - self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=True) - self._cur_len_number = gr.Number(label="Dataset length", interactive=False) - self._beyond_dataset_limits_warning = gr.Markdown("**Current index is beyond dataset limits.**") - self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) - self._cur_caption = gr.Textbox(label="Caption", interactive=True) - with gr.Row(): - self._save_and_prev_button = gr.Button("Save and Go-To Previous") - self._save_and_next_button = gr.Button("Save and Go-To Next") - - gr.Markdown("## Raw JSONL") - self._data_jsonl = gr.Code(label="Dataset .jsonl", language="json", interactive=False) + self._select_dataset_group = select_dataset_group + + # HACK: I use a column as a wrapper to control visbility of this group of UI elements. gr.Group sounds like + # a more natural choice for this purpose, but it applies some styling that makes the group look weird. + with gr.Column(visible=False) as edit_dataset_group: + with gr.Row(): + self._current_jsonl_path = gr.Textbox(label="Currently editing:", interactive=False) + self._change_dataset_button = gr.Button("Change") + gr.Markdown("## Add Images") + with gr.Group(): + self._image_source_textbox = gr.Textbox( + label="Image Source", + info="Enter the path to a single image or a directory containing images. If a directory path " + "is passed, it will be searched recursively for image files.", + placeholder="/path/to/image_dir", + ) + self._add_images_button = gr.Button("Add Images") + + gr.Markdown("## Edit Captions") + with gr.Row(): + self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=True) + self._cur_len_number = gr.Number(label="Dataset length", interactive=False) + self._beyond_dataset_limits_warning = gr.Markdown("**Current index is beyond dataset limits.**") + self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) + self._cur_caption = gr.Textbox(label="Caption", interactive=True) + with gr.Row(): + self._save_and_prev_button = gr.Button("Save and Go-To Previous") + self._save_and_next_button = gr.Button("Save and Go-To Next") + + gr.Markdown("## Raw JSONL") + self._data_jsonl = gr.Code(label="Dataset .jsonl", language="json", interactive=False) + self._edit_dataset_group = edit_dataset_group self._app = app standard_outputs = [ + self._select_dataset_group, + self._edit_dataset_group, + self._current_jsonl_path, self._cur_len_number, self._cur_example_index, self._cur_image, @@ -91,77 +107,70 @@ def __init__(self): self._beyond_dataset_limits_warning, self._data_jsonl, ] + self._load_dataset_button.click( self._on_load_dataset_button_click, inputs=set([self._jsonl_path_textbox, self._image_column_textbox, self._caption_column_textbox]), outputs=standard_outputs, ) + + self._change_dataset_button.click( + self._on_change_dataset_button_click, inputs=None, outputs=standard_outputs + ) self._save_and_prev_button.click( self._on_save_and_prev_button_click, - inputs=set( - [ - self._jsonl_path_textbox, - self._image_column_textbox, - self._caption_column_textbox, - self._cur_example_index, - self._cur_caption, - ] - ), + inputs=set([self._cur_example_index, self._cur_caption]), outputs=standard_outputs, ) + self._save_and_next_button.click( self._on_save_and_next_button_click, - inputs=set( - [ - self._jsonl_path_textbox, - self._image_column_textbox, - self._caption_column_textbox, - self._cur_example_index, - self._cur_caption, - ] - ), + inputs=set([self._cur_example_index, self._cur_caption]), outputs=standard_outputs, ) + self._add_images_button.click( self._on_add_images_button_click, - inputs=set( - [ - self._jsonl_path_textbox, - self._image_column_textbox, - self._caption_column_textbox, - self._image_source_textbox, - ] - ), + inputs=set([self._image_source_textbox]), outputs=standard_outputs, ) + self._cur_example_index.input( self._on_cur_example_index_change, - inputs=set( - [ - self._jsonl_path_textbox, - self._image_column_textbox, - self._caption_column_textbox, - self._cur_example_index, - self._cur_caption, - ] - ), + inputs=set([self._cur_example_index]), outputs=standard_outputs, ) - def _update_state(self, dataset: ImageCaptionJsonlDataset, idx: int): + def _update_state(self, idx: int): + if self._dataset is None or self._jsonl_path is None: + return { + self._select_dataset_group: gr.Group(visible=True), + self._edit_dataset_group: gr.Column(visible=False), + self._current_jsonl_path: None, + self._cur_len_number: 0, + self._cur_example_index: 0, + self._cur_image: None, + self._cur_caption: None, + self._beyond_dataset_limits_warning: gr.Markdown(visible=False), + self._data_jsonl: "", + } + idx = idx image = None caption = None beyond_limits = True - if 0 <= idx and idx < len(dataset): + if 0 <= idx and idx < len(self._dataset): beyond_limits = False - example = dataset[idx] + example = self._dataset[idx] image = example["image"] caption = example["caption"] - jsonl_str = "\n".join([example.model_dump_json() for example in dataset.examples]) + jsonl_str = "\n".join([example.model_dump_json() for example in self._dataset.examples]) return { - self._cur_len_number: len(dataset), + self._select_dataset_group: gr.Group(visible=self._dataset is None), + self._edit_dataset_group: gr.Column(visible=self._dataset is not None), + self._current_jsonl_path: str(self._jsonl_path), + self._cur_len_number: len(self._dataset), self._cur_example_index: idx, self._cur_image: image, self._cur_caption: caption, @@ -188,24 +197,24 @@ def _on_load_dataset_button_click(self, data: dict): caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, ) - return self._update_state(dataset, 0) + self._jsonl_path = jsonl_path + self._dataset = dataset + return self._update_state(0) - def _on_save_and_go_button_click(self, data: dict, idx_change: int): - jsonl_path = Path(data[self._jsonl_path_textbox]) - dataset = ImageCaptionJsonlDataset( - jsonl_path=jsonl_path, - image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, - caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, - ) + def _on_change_dataset_button_click(self): + self._jsonl_path = None + self._dataset = None + return self._update_state(0) + def _on_save_and_go_button_click(self, data: dict, idx_change: int): # Update the current caption and re-save the jsonl file. idx: int = data[self._cur_example_index] - print(f"Updating caption for example {idx} of '{jsonl_path}'.") + print(f"Updating caption for example {idx} of '{self._jsonl_path}'.") caption = data[self._cur_caption] - dataset.examples[idx].caption = caption - dataset.save_jsonl() + self._dataset.examples[idx].caption = caption + self._dataset.save_jsonl() - return self._update_state(dataset, idx + idx_change) + return self._update_state(idx + idx_change) def _on_save_and_next_button_click(self, data: dict): return self._on_save_and_go_button_click(data, 1) @@ -214,14 +223,7 @@ def _on_save_and_prev_button_click(self, data: dict): return self._on_save_and_go_button_click(data, -1) def _on_cur_example_index_change(self, data: dict): - jsonl_path = Path(data[self._jsonl_path_textbox]) - dataset = ImageCaptionJsonlDataset( - jsonl_path=jsonl_path, - image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, - caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, - ) - - return self._update_state(dataset, data[self._cur_example_index]) + return self._update_state(data[self._cur_example_index]) def _on_add_images_button_click(self, data: dict): """Add images to the dataset.""" @@ -230,13 +232,6 @@ def _on_add_images_button_click(self, data: dict): if not image_source_path.exists(): raise ValueError(f"'{image_source_path}' does not exist.") - jsonl_path = Path(data[self._jsonl_path_textbox]) - dataset = ImageCaptionJsonlDataset( - jsonl_path=jsonl_path, - image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, - caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, - ) - # Determine the list of image paths to add to the dataset. image_paths = [] if image_source_path.is_file(): @@ -253,21 +248,21 @@ def _on_add_images_button_click(self, data: dict): image_paths.append(file_path.resolve()) # Avoid adding duplicate images. - cur_image_paths = set([Path(example.image_path) for example in dataset.examples]) + cur_image_paths = set([Path(example.image_path) for example in self._dataset.examples]) image_paths = set(image_paths) new_image_paths = image_paths - cur_image_paths if len(new_image_paths) < len(image_paths): print(f"Skipping {len(image_paths) - len(new_image_paths)} images that are already in the dataset.") # Add the new images to the dataset. - print(f"Adding {len(new_image_paths)} images to '{jsonl_path}'.") + print(f"Adding {len(new_image_paths)} images to '{self._jsonl_path}'.") for image_path in new_image_paths: - dataset.examples.append(ImageCaptionExample(image_path=str(image_path), caption="")) + self._dataset.examples.append(ImageCaptionExample(image_path=str(image_path), caption="")) # Save the updated dataset. - dataset.save_jsonl() + self._dataset.save_jsonl() - return self._update_state(dataset, 0) + return self._update_state(0) def app(self): return self._app From de81ee0b01d153a85f1f403fad0395133f37e1e1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 9 Feb 2024 13:54:20 -0500 Subject: [PATCH 11/19] Fix crash when editing a dataset caption outside of the data bounds in the UI. --- src/invoke_training/ui/pages/data_page.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/invoke_training/ui/pages/data_page.py b/src/invoke_training/ui/pages/data_page.py index bcadc29d..f9c79626 100644 --- a/src/invoke_training/ui/pages/data_page.py +++ b/src/invoke_training/ui/pages/data_page.py @@ -209,6 +209,10 @@ def _on_change_dataset_button_click(self): def _on_save_and_go_button_click(self, data: dict, idx_change: int): # Update the current caption and re-save the jsonl file. idx: int = data[self._cur_example_index] + if idx < 0 or idx >= len(self._dataset): + # idx is out of bounds, so don't update the caption, but still change the index. + return self._update_state(idx + idx_change) + print(f"Updating caption for example {idx} of '{self._jsonl_path}'.") caption = data[self._cur_caption] self._dataset.examples[idx].caption = caption From 8a711f8b30289b4cefe8b05f905147bcbf83c58a Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 9 Feb 2024 14:59:07 -0500 Subject: [PATCH 12/19] Tidy the main landing page for the gradio UI. --- src/invoke_training/ui/app.py | 3 ++ src/invoke_training/ui/index.html | 69 +++++++++++++++++++++++++++---- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/invoke_training/ui/app.py b/src/invoke_training/ui/app.py index 1bb3b0e2..418f3ce6 100644 --- a/src/invoke_training/ui/app.py +++ b/src/invoke_training/ui/app.py @@ -3,6 +3,7 @@ import gradio as gr from fastapi import FastAPI from fastapi.responses import FileResponse +from fastapi.staticfiles import StaticFiles from invoke_training.ui.pages.data_page import DataPage from invoke_training.ui.pages.training_page import TrainingPage @@ -19,6 +20,8 @@ async def root(): index_path = Path(__file__).parent / "index.html" return FileResponse(index_path) + app.mount("/assets", StaticFiles(directory=Path(__file__).parent.parent / "assets"), name="assets") + app = gr.mount_gradio_app(app, training_page.app(), "/train") app = gr.mount_gradio_app(app, data_page.app(), "/data") return app diff --git a/src/invoke_training/ui/index.html b/src/invoke_training/ui/index.html index 861cfa41..ea807328 100644 --- a/src/invoke_training/ui/index.html +++ b/src/invoke_training/ui/index.html @@ -4,14 +4,67 @@ invoke-training + + + -

invoke-training

-

- Training -

-

- Dataset Annotation -

+
+ Invoke logo. +

invoke-training

+

Invoke Training - Documentation

+

Learn more about Invoke at invoke.com

+
+ + + - + \ No newline at end of file From 8a4c35964799d7dbc18bf6f722fefa0e4c8112f0 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 9 Feb 2024 15:35:47 -0500 Subject: [PATCH 13/19] Create a shared Header gradio component. And add a favicon to the GUI. --- src/invoke_training/assets/favicon.png | Bin 0 -> 584 bytes src/invoke_training/ui/app.py | 4 +-- .../ui/gradio_blocks/header.py | 20 ++++++++++++++ .../ui/{ => gradio_blocks}/pipeline_tab.py | 0 src/invoke_training/ui/index.html | 1 + src/invoke_training/ui/pages/data_page.py | 24 +++++----------- src/invoke_training/ui/pages/training_page.py | 26 ++++++------------ 7 files changed, 39 insertions(+), 36 deletions(-) create mode 100644 src/invoke_training/assets/favicon.png create mode 100644 src/invoke_training/ui/gradio_blocks/header.py rename src/invoke_training/ui/{ => gradio_blocks}/pipeline_tab.py (100%) diff --git a/src/invoke_training/assets/favicon.png b/src/invoke_training/assets/favicon.png new file mode 100644 index 0000000000000000000000000000000000000000..0cb244f8b19b9a22fe622ba579a6d1088aa9e9d7 GIT binary patch literal 584 zcmV-O0=NB%P)y?{jCfKjZF7HY@+*;SyA$aAl_ zDe3A!yVV=nOKJoX?BnIPZV1b8oafnG1nSMi$o}3|2X3w{$Nz|w6e^I`oznx~5{+Cm zeqnnjI`C{wS0VjmgCBf1qUy;Hl#MU$;N_j>7z#i5u7W}_#zTNg936f`SOT?TG`_k1 z&e!#3k~!p>%R&P60hGT~j`OvV??T56K28ERqCzpw98tL%m={3_H4U*Eu=J4I$0cFq zj3ZPl#$+t)$N1XFe?kyK&w|pyJP!q`Q9bq_3SkKPSw2iQ7lle~VPFrNQhv9~*P<|B zh@24CkEpM=_0dj;9TCHUD)^lTt_3Xee3;tVB78WMqew~54tUM%gkVCt;p)JoPe#Mc z?ujVuNysbG6+j0}kca|=Jb)x9g$BEvYV2~VO1hld86`!`7B2A04Aa46s$Zc^wEYLr W5~|sB7iAg%0000 invoke-training +