-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #93 from invoke-ai/gradio-image-annotation
Add a Gradio app for dataset annotation
- Loading branch information
Showing
15 changed files
with
601 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 4 additions & 16 deletions
20
src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import json | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
|
||
def load_jsonl(jsonl_path: Path | str) -> list[Any]: | ||
"""Load a JSONL file.""" | ||
data = [] | ||
with open(jsonl_path) as f: | ||
while (line := f.readline().strip()) != "": | ||
data.append(json.loads(line)) | ||
return data | ||
|
||
|
||
def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None: | ||
"""Save a list of objects to a JSONL file.""" | ||
with open(jsonl_path, "w") as f: | ||
for line in data: | ||
f.write(json.dumps(line) + "\n") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,121 +1,27 @@ | ||
import os | ||
import subprocess | ||
import tempfile | ||
import time | ||
from pathlib import Path | ||
|
||
import gradio as gr | ||
import yaml | ||
from fastapi import FastAPI | ||
from fastapi.responses import FileResponse | ||
from fastapi.staticfiles import StaticFiles | ||
|
||
from invoke_training.config.pipeline_config import PipelineConfig | ||
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig | ||
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig | ||
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig | ||
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( | ||
SdxlLoraAndTextualInversionConfig, | ||
) | ||
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig | ||
from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup | ||
from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup | ||
from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import ( | ||
SdxlLoraAndTextualInversionConfigGroup, | ||
) | ||
from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup | ||
from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup | ||
from invoke_training.ui.pipeline_tab import PipelineTab | ||
from invoke_training.ui.utils import get_assets_dir_path, get_config_dir_path | ||
from invoke_training.ui.pages.data_page import DataPage | ||
from invoke_training.ui.pages.training_page import TrainingPage | ||
|
||
|
||
class App: | ||
def __init__(self): | ||
self._config_temp_directory = tempfile.TemporaryDirectory() | ||
self._training_process = None | ||
def build_app(): | ||
training_page = TrainingPage() | ||
data_page = DataPage() | ||
|
||
logo_path = get_assets_dir_path() / "logo.png" | ||
with gr.Blocks(title="invoke-training", analytics_enabled=False) as app: | ||
with gr.Column(): | ||
gr.Image( | ||
value=logo_path, | ||
label="Invoke Training App", | ||
width=200, | ||
interactive=False, | ||
container=False, | ||
) | ||
with gr.Row(): | ||
gr.Markdown( | ||
"*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" | ||
" Learn more about Invoke at [invoke.com](https://www.invoke.com/)" | ||
) | ||
with gr.Tab(label="SD LoRA"): | ||
PipelineTab( | ||
name="SD LoRA", | ||
default_config_file_path=str(get_config_dir_path() / "sd_lora_pokemon_1x8gb.yaml"), | ||
pipeline_config_cls=SdLoraConfig, | ||
config_group_cls=SdLoraConfigGroup, | ||
run_training_cb=self._run_training, | ||
app=app, | ||
) | ||
with gr.Tab(label="SDXL LoRA"): | ||
PipelineTab( | ||
name="SDXL LoRA", | ||
default_config_file_path=str(get_config_dir_path() / "sdxl_lora_pokemon_1x24gb.yaml"), | ||
pipeline_config_cls=SdxlLoraConfig, | ||
config_group_cls=SdxlLoraConfigGroup, | ||
run_training_cb=self._run_training, | ||
app=app, | ||
) | ||
with gr.Tab(label="SD Textual Inversion"): | ||
PipelineTab( | ||
name="SD Textual Inversion", | ||
default_config_file_path=str(get_config_dir_path() / "sd_textual_inversion_gnome_1x8gb.yaml"), | ||
pipeline_config_cls=SdTextualInversionConfig, | ||
config_group_cls=SdTextualInversionConfigGroup, | ||
run_training_cb=self._run_training, | ||
app=app, | ||
) | ||
with gr.Tab(label="SDXL Textual Inversion"): | ||
PipelineTab( | ||
name="SDXL Textual Inversion", | ||
default_config_file_path=str(get_config_dir_path() / "sdxl_textual_inversion_gnome_1x24gb.yaml"), | ||
pipeline_config_cls=SdxlTextualInversionConfig, | ||
config_group_cls=SdxlTextualInversionConfigGroup, | ||
run_training_cb=self._run_training, | ||
app=app, | ||
) | ||
with gr.Tab(label="SDXL LoRA and Textual Inversion"): | ||
PipelineTab( | ||
name="SDXL LoRA and Textual Inversion", | ||
default_config_file_path=str(get_config_dir_path() / "sdxl_lora_and_ti_gnome_1x24gb.yaml"), | ||
pipeline_config_cls=SdxlLoraAndTextualInversionConfig, | ||
config_group_cls=SdxlLoraAndTextualInversionConfigGroup, | ||
run_training_cb=self._run_training, | ||
app=app, | ||
) | ||
app = FastAPI() | ||
|
||
self._app = app | ||
@app.get("/") | ||
async def root(): | ||
index_path = Path(__file__).parent / "index.html" | ||
return FileResponse(index_path) | ||
|
||
def launch(self): | ||
self._app.launch() | ||
app.mount("/assets", StaticFiles(directory=Path(__file__).parent.parent / "assets"), name="assets") | ||
|
||
def _run_training(self, config: PipelineConfig): | ||
# Check if there is already a training process running. | ||
if self._training_process is not None: | ||
if self._training_process.poll() is None: | ||
print( | ||
"Tried to start a new training process, but another training process is already running. " | ||
"Terminate the existing process first." | ||
) | ||
return | ||
else: | ||
self._training_process = None | ||
|
||
print(f"Starting {config.type} training...") | ||
|
||
# Write the config to a temporary config file where the training subprocess can read it. | ||
timestamp = str(time.time()).replace(".", "_") | ||
config_path = os.path.join(self._config_temp_directory.name, f"{timestamp}.yaml") | ||
with open(config_path, "w") as f: | ||
yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) | ||
|
||
self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)]) | ||
|
||
print(f"Started {config.type} training.") | ||
app = gr.mount_gradio_app(app, training_page.app(), "/train", app_kwargs={"favicon_path": "/assets/favicon.png"}) | ||
app = gr.mount_gradio_app(app, data_page.app(), "/data", app_kwargs={"favicon_path": "/assets/favicon.png"}) | ||
return app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import gradio as gr | ||
|
||
from invoke_training.ui.utils import get_assets_dir_path | ||
|
||
|
||
class Header: | ||
def __init__(self): | ||
logo_path = get_assets_dir_path() / "logo.png" | ||
gr.Image( | ||
value=logo_path, | ||
label="Invoke Training App", | ||
width=200, | ||
interactive=False, | ||
container=False, | ||
) | ||
gr.Markdown( | ||
"[Home](/)\n\n" | ||
"*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" | ||
" Learn more about Invoke at [invoke.com](https://www.invoke.com/)" | ||
) |
File renamed without changes.
Oops, something went wrong.