Skip to content

Commit

Permalink
Merge pull request #93 from invoke-ai/gradio-image-annotation
Browse files Browse the repository at this point in the history
Add a Gradio app for dataset annotation
  • Loading branch information
RyanJDick authored Feb 12, 2024
2 parents ff828eb + df50169 commit 4e7826b
Show file tree
Hide file tree
Showing 15 changed files with 601 additions and 161 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"accelerate~=0.25.0",
"datasets~=2.14.3",
"diffusers~=0.25.0",
"fastapi",
"gradio",
"numpy",
"omegaconf",
Expand All @@ -32,6 +33,7 @@ dependencies = [
"torchvision",
"tqdm",
"transformers~=4.36.0",
"uvicorn[standard]",
"xformers>=0.0.23",
]

Expand All @@ -58,8 +60,9 @@ dependencies = [
"Discord" = "https://discord.gg/ZmtBAhwWhy"

[tool.setuptools.package-data]
"invoke_training.sample_configs" = ["**/*.yaml"]
"invoke_training.assets" = ["*.png"]
"invoke_training.sample_configs" = ["**/*.yaml"]
"invoke_training.ui" = ["*.html"]

[tool.ruff]
src = ["src"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,55 @@
import json
import typing
from pathlib import Path

import torch.utils.data
from PIL import Image
from pydantic import BaseModel

from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl

IMAGE_COLUMN_DEFAULT = "image"
CAPTION_COLUMN_DEFAULT = "text"


class ImageCaptionExample(BaseModel):
image_path: str
caption: str


class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
"""A dataset that loads images and captions from a directory of image files and .txt files."""

def __init__(
self,
jsonl_path: str,
image_column: str = "image",
caption_column: str = "text",
jsonl_path: Path | str,
image_column: str = IMAGE_COLUMN_DEFAULT,
caption_column: str = CAPTION_COLUMN_DEFAULT,
):
"""Initialize an ImageCaptionDirDataset"""
super().__init__()

self._jsonl_path = Path(jsonl_path)
self._data: list[dict[str, typing.Any]] = []
with open(jsonl_path) as f:
while (line := f.readline()) != "":
line_json = json.loads(line)
assert image_column in line_json
assert caption_column in line_json
self._data.append(line_json)

self._image_column = image_column
self._caption_column = caption_column

data = load_jsonl(jsonl_path)
examples: list[ImageCaptionExample] = []
for d in data:
# Clear error messages here are helpful in the Gradio UI.
if image_column not in d:
raise ValueError(f"Column '{image_column}' not found in jsonl file '{jsonl_path}'.")
if caption_column not in d:
raise ValueError(f"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.")
examples.append(ImageCaptionExample(image_path=d[image_column], caption=d[caption_column]))
self.examples = examples

def save_jsonl(self):
data = []
for example in self.examples:
data.append({self._image_column: example.image_path, self._caption_column: example.caption})
save_jsonl(data, self._jsonl_path)

def _get_image_path(self, idx: int) -> str:
image_path = self._data[idx][self._image_column]
image_path = self.examples[idx].image_path

# image_path could be either absolute, or relative to the jsonl file.
if not image_path.startswith("/"):
Expand All @@ -53,15 +69,15 @@ def get_image_dimensions(self) -> list[Resolution]:
calculate this dynamically every time.
"""
image_dims: list[Resolution] = []
for i in range(len(self._data)):
for i in range(len(self.examples)):
image = Image.open(self._get_image_path(i))
image_dims.append(Resolution(image.height, image.width))

return image_dims

def __len__(self) -> int:
return len(self._data)
return len(self.examples)

def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
image = self._load_image(self._get_image_path(idx))
return {"id": str(idx), "image": image, "caption": self._data[idx][self._caption_column]}
return {"id": str(idx), "image": image, "caption": self.examples[idx].caption}
Original file line number Diff line number Diff line change
@@ -1,39 +1,27 @@
import json
import os
import typing
from pathlib import Path

import torch.utils.data
from PIL import Image

from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl


class ImagePairPreferenceDataset(torch.utils.data.Dataset):
def __init__(self, dataset_dir: str):
super().__init__()
self._dataset_dir = dataset_dir

self._metadata = self.load_metadata(self._dataset_dir)

@classmethod
def load_metadata(cls, dataset_dir: Path | str) -> list[dict[str, typing.Any]]:
"""Load the dataset metadata from metadata.jsonl."""
metadata: list[dict[str, typing.Any]] = []
with open(Path(dataset_dir) / "metadata.jsonl") as f:
while (line := f.readline()) != "":
metadata.append(json.loads(line))
return metadata
self._metadata = load_jsonl(Path(dataset_dir) / "metadata.jsonl")

@classmethod
def save_metadata(
cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = "metadata.jsonl"
) -> Path:
"""Load the dataset metadata from metadata.jsonl."""
metadata_path = Path(dataset_dir) / metadata_file
with open(metadata_path, "w") as f:
for m in metadata:
json.dump(m, f)
f.write("\n")

save_jsonl(metadata, metadata_path)
return metadata_path

def __len__(self) -> int:
Expand Down
19 changes: 19 additions & 0 deletions src/invoke_training/_shared/utils/jsonl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import json
from pathlib import Path
from typing import Any


def load_jsonl(jsonl_path: Path | str) -> list[Any]:
"""Load a JSONL file."""
data = []
with open(jsonl_path) as f:
while (line := f.readline().strip()) != "":
data.append(json.loads(line))
return data


def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None:
"""Save a list of objects to a JSONL file."""
with open(jsonl_path, "w") as f:
for line in data:
f.write(json.dumps(line) + "\n")
Binary file added src/invoke_training/assets/favicon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions src/invoke_training/scripts/invoke_train_ui.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from invoke_training.ui.app import App
import uvicorn

from invoke_training.ui.app import build_app


def main():
app = App()
app.launch()
app = build_app()
uvicorn.run(app)


if __name__ == "__main__":
Expand Down
130 changes: 18 additions & 112 deletions src/invoke_training/ui/app.py
Original file line number Diff line number Diff line change
@@ -1,121 +1,27 @@
import os
import subprocess
import tempfile
import time
from pathlib import Path

import gradio as gr
import yaml
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup
from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup
from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import (
SdxlLoraAndTextualInversionConfigGroup,
)
from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup
from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup
from invoke_training.ui.pipeline_tab import PipelineTab
from invoke_training.ui.utils import get_assets_dir_path, get_config_dir_path
from invoke_training.ui.pages.data_page import DataPage
from invoke_training.ui.pages.training_page import TrainingPage


class App:
def __init__(self):
self._config_temp_directory = tempfile.TemporaryDirectory()
self._training_process = None
def build_app():
training_page = TrainingPage()
data_page = DataPage()

logo_path = get_assets_dir_path() / "logo.png"
with gr.Blocks(title="invoke-training", analytics_enabled=False) as app:
with gr.Column():
gr.Image(
value=logo_path,
label="Invoke Training App",
width=200,
interactive=False,
container=False,
)
with gr.Row():
gr.Markdown(
"*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --"
" Learn more about Invoke at [invoke.com](https://www.invoke.com/)"
)
with gr.Tab(label="SD LoRA"):
PipelineTab(
name="SD LoRA",
default_config_file_path=str(get_config_dir_path() / "sd_lora_pokemon_1x8gb.yaml"),
pipeline_config_cls=SdLoraConfig,
config_group_cls=SdLoraConfigGroup,
run_training_cb=self._run_training,
app=app,
)
with gr.Tab(label="SDXL LoRA"):
PipelineTab(
name="SDXL LoRA",
default_config_file_path=str(get_config_dir_path() / "sdxl_lora_pokemon_1x24gb.yaml"),
pipeline_config_cls=SdxlLoraConfig,
config_group_cls=SdxlLoraConfigGroup,
run_training_cb=self._run_training,
app=app,
)
with gr.Tab(label="SD Textual Inversion"):
PipelineTab(
name="SD Textual Inversion",
default_config_file_path=str(get_config_dir_path() / "sd_textual_inversion_gnome_1x8gb.yaml"),
pipeline_config_cls=SdTextualInversionConfig,
config_group_cls=SdTextualInversionConfigGroup,
run_training_cb=self._run_training,
app=app,
)
with gr.Tab(label="SDXL Textual Inversion"):
PipelineTab(
name="SDXL Textual Inversion",
default_config_file_path=str(get_config_dir_path() / "sdxl_textual_inversion_gnome_1x24gb.yaml"),
pipeline_config_cls=SdxlTextualInversionConfig,
config_group_cls=SdxlTextualInversionConfigGroup,
run_training_cb=self._run_training,
app=app,
)
with gr.Tab(label="SDXL LoRA and Textual Inversion"):
PipelineTab(
name="SDXL LoRA and Textual Inversion",
default_config_file_path=str(get_config_dir_path() / "sdxl_lora_and_ti_gnome_1x24gb.yaml"),
pipeline_config_cls=SdxlLoraAndTextualInversionConfig,
config_group_cls=SdxlLoraAndTextualInversionConfigGroup,
run_training_cb=self._run_training,
app=app,
)
app = FastAPI()

self._app = app
@app.get("/")
async def root():
index_path = Path(__file__).parent / "index.html"
return FileResponse(index_path)

def launch(self):
self._app.launch()
app.mount("/assets", StaticFiles(directory=Path(__file__).parent.parent / "assets"), name="assets")

def _run_training(self, config: PipelineConfig):
# Check if there is already a training process running.
if self._training_process is not None:
if self._training_process.poll() is None:
print(
"Tried to start a new training process, but another training process is already running. "
"Terminate the existing process first."
)
return
else:
self._training_process = None

print(f"Starting {config.type} training...")

# Write the config to a temporary config file where the training subprocess can read it.
timestamp = str(time.time()).replace(".", "_")
config_path = os.path.join(self._config_temp_directory.name, f"{timestamp}.yaml")
with open(config_path, "w") as f:
yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)

self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)])

print(f"Started {config.type} training.")
app = gr.mount_gradio_app(app, training_page.app(), "/train", app_kwargs={"favicon_path": "/assets/favicon.png"})
app = gr.mount_gradio_app(app, data_page.app(), "/data", app_kwargs={"favicon_path": "/assets/favicon.png"})
return app
20 changes: 20 additions & 0 deletions src/invoke_training/ui/gradio_blocks/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import gradio as gr

from invoke_training.ui.utils import get_assets_dir_path


class Header:
def __init__(self):
logo_path = get_assets_dir_path() / "logo.png"
gr.Image(
value=logo_path,
label="Invoke Training App",
width=200,
interactive=False,
container=False,
)
gr.Markdown(
"[Home](/)\n\n"
"*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --"
" Learn more about Invoke at [invoke.com](https://www.invoke.com/)"
)
File renamed without changes.
Loading

0 comments on commit 4e7826b

Please sign in to comment.