Skip to content

Commit

Permalink
Batch of small fixes and updates to Animation SDK (#235)
Browse files Browse the repository at this point in the history
* avoid 'prefill' border when 'inpaint_border' is disabled
* Provide user more feedback when stopping anim/post and while creating mp4 from frames.
* Ensure animation dimensions are a multiple of 64
* Switch from Colab form UI to `getpass` for entering API keys.

---------

Co-authored-by: Dmitrii Tochilkin <[email protected]>
  • Loading branch information
pharmapsychotic and kostarion authored May 18, 2023
1 parent 1e85abd commit 4a0a35b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 83 deletions.
60 changes: 31 additions & 29 deletions nbs/animation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
"# Animation SDK example"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "eeA1mYLdxr2j"
},
"outputs": [],
"source": [
"#@title Install the Stability SDK\n",
"%%capture captured --no-stderr\n",
"%pip install stability-sdk[anim]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -32,45 +46,40 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "zj56t6tc3prF"
},
"outputs": [],
"source": [
"%%capture\n",
"#@title Connect to the Stability API\n",
"\n",
"# install Stability Animation SDK for Python\n",
"%pip install stability-sdk[anim]\n",
"\n",
"import datetime\n",
"import getpass\n",
"import json\n",
"import os\n",
"import panel as pn\n",
"import param\n",
"import shutil\n",
"import sys\n",
"\n",
"from base64 import b64encode\n",
"from IPython import display\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"from types import SimpleNamespace\n",
"\n",
"from stability_sdk.api import Context\n",
"from stability_sdk.animation import AnimationArgs, Animator\n",
"from stability_sdk.utils import create_video_from_frames\n",
"\n",
"\n",
"# Enter your API key from dreamstudio.ai\n",
"# @markdown To get your API key visit https://dreamstudio.ai/account\n",
"STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
"STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
"\n",
"# Connect to Stability API\n",
"api_context = Context(STABILITY_HOST, STABILITY_KEY)"
"context = Context(STABILITY_HOST, STABILITY_KEY)\n",
"\n",
"# Test the connection\n",
"context.get_user_info()\n",
"print(\"Connection successful!\")"
]
},
{
Expand All @@ -84,12 +93,10 @@
"source": [
"# @title Settings\n",
"\n",
"# @markdown Run this cell to reveal the settings UI. After entering values, move on to the next step.\n",
"# @markdown Run this cell to reveal the settings UI grouped across several tabs. After entering values, move on to the next step.\n",
"\n",
"# @markdown To reset values to default, simply re-run this cell.\n",
"\n",
"# @markdown NB: Settings are grouped across several tabs.\n",
"\n",
"show_documentation = True # @param {type:'boolean'}\n",
"\n",
"# #@markdown ####**Resume:**\n",
Expand Down Expand Up @@ -158,6 +165,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "_SudvbZG3prI"
Expand All @@ -168,7 +176,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {
"id": "FT9slDSw3prJ"
},
Expand Down Expand Up @@ -228,7 +236,7 @@
"print(f\"Saving animation frames to {out_dir}...\")\n",
"\n",
"animator = Animator(\n",
" api_context=api_context,\n",
" api_context=context,\n",
" animation_prompts=animation_prompts,\n",
" args=args,\n",
" out_dir=out_dir, \n",
Expand Down Expand Up @@ -271,13 +279,12 @@
],
"metadata": {
"colab": {
"collapsed_sections": [],
"provenance": []
},
"kernelspec": {
"display_name": "client",
"display_name": "venv",
"language": "python",
"name": "client"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -291,13 +298,8 @@
"pygments_lexer": "ipython3",
"version": "3.9.5"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 0
}
45 changes: 28 additions & 17 deletions nbs/animation_gradio.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,30 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LUMF8i8BTwYH",
"outputId": "f61c635e-bc57-48ab-cc3d-5166286b158f"
"id": "enjwV3WW1yxL"
},
"outputs": [],
"source": [
"#@title Install the Stability SDK\n",
"%%capture captured --no-stderr\n",
"%pip install stability-sdk[anim_ui]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "LUMF8i8BTwYH"
},
"outputs": [],
"source": [
"#@title Mount Google Drive\n",
"import os\n",
"try:\n",
" from google.colab import drive\n",
" drive.mount('/content/gdrive')\n",
" outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n",
" os.makedirs(outputs_path, exist_ok=True)\n",
" !mkdir -p $outputs_path\n",
"except:\n",
" outputs_path = \".\"\n",
"print(f\"Animations will be saved to {outputs_path}\")"
Expand All @@ -44,18 +53,21 @@
},
"outputs": [],
"source": [
"#@title Install Animation SDK and connect to the Stability API\n",
"%pip install stability-sdk[anim_ui]\n",
"\n",
"#@title Connect to the Stability API\n",
"import getpass\n",
"from stability_sdk.api import Context\n",
"from stability_sdk.animation_ui import create_ui\n",
"\n",
"# Enter your API key from dreamstudio.ai\n",
"# @markdown To get your API key visit https://dreamstudio.ai/account\n",
"STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
"STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
"\n",
"# Connect to Stability API\n",
"api_context = Context(STABILITY_HOST, STABILITY_KEY)"
"context = Context(STABILITY_HOST, STABILITY_KEY)\n",
"\n",
"# Test the connection\n",
"context.get_user_info()\n",
"print(\"Connection successful!\")"
]
},
{
Expand All @@ -70,8 +82,7 @@
"#@title Animation UI\n",
"show_ui_in_notebook = True #@param {type:\"boolean\"}\n",
"\n",
"ui = create_ui(api_context, outputs_path)\n",
"\n",
"ui = create_ui(context, outputs_path)\n",
"ui.queue(concurrency_count=2, max_size=2)\n",
"ui.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=True, show_error=True)"
]
Expand Down Expand Up @@ -106,5 +117,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
}
"nbformat_minor": 0
}
15 changes: 9 additions & 6 deletions src/stability_sdk/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import param
import random
import shutil
import subprocess

from collections import OrderedDict, deque
from dataclasses import dataclass, fields
Expand Down Expand Up @@ -805,10 +804,18 @@ def setup_animation(self, resume):
# select image generation model
self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model

# validate dimensions
if args.width % 64 != 0 or args.height % 64 != 0:
args.width, args.height = map(lambda x: x - x % 64, (args.width, args.height))
logger.warning(f"Adjusted dimensions to {args.width}x{args.height} to be multiples of 64.")

# validate border settings
if args.border == 'wrap' and args.animation_mode != '2D':
args.border = 'reflect'
logger.warning(f"Border 'wrap' is only supported in 2D mode, switching to '{args.border}'.")
if args.border == 'prefill' and args.animation_mode in ('2D', '3D warp') and not args.inpaint_border:
args.border = 'reflect'
logger.warning(f"Border 'prefill' is only supported when 'inpaint_border' is enabled, switching to '{args.border}'.")

# validate clip guidance setting against selected model and sampler
if args.clip_guidance.lower() != 'none':
Expand All @@ -817,11 +824,8 @@ def setup_animation(self, resume):
logger.warning(f"CLIP guidance is not supported by {unsupported}, disabling guidance.")
args.clip_guidance = 'None'

def curve_to_series(curve: str) -> List[float]:
return curve_from_cn_string(curve)

# expand key frame strings to per frame series
frame_args_dict = {f.name: curve_to_series(getattr(args, f.name)) for f in fields(FrameArgs)}
frame_args_dict = {f.name: curve_from_cn_string(getattr(args, f.name)) for f in fields(FrameArgs)}
self.frame_args = FrameArgs(**frame_args_dict)

# prepare sorted list of key frames
Expand Down Expand Up @@ -953,7 +957,6 @@ def transform_video(self, frame_idx) -> Optional[Image.Image]:
mask = masks[0]
self.prior_frames.extend(transformed_prior_frames)
self.video_prev_frame = video_next_frame
self.color_match_image = video_next_frame
return mask
return None

Expand Down
Loading

0 comments on commit 4a0a35b

Please sign in to comment.