diff --git a/nbs/animation.ipynb b/nbs/animation.ipynb index 16f58781..8e855ece 100644 --- a/nbs/animation.ipynb +++ b/nbs/animation.ipynb @@ -10,6 +10,20 @@ "# Animation SDK example" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "cellView": "form", + "id": "eeA1mYLdxr2j" + }, + "outputs": [], + "source": [ + "#@title Install the Stability SDK\n", + "%%capture captured --no-stderr\n", + "%pip install stability-sdk[anim]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,31 +46,23 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "cellView": "form", "id": "zj56t6tc3prF" }, "outputs": [], "source": [ - "%%capture\n", "#@title Connect to the Stability API\n", - "\n", - "# install Stability Animation SDK for Python\n", - "%pip install stability-sdk[anim]\n", - "\n", "import datetime\n", + "import getpass\n", "import json\n", "import os\n", "import panel as pn\n", "import param\n", - "import shutil\n", - "import sys\n", "\n", "from base64 import b64encode\n", "from IPython import display\n", - "from pathlib import Path\n", - "from PIL import Image\n", "from tqdm import tqdm\n", "from types import SimpleNamespace\n", "\n", @@ -64,13 +70,16 @@ "from stability_sdk.animation import AnimationArgs, Animator\n", "from stability_sdk.utils import create_video_from_frames\n", "\n", - "\n", - "# Enter your API key from dreamstudio.ai\n", + "# @markdown To get your API key visit https://dreamstudio.ai/account\n", "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n", - "STABILITY_KEY = \"\" #@param {type:\"string\"}\n", + "STABILITY_KEY = getpass.getpass('Enter your API Key')\n", "\n", "# Connect to Stability API\n", - "api_context = Context(STABILITY_HOST, STABILITY_KEY)" + "context = Context(STABILITY_HOST, STABILITY_KEY)\n", + "\n", + "# Test the connection\n", + "context.get_user_info()\n", + "print(\"Connection successful!\")" ] }, { @@ -84,12 +93,10 @@ "source": [ "# @title Settings\n", "\n", - "# @markdown Run this cell to reveal the settings UI. After entering values, move on to the next step.\n", + "# @markdown Run this cell to reveal the settings UI grouped across several tabs. After entering values, move on to the next step.\n", "\n", "# @markdown To reset values to default, simply re-run this cell.\n", "\n", - "# @markdown NB: Settings are grouped across several tabs.\n", - "\n", "show_documentation = True # @param {type:'boolean'}\n", "\n", "# #@markdown ####**Resume:**\n", @@ -158,6 +165,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { "id": "_SudvbZG3prI" @@ -168,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": { "id": "FT9slDSw3prJ" }, @@ -228,7 +236,7 @@ "print(f\"Saving animation frames to {out_dir}...\")\n", "\n", "animator = Animator(\n", - " api_context=api_context,\n", + " api_context=context,\n", " animation_prompts=animation_prompts,\n", " args=args,\n", " out_dir=out_dir, \n", @@ -271,13 +279,12 @@ ], "metadata": { "colab": { - "collapsed_sections": [], "provenance": [] }, "kernelspec": { - "display_name": "client", + "display_name": "venv", "language": "python", - "name": "client" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -291,13 +298,8 @@ "pygments_lexer": "ipython3", "version": "3.9.5" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c" - } - } + "orig_nbformat": 4 }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 0 } diff --git a/nbs/animation_gradio.ipynb b/nbs/animation_gradio.ipynb index c38a0db6..935471b8 100644 --- a/nbs/animation_gradio.ipynb +++ b/nbs/animation_gradio.ipynb @@ -15,21 +15,30 @@ "execution_count": null, "metadata": { "cellView": "form", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LUMF8i8BTwYH", - "outputId": "f61c635e-bc57-48ab-cc3d-5166286b158f" + "id": "enjwV3WW1yxL" + }, + "outputs": [], + "source": [ + "#@title Install the Stability SDK\n", + "%%capture captured --no-stderr\n", + "%pip install stability-sdk[anim_ui]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "LUMF8i8BTwYH" }, "outputs": [], "source": [ "#@title Mount Google Drive\n", - "import os\n", "try:\n", " from google.colab import drive\n", " drive.mount('/content/gdrive')\n", " outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n", - " os.makedirs(outputs_path, exist_ok=True)\n", + " !mkdir -p $outputs_path\n", "except:\n", " outputs_path = \".\"\n", "print(f\"Animations will be saved to {outputs_path}\")" @@ -44,18 +53,21 @@ }, "outputs": [], "source": [ - "#@title Install Animation SDK and connect to the Stability API\n", - "%pip install stability-sdk[anim_ui]\n", - "\n", + "#@title Connect to the Stability API\n", + "import getpass\n", "from stability_sdk.api import Context\n", "from stability_sdk.animation_ui import create_ui\n", "\n", - "# Enter your API key from dreamstudio.ai\n", + "# @markdown To get your API key visit https://dreamstudio.ai/account\n", "STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n", - "STABILITY_KEY = \"\" #@param {type:\"string\"}\n", + "STABILITY_KEY = getpass.getpass('Enter your API Key')\n", "\n", "# Connect to Stability API\n", - "api_context = Context(STABILITY_HOST, STABILITY_KEY)" + "context = Context(STABILITY_HOST, STABILITY_KEY)\n", + "\n", + "# Test the connection\n", + "context.get_user_info()\n", + "print(\"Connection successful!\")" ] }, { @@ -70,8 +82,7 @@ "#@title Animation UI\n", "show_ui_in_notebook = True #@param {type:\"boolean\"}\n", "\n", - "ui = create_ui(api_context, outputs_path)\n", - "\n", + "ui = create_ui(context, outputs_path)\n", "ui.queue(concurrency_count=2, max_size=2)\n", "ui.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=True, show_error=True)" ] @@ -106,5 +117,5 @@ } }, "nbformat": 4, - "nbformat_minor": 4 -} + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/src/stability_sdk/animation.py b/src/stability_sdk/animation.py index 5f183749..239b3358 100644 --- a/src/stability_sdk/animation.py +++ b/src/stability_sdk/animation.py @@ -10,7 +10,6 @@ import param import random import shutil -import subprocess from collections import OrderedDict, deque from dataclasses import dataclass, fields @@ -805,10 +804,18 @@ def setup_animation(self, resume): # select image generation model self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model + # validate dimensions + if args.width % 64 != 0 or args.height % 64 != 0: + args.width, args.height = map(lambda x: x - x % 64, (args.width, args.height)) + logger.warning(f"Adjusted dimensions to {args.width}x{args.height} to be multiples of 64.") + # validate border settings if args.border == 'wrap' and args.animation_mode != '2D': args.border = 'reflect' logger.warning(f"Border 'wrap' is only supported in 2D mode, switching to '{args.border}'.") + if args.border == 'prefill' and args.animation_mode in ('2D', '3D warp') and not args.inpaint_border: + args.border = 'reflect' + logger.warning(f"Border 'prefill' is only supported when 'inpaint_border' is enabled, switching to '{args.border}'.") # validate clip guidance setting against selected model and sampler if args.clip_guidance.lower() != 'none': @@ -817,11 +824,8 @@ def setup_animation(self, resume): logger.warning(f"CLIP guidance is not supported by {unsupported}, disabling guidance.") args.clip_guidance = 'None' - def curve_to_series(curve: str) -> List[float]: - return curve_from_cn_string(curve) - # expand key frame strings to per frame series - frame_args_dict = {f.name: curve_to_series(getattr(args, f.name)) for f in fields(FrameArgs)} + frame_args_dict = {f.name: curve_from_cn_string(getattr(args, f.name)) for f in fields(FrameArgs)} self.frame_args = FrameArgs(**frame_args_dict) # prepare sorted list of key frames @@ -953,7 +957,6 @@ def transform_video(self, frame_idx) -> Optional[Image.Image]: mask = masks[0] self.prior_frames.extend(transformed_prior_frames) self.video_prev_frame = video_next_frame - self.color_match_image = video_next_frame return mask return None diff --git a/src/stability_sdk/animation_ui.py b/src/stability_sdk/animation_ui.py index 88a2010a..d307eb4f 100644 --- a/src/stability_sdk/animation_ui.py +++ b/src/stability_sdk/animation_ui.py @@ -288,7 +288,7 @@ def post_process_tab(): video_out = gr.Video(label="video", visible=False) process_button = gr.Button("Process") stop_button = gr.Button("Stop", visible=False) - error_log = gr.Textbox(label="Error", lines=3, visible=False) + status = gr.Textbox(lines=3, visible=False) def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: int, upscale: bool, use_video_instead: bool, video_to_postprocess: str): @@ -300,15 +300,14 @@ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: raise gr.Error("Videofile does not exist") yield { - header: gr.update(), - image_out: gr.update(visible=True, label=""), + image_out: gr.update(visible=True, label="", value=None), video_out: gr.update(visible=False), process_button: gr.update(visible=False), stop_button: gr.update(visible=True), - error_log: gr.update(visible=False), + status: gr.update(visible=False), } - error = None + error, output_video = None, None try: outdir = os.path.dirname(last_project_settings_path) \ if not use_video_instead \ @@ -333,10 +332,6 @@ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: yield { header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), image_out: gr.update(value=frame, label=f"upscale {frame_idx}/{num_frames}", visible=True), - video_out: gr.update(visible=False), - process_button: gr.update(visible=False), - stop_button: gr.update(visible=True), - error_log: gr.update(visible=False), } if interrupt: break @@ -354,10 +349,6 @@ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: yield { header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), image_out: gr.update(value=frame, label=f"interpolate {frame_idx}/{num_frames}", visible=True), - video_out: gr.update(visible=False), - process_button: gr.update(visible=False), - stop_button: gr.update(visible=True), - error_log: gr.update(visible=False), } if interrupt: break @@ -369,6 +360,8 @@ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: else: _, video_ext = os.path.splitext(video_to_postprocess) output_video = video_to_postprocess.replace(video_ext, f"{suffix}.mp4") + + yield { status: gr.update(label="Status", value="Compiling frames to MP4...", visible=True) } create_video_from_frames(outdir, output_video, fps=fps, reverse=reverse) except Exception as e: traceback.print_exc() @@ -380,19 +373,20 @@ def postprocess_video(fps: int, reverse: bool, interp_mode: str, interp_factor: video_out: gr.update(value=output_video, visible=True), process_button: gr.update(visible=True), stop_button: gr.update(visible=False), - error_log: gr.update(value=error, visible=bool(error)) + status: gr.update(label="Error", value=error, visible=bool(error)) } process_button.click( postprocess_video, inputs=[fps, reverse, frame_interp_mode, frame_interp_factor, upscale, use_video_instead, video_to_postprocess], - outputs=[header, image_out, video_out, process_button, stop_button, error_log] + outputs=[header, image_out, video_out, process_button, stop_button, status] ) - def stop_button_click(): + def stop(): global interrupt interrupt = True - stop_button.click(stop_button_click) + return { status: gr.update(label="Status", value="Stopping...", visible=True)} + stop_button.click(stop, outputs=[status]) def project_create(title, preset): @@ -556,7 +550,7 @@ def render_tab(): video_out = gr.Video(label="video", visible=False) button = gr.Button("Render") button_stop = gr.Button("Stop", visible=False) - error_log = gr.Textbox(label="Error", lines=3, visible=False) + status = gr.Textbox(lines=3, visible=False) def render(resume: bool, resume_from: int, *render_args): global interrupt, last_interp_factor, last_interp_mode, last_project_settings_path, last_upscale, project @@ -611,10 +605,9 @@ def render(resume: bool, resume_from: int, *render_args): yield { button: gr.update(visible=False), button_stop: gr.update(visible=True), - image_out: gr.update(visible=True, label=""), + image_out: gr.update(visible=True, label="", value=None), video_out: gr.update(visible=False), - header: gr.update(), - error_log: gr.update(visible=False), + status: gr.update(visible=False), } # delete frames from previous animation @@ -641,16 +634,9 @@ def render(resume: bool, resume_from: int, *render_args): if interrupt: break - # saving frames to project - #frame_uuid = project.put_image_asset(frame) - yield { - button: gr.update(visible=False), - button_stop: gr.update(visible=True), image_out: gr.update(value=frame, label=f"frame {frame_idx}/{args.max_frames}", visible=True), - video_out: gr.update(visible=False), header: gr.update(value=format_header_html()) if frame_idx % 12 == 0 else gr.update(), - error_log: gr.update(visible=False), } except ClassifierException as e: error = "Animation terminated early due to NSFW classifier." @@ -666,6 +652,9 @@ def render(resume: bool, resume_from: int, *render_args): last_project_settings_path = project_settings_path last_interp_factor, last_interp_mode, last_upscale = None, None, None output_video = project_settings_path.replace(".json", ".mp4") + yield { + status: gr.update(label="Status", value="Compiling frames to MP4...", visible=True), + } try: create_video_from_frames(outdir, output_video, fps=args.fps, reverse=args.reverse) except RuntimeError as e: @@ -679,20 +668,21 @@ def render(resume: bool, resume_from: int, *render_args): image_out: gr.update(visible=False), video_out: gr.update(value=output_video, visible=True), header: gr.update(value=format_header_html()), - error_log: gr.update(value=error, visible=bool(error)), + status: gr.update(label="Error", value=error, visible=bool(error)), } button.click( render, inputs=[resume_checkbox, resume_from_number] + list(controls.values()), - outputs=[button, button_stop, image_out, video_out, header, error_log] + outputs=[button, button_stop, image_out, video_out, header, status] ) # stop animation in progress def stop(): global interrupt interrupt = True - button_stop.click(stop) + return { status: gr.update(label="Status", value="Stopping...", visible=True) } + button_stop.click(stop, outputs=[status]) def ui_for_animation_settings(args: AnimationSettings): with gr.Row():