diff --git a/backend/db_repo.py b/backend/db_repo.py
index 68e70470..6f4e73d2 100644
--- a/backend/db_repo.py
+++ b/backend/db_repo.py
@@ -1029,7 +1029,7 @@ def update_app_setting(self, **kwargs):
 
         return InternalResponse({}, 'app_setting updated successfully', True)
 
-    def get_app_secrets_from_user_uuid(self, user_uuid):
+    def get_app_secrets_from_user_uuid(self, user_uuid, secret_access=None):
         if user_uuid:
             user: User = User.objects.filter(uuid=user_uuid, is_disabled=False).first()
             if not user:
@@ -1412,7 +1412,11 @@ def release_lock(self, key):
 
     # shot
     def get_shot_from_number(self, project_uuid, shot_number=0):
-        shot: Shot = Shot.objects.filter(project_id=project_uuid, shot_idx=shot_number, is_disabled=False).first()
+        project = Project.objects.filter(uuid=project_uuid, is_disabled=False).first()
+        if not project:
+            return InternalResponse({}, 'invalid project uuid', False)
+
+        shot: Shot = Shot.objects.filter(project_id=project.id, shot_idx=shot_number, is_disabled=False).first()
         if not shot:
             return InternalResponse({}, 'invalid shot number', False)
diff --git a/banodoco_settings.py b/banodoco_settings.py
index aba57f22..12e6811e 100644
--- a/banodoco_settings.py
+++ b/banodoco_settings.py
@@ -1,4 +1,4 @@
-import json
+import copy
 import uuid
 
 import streamlit as st
@@ -89,7 +89,7 @@ def create_new_project(user: InternalUserObject, project_name: str, width=512, h
     shot_data = {
         "project_uuid": project.uuid,
         "desc": "",
-        "duration": 2
+        "duration": 10
     }
 
     shot = data_repo.create_shot(**shot_data)
@@ -99,7 +99,10 @@ def create_new_project(user: InternalUserObject, project_name: str, width=512, h
     sample_file_location = "sample_assets/sample_images/v.jpeg"
     img = Image.open(sample_file_location)
     img = img.resize((width, height))
-    hosted_url = save_or_host_file(img, sample_file_location, mime_type='image/png', dim=(width, height))
+
+    unique_file_name = f"{str(uuid.uuid4())}.png"
+    file_location = f"videos/{project.uuid}/resources/prompt_images/{unique_file_name}"
+    hosted_url = save_or_host_file(img, file_location, mime_type='image/png', dim=(width, height))
 
     file_data = {
         "name": str(uuid.uuid4()),
         "type": InternalFileType.IMAGE.value,
@@ -110,7 +113,7 @@ def create_new_project(user: InternalUserObject, project_name: str, width=512, h
     if hosted_url:
         file_data.update({'hosted_url': hosted_url})
     else:
-        file_data.update({'local_path': sample_file_location})
+        file_data.update({'local_path': file_location})
 
     source_image = data_repo.create_file(**file_data)
 
@@ -148,11 +151,12 @@ def create_predefined_models(user):
 
     # create predefined models
     data = []
-    for model in ML_MODEL_LIST:
-        if model['enabled']:
-            del model['enabled']
-            model['user_id'] = user.uuid
-            data.append(model)
+    predefined_model_list = copy.deepcopy(ML_MODEL_LIST)
+    for m in predefined_model_list:
+        if 'enabled' in m and m['enabled']:
+            del m['enabled']
+            m['user_id'] = user.uuid
+            data.append(m)
 
     # only creating pre-defined models for the first time
     available_models = data_repo.get_all_ai_model_list(\
diff --git a/sample_assets/sample_images/v.jpeg b/sample_assets/sample_images/v.jpeg
index c19608b2..0ce960d9 100644
Binary files a/sample_assets/sample_images/v.jpeg and b/sample_assets/sample_images/v.jpeg differ
diff --git a/shared/constants.py b/shared/constants.py
index 39ff837c..3672cd81 100644
--- a/shared/constants.py
+++ b/shared/constants.py
@@ -116,4 +116,12 @@ class SortOrder(ExtendedEnum):
 
 ENCRYPTION_KEY = os.getenv('ENCRYPTION_KEY', 'J2684nBgNUYa_K0a6oBr5H8MpSRW0EJ52Qmq7jExE-w=')
 QUEUE_INFERENCE_QUERIES = True
-HOSTED_BACKGROUND_RUNNER_MODE = os.getenv('HOSTED_BACKGROUND_RUNNER_MODE', False)
\ No newline at end of file
+HOSTED_BACKGROUND_RUNNER_MODE = os.getenv('HOSTED_BACKGROUND_RUNNER_MODE', False)
+
+if OFFLINE_MODE:
+    SECRET_ACCESS_TOKEN = os.getenv('SECRET_ACCESS_TOKEN', None)
+else:
+    import boto3
+    ssm = boto3.client("ssm", region_name="ap-south-1")
+
+    SECRET_ACCESS_TOKEN = ssm.get_parameter(Name='/backend/banodoco/secret-access-token')['Parameter']['Value']
\ No newline at end of file
diff --git a/ui_components/components/app_settings_page.py b/ui_components/components/app_settings_page.py
index f1cc3cc6..6613c2a5 100644
--- a/ui_components/components/app_settings_page.py
+++ b/ui_components/components/app_settings_page.py
@@ -9,7 +9,6 @@ def app_settings_page():
     data_repo = DataRepo()
 
-    app_secrets = data_repo.get_app_secrets_from_user_uuid()
 
     if SERVER == ServerType.DEVELOPMENT.value:
         st.subheader("Purchase Credits")
diff --git a/ui_components/components/new_project_page.py b/ui_components/components/new_project_page.py
index fbc2fb63..b079d38a 100644
--- a/ui_components/components/new_project_page.py
+++ b/ui_components/components/new_project_page.py
@@ -88,7 +88,7 @@ def new_project_page():
             new_project_name = new_project_name.replace(" ", "_")
             current_user = data_repo.get_first_active_user()
-            new_project, shot = create_new_project(current_user, new_project_name, width, height, "Images", "Interpolation")
+            new_project, shot = create_new_project(current_user, new_project_name, width, height)
             new_timing = create_frame_inside_shot(shot.uuid, 0)
 
             if starting_image:
@@ -100,7 +100,7 @@ def new_project_page():
 
             # remvoing the initial frame which moved to the 1st position
            # (since creating new project also creates a frame)
-            shot = data_repo.get_shot_from_number(new_project.uuid, 0)
+            shot = data_repo.get_shot_from_number(new_project.uuid, 1)
             initial_frame = data_repo.get_timing_from_frame_number(shot.uuid, 0)
             data_repo.delete_timing_from_uuid(initial_frame.uuid)
diff --git a/ui_components/components/project_settings_page.py b/ui_components/components/project_settings_page.py
index a842db9d..598327c1 100644
--- a/ui_components/components/project_settings_page.py
+++ b/ui_components/components/project_settings_page.py
@@ -3,6 +3,7 @@ import os
 import time
 from ui_components.widgets.attach_audio_element import attach_audio_element
+from PIL import Image
 
 from utils.data_repo.data_repo import DataRepo
 
@@ -13,14 +14,23 @@ def project_settings_page(project_uuid):
     project_settings = data_repo.get_project_setting(project_uuid)
     attach_audio_element(project_uuid, True)
 
+    frame_sizes = ["512x512", "768x512", "512x768"]
+    current_size = f"{project_settings.width}x{project_settings.height}"
+    current_index = frame_sizes.index(current_size) if current_size in frame_sizes else 0
+
     with st.expander("Frame Size", expanded=True):
-        st.write("Current Size = ",
-                 project_settings.width, "x", project_settings.height)
-        width = st.selectbox("Select video width", options=[
-                             "512", "683", "704", "1024"], key="video_width")
-        height = st.selectbox("Select video height", options=[
-                              "512", "704", "1024"], key="video_height")
-        if st.button("Save"):
-            data_repo.update_project_setting(project_uuid, width=width)
-            data_repo.update_project_setting(project_uuid, height=height)
-            st.rerun()
+        v1, v2, v3 = st.columns([4, 4, 2])
+        with v1:
+            st.write("Current Size = ", project_settings.width, "x", project_settings.height)
+
+            frame_size = st.radio("Select frame size:", options=frame_sizes, index=current_index, key="frame_size", horizontal=True)
+            width, height = map(int, frame_size.split('x'))
+
+
+            img = Image.new('RGB', (width, height), color = (73, 109, 137))
+            st.image(img, width=70)
+
+        if st.button("Save"):
+            data_repo.update_project_setting(project_uuid, width=width)
+            data_repo.update_project_setting(project_uuid, height=height)
+            st.experimental_rerun()
\ No newline at end of file
diff --git a/ui_components/methods/common_methods.py b/ui_components/methods/common_methods.py
index 28b2235e..3100c8ae 100644
--- a/ui_components/methods/common_methods.py
+++ b/ui_components/methods/common_methods.py
@@ -12,7 +12,7 @@ from io import BytesIO
 import numpy as np
 import urllib3
-from shared.constants import SERVER, InferenceType, InternalFileTag, InternalFileType, ProjectMetaData, ServerType
+from shared.constants import OFFLINE_MODE, SERVER, InferenceType, InternalFileTag, InternalFileType, ProjectMetaData, ServerType
 from pydub import AudioSegment
 from backend.models import InternalFileObject
 from shared.logging.constants import LoggingType
@@ -734,6 +734,7 @@ def process_inference_output(**kwargs):
     inference_time = 0.0
     inference_type = kwargs.get('inference_type')
+    log_uuid = None
     # ------------------- FRAME TIMING IMAGE INFERENCE -------------------
     if inference_type == InferenceType.FRAME_TIMING_IMAGE_INFERENCE.value:
         output = kwargs.get('output')
@@ -882,7 +883,7 @@ def process_inference_output(**kwargs):
 
     if inference_time:
         credits_used = round(inference_time * 0.004, 3)    # make this more granular for different models
-        data_repo.update_usage_credits(-credits_used)
+        data_repo.update_usage_credits(-credits_used, log_uuid)
 
     return True
 
@@ -929,7 +930,7 @@ def update_app_setting_keys():
     data_repo = DataRepo()
     app_logger = AppLogger()
-    if True or SERVER == ServerType.DEVELOPMENT.value:
+    if OFFLINE_MODE:
         key = os.getenv('REPLICATE_KEY', None)
     else:
         import boto3
diff --git a/ui_components/methods/video_methods.py b/ui_components/methods/video_methods.py
index 0e59c57e..506871b4 100644
--- a/ui_components/methods/video_methods.py
+++ b/ui_components/methods/video_methods.py
@@ -7,7 +7,8 @@ import uuid
 import ffmpeg
 import streamlit as st
-from moviepy.editor import concatenate_videoclips, VideoFileClip, AudioFileClip
+from moviepy.editor import concatenate_videoclips, concatenate_audioclips, VideoFileClip, AudioFileClip, CompositeVideoClip
+from pydub import AudioSegment
 
 from shared.constants import InferenceType, InternalFileTag
 from shared.file_upload.s3 import is_s3_image_url
@@ -123,6 +124,9 @@ def add_audio_to_video_slice(video_file, audio_bytes):
     os.rename("output_with_audio.mp4", video_location)
 
 def sync_audio_and_duration(video_file: InternalFileObject, shot_uuid, audio_sync_required=False):
+    '''
+    audio_sync_required: this ensures that entire video clip is filled with proper audio
+    '''
     from ui_components.methods.file_methods import convert_bytes_to_file, generate_temp_file
 
     data_repo = DataRepo()
@@ -165,11 +169,32 @@ def sync_audio_and_duration(video_file: InternalFileObject, shot_uuid, audio_syn
             start_timestamp += round(shot_list[i - 1].duration, 2)
 
     trimmed_audio_clip = None
-    if audio_clip.duration >= video_clip.duration and start_timestamp + video_clip.duration <= audio_clip.duration:
-        trimmed_audio_clip = audio_clip.subclip(start_timestamp, start_timestamp + video_clip.duration)
-        video_clip = video_clip.set_audio(trimmed_audio_clip)
+    audio_len_overlap = 0   # length of audio that can be added to the video clip
+    if audio_clip.duration >= start_timestamp:
+        audio_len_overlap = round(min(video_clip.duration, audio_clip.duration - start_timestamp), 2)
+
+
+    # audio doesn't fit the video clip
+    if audio_len_overlap < video_clip.duration and audio_sync_required:
+        return None
+
+    if audio_len_overlap:
+        trimmed_audio_clip = audio_clip.subclip(start_timestamp, start_timestamp + audio_len_overlap)
+        trimmed_audio_clip_duration = round(trimmed_audio_clip.duration, 2)
+        if trimmed_audio_clip_duration < video_clip.duration:
+            video_with_sound = video_clip.subclip(0, trimmed_audio_clip_duration)
+            video_with_sound = video_with_sound.copy()
+            video_without_sound = video_clip.subclip(trimmed_audio_clip_duration)
+            video_without_sound = video_without_sound.copy()
+            video_with_sound = video_with_sound.set_audio(trimmed_audio_clip)
+            video_clip = concatenate_videoclips([video_with_sound, video_without_sound])
+        else:
+            video_clip = video_clip.set_audio(trimmed_audio_clip)
     else:
-        return None if audio_sync_required else output_video
+        for file in temp_file_list:
+            os.remove(file.name)
+
+        return output_video
 
     # writing the video to the temp file
     output_temp_video_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
@@ -197,6 +222,8 @@ def sync_audio_and_duration(video_file: InternalFileObject, shot_uuid, audio_syn
     for file in temp_file_list:
         os.remove(file.name)
 
+    output_video = data_repo.get_file_from_uuid(output_video.uuid)
+    _ = data_repo.get_shot_list(shot.project.uuid, invalidate_cache=True)
     return output_video
 
@@ -226,7 +253,7 @@ def render_video(final_video_name, project_uuid, file_tag=InternalFileTag.GENERA
                 time.sleep(0.3)
                 return False
 
-            shot_video = sync_audio_and_duration(shot.main_clip, shot.uuid, audio_sync_required=True)
+            shot_video = sync_audio_and_duration(shot.main_clip, shot.uuid, audio_sync_required=False)
             if not shot_video:
                 st.error("Audio sync failed. Length mismatch")
                 time.sleep(0.7)
@@ -236,11 +263,11 @@ def render_video(final_video_name, project_uuid, file_tag=InternalFileTag.GENERA
             data_repo.add_interpolated_clip(shot.uuid, interpolated_clip_id=shot_video.uuid)
 
             temp_video_file = None
-            if shot.main_clip.hosted_url:
-                temp_video_file = generate_temp_file(shot.main_clip.hosted_url, '.mp4')
+            if shot_video.hosted_url:
+                temp_video_file = generate_temp_file(shot_video.hosted_url, '.mp4')
                 temp_file_list.append(temp_video_file)
 
-            file_path = temp_video_file.name if temp_video_file else shot.main_clip.local_path
+            file_path = temp_video_file.name if temp_video_file else shot_video.local_path
             video_list.append(file_path)
 
     finalclip = concatenate_videoclips([VideoFileClip(v) for v in video_list])
diff --git a/ui_components/setup.py b/ui_components/setup.py
index 38b5db22..b107c64d 100644
--- a/ui_components/setup.py
+++ b/ui_components/setup.py
@@ -172,7 +172,7 @@ def setup_app_ui():
 
         elif st.session_state["main_view_type"] == "Tools & Settings":
             with st.sidebar:
-                tool_pages = ["Query Logger", "Custom Models", "Project Settings"]
+                tool_pages = ["Query Logger", "Project Settings"]
 
                 if st.session_state["page"] not in tool_pages:
                     st.session_state["page"] = tool_pages[0]
diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py
index 51504dcf..f12625f1 100644
--- a/ui_components/widgets/animation_style_element.py
+++ b/ui_components/widgets/animation_style_element.py
@@ -13,42 +13,17 @@ import matplotlib.pyplot as plt
 
 def animation_style_element(shot_uuid):
+    motion_modules = AnimateDiffCheckpoint.get_name_list()
     variant_count = 1
     current_animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value    # setting a default value
     data_repo = DataRepo()
-
-    # animation_type = st.radio("Animation Interpolation:", \
-    #     options=[AnimationStyleType.CREATIVE_INTERPOLATION.value, AnimationStyleType.IMAGE_TO_VIDEO.value], \
-    #     key="animation_tool", horizontal=True, disabled=True)
-
-
-
     shot: InternalShotObject = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"])
     st.session_state['project_uuid'] = str(shot.project.uuid)
     timing_list: List[InternalFrameTimingObject] = shot.timing_list
-    '''
-    st.markdown("#### Keyframe Settings")
-    if timing_list and len(timing_list):
-        columns = st.columns(len(timing_list))
-        disable_generate = False
-        help = ""
-        for idx, timing in enumerate(timing_list):
-            if timing.primary_image and timing.primary_image.location:
-                columns[idx].image(timing.primary_image.location, use_column_width=True)
-                b = timing.primary_image.inference_params
-                prompt = columns[idx].text_area(f"Prompt {idx+1}", value=(b['prompt'] if b else ""), key=f"prompt_{idx+1}")
-                # base_style_on_image = columns[idx].checkbox(f"Use base style image for prompt {idx+1}", key=f"base_style_image_{idx+1}",value=True)
-            else:
-                columns[idx].warning("No primary image present")
-                disable_generate = True
-                help = "You can't generate a video because one of your keyframes is missing an image."
- else: - st.warning("No keyframes present") - - st.markdown("***") - ''' + + video_resolution = None settings = { @@ -56,79 +31,116 @@ def animation_style_element(shot_uuid): } - st.markdown("#### Keyframe Influence Settings") - d1, d2 = st.columns([1.5, 4]) - + st.markdown("#### Key Frame Settings") + d1, d2 = st.columns([1, 4]) + st.session_state['frame_position'] = 0 with d1: - setting_a_1, setting_a_2 = st.columns([1, 1]) + setting_a_1, setting_a_2, = st.columns([1, 1]) with setting_a_1: - type_of_frame_distribution = st_memory.radio("Type of Frame Distribution", options=["linear", "dynamic"], key="type_of_frame_distribution").lower() + type_of_frame_distribution = st_memory.radio("Type of key frame distribution:", options=["Linear", "Dynamic"], key="type_of_frame_distribution").lower() if type_of_frame_distribution == "linear": with setting_a_2: - linear_frame_distribution_value = st_memory.number_input("Frames per Keyframe", min_value=8, max_value=36, value=16, step=1, key="frames_per_keyframe") + linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="frames_per_keyframe") dynamic_frame_distribution_values = [] + st.markdown("***") setting_b_1, setting_b_2 = st.columns([1, 1]) with setting_b_1: - type_of_key_frame_influence = st_memory.radio("Type of Keyframe Influence", options=["linear", "dynamic"], key="type_of_key_frame_influence").lower() + type_of_key_frame_influence = st_memory.radio("Type of key frame length influence:", options=["Linear", "Dynamic"], key="type_of_key_frame_influence").lower() if type_of_key_frame_influence == "linear": with setting_b_2: - linear_key_frame_influence_value = st_memory.slider("Length of Keyframe Influence", min_value=0.0, max_value=2.0, value=1.1, step=0.1, key="length_of_key_frame_influence") + linear_key_frame_influence_value = st_memory.slider("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.1, key="length_of_key_frame_influence") dynamic_key_frame_influence_values = [] - setting_c_1, setting_c_2 = st.columns([1, 1]) - with setting_c_1: - type_of_cn_strength_distribution = st_memory.radio("Type of CN Strength Distribution", options=["linear", "dynamic"], key="type_of_cn_strength_distribution").lower() + st.markdown("***") + + setting_d_1, setting_d_2 = st.columns([1, 1]) + + with setting_d_1: + type_of_cn_strength_distribution = st_memory.radio("Type of key frame strength control:", options=["Linear", "Dynamic"], key="type_of_cn_strength_distribution").lower() if type_of_cn_strength_distribution == "linear": - with setting_c_2: - linear_cn_strength_value = st_memory.slider("CN Strength", min_value=0.0, max_value=1.0, value=0.5, step=0.1, key="linear_cn_strength_value") + with setting_d_2: + linear_cn_strength_value = st_memory.slider("Range of strength:", min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.1, key="linear_cn_strength_value") dynamic_cn_strength_values = [] - - # length_of_key_frame_influence = st_memory.slider("Length of Keyframe Influence", min_value=0.0, max_value=2.0, value=1.1, step=0.1, key="length_of_key_frame_influence") - interpolation_style = st_memory.selectbox("Interpolation Style", options=["ease-in-out", "ease-in", "ease-out", "linear"], key="interpolation_style") - motion_scale = st_memory.slider("Motion Scale", min_value=0.0, max_value=2.0, value=1.1, step=0.1, key="motion_scale") + st.markdown("***") + footer1, _ = st.columns([2, 1]) + with footer1: + interpolation_style = 'ease-in-out' + 
motion_scale = st_memory.slider("Motion scale:", min_value=0.0, max_value=2.0, value=1.0, step=0.1, key="motion_scale") + + st.markdown("***") + if st.button("Reset to default settings", key="reset_animation_style"): + update_interpolation_settings(timing_list=timing_list) + st.rerun() with d2: - - columns = st.columns(max(5, len(timing_list))) + columns = st.columns(max(7, len(timing_list))) disable_generate = False help = "" dynamic_frame_distribution_values = [] dynamic_key_frame_influence_values = [] - dynamic_cn_strength_values = [] - - if type_of_frame_distribution == "dynamic" or type_of_key_frame_influence == "dynamic" or type_of_cn_strength_distribution == "dynamic": - for idx, timing in enumerate(timing_list): - if timing.primary_image and timing.primary_image.location: - columns[idx].info(f"Frame {idx+1}") - columns[idx].image(timing.primary_image.location, use_column_width=True) - b = timing.primary_image.inference_params - if type_of_frame_distribution == "dynamic": - linear_frame_distribution_value = 16 - if f"frame_{idx+1}" not in st.session_state: - st.session_state[f"frame_{idx+1}"] = idx * 16 # Default values in increments of 16 - if idx == 0: # For the first frame, position is locked to 0 - frame_position = columns[idx].number_input(f"Frame Position {idx+1}", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx+1}", disabled=True) - else: - min_value = st.session_state[f"frame_{idx}"] + 1 - frame_position = columns[idx].number_input(f"Frame Position {idx+1}", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx+1}") - st.session_state[f"frame_{idx+1}"] = frame_position - dynamic_frame_distribution_values.append(frame_position) - if type_of_key_frame_influence == "dynamic": - linear_key_frame_influence_value = 1.1 - dynamic_key_frame_influence_individual_value = columns[idx].slider(f"Length of Keyframe Influence {idx+1}", min_value=0.0, max_value=5.0, value=(b['dynamic_key_frame_influence_values'] if b and 'dynamic_key_frame_influence_values' in b else 1.1), step=0.1, key=f"dynamic_key_frame_influence_values_{idx+1}") - dynamic_key_frame_influence_values.append(str(dynamic_key_frame_influence_individual_value)) - if type_of_cn_strength_distribution == "dynamic": - linear_cn_strength_value = 1 - dynamic_cn_strength_individual_value = columns[idx].slider(f"CN Strength {idx+1}", min_value=0.0, max_value=1.0, value=(b['dynamic_cn_strength_values'] if b and 'dynamic_cn_strength_values' in b else 0.5), step=0.1, key=f"dynamic_cn_strength_values_{idx+1}") - dynamic_cn_strength_values.append(str(dynamic_cn_strength_individual_value)) + dynamic_cn_strength_values = [] + mpl_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + + color_mapping = { + '#1f77b4': 'blue', '#ff7f0e': 'orange', '#2ca02c': 'green', + '#d62728': 'red', '#9467bd': 'purple', '#8c564b': 'brown', + '#e377c2': 'pink', '#7f7f7f': 'gray', '#bcbd22': 'olive', + '#17becf': 'cyan' + } + + streamlit_color_names = [color_mapping.get(color, 'black') for color in mpl_colors] + + for idx, timing in enumerate(timing_list): + # Only create markdown text for the current index + markdown_text = f'##### :{streamlit_color_names[idx]}[**Frame {idx + 1}** ___]' + + with columns[idx]: + st.markdown(markdown_text) + + if timing.primary_image and timing.primary_image.location: + columns[idx].image(timing.primary_image.location, use_column_width=True) + b = timing.primary_image.inference_params + if type_of_frame_distribution == 
"dynamic": + linear_frame_distribution_value = 16 + if f"frame_{idx+1}" not in st.session_state: + st.session_state[f"frame_{idx+1}"] = idx * 16 # Default values in increments of 16 + if idx == 0: # For the first frame, position is locked to 0 + with columns[idx]: + frame_position = st_memory.number_input(f"{idx+1} frame Position", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx+1}", disabled=True) + else: + min_value = st.session_state[f"frame_{idx}"] + 1 + with columns[idx]: + frame_position = st_memory.number_input(f"#{idx+1} position:", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx+1}") + # st.session_state[f"frame_{idx+1}"] = frame_position + dynamic_frame_distribution_values.append(frame_position) + + if type_of_key_frame_influence == "dynamic": + linear_key_frame_influence_value = 1.1 + with columns[idx]: + dynamic_key_frame_influence_individual_value = st_memory.slider(f"#{idx+1} length of influence:", min_value=0.0, max_value=5.0, value=1.0, step=0.1, key=f"dynamic_key_frame_influence_values_{idx}") + dynamic_key_frame_influence_values.append(str(dynamic_key_frame_influence_individual_value)) + + if type_of_cn_strength_distribution == "dynamic": + linear_cn_strength_value = (0.0,1.0) + with columns[idx]: + help_texts = ["For the first frame, it'll start at the endpoint and decline to the starting point", + "For the final frame, it'll start at the starting point and end at the endpoint", + "For intermediate frames, it'll start at the starting point, peak in the middle at the endpoint, and decline to the starting point"] + label_texts = [f"#{idx+1} end -> start:", f"#{idx+1} start -> end:", f"#{idx+1} start -> peak:"] + help_text = help_texts[0] if idx == 0 else help_texts[1] if idx == len(timing_list) - 1 else help_texts[2] + label_text = label_texts[0] if idx == 0 else label_texts[1] if idx == len(timing_list) - 1 else label_texts[2] + dynamic_cn_strength_individual_value = st_memory.slider(label_text, min_value=0.0, max_value=1.0, value=(0.0,0.7), step=0.1, key=f"dynamic_cn_strength_values_{idx}",help=help_text) + dynamic_cn_strength_values.append(str(dynamic_cn_strength_individual_value)) # Convert lists to strings dynamic_frame_distribution_values = ",".join(map(str, dynamic_frame_distribution_values)) # Convert integers to strings before joining dynamic_key_frame_influence_values = ",".join(dynamic_key_frame_influence_values) dynamic_cn_strength_values = ",".join(dynamic_cn_strength_values) + # dynamic_start_and_endpoint_values = ",".join(dynamic_start_and_endpoint_values) + # st.write(dynamic_start_and_endpoint_values) - def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values): + def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values, allow_extension=True): if len(keyframe_positions) < 2 or len(keyframe_positions) != len(key_frame_influence_values): return [] @@ -140,13 +152,16 @@ def calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_v start_influence = position - range_size end_influence = position + range_size - start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) - end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) + # If extension beyond the adjacent keyframe is allowed, do not constrain the start and end influence. 
+ if not allow_extension: + start_influence = max(start_influence, keyframe_positions[i - 1] if i > 0 else 0) + end_influence = min(end_influence, keyframe_positions[i + 1] if i < len(keyframe_positions) - 1 else keyframe_positions[-1]) influence_ranges.append((round(start_influence), round(end_influence))) return influence_ranges + def get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, images, linear_frame_distribution_value): if type_of_frame_distribution == "dynamic": return sorted([int(kf.strip()) for kf in dynamic_frame_distribution_values.split(',')]) @@ -159,34 +174,48 @@ def extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influ else: return [linear_key_frame_influence_value for _ in keyframe_positions] - def calculate_weights_and_plot(influence_ranges, interpolation, strengths): - plt.figure(figsize=(12, 6)) + def extract_start_and_endpoint_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value): + if type_of_key_frame_influence == "dynamic": + # If dynamic_key_frame_influence_values is a list of characters representing tuples, process it + if isinstance(dynamic_key_frame_influence_values[0], str) and dynamic_key_frame_influence_values[0] == "(": + # Join the characters to form a single string and evaluate to convert into a list of tuples + string_representation = ''.join(dynamic_key_frame_influence_values) + dynamic_values = eval(f'[{string_representation}]') + else: + # If it's already a list of tuples or a single tuple, use it directly + dynamic_values = dynamic_key_frame_influence_values if isinstance(dynamic_key_frame_influence_values, list) else [dynamic_key_frame_influence_values] + return dynamic_values + else: + # Return a list of tuples with the linear_key_frame_influence_value as a tuple repeated for each position + return [linear_key_frame_influence_value for _ in keyframe_positions] + + def calculate_weights(influence_ranges, interpolation, start_and_endpoint_strength, last_key_frame_position): + weights_list = [] + frame_numbers_list = [] - frame_names = [f'Frame {i+1}' for i in range(len(influence_ranges))] for i, (range_start, range_end) in enumerate(influence_ranges): - strength = float(strengths[i]) # Get the corresponding strength value + # Initialize variables if i == 0: - - strength_from = 1.0 - strength_to = 0.0 - revert_direction_at_midpoint = False - elif i == len(influence_ranges) - 1: - strength_from = 0.0 - strength_to = 1.0 - revert_direction_at_midpoint = False + strength_to, strength_from = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (0.0, 1.0) else: - strength_from = 0.0 - strength_to = 1.0 - revert_direction_at_midpoint = True + strength_from, strength_to = start_and_endpoint_strength[i] if i < len(start_and_endpoint_strength) else (1.0, 0.0) + revert_direction_at_midpoint = (i != 0) and (i != len(influence_ranges) - 1) + + # if it's the first value, set influence range from 1.0 to 0.0 + if i == 0: + range_start = 0 + + # if it's the last value, set influence range to end at last_key_frame_position + if i == len(influence_ranges) - 1: + range_end = last_key_frame_position steps = range_end - range_start diff = strength_to - strength_from - if revert_direction_at_midpoint: - index = np.linspace(0, 1, steps // 2 + 1) - else: - index = np.linspace(0, 1, steps) + # Calculate index for interpolation + index = np.linspace(0, 1, steps // 2 + 1) if revert_direction_at_midpoint else np.linspace(0, 1, 
steps) + # Calculate weights based on interpolation type if interpolation == "linear": weights = np.linspace(strength_from, strength_to, len(index)) elif interpolation == "ease-in": @@ -195,35 +224,58 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): weights = diff * (1 - np.power(1 - index, 2)) + strength_from elif interpolation == "ease-in-out": weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from - - weights = weights.astype(float) * strength + # If it's a middle keyframe, mirror the weights if revert_direction_at_midpoint: - if steps % 2 == 0: - weights = np.concatenate([weights, weights[-1::-1]]) - else: - weights = np.concatenate([weights, weights[-2::-1]]) + weights = np.concatenate([weights, weights[::-1]]) + # Generate frame numbers frame_numbers = np.arange(range_start, range_start + len(weights)) + + # "Dropper" component: For keyframes with negative start, drop the weights + if range_start < 0 and i > 0: + drop_count = abs(range_start) + weights = weights[drop_count:] + frame_numbers = frame_numbers[drop_count:] + + # Dropper component: for keyframes a range_End is greater than last_key_frame_position, drop the weights + if range_end > last_key_frame_position and i < len(influence_ranges) - 1: + drop_count = range_end - last_key_frame_position + weights = weights[:-drop_count] + frame_numbers = frame_numbers[:-drop_count] + + weights_list.append(weights) + frame_numbers_list.append(frame_numbers) + + return weights_list, frame_numbers_list + + + def plot_weights(weights_list, frame_numbers_list, frame_names): + plt.figure(figsize=(12, 6)) + + for i, weights in enumerate(weights_list): + frame_numbers = frame_numbers_list[i] plt.plot(frame_numbers, weights, label=f'{frame_names[i]}') + # Plot settings plt.xlabel('Frame Number') plt.ylabel('Weight') - plt.title('Key Framing Influence Over Frames') plt.legend() plt.ylim(0, 1.0) plt.show() - - + st.set_option('deprecation.showPyplotGlobalUse', False) + st.pyplot() + keyframe_positions = get_keyframe_positions(type_of_frame_distribution, dynamic_frame_distribution_values, timing_list, linear_frame_distribution_value) - cn_strength_values = extract_keyframe_values(type_of_cn_strength_distribution, dynamic_cn_strength_values, keyframe_positions, linear_cn_strength_value) + last_key_frame_position = keyframe_positions[-1] + cn_strength_values = extract_start_and_endpoint_values(type_of_cn_strength_distribution, dynamic_cn_strength_values, keyframe_positions, linear_cn_strength_value) key_frame_influence_values = extract_keyframe_values(type_of_key_frame_influence, dynamic_key_frame_influence_values, keyframe_positions, linear_key_frame_influence_value) - influence_ranges = calculate_dynamic_influence_ranges(keyframe_positions,key_frame_influence_values) - - calculate_weights_and_plot(influence_ranges, interpolation_style, cn_strength_values) - st.set_option('deprecation.showPyplotGlobalUse', False) - st.pyplot() - + # start_and_endpoint_values = extract_start_and_endpoint_values(type_of_start_and_endpoint, dynamic_start_and_endpoint_values, keyframe_positions, linear_start_and_endpoint_value) + influence_ranges = calculate_dynamic_influence_ranges(keyframe_positions, key_frame_influence_values) + weights_list, frame_numbers_list = calculate_weights(influence_ranges, interpolation_style, cn_strength_values, last_key_frame_position) + frame_names = [f'Frame {i+1}' for i in range(len(influence_ranges))] + plot_weights(weights_list, frame_numbers_list, frame_names) + st.markdown("***") e1, e2 = 
st.columns([1, 1]) @@ -238,28 +290,13 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): ] # remove .safe tensors from the end of each model name - sd_model = st_memory.selectbox("Which model would you like to use?", options=sd_model_list, key="sd_model") - negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="bad image, worst quality", key="negative_prompt") - ip_adapter_weight = st_memory.slider("How tightly would you like the style to adhere to the input images?", min_value=0.0, max_value=1.0, value=0.66, step=0.1, key="ip_adapter_weight") - soft_scaled_cn_weights_multipler = st_memory.slider("How much would you like to scale the CN weights?", min_value=0.0, max_value=10.0, value=0.85, step=0.1, key="soft_scaled_cn_weights_multipler") + sd_model = st_memory.selectbox("Which model would you like to use?", options=sd_model_list, key="sd_model_video") + negative_prompt = st_memory.text_area("What would you like to avoid in the videos?", value="bad image, worst quality", key="negative_prompt_video") + ip_adapter_weight = st_memory.slider("How tightly would you like the style to adhere to the input images?", min_value=0.0, max_value=1.0, value=0.66, step=0.1, key="ip_adapter_weight_video") + soft_scaled_cn_weights_multipler = st_memory.slider("How much would you like to scale the CN weights?", min_value=0.0, max_value=10.0, value=0.85, step=0.1, key="soft_scaled_cn_weights_multiple_video") - normalise_speed = st.checkbox("Normalise Speed", value=True, key="normalise_speed") + normalise_speed = True - # st.write(f"type_of_frame_distribution: {type_of_frame_distribution}") - # st.write(f"dynamic_frame_distribution_values: {dynamic_frame_distribution_values}") - # st.write(f"linear_frame_distribution_value: {linear_frame_distribution_value}") - # st.write(f"type_of_key_frame_influence: {type_of_key_frame_influence}") - # st.write(f"linear_key_frame_influence_value: {linear_key_frame_influence_value}") - # st.write(f"dynamic_key_frame_influence_values: {dynamic_key_frame_influence_values}") - # st.write(f"type_of_cn_strength_distribution: {type_of_cn_strength_distribution}") - # st.write(f"dynamic_cn_strength_values: {dynamic_cn_strength_values}") - # st.write(f"linear_cn_strength_value: {linear_cn_strength_value}") - # st.write(f"buffer: {buffer}") - # st.write(f"context_length: {context_length}") - # st.write(f"context_stride: {context_stride}") - # st.write(f"context_overlap: {context_overlap}") - - # TODO: add type of cn strength distribution project_settings = data_repo.get_project_setting(shot.project.uuid) width = project_settings.width height = project_settings.height @@ -277,8 +314,8 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): ip_adapter_model_weight=ip_adapter_weight, soft_scaled_cn_multiplier=soft_scaled_cn_weights_multipler, type_of_cn_strength_distribution=type_of_cn_strength_distribution, - linear_cn_strength_value=linear_cn_strength_value, - dynamic_cn_strength_values=dynamic_cn_strength_values, + linear_cn_strength_value=str(linear_cn_strength_value), + dynamic_cn_strength_values=str(dynamic_cn_strength_values), type_of_frame_distribution=type_of_frame_distribution, linear_frames_per_keyframe=linear_frame_distribution_value, dynamic_frames_per_keyframe=dynamic_frame_distribution_values, @@ -299,7 +336,7 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): if st.button("Generate Animation Clip", key="generate_animation_clip", disabled=disable_generate, 
help=help): vid_quality = "full" if video_resolution == "Full Resolution" else "preview" - st.write("Generating animation clip...") + st.success("Generating clip - see status in the generation log on the left.") positive_prompt = "" for idx, timing in enumerate(timing_list): @@ -310,7 +347,7 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): positive_prompt += ":" + frame_prompt if positive_prompt else frame_prompt else: st.error("Please generate primary images") - time.sleep(0.5) + time.sleep(0.7) st.rerun() settings.update( @@ -338,18 +375,16 @@ def calculate_weights_and_plot(influence_ranges, interpolation, strengths): st.info(f"Generating a video with {number_of_frames} frames in the cloud will cost c. ${cost_per_generation:.2f} USD.") elif where_to_generate == "Local": - h1, _ = st.columns([1,1]) + h1,h2 = st.columns([1,1]) with h1: st.info("You can run this locally in ComfyUI but you'll need at least 16GB VRAM. To get started, you can follow the instructions [here]() and download the workflow and images below.") - - # NOTE: this is a streamlit limitation (double btn click to download) - if st.button("Generate zip", key="download_workflow_and_images"): - zip_data = zip_shot_data(shot_uuid, settings) - st.download_button( - label="Download zip", - data=zip_data, - file_name='data.zip' - ) + if st.button("Generate zip", key="download_workflow_and_images"): + zip_data = zip_shot_data(shot_uuid, settings) + st.download_button( + label="Download zip", + data=zip_data, + file_name='data.zip' + ) def zip_shot_data(shot_uuid, settings): @@ -370,14 +405,14 @@ def zip_shot_data(shot_uuid, settings): prompt = b['prompt'] if b else "" frame_prompt = f"{idx * settings['linear_frames_per_keyframe']}:" + prompt + ("," if idx != len(shot.timing_list) - 1 else "") positive_prompt += frame_prompt - + settings['image_prompt_list'] = positive_prompt with zipfile.ZipFile(buffer, 'w') as zip_file: for idx, image_location in enumerate(image_locations): if not image_location: continue - + image_name = f"{idx}.png" if image_location.startswith('http'): response = requests.get(image_location) @@ -425,7 +460,7 @@ def create_workflow_json(image_locations, settings): output_format = settings['output_format'] soft_scaled_cn_multiplier = settings['soft_scaled_cn_multiplier'] stmfnet_multiplier = settings['stmfnet_multiplier'] - + if settings['type_of_frame_distribution'] == 'linear': batch_size = (len(image_locations) - 1) * settings['linear_frames_per_keyframe'] + int(buffer) else: @@ -439,13 +474,13 @@ def create_workflow_json(image_locations, settings): node['widgets_values'][-2] = int(img_height) node['widgets_values'][0] = ckpt node['widgets_values'][-1] = batch_size - + elif node['id'] == '187': json_data["widgets_values"][-2] = motion_scale elif node['id'] == '347': json_data["widgets_values"][0] = image_prompt_list - + elif node['id'] == '352': json_data["widgets_values"] = [negative_prompt] @@ -468,8 +503,32 @@ def create_workflow_json(image_locations, settings): elif node['id'] == '301': json_data["widgets_values"] = [ip_adapter_model_weight] - + elif node['id'] == '281': json_data["widgets_values"][3] = output_format - return json_data \ No newline at end of file + return json_data + + +def update_interpolation_settings(values=None, timing_list=None): + default_values = { + 'type_of_frame_distribution': 0, + 'frames_per_keyframe': 16, + 'type_of_key_frame_influence': 0, + 'length_of_key_frame_influence': 1.0, + 'type_of_cn_strength_distribution': 0, + 
'linear_cn_strength_value': (0.0,0.7), + 'interpolation_style': 0, + 'motion_scale': 1.0, + 'negative_prompt_video': 'bad image, worst quality', + 'ip_adapter_weight_video': 0.66, + 'soft_scaled_cn_weights_multiple_video': 0.85 + } + + for idx in range(0, len(timing_list)): + default_values[f'dynamic_frame_distribution_values_{idx}'] = (idx - 1) * 16 + default_values[f'dynamic_key_frame_influence_values_{idx}'] = 1.0 + default_values[f'dynamic_cn_strength_values_{idx}'] = (0.0,0.7) + + for key, default_value in default_values.items(): + st.session_state[key] = values.get(key, default_value) if values and values.get(key) is not None else default_value \ No newline at end of file diff --git a/ui_components/widgets/variant_comparison_grid.py b/ui_components/widgets/variant_comparison_grid.py index 3f99f17e..b8f31ab0 100644 --- a/ui_components/widgets/variant_comparison_grid.py +++ b/ui_components/widgets/variant_comparison_grid.py @@ -8,6 +8,7 @@ from ui_components.methods.video_methods import sync_audio_and_duration from ui_components.models import InternalFileObject from ui_components.widgets.add_key_frame_element import add_key_frame +from ui_components.widgets.animation_style_element import update_interpolation_settings from utils.data_repo.data_repo import DataRepo @@ -18,17 +19,18 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): ''' data_repo = DataRepo() - timing_uuid, shot_uuid, project_uuid = None, None, None + timing_uuid, shot_uuid = None, None if stage == CreativeProcessType.MOTION.value: shot_uuid = ele_uuid shot = data_repo.get_shot_from_uuid(shot_uuid) variants = shot.interpolated_clip_list - project_uuid = shot.project.uuid + timing_list = data_repo.get_timing_list_from_shot(shot.uuid) else: timing_uuid = ele_uuid timing = data_repo.get_timing_from_uuid(timing_uuid) variants = timing.alternative_images_list - project_uuid = timing.shot.project.uuid + shot_uuid = timing.shot.uuid + timing_list ="" st.markdown("***") @@ -53,7 +55,6 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): timing.primary_variant_index) st.markdown("***") - cols = st.columns(num_columns) with cols[0]: if stage == CreativeProcessType.MOTION.value: @@ -61,8 +62,8 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): else: st.image(variants[current_variant].location, use_column_width=True) with st.expander("Inference details"): - st.markdown(f"Details:") - inference_detail_element(variants[current_variant]) + variant_inference_detail_element(variants[current_variant], stage, shot_uuid, timing_list) + st.success("**Main variant**") start = (page - 1) * items_to_show @@ -75,35 +76,17 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): if stage == CreativeProcessType.MOTION.value: st.video(variants[variant_index].location, format='mp4', start_time=0) if variants[variant_index] else st.error("No video present") else: - st.image(variants[variant_index].location, use_column_width=True) if variants[variant_index] else st.error("No image present") - - with st.expander("Inference details"): - st.markdown(f"Details:") - inference_detail_element(variants[variant_index]) - if stage != CreativeProcessType.MOTION.value: - h1, h2 = st.columns([1, 1]) - with h1: - st.markdown(f"Add to shortlist:") - add_variant_to_shortlist_element(variants[variant_index], project_uuid) - with h2: - add_variant_to_shot_element(variants[variant_index], project_uuid) + st.image(variants[variant_index].location, 
use_column_width=True) if variants[variant_index] else st.error("No image present") + with st.expander("Inference details"): + variant_inference_detail_element(variants[variant_index], stage, shot_uuid, timing_list) if st.button(f"Promote Variant #{variant_index + 1}", key=f"Promote Variant #{variant_index + 1} for {st.session_state['current_frame_index']}", help="Promote this variant to the primary image", use_container_width=True): if stage == CreativeProcessType.MOTION.value: promote_video_variant(shot.uuid, variants[variant_index].uuid) else: - promote_image_variant(timing.uuid, variant_index) - + promote_image_variant(timing.uuid, variant_index) st.rerun() - if stage == CreativeProcessType.MOTION.value: - if st.button("Sync audio/duration", key=f"{variants[variant_index].uuid}", help="Updates video length and the attached audio", use_container_width=True): - _ = sync_audio_and_duration(variants[variant_index], shot_uuid) - _ = data_repo.get_shot_list(project_uuid, invalidate_cache=True) - st.success("Video synced") - time.sleep(0.3) - st.rerun() - next_col += 1 if next_col >= num_columns: @@ -111,7 +94,81 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): next_col = 0 # Reset column counter -def inference_detail_element(file: InternalFileObject): +def variant_inference_detail_element(variant, stage, shot_uuid, timing_list=""): + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) + + st.markdown(f"Details:") + inf_data = fetch_inference_data(variant) + if 'image_prompt_list' in inf_data: + del inf_data['image_prompt_list'] + del inf_data['image_list'] + del inf_data['output_format'] + + st.write(inf_data) + + if stage != CreativeProcessType.MOTION.value: + h1, h2 = st.columns([1, 1]) + with h1: + st.markdown(f"Add to shortlist:") + add_variant_to_shortlist_element(variant, shot.project.uuid) + with h2: + add_variant_to_shot_element(variant, shot.project.uuid) + + if stage == CreativeProcessType.MOTION.value: + if st.button("Load up settings from this variant", key=f"{variant.name}", help="This will enter the settings from this variant into the inputs below - you can also use them on other shots", use_container_width=True): + new_data = prepare_values(fetch_inference_data(variant), timing_list) + update_interpolation_settings(values=new_data, timing_list=timing_list) + st.success("Settings loaded - scroll down to run them.") + st.rerun() + if st.button("Sync audio/duration", key=f"{variant.uuid}", help="Updates video length and the attached audio", use_container_width=True): + data_repo = DataRepo() + _ = sync_audio_and_duration(variant, shot_uuid) + _ = data_repo.get_shot_list(shot.project.uuid, invalidate_cache=True) + st.success("Video synced") + time.sleep(0.3) + st.rerun() + + +def prepare_values(inf_data, timing_list): + settings = inf_data # Map interpolation_type to indices + interpolation_style_map = { + 'ease-in-out': 0, + 'ease-in': 1, + 'ease-out': 2, + 'linear': 3 + } + + values = { + 'type_of_frame_distribution': 1 if settings.get('type_of_frame_distribution') == 'dynamic' else 0, + 'frames_per_keyframe': settings.get('linear_frames_per_keyframe', None), + 'type_of_key_frame_influence': 1 if settings.get('type_of_key_frame_influence') == 'dynamic' else 0, + 'length_of_key_frame_influence': float(settings.get('linear_key_frame_influence_value')) if settings.get('linear_key_frame_influence_value') else None, + 'type_of_cn_strength_distribution': 1 if settings.get('type_of_cn_strength_distribution') == 'dynamic' else 0, + 
+        'linear_cn_strength_value': float(settings.get('linear_cn_strength_value')) if settings.get('linear_cn_strength_value') else None,
+        'interpolation_style': interpolation_style_map[settings.get('interpolation_type')] if settings.get('interpolation_type', 'ease-in-out') in interpolation_style_map else None,
+        'motion_scale': settings.get('motion_scale', None),
+        'negative_prompt_video': settings.get('negative_prompt', None),
+        'ip_adapter_weight_video': settings.get('ip_adapter_model_weight', None),
+        'soft_scaled_cn_weights_multiple_video': settings.get('soft_scaled_cn_multiplier', None)
+    }
+
+    # Add dynamic values
+    dynamic_frame_distribution_values = settings['dynamic_frames_per_keyframe'].split(',') if settings['dynamic_frames_per_keyframe'] else []
+    dynamic_key_frame_influence_values = settings['dynamic_key_frame_influence_value'].split(',') if settings['dynamic_key_frame_influence_value'] else []
+    dynamic_cn_strength_values = settings['dynamic_cn_strength_values'].split(',') if settings['dynamic_cn_strength_values'] else []
+
+    min_length = min(len(timing_list), len(dynamic_frame_distribution_values), len(dynamic_key_frame_influence_values), len(dynamic_cn_strength_values))
+
+    for idx in range(1, min_length + 1):
+        values[f'dynamic_frame_distribution_values_{idx}'] = int(dynamic_frame_distribution_values[idx - 1]) if dynamic_frame_distribution_values[idx - 1] and dynamic_frame_distribution_values[idx - 1].strip() else None
+        values[f'dynamic_key_frame_influence_values_{idx}'] = float(dynamic_key_frame_influence_values[idx - 1]) if dynamic_key_frame_influence_values[idx - 1] and dynamic_key_frame_influence_values[idx - 1].strip() else None
+        values[f'dynamic_cn_strength_values_{idx}'] = float(dynamic_cn_strength_values[idx - 1]) if dynamic_cn_strength_values[idx - 1] and dynamic_cn_strength_values[idx - 1].strip() else None
+
+    return values
+
+
+def fetch_inference_data(file: InternalFileObject):
     if not file:
         return
 
@@ -125,8 +182,8 @@ def inference_detail_element(file: InternalFileObject):
                 del inf_data[data_type]
 
     inf_data = inf_data or not_found_msg
-    st.write(inf_data)
 
+    return inf_data
 
 def add_variant_to_shortlist_element(file: InternalFileObject, project_uuid):
     data_repo = DataRepo()
@@ -138,7 +195,6 @@ def add_variant_to_shortlist_element(file: InternalFileObject, project_uuid):
             time.sleep(0.3)
             st.rerun()
 
-
 def add_variant_to_shot_element(file: InternalFileObject, project_uuid):
     data_repo = DataRepo()
diff --git a/utils/common_utils.py b/utils/common_utils.py
index d6c8fd1a..0923c638 100644
--- a/utils/common_utils.py
+++ b/utils/common_utils.py
@@ -131,6 +131,7 @@ def reset_project_state():
         "seed",
         "promote_new_generation",
         "use_new_settings",
+        "shot_uuid"
     ]
 
     for k in keys_to_delete:
diff --git a/utils/data_repo/api_repo.py b/utils/data_repo/api_repo.py
index f20d0a88..6838fc20 100644
--- a/utils/data_repo/api_repo.py
+++ b/utils/data_repo/api_repo.py
@@ -401,8 +401,8 @@ def get_app_setting_from_uuid(self, uuid=None):
         res = self.http_get(self.APP_SETTING_URL, params={'uuid': uuid})
         return InternalResponse(res['payload'], 'success', res['status'])
 
-    def get_app_secrets_from_user_uuid(self, uuid=None):
-        res = self.http_get(self.APP_SECRET_URL)
+    def get_app_secrets_from_user_uuid(self, uuid=None, secret_access=None):
+        res = self.http_post(self.APP_SECRET_URL, data={'secret_access': secret_access})
         return InternalResponse(res['payload'], 'success', res['status'])
 
     # TODO: complete this code
diff --git a/utils/data_repo/data_repo.py b/utils/data_repo/data_repo.py
index 479fddaf..ff48c75b 100644
--- a/utils/data_repo/data_repo.py
+++ b/utils/data_repo/data_repo.py
@@ -1,7 +1,7 @@
 # this repo serves as a middlerware between API backend and the frontend
 import json
 import time
-from shared.constants import InferenceParamType, InternalFileType, InternalResponse
+from shared.constants import SECRET_ACCESS_TOKEN, InferenceParamType, InternalFileType, InternalResponse
 from shared.constants import SERVER, ServerType
 from shared.logging.constants import LoggingType
 from shared.logging.logging import AppLogger
@@ -327,7 +328,8 @@ def get_app_secrets_from_user_uuid(self, uuid=None):
         if not uuid:
             uuid = get_current_user_uuid()
 
-        app_secrets = self.db_repo.get_app_secrets_from_user_uuid(uuid).data['data']
+        app_secrets = self.db_repo.get_app_secrets_from_user_uuid(uuid, \
+            secret_access=SECRET_ACCESS_TOKEN).data['data']
         return app_secrets
 
     def get_all_app_setting_list(self):
@@ -392,8 +393,13 @@ def restore_backup(self, uuid):
         return res.status
 
     # update user credits - updates the credit of the user calling the API
-    def update_usage_credits(self, credits_to_add):
-        user = self.update_user(user_id=None, credits_to_add=credits_to_add)
+    def update_usage_credits(self, credits_to_add, log_uuid=None):
+        user_id = None
+        if log_uuid:
+            log = self.get_inference_log_from_uuid(log_uuid)
+            user_id = log.project.user_uuid
+
+        user = self.update_user(user_id=user_id, credits_to_add=credits_to_add)
         return user
 
     def generate_payment_link(self, amount):
diff --git a/utils/ml_processor/replicate/constants.py b/utils/ml_processor/replicate/constants.py
index 6b02d09c..4e163c5e 100644
--- a/utils/ml_processor/replicate/constants.py
+++ b/utils/ml_processor/replicate/constants.py
@@ -45,7 +45,7 @@ class REPLICATE_MODEL:
     epicrealism_v5 = ReplicateModel("pagebrain/epicrealism-v5", "222465e57e4d9812207f14133c9499d47d706ecc41a8bf400120285b2f030b42")
     sdxl_controlnet = ReplicateModel("lucataco/sdxl-controlnet", "db2ffdbdc7f6cb4d6dab512434679ee3366ae7ab84f89750f8947d5594b79a47")
     realistic_vision_v5_img2img = ReplicateModel("lucataco/realistic-vision-v5-img2img", "82bbb4595458d6be142450fc6d8c4d79c936b92bd184dd2d6dd71d0796159819")
-    ad_interpolation = ReplicateModel("piyushk52/ad_infinite", "cb1ec688474f38da9c7c4f598166957587f2802b6f2b9448f6d938ed22892bea")
+    ad_interpolation = ReplicateModel("peter942/steerable-motion", "443f102a9d608ec715635ff28640249d2e32b779f6b0b71e20c40c13e6ff6866")
 
     # addition 17/10/2023
     llama_2_7b = ReplicateModel("meta/llama-2-7b", "527827021d8756c7ab79fde0abbfaac885c37a3ed5fe23c7465093f0878d55ef")
diff --git a/utils/ml_processor/replicate/replicate.py b/utils/ml_processor/replicate/replicate.py
index b3bc7273..284b43b2 100644
--- a/utils/ml_processor/replicate/replicate.py
+++ b/utils/ml_processor/replicate/replicate.py
@@ -125,7 +125,7 @@ def queue_prediction(self, replicate_model: ReplicateModel, **kwargs):
 
         # converting io buffers to base64 format
         for k, v in data['input'].items():
-            if not isinstance(v, (int, str, list, dict, float)):
+            if not isinstance(v, (int, str, list, dict, float, tuple)):
                 data['input'][k] = convert_file_to_base64(v)
 
         response = r.post(url, headers=headers, json=data)
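
For reference, the reworked sync_audio_and_duration in ui_components/methods/video_methods.py (see hunk @@ -165,11 +169,32 above) attaches only the slice of the project audio that overlaps the current shot, and returns None when audio_sync_required is set but that overlap is shorter than the clip. The snippet below is a minimal standalone sketch of that overlap rule; the helper name and the sample numbers are illustrative only and are not part of the codebase.

    # Sketch of the overlap rule applied in sync_audio_and_duration (hypothetical helper).
    def audio_overlap_seconds(audio_duration: float, video_duration: float, start_timestamp: float) -> float:
        """Seconds of audio, starting at start_timestamp, available to cover this shot's clip."""
        if audio_duration < start_timestamp:
            return 0.0  # the audio track ends before this shot begins, so the clip stays silent
        return round(min(video_duration, audio_duration - start_timestamp), 2)

    # A 4 s clip that starts 118 s into a 120 s track gets only 2 s of audio attached;
    # with audio_sync_required=True the real function returns None instead of a partial mix.
    assert audio_overlap_seconds(audio_duration=120.0, video_duration=4.0, start_timestamp=118.0) == 2.0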