Skip to content

Commit

Permalink
Merge pull request #163 from VoltaML/experimental
Browse files Browse the repository at this point in the history
Version 0.4.0
  • Loading branch information
Stax124 authored Oct 27, 2023
2 parents 4babe57 + ded547d commit 2205279
Show file tree
Hide file tree
Showing 238 changed files with 16,001 additions and 10,214 deletions.
8 changes: 5 additions & 3 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ pyproject.toml
# User generated files
/converted
/data
!/data/themes/dark.json
!/data/themes/dark_flat.json
!/data/themes/light.json
!/data/themes/light_flat.json
/engine
/onnx
/traced_unet
Expand All @@ -22,9 +26,6 @@ yarn.lock
# Frontend
frontend/dist/

# Static files
/static

# Python
/venv

Expand All @@ -42,6 +43,7 @@ test.docker-compose.yml

# Other
**/**.pyc
poetry.lock
.pytest_cache
.coverage
/.ruff_cache
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# CI workflow: lint the whole repository with Ruff on every push and PR.
# (Scrape stripped the indentation; restored to valid GitHub Actions YAML.)
name: Ruff
on: [push, pull_request]
jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      # Check out the repository so the linter can see the sources.
      - uses: actions/checkout@v3
      # Runs `ruff check` over the repo with the project's own config.
      - uses: chartboost/ruff-action@v1
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ docs/.vitepress/dist
external
/tmp
/data
/data/settings.json
/AITemplate

# Ignore for black
Expand Down
3 changes: 0 additions & 3 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
{
"python.linting.pylintEnabled": true,
"python.linting.enabled": true,
"python.testing.pytestArgs": ["."],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.typeCheckingMode": "basic",
"python.formatting.provider": "black",
"python.languageServer": "Pylance",
"rust-analyzer.linkedProjects": ["./manager/Cargo.toml"]
}
111 changes: 55 additions & 56 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,21 @@
from pathlib import Path

from api_analytics.fastapi import Analytics
from fastapi import Depends, FastAPI, Request
from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi_simple_cachecontrol.middleware import CacheControlMiddleware
from fastapi_simple_cachecontrol.types import CacheControl
from huggingface_hub.hf_api import LocalTokenNotFoundError
from starlette import status
from starlette.responses import JSONResponse

from api import websocket_manager
from api.routes import (
general,
generate,
hardware,
models,
outputs,
settings,
static,
test,
ws,
)
from api.routes import static, ws
from api.websockets.data import Data
from api.websockets.notification import Notification
from core import shared
from core.types import InferenceBackend

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,50 +89,53 @@ async def hf_token_error(_request, _exc):


@app.exception_handler(404)
async def custom_http_exception_handler(_request, _exc):
async def custom_http_exception_handler(request: Request, _exc):
    "Redirect back to the main page (frontend will handle it)"

    # API calls get a structured JSON 404 body; anything else falls back to
    # the SPA entry point so the frontend router can display its own 404.
    is_api_call = request.url.path.startswith("/api")

    if not is_api_call:
        return FileResponse("frontend/dist/index.html")

    payload = {
        "status_code": 10404,
        "message": "Not Found",
        "data": None,
    }
    return JSONResponse(content=payload, status_code=status.HTTP_404_NOT_FOUND)


@app.on_event("startup")
async def startup_event():
    "Prepare the event loop for other asynchronous tasks"

    # Inject the Rich logger into the servers' loggers for pretty output.
    from rich.logging import RichHandler

    # Disable duplicate logger — uvicorn installs its own handlers.
    logging.getLogger("uvicorn").handlers = []

    for logger_name in ("uvicorn.access", "uvicorn.error", "fastapi"):
        # `uvicorn_logger` instead of the ambiguous single-letter `l` (Ruff E741,
        # and this commit adds a Ruff CI workflow).
        uvicorn_logger = logging.getLogger(logger_name)
        handler = RichHandler(
            rich_tracebacks=True, show_time=False, omit_repeated_times=False
        )
        handler.setFormatter(
            logging.Formatter(
                fmt="%(asctime)s | %(name)s » %(message)s", datefmt="%H:%M:%S"
            )
        )
        uvicorn_logger.handlers = [handler]

    # Silence transformers' chatter unless we are debugging.
    if logger.level > logging.DEBUG:
        from transformers import logging as transformers_logging

        transformers_logging.set_verbosity_error()

    # Expose the running loop so background threads can schedule coroutines.
    shared.asyncio_loop = asyncio.get_event_loop()
    websocket_manager.loop = shared.asyncio_loop

    # Keep references to the background tasks so they are not garbage collected
    # and can be cancelled on shutdown.
    sync_task = asyncio.create_task(websocket_manager.sync_loop())
    logger.info("Started WebSocketManager sync loop")
    perf_task = asyncio.create_task(websocket_manager.perf_loop())

    shared.asyncio_tasks.append(sync_task)
    shared.asyncio_tasks.append(perf_task)

    # Imported lazily: core.config reads settings from disk at import time.
    from core.config import config

    if config.api.autoloaded_models:
        from core.shared_dependent import cached_model_list, gpu

        all_models = cached_model_list.all()

        # One pass over the model list instead of rebuilding a comprehension
        # per autoloaded model (was O(models × autoloaded)).
        backends = {i.path: i.backend for i in all_models}

        for model in config.api.autoloaded_models:
            backend: InferenceBackend = backends.get(model)  # type: ignore
            if backend is not None:
                await gpu.load_model(model, backend)
            else:
                logger.warning(f"Autoloaded model {model} not found, skipping")

    logger.info("Started WebSocketManager performance monitoring loop")
    # Use the configured port — the old hard-coded :5003 line was stale and
    # logged the URL a second time.
    logger.info(f"UI Available at: http://localhost:{shared.api_port}/")


@app.on_event("shutdown")
Expand All @@ -165,42 +157,49 @@ async def shutdown_event():
# Mount routers
## HTTP
app.include_router(static.router)
app.include_router(test.router, prefix="/api/test")
app.include_router(generate.router, prefix="/api/generate")
app.include_router(hardware.router, prefix="/api/hardware")
app.include_router(models.router, prefix="/api/models")
app.include_router(outputs.router, prefix="/api/output")
app.include_router(general.router, prefix="/api/general")
app.include_router(settings.router, prefix="/api/settings")

# Walk the routes folder and mount all routers automatically, so a new route
# module only needs to expose a module-level `router` to be picked up.
from importlib import import_module

for file in Path("api/routes").iterdir():
    # Only real .py modules count; skip the package init and the two routers
    # that are mounted manually (static serves "/", ws lives under /api/websockets).
    if (
        file.is_file()
        and file.suffix == ".py"
        and file.name != "__init__.py"
        and file.stem not in ("static", "ws")
    ):
        logger.debug(f"Mounting: {file} as /api/{file.stem}")
        # importlib.import_module is the documented replacement for raw
        # __import__(..., fromlist=...) and returns the submodule directly.
        module = import_module(f"api.routes.{file.stem}")
        app.include_router(module.router, prefix=f"/api/{file.stem}")

## WebSockets
app.include_router(ws.router, prefix="/api/websockets")

# Mount outputs folder
output_folder = Path("data/outputs")
output_folder.mkdir(exist_ok=True)
app.mount("/data/outputs", StaticFiles(directory="data/outputs"), name="outputs")

# Mount static files (css, js, images, etc.)
static_app = FastAPI()
static_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
static_app.add_middleware(
CacheControlMiddleware, cache_control=CacheControl("no-cache")
)
static_app.mount("/", StaticFiles(directory="frontend/dist/assets"), name="assets")

app.mount("/assets", static_app)
app.mount("/static", StaticFiles(directory="static"), name="extra_static_files")
app.mount("/themes", StaticFiles(directory="data/themes"), name="themes")

origins = ["*"]

# Allow CORS for specified origins
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
static_app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
Expand Down
29 changes: 29 additions & 0 deletions api/routes/autofill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging
from pathlib import Path
from typing import List

from fastapi import APIRouter

router = APIRouter(tags=["autofill"])
logger = logging.getLogger(__name__)


@router.get("/")
def get_autofill_list() -> List[str]:
    "Gathers and returns all words from the prompt autofill files"

    autofill_folder = Path("data/autofill")

    # The folder is user-populated and may not exist; returning an empty list
    # beats crashing the endpoint with FileNotFoundError from iterdir().
    if not autofill_folder.is_dir():
        logger.debug(f"Autofill folder {autofill_folder} does not exist")
        return []

    logger.debug(f"Looking for autofill files in {autofill_folder}")

    words: List[str] = []

    # glob("*.txt") replaces iterdir() + suffix check; the old code also
    # materialized list(iterdir()) inside a debug f-string on every request,
    # even with debug logging disabled.
    for file in autofill_folder.glob("*.txt"):
        if file.is_file():
            logger.debug(f"Found autofill file: {file}")
            words.extend(file.read_text(encoding="utf-8").splitlines())

    # Deduplicate; order is unspecified, as before.
    return list(set(words))
14 changes: 14 additions & 0 deletions api/routes/general.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from pathlib import Path

from fastapi import APIRouter

Expand Down Expand Up @@ -83,3 +84,16 @@ async def queue_clear():
queue.clear()

return {"message": "Queue cleared"}


@router.get("/themes")
async def themes():
    "Get all available themes"

    # A theme is any JSON file in data/themes; expose the stems, sorted.
    theme_dir = Path("data/themes")
    return sorted(theme.stem for theme in theme_dir.glob("*.json"))
26 changes: 8 additions & 18 deletions api/routes/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ async def txt2img_job(job: Txt2ImgQueueEntry):
time: float
images, time = await gpu.generate(job)
except ModelNotLoadedError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="Model is not loaded"
)
raise HTTPException(status_code=400, detail="Model is not loaded")

return images_to_response(images, time)

Expand All @@ -55,9 +53,7 @@ async def img2img_job(job: Img2ImgQueueEntry):
time: float
images, time = await gpu.generate(job)
except ModelNotLoadedError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="Model is not loaded"
)
raise HTTPException(status_code=400, detail="Model is not loaded")

return images_to_response(images, time)

Expand All @@ -79,16 +75,14 @@ async def inpaint_job(job: InpaintQueueEntry):
time: float
images, time = await gpu.generate(job)
except ModelNotLoadedError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="Model is not loaded"
)
raise HTTPException(status_code=400, detail="Model is not loaded")

return images_to_response(images, time)


@router.post("/controlnet")
async def controlnet_job(job: ControlNetQueueEntry):
"Generate variations of the image"
"Generate images based on a reference image"

image_bytes = job.data.image
assert isinstance(image_bytes, bytes)
Expand All @@ -99,9 +93,7 @@ async def controlnet_job(job: ControlNetQueueEntry):
time: float
images, time = await gpu.generate(job)
except ModelNotLoadedError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="Model is not loaded"
)
raise HTTPException(status_code=400, detail="Model is not loaded")

return images_to_response(images, time)

Expand All @@ -119,9 +111,7 @@ async def realesrgan_upscale_job(job: UpscaleQueueEntry):
time: float
image, time = await gpu.upscale(job)
except ModelNotLoadedError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="Model is not loaded"
)
raise HTTPException(status_code=400, detail="Model is not loaded")

return {
"time": time,
Expand All @@ -133,7 +123,7 @@ async def realesrgan_upscale_job(job: UpscaleQueueEntry):

@router.post("/generate-aitemplate")
async def generate_aitemplate(request: AITemplateBuildRequest):
"Generate a AITemplate model from a local model"
"Generate an AITemplate model from a local model"

await gpu.build_aitemplate_engine(request)

Expand All @@ -142,7 +132,7 @@ async def generate_aitemplate(request: AITemplateBuildRequest):

@router.post("/generate-dynamic-aitemplate")
async def generate_dynamic_aitemplate(request: AITemplateDynamicBuildRequest):
"Generate a AITemplate engine from a local model"
"Generate an AITemplate engine from a local model"

await gpu.build_dynamic_aitemplate_engine(request)

Expand Down
4 changes: 1 addition & 3 deletions api/routes/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ async def gpu_memory(gpu_id: int):
gpu_data = GPUStatCollection.new_query().gpus[gpu_id]
return (gpu_data.memory_total, gpu_data.memory_free, "MB")
except IndexError:
raise HTTPException( # pylint: disable=raise-missing-from
status_code=400, detail="GPU not found"
)
raise HTTPException(status_code=400, detail="GPU not found")


@router.get("/capabilities")
Expand Down
Loading

0 comments on commit 2205279

Please sign in to comment.