From b00f2e9815fab2d245526def67d49941bd335f0a Mon Sep 17 00:00:00 2001 From: peter942 Date: Sat, 23 Dec 2023 02:55:58 +0000 Subject: [PATCH 01/11] Refactor and reorganise --- ui_components/components/adjust_shot_page.py | 39 ++++ ui_components/components/animate_shot_page.py | 14 ++ .../explorer_page.py} | 74 +++--- .../components/frame_styling_page.py | 216 +++--------------- ui_components/components/shortlist_page.py | 21 ++ .../components/timeline_view_page.py | 25 ++ ui_components/setup.py | 117 ++++++---- .../widgets/frame_movement_widgets.py | 7 +- ui_components/widgets/frame_selector.py | 66 +++--- ui_components/widgets/shot_view.py | 83 +++---- ui_components/widgets/timeline_view.py | 14 +- .../widgets/variant_comparison_grid.py | 4 +- utils/common_utils.py | 40 ++++ utils/st_memory.py | 17 +- 14 files changed, 375 insertions(+), 362 deletions(-) create mode 100644 ui_components/components/adjust_shot_page.py create mode 100644 ui_components/components/animate_shot_page.py rename ui_components/{widgets/explorer_element.py => components/explorer_page.py} (92%) create mode 100644 ui_components/components/shortlist_page.py create mode 100644 ui_components/components/timeline_view_page.py diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py new file mode 100644 index 00000000..c8db8f5f --- /dev/null +++ b/ui_components/components/adjust_shot_page.py @@ -0,0 +1,39 @@ +import streamlit as st +from ui_components.widgets.shot_view import shot_keyframe_element +from ui_components.components.explorer_page import gallery_image_view +from ui_components.components.explorer_page import generate_images_element +from ui_components.widgets.frame_selector import frame_selector_widget +from utils import st_memory + + + +def adjust_shot_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): + with h2: + frame_selector_widget(show=['shot_selector']) + + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") + + st.markdown("***") + + shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") + # with st.expander("📋 Explorer Shortlist",expanded=True): + shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"], + icons=['grid-3x3','airplane'], + menu_icon="cast", + default_index=st.session_state.get('shot_explorer_view', 0), + key="shot_explorer_view", orientation="horizontal", + styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#868c91"}}) + + st.markdown("***") + + if shot_explorer_view == "Shortlist": + project_setting = data_repo.get_project_setting(shot.project.uuid) + page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + st.markdown("***") + gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=True, num_columns=4,view="individual_shot", shot=shot) + elif shot_explorer_view == "Explore": + project_setting = data_repo.get_project_setting(shot.project.uuid) + page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + st.markdown("***") + gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=False, num_columns=4,view="individual_shot", shot=shot) \ No newline at end of file diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py new file mode 100644 index 00000000..e0033480 --- /dev/null +++ b/ui_components/components/animate_shot_page.py @@ -0,0 +1,14 @@ +import streamlit as st +from ui_components.widgets.frame_selector import frame_selector_widget +from ui_components.widgets.variant_comparison_grid import variant_comparison_grid +from ui_components.widgets.animation_style_element import animation_style_element + +def animate_shot_page(shot_uuid: str,h2,data_repo,shot,timing_list, project_settings): + + with h2: + frame_selector_widget(show=['shot_selector']) + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") + st.markdown("***") + variant_comparison_grid(st.session_state['shot_uuid'], stage="Shots") + with st.expander("🎬 Choose Animation Style & Create Variants", expanded=True): + animation_style_element(st.session_state['shot_uuid']) \ No newline at end of file diff --git a/ui_components/widgets/explorer_element.py b/ui_components/components/explorer_page.py similarity index 92% rename from ui_components/widgets/explorer_element.py rename to ui_components/components/explorer_page.py index c7ebccc6..f64cbb58 100644 --- a/ui_components/widgets/explorer_element.py +++ b/ui_components/components/explorer_page.py @@ -24,53 +24,45 @@ class InputImageStyling(ExtendedEnum): MAINTAIN_STRUCTURE = "Maintain Structure" +def columnn_selecter(): + f1, f2 = st.columns([1, 1]) + with f1: + st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") + with f2: + st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") -def explorer_element(project_uuid): - - st.markdown("***") - +def explorer_page(project_uuid): + data_repo = DataRepo() + project_setting = data_repo.get_project_setting(project_uuid) - project_setting = data_repo.get_project_setting(project_uuid) - - - f1, f2 = st.columns([1, 1]) - with f1: - num_columns = st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") - with f2: - num_items_per_page = st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") + st.markdown("***") + z1, z2, z3 = st.columns([0.25,2,0.25]) + with z2: + with st.expander("Prompt Settings", expanded=True): + generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) + st.markdown("***") + columnn_selecter() + k1,k2 = st.columns([5,1]) + page_number = k1.radio("Select page:", options=range(1, project_setting.total_gallery_pages + 1), horizontal=True, key="main_gallery") + open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') st.markdown("***") + gallery_image_view(project_uuid, page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, False, st.session_state['num_columns_explorer'],view="explorer") - with st.sidebar: - - st.session_state['explorer_view'] = st_memory.menu( - '', - ["Explorations", "Shortlist"], - icons=['airplane', 'grid-3x3', "paint-bucket", 'pencil'], - menu_icon="cast", - default_index=0, - key="explorer_view_selector", - orientation="horizontal", - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#0068c9"}}, - ) - # tab1, tab2 = st.tabs(["Explorations", "Shortlist"]) - if st.session_state['explorer_view'] == "Explorations": - z1, z2, z3 = st.columns([0.25,2,0.25]) - with z2: - with st.expander("Prompt Settings", expanded=True): - generate_images_element(position='explorer', project_uuid=project_uuid, timing_uuid=None) - - k1,k2 = st.columns([5,1]) - page_number = k1.radio("Select page:", options=range(1, project_setting.total_gallery_pages + 1), horizontal=True, key="main_gallery") - open_detailed_view_for_all = k2.toggle("Open detailed view for all:", key='main_gallery_toggle') - gallery_image_view(project_uuid, page_number, num_items_per_page, open_detailed_view_for_all, False, num_columns,view="explorer") - elif st.session_state['explorer_view'] == "Shortlist": - k1,k2 = st.columns([5,1]) - shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") - with k2: - open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle') - gallery_image_view(project_uuid, shortlist_page_number, num_items_per_page, open_detailed_view_for_all, True, num_columns,view="shortlist") + +def shortlist_element(project_uuid): + data_repo = DataRepo() + project_setting = data_repo.get_project_setting(project_uuid) + columnn_selecter() + k1,k2 = st.columns([5,1]) + shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + with k2: + open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle') + st.markdown("***") + gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") + def generate_images_element(position='explorer', project_uuid=None, timing_uuid=None): diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 3857290c..87cc2421 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -6,13 +6,13 @@ from ui_components.widgets.frame_selector import frame_selector_widget from ui_components.widgets.add_key_frame_element import add_key_frame, add_key_frame_element from ui_components.widgets.timeline_view import timeline_view -from ui_components.widgets.explorer_element import generate_images_element +from ui_components.components.explorer_page import generate_images_element from ui_components.widgets.animation_style_element import animation_style_element from ui_components.widgets.video_cropping_element import video_cropping_element from ui_components.widgets.inpainting_element import inpainting_element from ui_components.widgets.drawing_element import drawing_element from ui_components.widgets.sidebar_logger import sidebar_logger -from ui_components.widgets.explorer_element import explorer_element,gallery_image_view +# from ui_components.components.explorer_page import explorer_element,gallery_image_view from ui_components.widgets.variant_comparison_grid import variant_comparison_grid from ui_components.widgets.shot_view import shot_keyframe_element from utils import st_memory @@ -23,190 +23,40 @@ from utils.data_repo.data_repo import DataRepo -def frame_styling_page(shot_uuid: str): +def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): - data_repo = DataRepo() - shot = data_repo.get_shot_from_uuid(shot_uuid) - - timing_list = data_repo.get_timing_list_from_shot(shot_uuid) - - project_settings = data_repo.get_project_setting(shot.project.uuid) - - if "strength" not in st.session_state: - st.session_state['strength'] = DefaultProjectSettingParams.batch_strength - st.session_state['prompt_value'] = DefaultProjectSettingParams.batch_prompt - st.session_state['model'] = None - st.session_state['negative_prompt_value'] = DefaultProjectSettingParams.batch_negative_prompt - st.session_state['guidance_scale'] = DefaultProjectSettingParams.batch_guidance_scale - st.session_state['seed'] = DefaultProjectSettingParams.batch_seed - st.session_state['num_inference_steps'] = DefaultProjectSettingParams.batch_num_inference_steps - st.session_state['transformation_stage'] = DefaultProjectSettingParams.batch_transformation_stage - - if "current_frame_uuid" not in st.session_state and len(timing_list) > 0: - timing = data_repo.get_timing_list_from_shot(shot_uuid)[0] - st.session_state['current_frame_uuid'] = timing.uuid - st.session_state['current_frame_index'] = timing.aux_frame_index + 1 - if 'frame_styling_view_type' not in st.session_state: - st.session_state['frame_styling_view_type'] = "Individual" - st.session_state['frame_styling_view_type_index'] = 0 - - if st.session_state['change_view_type'] == True: - st.session_state['change_view_type'] = False - - if "explorer_view" not in st.session_state: - st.session_state['explorer_view'] = "Explorations" - st.session_state['explorer_view_index'] = 0 - - if "shot_view" not in st.session_state: - st.session_state['shot_view'] = "Animate Frames" - st.session_state['shot_view_index'] = 0 - - if "styling_view" not in st.session_state: - st.session_state['styling_view'] = "Generate" - st.session_state['styling_view_index'] = 0 - - if st.session_state['frame_styling_view_type'] == "Explorer": - st.markdown( - f"#### :red[{st.session_state['main_view_type']}] > **:green[{st.session_state['frame_styling_view_type']}]** > :orange[Explorer] > :blue[{st.session_state['explorer_view']}]") - elif st.session_state['frame_styling_view_type'] == "Timeline": - st.markdown( - f"#### :red[{st.session_state['main_view_type']}] > **:green[{st.session_state['frame_styling_view_type']}]** > :orange[{st.session_state['page']}]") - else: - if st.session_state['page'] == "Key Frames": - st.markdown( - f"#### :red[{st.session_state['main_view_type']}] > **:green[{st.session_state['frame_styling_view_type']}]** > :orange[{shot.name}] > :blue[{st.session_state['styling_view']}] > {shot.name} > #{st.session_state['current_frame_index']}") - else: - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > **:green[{st.session_state['frame_styling_view_type']}]** > :orange[{st.session_state['page']}] > :blue[{st.session_state['shot_view']}] > {shot.name}") - - project_settings = data_repo.get_project_setting(shot.project.uuid) - - if st.session_state['frame_styling_view_type'] == "Explorer": - - explorer_element(shot.project.uuid) - - # -------------------- INDIVIDUAL VIEW ---------------------- - elif st.session_state['frame_styling_view_type'] == "Individual": - - if st.session_state['page'] == CreativeProcessType.MOTION.value: - - with st.sidebar: - - if 'shot_view_manual_select' not in st.session_state: - st.session_state['shot_view_manual_select'] = None + with st.sidebar: + with h2: - if 'shot_view_index' not in st.session_state: - st.session_state['shot_view_index'] = 0 + frame_selector_widget(show=['shot_selector','frame_selector']) + + st.session_state['styling_view'] = st_memory.menu('',\ + ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ + icons=['magic', 'crop', "paint-bucket", 'pencil'], \ + menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ + key="styling_view_selector", orientation="horizontal", \ + styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) - shot_views = ["Animate Frames", "Adjust Frames"] - # with st.sidebar: - st.session_state['shot_view'] = option_menu('', - shot_views, - icons=['film', 'crop', "paint-bucket", 'pencil'], - menu_icon="cast", default_index=st.session_state['shot_view_index'], - key="animation_view_selector", orientation="horizontal", - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#0068c9"}}, - manual_select=st.session_state['shot_view_manual_select']) - - if st.session_state['shot_view_manual_select'] != None: - st.session_state['shot_view_manual_select'] = None - - if shot_views.index(st.session_state['shot_view']) != st.session_state['shot_view_index']: - st.session_state['shot_view_index'] = shot_views.index(st.session_state['shot_view']) - st.rerun() - - - - - if st.session_state['shot_view'] == "Animate Frames": - variant_comparison_grid(shot_uuid, stage=CreativeProcessType.MOTION.value) - with st.expander("🎬 Choose Animation Style & Create Variants", expanded=True): - animation_style_element(shot_uuid) - - elif st.session_state['shot_view'] == "Adjust Frames": - st.markdown("***") - shot_keyframe_element(shot_uuid, 4, position="Individual") - # with st.expander("📋 Explorer Shortlist",expanded=True): - shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"], - icons=['grid-3x3','airplane'], - menu_icon="cast", - default_index=st.session_state.get('shot_explorer_view', 0), - key="shot_explorer_view", orientation="horizontal", - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#868c91"}}) - st.markdown("***") - if shot_explorer_view == "Shortlist": - project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=True, num_columns=4,view="individual_shot", shot=shot) - elif shot_explorer_view == "Explore": - project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) - st.markdown("***") - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=False, num_columns=4,view="individual_shot", shot=shot) - #with st.expander("🤏 Crop, Move & Rotate Image", expanded=True): - # video_cropping_element(shot_uuid) - - elif st.session_state['page'] == CreativeProcessType.STYLING.value: - - - - - - - with st.sidebar: - - st.session_state['styling_view'] = st_memory.menu('',\ - ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ - icons=['magic', 'crop', "paint-bucket", 'pencil'], \ - menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ - key="styling_view_selector", orientation="horizontal", \ - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#0068c9"}}) - - - if st.session_state['styling_view'] == "Generate": - variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) - with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): - generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) - - elif st.session_state['styling_view'] == "Crop/Move": - with st.expander("🤏 Crop, Move & Rotate", expanded=True): - cropping_selector_element(shot_uuid) - - elif st.session_state['styling_view'] == "Inpainting": - with st.expander("🌌 Inpainting", expanded=True): - inpainting_element(st.session_state['current_frame_uuid']) - - elif st.session_state['styling_view'] == "Scribbling": - with st.expander("📝 Draw On Image", expanded=True): - drawing_element(timing_list,project_settings, shot_uuid) - with st.sidebar: - frame_selector_widget() - - - # -------------------- TIMELINE VIEW -------------------------- - elif st.session_state['frame_styling_view_type'] == "Timeline": - if st.session_state['page'] == "Key Frames": - - with st.sidebar: - with st.expander("📋 Explorer Shortlist",expanded=True): - - if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): - project_setting = data_repo.get_project_setting(shot.project.uuid) - page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) - gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=10, open_detailed_view_for_all=False, shortlist=True, num_columns=2,view="sidebar") - - timeline_view(shot_uuid, "Key Frames") - elif st.session_state['page'] == "Shots": - timeline_view(shot_uuid, "Shots") - - # -------------------- SIDEBAR NAVIGATION -------------------------- - with st.sidebar: - - - with st.expander("🔍 Generation Log", expanded=True): - if st_memory.toggle("Open", value=True, key="generaton_log_toggle"): - sidebar_logger(shot_uuid) - st.markdown("***") + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") + + if st.session_state['styling_view'] == "Generate": + variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) + with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): + generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + + elif st.session_state['styling_view'] == "Crop/Move": + with st.expander("🤏 Crop, Move & Rotate", expanded=True): + cropping_selector_element(shot_uuid) + + elif st.session_state['styling_view'] == "Inpainting": + with st.expander("🌌 Inpainting", expanded=True): + inpainting_element(st.session_state['current_frame_uuid']) + + elif st.session_state['styling_view'] == "Scribbling": + with st.expander("📝 Draw On Image", expanded=True): + drawing_element(timing_list,project_settings, shot_uuid) + + \ No newline at end of file diff --git a/ui_components/components/shortlist_page.py b/ui_components/components/shortlist_page.py new file mode 100644 index 00000000..d7b35bf0 --- /dev/null +++ b/ui_components/components/shortlist_page.py @@ -0,0 +1,21 @@ +import streamlit as st +from ui_components.components.explorer_page import columnn_selecter,gallery_image_view +from utils.data_repo.data_repo import DataRepo +from utils import st_memory + + +def shortlist_page(project_uuid): + + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") + st.markdown("***") + + data_repo = DataRepo() + project_setting = data_repo.get_project_setting(project_uuid) + columnn_selecter() + k1,k2 = st.columns([5,1]) + shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") + with k2: + open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) + st.markdown("***") + gallery_image_view(project_uuid, shortlist_page_number, st.session_state['num_items_per_page_explorer'], open_detailed_view_for_all, True, st.session_state['num_columns_explorer'],view="shortlist") + \ No newline at end of file diff --git a/ui_components/components/timeline_view_page.py b/ui_components/components/timeline_view_page.py new file mode 100644 index 00000000..edbe95cf --- /dev/null +++ b/ui_components/components/timeline_view_page.py @@ -0,0 +1,25 @@ + +import streamlit as st +from ui_components.constants import CreativeProcessType +from ui_components.widgets.timeline_view import timeline_view +from streamlit_option_menu import option_menu + +def timeline_view_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): + with st.sidebar: + + views = CreativeProcessType.value_list() + + if "view" not in st.session_state: + st.session_state["view"] = views[0] + st.session_state["manual_select"] = None + with h2: + st.session_state['view'] = option_menu(None, views, icons=['palette', 'camera-reels', "hourglass", 'stopwatch'], menu_icon="cast", orientation="vertical", key="secti2on_selector", styles={ + "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}, manual_select=st.session_state["manual_select"]) + + if st.session_state["manual_select"] != None: + st.session_state["manual_select"] = None + + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{st.session_state['view']}]") + + st.markdown("***") + timeline_view(st.session_state["shot_uuid"], st.session_state['view']) \ No newline at end of file diff --git a/ui_components/setup.py b/ui_components/setup.py index bdc33e62..7a98a100 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -2,20 +2,30 @@ import os from moviepy.editor import * from shared.constants import SERVER, ServerType - +# from ui_components.components.explorer_page import explorer_element,shortlist_element +from ui_components.widgets.timeline_view import timeline_view +from ui_components.widgets.sidebar_logger import sidebar_logger from ui_components.components.app_settings_page import app_settings_page from ui_components.components.custom_models_page import custom_models_page from ui_components.components.frame_styling_page import frame_styling_page +from ui_components.components.shortlist_page import shortlist_page +from ui_components.components.timeline_view_page import timeline_view_page +from ui_components.components.adjust_shot_page import adjust_shot_page +from ui_components.components.animate_shot_page import animate_shot_page +from ui_components.components.explorer_page import explorer_page + from ui_components.components.new_project_page import new_project_page from ui_components.components.project_settings_page import project_settings_page from ui_components.components.video_rendering_page import video_rendering_page from streamlit_option_menu import option_menu -from ui_components.constants import CreativeProcessType +from utils.common_utils import set_default_values + from ui_components.methods.common_methods import check_project_meta_data, update_app_setting_keys from ui_components.models import InternalAppSettingObject from utils.common_utils import create_working_assets, get_current_user, get_current_user_uuid, reset_project_state from utils import st_memory + from utils.data_repo.data_repo import DataRepo @@ -126,63 +136,70 @@ def setup_app_ui(): "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "red"}}) if st.session_state["main_view_type"] == "Creative Process": - with st.sidebar: - view_types = ["Explorer","Timeline","Individual"] - - if 'frame_styling_view_type_manual_select' not in st.session_state: - st.session_state['frame_styling_view_type_manual_select'] = 0 - st.session_state['frame_styling_view_type'] = "Explorer" - st.session_state['change_view_type'] = False - if 'change_view_type' not in st.session_state: - st.session_state['change_view_type'] = False + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) + timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) + project_settings = data_repo.get_project_setting(shot.project.uuid) + set_default_values(timing_list,shot.uuid, data_repo) - if st.session_state['change_view_type'] == True: - st.session_state['frame_styling_view_type_index'] = view_types.index( - st.session_state['frame_styling_view_type']) - else: - st.session_state['frame_styling_view_type_index'] = None + with st.sidebar: + creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Adjust Frame", "Animate Shot"] - st.session_state['frame_styling_view_type'] = option_menu( - None, - view_types, - icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], - menu_icon="cast", - orientation="horizontal", - key="section-selecto1r", - styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, - "nav-link-selected": {"background-color": "green"}}, - manual_select=st.session_state['frame_styling_view_type_manual_select'] - ) - - if st.session_state['frame_styling_view_type_manual_select'] != None: - st.session_state['frame_styling_view_type_manual_select'] = None - - if 'page' not in st.session_state: - st.session_state["page"] = CreativeProcessType.value_list()[0] - st.session_state["manual_select"] = None + if 'creative_process_manual_select' not in st.session_state: + st.session_state['creative_process_manual_select'] = 0 + st.session_state['page'] = creative_process_pages[0] + + + h1,h2 = st.columns([1.5,1]) + with h1: + # view_types = ["Explorer","Timeline","Individual"] + creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Adjust Frame", "Animate Shot"] + st.session_state['page'] = option_menu( + None, + creative_process_pages, + icons=['compass', 'bookshelf','aspect-ratio', "hourglass", 'stopwatch'], + menu_icon="cast", + orientation="vertical", + key="section-selecto1r", + styles={"nav-link": {"font-size": "15px", "margin":"0px", "--hover-color": "#eee"}, + "nav-link-selected": {"background-color": "green"}}, + manual_select=st.session_state['creative_process_manual_select'] + ) + + if st.session_state['creative_process_manual_select'] != None: + st.session_state['creative_process_manual_select'] = None + + + if st.session_state['page'] == "Explore": + explorer_page(st.session_state["project_uuid"]) + + elif st.session_state['page'] == "Shortlist": + shortlist_page(st.session_state["project_uuid"]) + + elif st.session_state['page'] == "Timeline": + timeline_view_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) + + elif st.session_state['page'] == "Adjust Frame": + frame_styling_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) + + elif st.session_state['page'] == "Adjust Shot": + adjust_shot_page(st.session_state["shot_uuid"], h2,data_repo,shot,timing_list, project_settings) + + elif st.session_state['page'] == "Animate Shot": + animate_shot_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) - if st.session_state['frame_styling_view_type'] != "Explorer": - pages = CreativeProcessType.value_list() - else: - pages = ["Key Frames"] + with st.sidebar: - if st.session_state['page'] != "Key Frames": - st.session_state["manual_select"] = 0 - st.session_state['page'] = "Key Frames" + with st.expander("🔍 Generation Log", expanded=True): + if st_memory.toggle("Open", value=True, key="generaton_log_toggle"): + sidebar_logger(st.session_state["shot_uuid"]) + st.markdown("***") - if st.session_state["page"] not in pages: - st.session_state["page"] = pages[0] - st.session_state["manual_select"] = None - st.session_state['page'] = option_menu(None, pages, icons=['palette', 'camera-reels', "hourglass", 'stopwatch'], menu_icon="cast", orientation="horizontal", key="secti2on_selector", styles={ - "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}, manual_select=st.session_state["manual_select"]) - - if st.session_state["manual_select"] != None: - st.session_state["manual_select"] = None - frame_styling_page(st.session_state["shot_uuid"]) + # frame_styling_page(st.session_state["shot_uuid"]) elif st.session_state["main_view_type"] == "Tools & Settings": with st.sidebar: diff --git a/ui_components/widgets/frame_movement_widgets.py b/ui_components/widgets/frame_movement_widgets.py index 65ee150d..6c6b5ecb 100644 --- a/ui_components/widgets/frame_movement_widgets.py +++ b/ui_components/widgets/frame_movement_widgets.py @@ -168,12 +168,11 @@ def jump_to_single_frame_view_button(display_number, timing_list, src,uuid=None) if st.button(f"Jump to #{display_number}", key=f"{src}_{uuid}", use_container_width=True): st.session_state['prev_frame_index'] = st.session_state['current_frame_index'] = display_number - st.session_state['current_frame_uuid'] = timing_list[st.session_state['current_frame_index'] - 1].uuid - st.session_state['frame_styling_view_type'] = "Individual" - st.session_state['change_view_type'] = True + st.session_state['current_frame_uuid'] = timing_list[st.session_state['current_frame_index'] - 1].uuid st.session_state['frame_styling_view_type_manual_select'] = 2 st.session_state['shot_uuid'] = timing_list[st.session_state['current_frame_index'] - 1].shot.uuid st.session_state['prev_shot_index'] = st.session_state['current_shot_index'] = timing_list[st.session_state['current_frame_index'] - 1].shot.shot_idx - st.session_state["manual_select"] = 0 + st.session_state["creative_process_manual_select"] = 4 + st.session_state["styling_view_selector_manual_select"] = 0 st.session_state['page'] = "Key Frames" st.rerun() diff --git a/ui_components/widgets/frame_selector.py b/ui_components/widgets/frame_selector.py index f2a6e7ee..d05d67c2 100644 --- a/ui_components/widgets/frame_selector.py +++ b/ui_components/widgets/frame_selector.py @@ -10,21 +10,18 @@ -def frame_selector_widget(): +def frame_selector_widget(show: List[str]): data_repo = DataRepo() - time1, time2 = st.columns([1,1]) - st.markdown("***") timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) shot_list = data_repo.get_shot_list(shot.project.uuid) len_timing_list = len(timing_list) if len(timing_list) > 0 else 1.0 + if 'prev_shot_index' not in st.session_state: + st.session_state['prev_shot_index'] = shot.shot_idx - with time1: - if 'prev_shot_index' not in st.session_state: - st.session_state['prev_shot_index'] = shot.shot_idx - + if 'shot_selector' in show: shot_names = [s.name for s in shot_list] shot_name = st.selectbox('Shot name:', shot_names, key="current_shot_sidebar_selector",index=shot_names.index(shot.name)) # find shot index based on shot name @@ -33,36 +30,45 @@ def frame_selector_widget(): if shot_name != shot.name: st.session_state["shot_uuid"] = shot_list[shot_names.index(shot_name)].uuid st.rerun() - + if not ('current_shot_index' in st.session_state and st.session_state['current_shot_index']): st.session_state['current_shot_index'] = shot_names.index(shot_name) + 1 update_current_shot_index(st.session_state['current_shot_index']) + # st.write if frame_selector is present + + if 'frame_selector' in show: - + if st.session_state['page'] == "Key Frames": + if st.session_state['current_frame_index'] > len_timing_list: + update_current_frame_index(len_timing_list) - if st.session_state['page'] == "Key Frames": - if st.session_state['current_frame_index'] > len_timing_list: - update_current_frame_index(len_timing_list) - # st.progress(st.session_state['current_frame_index'] / len_timing_list) - elif st.session_state['page'] == "Shots": - if st.session_state['current_shot_index'] > len(shot_list): - update_current_shot_index(len(shot_list)) - # st.progress(st.session_state['current_shot_index'] / len(shot_list)) - if st.session_state['page'] == "Key Frames": + elif st.session_state['page'] == "Shots": + if st.session_state['current_shot_index'] > len(shot_list): + update_current_shot_index(len(shot_list)) - if len(timing_list): - with time2: - if 'prev_frame_index' not in st.session_state: - st.session_state['prev_frame_index'] = 1 + + if len(timing_list): + if 'prev_frame_index' not in st.session_state or st.session_state['prev_frame_index'] > len(timing_list): - st.session_state['current_frame_index'] = st.number_input(f"Key frame # (out of {len(timing_list)})", 1, - len(timing_list), value=st.session_state['prev_frame_index'], - step=1, key="current_frame_sidebar_selector") - - update_current_frame_index(st.session_state['current_frame_index']) + st.session_state['prev_frame_index'] = 1 + + st.session_state['current_frame_index'] = st.number_input(f"Key frame # (out of {len(timing_list)})", 1, + len(timing_list), value=st.session_state['prev_frame_index'], + step=1, key="current_frame_sidebar_selector") + + update_current_frame_index(st.session_state['current_frame_index']) else: - with time2: - st.error("No frames present") + st.error("No frames present") + +def frame_view(): + data_repo = DataRepo() + # time1, time2 = st.columns([1,1]) + st.markdown("***") + + timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) + shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) + if st.session_state['page'] == "Key Frames": + with st.expander(f"🖼️ Frame #{st.session_state['current_frame_index']} Details", expanded=True): if st_memory.toggle("Open", value=True, key="frame_toggle"): a1, a2 = st.columns([3,2]) @@ -124,7 +130,7 @@ def update_current_frame_index(index): st.session_state['current_frame_uuid'] = timing_list[index - 1].uuid st.session_state['reset_canvas'] = True st.session_state['frame_styling_view_type_index'] = 0 - st.session_state['frame_styling_view_type'] = "Individual View" + st.session_state['frame_styling_view_type'] = "Generate View" st.rerun() diff --git a/ui_components/widgets/shot_view.py b/ui_components/widgets/shot_view.py index 8824a66a..5e0eb472 100644 --- a/ui_components/widgets/shot_view.py +++ b/ui_components/widgets/shot_view.py @@ -24,46 +24,42 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg if "open_shot" not in st.session_state: st.session_state["open_shot"] = None - - # st.markdown(f"### {shot.name}", expanded=True) + timing_list: List[InternalFrameTimingObject] = shot.timing_list - - + if position == "Timeline": - header_col_0, header_col_1, header_col_2, header_col_3, header_col_4= st.columns([1.75,1,2,0.25,0.25]) - - - + header_col_0, header_col_1, header_col_2, header_col_3 = st.columns([2,1,1.5,1.5]) + with header_col_0: - update_shot_name(shot.uuid) - footer_col_1, footer_col_2, _ = st.columns([0.35,0.35,1]) - with footer_col_1: - shot_adjustment_button(shot) - - with footer_col_2: - shot_animation_button(shot) - - - + update_shot_name(shot.uuid) + with header_col_1: update_shot_duration(shot.uuid) + with header_col_2: + st.write("") + shot_adjustment_button(shot, show_label=True) + with header_col_3: + st.write("") + shot_animation_button(shot, show_label=True) + else: - header_col_1,_ = st.columns([3,4]) - with header_col_1: - col2, col3, col4 = st.columns(3) - - with col2: - delete_frames_toggle = st_memory.toggle("Delete Frames", value=True, key="delete_frames_toggle") - copy_frame_toggle = st_memory.toggle("Copy Frame", value=True, key="copy_frame_toggle") - with col3: - move_frames_toggle = st_memory.toggle("Move Frames", value=True, key="move_frames_toggle") - replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") - - with col4: - change_shot_toggle = st_memory.toggle("Change Shot", value=False, key="change_shot_toggle") + + col1, col2, col3, col4, col5, _ = st.columns([1,1,1,1,1,3]) + + with col1: + delete_frames_toggle = st_memory.toggle("Delete Frames", value=True, key="delete_frames_toggle") + with col2: + copy_frame_toggle = st_memory.toggle("Copy Frame", value=True, key="copy_frame_toggle") + with col3: + move_frames_toggle = st_memory.toggle("Move Frames", value=True, key="move_frames_toggle") + with col4: + replace_image_widget_toggle = st_memory.toggle("Replace Image", value=False, key="replace_image_widget_toggle") + + with col5: + change_shot_toggle = st_memory.toggle("Change Shot", value=False, key="change_shot_toggle") st.markdown("***") @@ -75,9 +71,10 @@ def shot_keyframe_element(shot_uuid, items_per_row, position="Timeline", **kwarg if idx <= len(timing_list): with grid[j]: if idx == len(timing_list): - # if position != "Timeline": - st.info("**Add new frame(s) to shot**") - add_key_frame_section(shot_uuid, False) + if position != "Timeline": + + st.info("**Add new frame(s) to shot**") + add_key_frame_section(shot_uuid, False) else: timing = timing_list[idx] @@ -187,7 +184,7 @@ def duplicate_shot_button(shot_uuid): def delete_shot_button(shot_uuid): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) - confirm_delete = st.checkbox("This will delete all the frames & videos within",key=f"confirm_delete_{shot.uuid}") + confirm_delete = st.checkbox("Confirm deletion",key=f"confirm_delete_{shot.uuid}") help_text = "Check the box above to enable the delete button." if not confirm_delete else "This will this shot and all the frames and videos within." if st.button("Delete shot", disabled=(not confirm_delete), help=help_text, key=f"delete_btn_{shot.uuid}", use_container_width=True): if st.session_state['shot_uuid'] == str(shot.uuid): @@ -239,6 +236,8 @@ def shot_video_element(shot_uuid): shot_animation_button(shot) with st.expander("Details", expanded=False): + update_shot_name(shot.uuid) + update_shot_duration(shot.uuid) move_shot_buttons(shot, "side") delete_shot_button(shot.uuid) if shot.main_clip: @@ -270,19 +269,21 @@ def create_video_download_button(video_location, tag="temp"): key=tag + str(file_name), use_container_width=True ) -def shot_adjustment_button(shot): - if st.button("🔧", key=f"jump_to_shot_adjustment_{shot.uuid}", help=f"Shot adjustment view for '{shot.name}'", use_container_width=True): +def shot_adjustment_button(shot, show_label=False): + button_label = "Shot Adjustment 🔧" if show_label else "🔧" + if st.button(button_label, key=f"jump_to_shot_adjustment_{shot.uuid}", help=f"Shot adjustment view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid - st.session_state["frame_styling_view_type_manual_select"] = 2 + st.session_state['creative_process_manual_select'] = 3 st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 1 st.session_state['shot_view_index'] = 1 st.rerun() -def shot_animation_button(shot): - if st.button("🎞️", key=f"jump_to_shot_animation_{shot.uuid}", help=f"Shot animation view for '{shot.name}'", use_container_width=True): +def shot_animation_button(shot, show_label=False): + button_label = "Shot Animation 🎞️" if show_label else "🎞️" + if st.button(button_label, key=f"jump_to_shot_animation_{shot.uuid}", help=f"Shot animation view for '{shot.name}'", use_container_width=True): st.session_state["shot_uuid"] = shot.uuid - st.session_state["frame_styling_view_type_manual_select"] = 2 + st.session_state['creative_process_manual_select'] = 5 st.session_state["manual_select"] = 1 st.session_state['shot_view_manual_select'] = 0 st.session_state['shot_view_index'] = 0 diff --git a/ui_components/widgets/timeline_view.py b/ui_components/widgets/timeline_view.py index 619a462a..2588108e 100644 --- a/ui_components/widgets/timeline_view.py +++ b/ui_components/widgets/timeline_view.py @@ -5,12 +5,11 @@ from utils import st_memory -def timeline_view(shot_uuid, stage): +def timeline_view(shot_uuid, stage, show_frame_upload=True): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) shot_list = data_repo.get_shot_list(shot.project.uuid) - - st.markdown("***") + _, header_col_2 = st.columns([5.5,1.5]) @@ -35,10 +34,11 @@ def timeline_view(shot_uuid, stage): shot_video_element(shot.uuid) if (idx + 1) % items_per_row == 0 or idx == len(shot_list) - 1: st.markdown("***") - if idx == len(shot_list) - 1: - with grid[(idx + 1) % items_per_row]: - st.markdown("### Add new shot") - add_new_shot_element(shot, data_repo) + if show_frame_upload == True: + if idx == len(shot_list) - 1: + with grid[(idx + 1) % items_per_row]: + st.markdown("### Add new shot") + add_new_shot_element(shot, data_repo) diff --git a/ui_components/widgets/variant_comparison_grid.py b/ui_components/widgets/variant_comparison_grid.py index 3c6dc22d..0efe1e16 100644 --- a/ui_components/widgets/variant_comparison_grid.py +++ b/ui_components/widgets/variant_comparison_grid.py @@ -29,13 +29,13 @@ def variant_comparison_grid(ele_uuid, stage=CreativeProcessType.MOTION.value): variants = shot.interpolated_clip_list timing_list = data_repo.get_timing_list_from_shot(shot.uuid) else: - timing_uuid = ele_uuid + timing_uuid = ele_uuid timing = data_repo.get_timing_from_uuid(timing_uuid) variants = timing.alternative_images_list shot_uuid = timing.shot.uuid timing_list ="" - st.markdown("***") + col1, col2, col3 = st.columns([1, 1,0.5]) items_to_show = col2.slider('Variants per page:', min_value=1, max_value=12, value=6) diff --git a/utils/common_utils.py b/utils/common_utils.py index ce3e78e5..98da8ded 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -10,6 +10,46 @@ from ui_components.models import InternalUserObject from utils.cache.cache import StCache from utils.data_repo.data_repo import DataRepo +from ui_components.constants import CreativeProcessType, DefaultProjectSettingParams, DefaultTimingStyleParams + +def set_default_values(timing_list, shot_uuid, data_repo): + + if "page" not in st.session_state: + st.session_state['strength'] + + if "strength" not in st.session_state: + st.session_state['strength'] = DefaultProjectSettingParams.batch_strength + st.session_state['prompt_value'] = DefaultProjectSettingParams.batch_prompt + st.session_state['model'] = None + st.session_state['negative_prompt_value'] = DefaultProjectSettingParams.batch_negative_prompt + st.session_state['guidance_scale'] = DefaultProjectSettingParams.batch_guidance_scale + st.session_state['seed'] = DefaultProjectSettingParams.batch_seed + st.session_state['num_inference_steps'] = DefaultProjectSettingParams.batch_num_inference_steps + st.session_state['transformation_stage'] = DefaultProjectSettingParams.batch_transformation_stage + + if "current_frame_uuid" not in st.session_state and len(timing_list) > 0: + timing = data_repo.get_timing_list_from_shot(shot_uuid)[0] + st.session_state['current_frame_uuid'] = timing.uuid + st.session_state['current_frame_index'] = timing.aux_frame_index + 1 + + if 'frame_styling_view_type' not in st.session_state: + st.session_state['frame_styling_view_type'] = "Generate" + st.session_state['frame_styling_view_type_index'] = 0 + + if st.session_state['change_view_type'] == True: + st.session_state['change_view_type'] = False + + if "explorer_view" not in st.session_state: + st.session_state['explorer_view'] = "Explorations" + st.session_state['explorer_view_index'] = 0 + + if "shot_view" not in st.session_state: + st.session_state['shot_view'] = "Animate Frames" + st.session_state['shot_view_index'] = 0 + + if "styling_view" not in st.session_state: + st.session_state['styling_view'] = "Generate" + st.session_state['styling_view_index'] = 0 def copy_sample_assets(project_uuid): import shutil diff --git a/utils/st_memory.py b/utils/st_memory.py index 9ccfe88f..55094eb3 100644 --- a/utils/st_memory.py +++ b/utils/st_memory.py @@ -117,17 +117,26 @@ def checkbox(label, value=True,key=None, help=None, on_change=None, disabled=Fal return selection -def menu(menu_title,options, icons=None, menu_icon=None, default_index=0, key=None, help=None, on_change=None, disabled=False, orientation="horizontal", default_value=0, styles=None): +def menu(menu_title, options, icons=None, menu_icon=None, default_index=0, key=None, help=None, on_change=None, disabled=False, orientation="horizontal", default_value=0, styles=None): if key not in st.session_state: st.session_state[key] = default_value - # st.write(styles) - selection = option_menu(menu_title,options=options, icons=icons, menu_icon=menu_icon, orientation=orientation, default_index=int(st.session_state[key]), styles=styles) + + # if {key}_manual_select doesn't exist, set it to None + manual_select_key = f'{key}_manual_select' + if manual_select_key not in st.session_state: + st.session_state[manual_select_key] = None + + selection = option_menu(menu_title, options=options, icons=icons, menu_icon=menu_icon, orientation=orientation, default_index=int(st.session_state[key]), styles=styles, manual_select=st.session_state[manual_select_key]) + + # if {key}_manual_select is not None, set it to None + if st.session_state[manual_select_key] is not None: + st.session_state[manual_select_key] = None if options.index(selection) != st.session_state[key]: st.session_state[key] = options.index(selection) st.rerun() - + return selection def text_area(label, value='', height=None, max_chars=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False, label_visibility="visible"): From 23f1b0347da062ee9a95b12e28767e3d20885686 Mon Sep 17 00:00:00 2001 From: peter942 Date: Sat, 23 Dec 2023 03:57:26 +0000 Subject: [PATCH 02/11] Small fix --- .../components/frame_styling_page.py | 68 +++++++++++-------- .../components/timeline_view_page.py | 10 +++ .../widgets/variant_comparison_grid.py | 36 +++++----- 3 files changed, 67 insertions(+), 47 deletions(-) diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 87cc2421..59c439b8 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -25,38 +25,46 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): + if len(timing_list) == 0: + with h2: + frame_selector_widget(show=['shot_selector','frame_selector']) + + st.markdown("#### There are no frames present in this shot yet.") + + + + else: + with st.sidebar: + with h2: - with st.sidebar: - with h2: + frame_selector_widget(show=['shot_selector','frame_selector']) + + st.session_state['styling_view'] = st_memory.menu('',\ + ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ + icons=['magic', 'crop', "paint-bucket", 'pencil'], \ + menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ + key="styling_view_selector", orientation="horizontal", \ + styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) - frame_selector_widget(show=['shot_selector','frame_selector']) - - st.session_state['styling_view'] = st_memory.menu('',\ - ["Generate", "Crop/Move", "Inpainting","Scribbling"], \ - icons=['magic', 'crop', "paint-bucket", 'pencil'], \ - menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ - key="styling_view_selector", orientation="horizontal", \ - styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") - st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") + + if st.session_state['styling_view'] == "Generate": + variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) + with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): + generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) + + elif st.session_state['styling_view'] == "Crop/Move": + with st.expander("🤏 Crop, Move & Rotate", expanded=True): + cropping_selector_element(shot_uuid) - - if st.session_state['styling_view'] == "Generate": - variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) - with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): - generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) - - elif st.session_state['styling_view'] == "Crop/Move": - with st.expander("🤏 Crop, Move & Rotate", expanded=True): - cropping_selector_element(shot_uuid) - - elif st.session_state['styling_view'] == "Inpainting": - with st.expander("🌌 Inpainting", expanded=True): - inpainting_element(st.session_state['current_frame_uuid']) - - elif st.session_state['styling_view'] == "Scribbling": - with st.expander("📝 Draw On Image", expanded=True): - drawing_element(timing_list,project_settings, shot_uuid) - - \ No newline at end of file + elif st.session_state['styling_view'] == "Inpainting": + with st.expander("🌌 Inpainting", expanded=True): + inpainting_element(st.session_state['current_frame_uuid']) + + elif st.session_state['styling_view'] == "Scribbling": + with st.expander("📝 Draw On Image", expanded=True): + drawing_element(timing_list,project_settings, shot_uuid) + + \ No newline at end of file diff --git a/ui_components/components/timeline_view_page.py b/ui_components/components/timeline_view_page.py index edbe95cf..ade647e7 100644 --- a/ui_components/components/timeline_view_page.py +++ b/ui_components/components/timeline_view_page.py @@ -2,7 +2,9 @@ import streamlit as st from ui_components.constants import CreativeProcessType from ui_components.widgets.timeline_view import timeline_view +from ui_components.components.explorer_page import gallery_image_view from streamlit_option_menu import option_menu +from utils import st_memory def timeline_view_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): with st.sidebar: @@ -12,6 +14,14 @@ def timeline_view_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se if "view" not in st.session_state: st.session_state["view"] = views[0] st.session_state["manual_select"] = None + + with st.expander("📋 Explorer Shortlist",expanded=True): + + if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): + project_setting = data_repo.get_project_setting(shot.project.uuid) + page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True) + gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=10, open_detailed_view_for_all=False, shortlist=True, num_columns=2,view="sidebar") + with h2: st.session_state['view'] = option_menu(None, views, icons=['palette', 'camera-reels', "hourglass", 'stopwatch'], menu_icon="cast", orientation="vertical", key="secti2on_selector", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}, manual_select=st.session_state["manual_select"]) diff --git a/ui_components/widgets/variant_comparison_grid.py b/ui_components/widgets/variant_comparison_grid.py index 0efe1e16..88708d2e 100644 --- a/ui_components/widgets/variant_comparison_grid.py +++ b/ui_components/widgets/variant_comparison_grid.py @@ -112,23 +112,6 @@ def variant_inference_detail_element(variant, stage, shot_uuid, timing_list="", data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) - st.markdown(f"Details:") - inf_data = fetch_inference_data(variant) - if 'image_prompt_list' in inf_data: - del inf_data['image_prompt_list'] - del inf_data['image_list'] - del inf_data['output_format'] - - st.write(inf_data) - - if stage != CreativeProcessType.MOTION.value: - h1, h2 = st.columns([1, 1]) - with h1: - st.markdown(f"Add to shortlist:") - add_variant_to_shortlist_element(variant, shot.project.uuid) - with h2: - add_variant_to_shot_element(variant, shot.project.uuid) - if stage == CreativeProcessType.MOTION.value: if st.button("Load up settings from this variant", key=f"{tag}_{variant.name}", help="This will enter the settings from this variant into the inputs below - you can also use them on other shots", use_container_width=True): print("Loading settings") @@ -147,6 +130,25 @@ def variant_inference_detail_element(variant, stage, shot_uuid, timing_list="", time.sleep(0.3) st.rerun() + st.markdown(f"Details:") + inf_data = fetch_inference_data(variant) + if 'image_prompt_list' in inf_data: + del inf_data['image_prompt_list'] + del inf_data['image_list'] + del inf_data['output_format'] + + st.write(inf_data) + + if stage != CreativeProcessType.MOTION.value: + h1, h2 = st.columns([1, 1]) + with h1: + st.markdown(f"Add to shortlist:") + add_variant_to_shortlist_element(variant, shot.project.uuid) + with h2: + add_variant_to_shot_element(variant, shot.project.uuid) + + + def prepare_values(inf_data, timing_list): settings = inf_data # Map interpolation_type to indices From 6589d4ff222cb25653d5763330d7c254aa123100 Mon Sep 17 00:00:00 2001 From: peter942 Date: Sat, 23 Dec 2023 04:25:16 +0000 Subject: [PATCH 03/11] update --- ui_components/widgets/animation_style_element.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py index bf9c97c1..d96e3d8b 100644 --- a/ui_components/widgets/animation_style_element.py +++ b/ui_components/widgets/animation_style_element.py @@ -99,11 +99,11 @@ def animation_style_element(shot_uuid): st.session_state[f"frame_{idx+1}"] = idx * 16 # Default values in increments of 16 if idx == 0: # For the first frame, position is locked to 0 with columns[idx]: - frame_position = st_memory.number_input(f"{idx+1} frame Position", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx+1}", disabled=True) + frame_position = st_memory.number_input(f"{idx+1} frame Position", min_value=0, max_value=0, value=0, step=1, key=f"dynamic_frame_distribution_values_{idx}", disabled=True) else: min_value = st.session_state[f"frame_{idx}"] + 1 with columns[idx]: - frame_position = st_memory.number_input(f"#{idx+1} position:", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx+1}") + frame_position = st_memory.number_input(f"#{idx+1} position:", min_value=min_value, value=st.session_state[f"frame_{idx+1}"], step=1, key=f"dynamic_frame_distribution_values_{idx}") # st.session_state[f"frame_{idx+1}"] = frame_position dynamic_frame_distribution_values.append(frame_position) @@ -545,4 +545,5 @@ def update_interpolation_settings(values=None, timing_list=None): default_values[f'dynamic_cn_strength_values_{idx}'] = (0.0,0.7) for key, default_value in default_values.items(): - st.session_state[key] = values.get(key, default_value) if values and values.get(key) is not None else default_value \ No newline at end of file + st.session_state[key] = values.get(key, default_value) if values and values.get(key) is not None else default_value + # print(f"{key}: {st.session_state[key]}") \ No newline at end of file From c58575d021f3905fc6df90e36a8a975ac44ef97d Mon Sep 17 00:00:00 2001 From: peter942 Date: Sat, 23 Dec 2023 04:45:13 +0000 Subject: [PATCH 04/11] Fix --- ui_components/components/frame_styling_page.py | 2 +- ui_components/widgets/sidebar_logger.py | 2 +- ui_components/widgets/timeline_view.py | 12 ++++++------ utils/common_utils.py | 4 +--- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 59c439b8..c37cdf62 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -31,7 +31,7 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se st.markdown("#### There are no frames present in this shot yet.") - + else: diff --git a/ui_components/widgets/sidebar_logger.py b/ui_components/widgets/sidebar_logger.py index 5e9fb9df..8f719908 100644 --- a/ui_components/widgets/sidebar_logger.py +++ b/ui_components/widgets/sidebar_logger.py @@ -108,7 +108,7 @@ def sidebar_logger(shot_uuid): st.session_state['main_view_type'] = "Creative Process" st.session_state['frame_styling_view_type_index'] = 0 st.session_state['frame_styling_view_type'] = "Explorer" - st.session_state['change_view_type'] = False + st.rerun() diff --git a/ui_components/widgets/timeline_view.py b/ui_components/widgets/timeline_view.py index 2588108e..95fc2d1f 100644 --- a/ui_components/widgets/timeline_view.py +++ b/ui_components/widgets/timeline_view.py @@ -5,7 +5,7 @@ from utils import st_memory -def timeline_view(shot_uuid, stage, show_frame_upload=True): +def timeline_view(shot_uuid, stage): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) shot_list = data_repo.get_shot_list(shot.project.uuid) @@ -34,11 +34,11 @@ def timeline_view(shot_uuid, stage, show_frame_upload=True): shot_video_element(shot.uuid) if (idx + 1) % items_per_row == 0 or idx == len(shot_list) - 1: st.markdown("***") - if show_frame_upload == True: - if idx == len(shot_list) - 1: - with grid[(idx + 1) % items_per_row]: - st.markdown("### Add new shot") - add_new_shot_element(shot, data_repo) + # if stage isn't + if idx == len(shot_list) - 1: + with grid[(idx + 1) % items_per_row]: + st.markdown("### Add new shot") + add_new_shot_element(shot, data_repo) diff --git a/utils/common_utils.py b/utils/common_utils.py index 98da8ded..62f3521d 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -15,7 +15,7 @@ def set_default_values(timing_list, shot_uuid, data_repo): if "page" not in st.session_state: - st.session_state['strength'] + st.session_state['page'] = "Explore" if "strength" not in st.session_state: st.session_state['strength'] = DefaultProjectSettingParams.batch_strength @@ -36,8 +36,6 @@ def set_default_values(timing_list, shot_uuid, data_repo): st.session_state['frame_styling_view_type'] = "Generate" st.session_state['frame_styling_view_type_index'] = 0 - if st.session_state['change_view_type'] == True: - st.session_state['change_view_type'] = False if "explorer_view" not in st.session_state: st.session_state['explorer_view'] = "Explorations" From d6a0b1c34008e746c32d4fd4ae0d710e6a55e3d9 Mon Sep 17 00:00:00 2001 From: piyushK52 Date: Sat, 23 Dec 2023 13:31:55 +0530 Subject: [PATCH 05/11] code refactoring --- ui_components/components/adjust_shot_page.py | 6 ++++- ui_components/components/animate_shot_page.py | 5 +++- .../components/frame_styling_page.py | 26 ++++++------------- .../components/timeline_view_page.py | 7 +++-- ui_components/setup.py | 23 ++++------------ ui_components/widgets/drawing_element.py | 13 ++++------ utils/cache/cache_methods.py | 7 +++++ utils/common_utils.py | 10 ++++--- 8 files changed, 45 insertions(+), 52 deletions(-) diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py index c8db8f5f..9d4d7631 100644 --- a/ui_components/components/adjust_shot_page.py +++ b/ui_components/components/adjust_shot_page.py @@ -4,10 +4,14 @@ from ui_components.components.explorer_page import generate_images_element from ui_components.widgets.frame_selector import frame_selector_widget from utils import st_memory +from utils.data_repo.data_repo import DataRepo -def adjust_shot_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): +def adjust_shot_page(shot_uuid: str, h2): + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) + with h2: frame_selector_widget(show=['shot_selector']) diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py index e0033480..fad9c04b 100644 --- a/ui_components/components/animate_shot_page.py +++ b/ui_components/components/animate_shot_page.py @@ -2,8 +2,11 @@ from ui_components.widgets.frame_selector import frame_selector_widget from ui_components.widgets.variant_comparison_grid import variant_comparison_grid from ui_components.widgets.animation_style_element import animation_style_element +from utils.data_repo.data_repo import DataRepo -def animate_shot_page(shot_uuid: str,h2,data_repo,shot,timing_list, project_settings): +def animate_shot_page(shot_uuid: str, h2): + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) with h2: frame_selector_widget(show=['shot_selector']) diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index c37cdf62..59ab4d30 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -1,39 +1,29 @@ import streamlit as st -from shared.constants import ViewType -from streamlit_option_menu import option_menu from ui_components.widgets.cropping_element import cropping_selector_element from ui_components.widgets.frame_selector import frame_selector_widget -from ui_components.widgets.add_key_frame_element import add_key_frame, add_key_frame_element -from ui_components.widgets.timeline_view import timeline_view from ui_components.components.explorer_page import generate_images_element -from ui_components.widgets.animation_style_element import animation_style_element -from ui_components.widgets.video_cropping_element import video_cropping_element from ui_components.widgets.inpainting_element import inpainting_element from ui_components.widgets.drawing_element import drawing_element -from ui_components.widgets.sidebar_logger import sidebar_logger -# from ui_components.components.explorer_page import explorer_element,gallery_image_view from ui_components.widgets.variant_comparison_grid import variant_comparison_grid -from ui_components.widgets.shot_view import shot_keyframe_element from utils import st_memory - -from ui_components.constants import CreativeProcessType, DefaultProjectSettingParams, DefaultTimingStyleParams - +from ui_components.constants import CreativeProcessType from utils.data_repo.data_repo import DataRepo -def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): +def frame_styling_page(shot_uuid: str, h2): + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) + timing_list = data_repo.get_timing_list_from_shot(shot_uuid) + if len(timing_list) == 0: with h2: frame_selector_widget(show=['shot_selector','frame_selector']) - + st.markdown("#### There are no frames present in this shot yet.") - - - else: with st.sidebar: with h2: @@ -65,6 +55,6 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se elif st.session_state['styling_view'] == "Scribbling": with st.expander("📝 Draw On Image", expanded=True): - drawing_element(timing_list,project_settings, shot_uuid) + drawing_element(shot_uuid) \ No newline at end of file diff --git a/ui_components/components/timeline_view_page.py b/ui_components/components/timeline_view_page.py index ade647e7..8799d2de 100644 --- a/ui_components/components/timeline_view_page.py +++ b/ui_components/components/timeline_view_page.py @@ -5,10 +5,13 @@ from ui_components.components.explorer_page import gallery_image_view from streamlit_option_menu import option_menu from utils import st_memory +from utils.data_repo.data_repo import DataRepo + +def timeline_view_page(shot_uuid: str, h2): + data_repo = DataRepo() + shot = data_repo.get_shot_from_uuid(shot_uuid) -def timeline_view_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings): with st.sidebar: - views = CreativeProcessType.value_list() if "view" not in st.session_state: diff --git a/ui_components/setup.py b/ui_components/setup.py index 7a98a100..78b0f381 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -136,21 +136,13 @@ def setup_app_ui(): "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "red"}}) if st.session_state["main_view_type"] == "Creative Process": - - data_repo = DataRepo() - shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) - timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) - project_settings = data_repo.get_project_setting(shot.project.uuid) - set_default_values(timing_list,shot.uuid, data_repo) + set_default_values(st.session_state["shot_uuid"]) with st.sidebar: - creative_process_pages = ["Explore", "Shortlist", "Timeline", "Adjust Shot", "Adjust Frame", "Animate Shot"] - if 'creative_process_manual_select' not in st.session_state: st.session_state['creative_process_manual_select'] = 0 st.session_state['page'] = creative_process_pages[0] - h1,h2 = st.columns([1.5,1]) with h1: @@ -179,28 +171,23 @@ def setup_app_ui(): shortlist_page(st.session_state["project_uuid"]) elif st.session_state['page'] == "Timeline": - timeline_view_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) + timeline_view_page(st.session_state["shot_uuid"], h2) elif st.session_state['page'] == "Adjust Frame": - frame_styling_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) + frame_styling_page(st.session_state["shot_uuid"], h2) elif st.session_state['page'] == "Adjust Shot": - adjust_shot_page(st.session_state["shot_uuid"], h2,data_repo,shot,timing_list, project_settings) + adjust_shot_page(st.session_state["shot_uuid"], h2) elif st.session_state['page'] == "Animate Shot": - animate_shot_page(st.session_state["shot_uuid"],h2,data_repo,shot,timing_list, project_settings) + animate_shot_page(st.session_state["shot_uuid"], h2) with st.sidebar: - with st.expander("🔍 Generation Log", expanded=True): if st_memory.toggle("Open", value=True, key="generaton_log_toggle"): sidebar_logger(st.session_state["shot_uuid"]) st.markdown("***") - - - # frame_styling_page(st.session_state["shot_uuid"]) - elif st.session_state["main_view_type"] == "Tools & Settings": with st.sidebar: tool_pages = ["Query Logger", "Project Settings"] diff --git a/ui_components/widgets/drawing_element.py b/ui_components/widgets/drawing_element.py index 3a23eb8b..67f6add0 100644 --- a/ui_components/widgets/drawing_element.py +++ b/ui_components/widgets/drawing_element.py @@ -1,30 +1,27 @@ from io import BytesIO import uuid -import json import time import requests import streamlit as st from PIL import Image from streamlit_drawable_canvas import st_canvas -from ui_components.constants import WorkflowStageType from utils.data_repo.data_repo import DataRepo from ui_components.methods.common_methods import add_image_variant, extract_canny_lines, promote_image_variant from shared.constants import InternalFileType from ui_components.methods.file_methods import save_or_host_file -from utils import st_memory - - -def drawing_element(timing_details, project_settings, shot_uuid, stage=WorkflowStageType.STYLED.value): +def drawing_element(shot_uuid): data_repo = DataRepo() shot = data_repo.get_shot_from_uuid(shot_uuid) project_uuid = shot.project.uuid + project_settings = data_repo.get_project_setting(project_uuid) + timing_list = data_repo.get_timing_list_from_shot(shot_uuid) canvas1, canvas2 = st.columns([1, 1.5]) timing = data_repo.get_timing_from_uuid(st.session_state['current_frame_uuid']) - image_path = timing_details[st.session_state['current_frame_index'] - 1].primary_image_location + image_path = timing_list[st.session_state['current_frame_index'] - 1].primary_image_location with canvas1: width = int(project_settings.width) height = int(project_settings.height) @@ -90,7 +87,7 @@ def drawing_element(timing_details, project_settings, shot_uuid, stage=WorkflowS st.session_state['canny_image'] = None if st.button("Extract Canny From image"): - image_path = timing_details[st.session_state['current_frame_index'] - 1].primary_image_location + image_path = timing_list[st.session_state['current_frame_index'] - 1].primary_image_location canny_image = extract_canny_lines( image_path, project_uuid, low_threshold, high_threshold) st.session_state['canny_image'] = canny_image.uuid diff --git a/utils/cache/cache_methods.py b/utils/cache/cache_methods.py index 76f68ea8..aae7ff76 100644 --- a/utils/cache/cache_methods.py +++ b/utils/cache/cache_methods.py @@ -593,6 +593,13 @@ def _cache_get_shot_from_uuid(self, *args, **kwargs): original_func = getattr(cls, '_original_get_shot_from_uuid') shot = original_func(self, *args, **kwargs) + + if shot and not (shot_list and len(shot_list)): + original_func = getattr(cls, '_original_get_shot_list') + shot_list = original_func(self, shot.project.uuid) + if shot_list: + StCache.delete_all(CacheKey.SHOT.value) + StCache.add_all(shot_list, CacheKey.SHOT.value) return shot diff --git a/utils/common_utils.py b/utils/common_utils.py index 62f3521d..5c2a4240 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -10,9 +10,11 @@ from ui_components.models import InternalUserObject from utils.cache.cache import StCache from utils.data_repo.data_repo import DataRepo -from ui_components.constants import CreativeProcessType, DefaultProjectSettingParams, DefaultTimingStyleParams +from ui_components.constants import DefaultProjectSettingParams -def set_default_values(timing_list, shot_uuid, data_repo): +def set_default_values(shot_uuid): + data_repo = DataRepo() + timing_list = data_repo.get_timing_list_from_shot(shot_uuid) if "page" not in st.session_state: st.session_state['page'] = "Explore" @@ -28,7 +30,7 @@ def set_default_values(timing_list, shot_uuid, data_repo): st.session_state['transformation_stage'] = DefaultProjectSettingParams.batch_transformation_stage if "current_frame_uuid" not in st.session_state and len(timing_list) > 0: - timing = data_repo.get_timing_list_from_shot(shot_uuid)[0] + timing = timing_list[0] st.session_state['current_frame_uuid'] = timing.uuid st.session_state['current_frame_index'] = timing.aux_frame_index + 1 @@ -36,7 +38,6 @@ def set_default_values(timing_list, shot_uuid, data_repo): st.session_state['frame_styling_view_type'] = "Generate" st.session_state['frame_styling_view_type_index'] = 0 - if "explorer_view" not in st.session_state: st.session_state['explorer_view'] = "Explorations" st.session_state['explorer_view_index'] = 0 @@ -49,6 +50,7 @@ def set_default_values(timing_list, shot_uuid, data_repo): st.session_state['styling_view'] = "Generate" st.session_state['styling_view_index'] = 0 + def copy_sample_assets(project_uuid): import shutil From 206c7e41fa8516eadd4d8e42158e8a7f392d39ee Mon Sep 17 00:00:00 2001 From: piyushK52 Date: Sat, 23 Dec 2023 15:47:31 +0530 Subject: [PATCH 06/11] basic log list added --- ui_components/components/query_logger_page.py | 46 +++++++++++++++++++ ui_components/constants.py | 1 + ui_components/setup.py | 3 +- 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 ui_components/components/query_logger_page.py diff --git a/ui_components/components/query_logger_page.py b/ui_components/components/query_logger_page.py new file mode 100644 index 00000000..43a5556c --- /dev/null +++ b/ui_components/components/query_logger_page.py @@ -0,0 +1,46 @@ +import json +import streamlit as st +from ui_components.constants import DefaultTimingStyleParams +from utils.common_utils import get_current_user + +from utils.data_repo.data_repo import DataRepo + +def query_logger_page(): + st.header("Inference Log list") + + data_repo = DataRepo() + current_user = get_current_user() + b1, b2 = st.columns([1, 1]) + + total_log_table_pages = st.session_state['total_log_table_pages'] if 'total_log_table_pages' in st.session_state else DefaultTimingStyleParams.total_log_table_pages + page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1) + inference_log_list, total_page_count = data_repo.get_all_inference_log_list( + page=page_number, + data_per_page=100 + ) + + if total_log_table_pages != total_page_count: + st.session_state['total_log_table_pages'] = total_page_count + st.rerun() + + data = { + 'Project': [], + 'Prompt': [], + 'Model': [], + 'Inference time (sec)': [], + 'Cost ($)': [], + 'Status': [] + } + + for log in inference_log_list: + data['Project'].append(log.project.name) + prompt = json.loads(log.input_params).get('prompt', '') if log.input_params else '' + data['Prompt'].append(prompt) + model_name = json.loads(log.output_details).get('model_name', '') if log.output_details else '' + data['Model'].append(model_name) + data['Inference time (sec)'].append(round(log.total_inference_time, 3)) + data['Cost ($)'].append(round(log.total_inference_time * 0.004, 3)) + data['Status'].append(log.status) + + + st.table(data=data) \ No newline at end of file diff --git a/ui_components/constants.py b/ui_components/constants.py index ec2bb3cf..867b0ce5 100644 --- a/ui_components/constants.py +++ b/ui_components/constants.py @@ -32,6 +32,7 @@ class DefaultTimingStyleParams: animation_tool = AnimationToolType.G_FILM.value animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value model = None + total_log_table_pages = 1 class DefaultProjectSettingParams: batch_prompt = "" diff --git a/ui_components/setup.py b/ui_components/setup.py index 78b0f381..09401807 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -2,6 +2,7 @@ import os from moviepy.editor import * from shared.constants import SERVER, ServerType +from ui_components.components.query_logger_page import query_logger_page # from ui_components.components.explorer_page import explorer_element,shortlist_element from ui_components.widgets.timeline_view import timeline_view from ui_components.widgets.sidebar_logger import sidebar_logger @@ -199,7 +200,7 @@ def setup_app_ui(): st.session_state['page'] = option_menu(None, tool_pages, icons=['pencil', 'palette', "hourglass", 'stopwatch'], menu_icon="cast", orientation="horizontal", key="secti2on_selector", styles={ "nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "green"}}, manual_select=st.session_state["manual_select"]) if st.session_state["page"] == "Query Logger": - st.info("Query Logger will appear here.") + query_logger_page() if st.session_state["page"] == "Custom Models": custom_models_page(st.session_state["project_uuid"]) elif st.session_state["page"] == "Project Settings": From 0907886de53dc0d1b02e151813d839dfc89bf32f Mon Sep 17 00:00:00 2001 From: peter942 Date: Sat, 23 Dec 2023 19:37:53 +0000 Subject: [PATCH 07/11] Fix --- ui_components/components/adjust_shot_page.py | 5 ++++- ui_components/components/animate_shot_page.py | 5 ++++- ui_components/components/frame_styling_page.py | 4 +++- ui_components/widgets/animation_style_element.py | 4 ++-- ui_components/widgets/frame_selector.py | 1 - 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py index c8db8f5f..f62dc230 100644 --- a/ui_components/components/adjust_shot_page.py +++ b/ui_components/components/adjust_shot_page.py @@ -2,7 +2,7 @@ from ui_components.widgets.shot_view import shot_keyframe_element from ui_components.components.explorer_page import gallery_image_view from ui_components.components.explorer_page import generate_images_element -from ui_components.widgets.frame_selector import frame_selector_widget +from ui_components.widgets.frame_selector import frame_selector_widget, frame_view from utils import st_memory @@ -15,6 +15,9 @@ def adjust_shot_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_sett st.markdown("***") + with st.sidebar: + frame_view() + shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") # with st.expander("📋 Explorer Shortlist",expanded=True): shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"], diff --git a/ui_components/components/animate_shot_page.py b/ui_components/components/animate_shot_page.py index e0033480..7995d92a 100644 --- a/ui_components/components/animate_shot_page.py +++ b/ui_components/components/animate_shot_page.py @@ -1,5 +1,5 @@ import streamlit as st -from ui_components.widgets.frame_selector import frame_selector_widget +from ui_components.widgets.frame_selector import frame_selector_widget, frame_view from ui_components.widgets.variant_comparison_grid import variant_comparison_grid from ui_components.widgets.animation_style_element import animation_style_element @@ -7,6 +7,9 @@ def animate_shot_page(shot_uuid: str,h2,data_repo,shot,timing_list, project_sett with h2: frame_selector_widget(show=['shot_selector']) + with st.sidebar: + frame_view() + st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]") st.markdown("***") variant_comparison_grid(st.session_state['shot_uuid'], stage="Shots") diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index c37cdf62..8009f5df 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -3,7 +3,7 @@ from streamlit_option_menu import option_menu from ui_components.widgets.cropping_element import cropping_selector_element -from ui_components.widgets.frame_selector import frame_selector_widget +from ui_components.widgets.frame_selector import frame_selector_widget, frame_view from ui_components.widgets.add_key_frame_element import add_key_frame, add_key_frame_element from ui_components.widgets.timeline_view import timeline_view from ui_components.components.explorer_page import generate_images_element @@ -46,6 +46,8 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \ key="styling_view_selector", orientation="horizontal", \ styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) + + frame_view() st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py index d96e3d8b..7f338242 100644 --- a/ui_components/widgets/animation_style_element.py +++ b/ui_components/widgets/animation_style_element.py @@ -40,7 +40,7 @@ def animation_style_element(shot_uuid): type_of_frame_distribution = st_memory.radio("Type of key frame distribution:", options=["Linear", "Dynamic"], key="type_of_frame_distribution").lower() if type_of_frame_distribution == "linear": with setting_a_2: - linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="frames_per_keyframe") + linear_frame_distribution_value = st_memory.number_input("Frames per key frame:", min_value=8, max_value=36, value=16, step=1, key="linear_frame_distribution_value") dynamic_frame_distribution_values = [] st.markdown("***") setting_b_1, setting_b_2 = st.columns([1, 1]) @@ -48,7 +48,7 @@ def animation_style_element(shot_uuid): type_of_key_frame_influence = st_memory.radio("Type of key frame length influence:", options=["Linear", "Dynamic"], key="type_of_key_frame_influence").lower() if type_of_key_frame_influence == "linear": with setting_b_2: - linear_key_frame_influence_value = st_memory.slider("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="length_of_key_frame_influence") + linear_key_frame_influence_value = st_memory.slider("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="linear_key_frame_influence_value") dynamic_key_frame_influence_values = [] st.markdown("***") diff --git a/ui_components/widgets/frame_selector.py b/ui_components/widgets/frame_selector.py index d05d67c2..c021ac9f 100644 --- a/ui_components/widgets/frame_selector.py +++ b/ui_components/widgets/frame_selector.py @@ -93,7 +93,6 @@ def frame_view(): st.markdown("---") delete_frame_button(st.session_state['current_frame_uuid']) - else: From d8336288e665c7ee6160e2d8331ed5ba9b6b6135 Mon Sep 17 00:00:00 2001 From: peter942 Date: Sun, 24 Dec 2023 04:36:39 +0000 Subject: [PATCH 08/11] Fixes --- ui_components/components/adjust_shot_page.py | 2 +- ui_components/components/explorer_page.py | 6 ++-- .../components/frame_styling_page.py | 8 +++-- ui_components/methods/file_methods.py | 32 ++++++++++++++++--- ui_components/methods/ml_methods.py | 2 +- .../widgets/animation_style_element.py | 4 +-- ui_components/widgets/frame_selector.py | 6 ++-- ui_components/widgets/inpainting_element.py | 2 +- utils/ml_processor/replicate/replicate.py | 2 +- 9 files changed, 45 insertions(+), 19 deletions(-) diff --git a/ui_components/components/adjust_shot_page.py b/ui_components/components/adjust_shot_page.py index dcb1eb1b..bb326d07 100644 --- a/ui_components/components/adjust_shot_page.py +++ b/ui_components/components/adjust_shot_page.py @@ -20,7 +20,7 @@ def adjust_shot_page(shot_uuid: str, h2): st.markdown("***") with st.sidebar: - frame_view() + frame_view(view='Video') shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual") # with st.expander("📋 Explorer Shortlist",expanded=True): diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py index f64cbb58..2f160fd5 100644 --- a/ui_components/components/explorer_page.py +++ b/ui_components/components/explorer_page.py @@ -27,9 +27,9 @@ class InputImageStyling(ExtendedEnum): def columnn_selecter(): f1, f2 = st.columns([1, 1]) with f1: - st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") + st_memory.number_input('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer") with f2: - st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") + st_memory.number_input('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") def explorer_page(project_uuid): @@ -121,7 +121,7 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid= with b3: edge_pil_img = None - strength_of_current_image = st_memory.slider("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.") + strength_of_current_image = st_memory.number_input("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.") if type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value: prompt_strength = round(1 - (strength_of_current_image / 100), 2) with c2: diff --git a/ui_components/components/frame_styling_page.py b/ui_components/components/frame_styling_page.py index 2439255e..4f252014 100644 --- a/ui_components/components/frame_styling_page.py +++ b/ui_components/components/frame_styling_page.py @@ -39,13 +39,15 @@ def frame_styling_page(shot_uuid: str, h2): key="styling_view_selector", orientation="horizontal", \ styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}}) - frame_view() + frame_view(view="Key Frame") st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]") - + variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) + + st.markdown("***") if st.session_state['styling_view'] == "Generate": - variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value) + with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True): generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid']) diff --git a/ui_components/methods/file_methods.py b/ui_components/methods/file_methods.py index 2438c6a7..7add1ce5 100644 --- a/ui_components/methods/file_methods.py +++ b/ui_components/methods/file_methods.py @@ -226,15 +226,39 @@ def load_from_env(key): def zip_images(image_locations, zip_filename='images.zip'): with zipfile.ZipFile(zip_filename, 'w') as zip_file: + # Sort the image_locations list to ensure they are added in numerical order + image_locations.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + + # Determine the number of digits to pad based on the highest number in the list + padding = len(str(len(image_locations))) + for idx, image_location in enumerate(image_locations): - # image_name = os.path.basename(image_location) - image_name = f"{idx}.png" + # Zero-padding the file index + image_name = f"{str(idx).zfill(padding)}.png" + if image_location.startswith('http'): + # Fetch the image over HTTP response = requests.get(image_location) image_data = response.content - zip_file.writestr(image_name, image_data) + # Convert to Image object to check and convert mode + img = Image.open(io.BytesIO(image_data)) + if img.mode != 'RGB': + img = img.convert('RGB') + # Save image data to zip using BytesIO + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr = img_byte_arr.getvalue() + zip_file.writestr(image_name, img_byte_arr) else: - zip_file.write(image_location, image_name) + # Open the image file to check and convert mode + with Image.open(image_location) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + # Save image to zip + img_byte_arr = io.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + zip_file.writestr(image_name, img_byte_arr.read()) return zip_filename diff --git a/ui_components/methods/ml_methods.py b/ui_components/methods/ml_methods.py index 3eae6890..b704796d 100644 --- a/ui_components/methods/ml_methods.py +++ b/ui_components/methods/ml_methods.py @@ -209,7 +209,7 @@ def inpainting(input_image: str, prompt, negative_prompt, timing_uuid, mask_in_p prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, - strength=1.0, + strength=0.99, queue_inference=QUEUE_INFERENCE_QUERIES ) diff --git a/ui_components/widgets/animation_style_element.py b/ui_components/widgets/animation_style_element.py index 7f338242..f510c5d1 100644 --- a/ui_components/widgets/animation_style_element.py +++ b/ui_components/widgets/animation_style_element.py @@ -48,7 +48,7 @@ def animation_style_element(shot_uuid): type_of_key_frame_influence = st_memory.radio("Type of key frame length influence:", options=["Linear", "Dynamic"], key="type_of_key_frame_influence").lower() if type_of_key_frame_influence == "linear": with setting_b_2: - linear_key_frame_influence_value = st_memory.slider("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="linear_key_frame_influence_value") + linear_key_frame_influence_value = st_memory.number_input("Length of key frame influence:", min_value=0.1, max_value=5.0, value=1.0, step=0.01, key="linear_key_frame_influence_value") dynamic_key_frame_influence_values = [] st.markdown("***") @@ -540,7 +540,7 @@ def update_interpolation_settings(values=None, timing_list=None): } for idx in range(0, len(timing_list)): - default_values[f'dynamic_frame_distribution_values_{idx}'] = (idx - 1) * 16 + default_values[f'dynamic_frame_distribution_values_{idx}'] = (idx ) * 16 default_values[f'dynamic_key_frame_influence_values_{idx}'] = 1.0 default_values[f'dynamic_cn_strength_values_{idx}'] = (0.0,0.7) diff --git a/ui_components/widgets/frame_selector.py b/ui_components/widgets/frame_selector.py index c021ac9f..a81dd7be 100644 --- a/ui_components/widgets/frame_selector.py +++ b/ui_components/widgets/frame_selector.py @@ -60,17 +60,17 @@ def frame_selector_widget(show: List[str]): else: st.error("No frames present") -def frame_view(): +def frame_view(view="Key Frame"): data_repo = DataRepo() # time1, time2 = st.columns([1,1]) st.markdown("***") timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) - if st.session_state['page'] == "Key Frames": + if view == "Key Frame": with st.expander(f"🖼️ Frame #{st.session_state['current_frame_index']} Details", expanded=True): - if st_memory.toggle("Open", value=True, key="frame_toggle"): + if st_memory.toggle("Open", value=True, key="frame_toggle"): a1, a2 = st.columns([3,2]) with a1: st.success(f"Main Key Frame:") diff --git a/ui_components/widgets/inpainting_element.py b/ui_components/widgets/inpainting_element.py index bb284ab6..26f338b4 100644 --- a/ui_components/widgets/inpainting_element.py +++ b/ui_components/widgets/inpainting_element.py @@ -218,7 +218,7 @@ def inpaint_in_black_space_element(cropped_img, project_uuid, stage=WorkflowStag st.markdown("##### Inpaint in black space:") - inpaint_prompt = st.text_area("Prompt", value=DefaultProjectSettingParams.batch_prompt) + inpaint_prompt = st.text_area("Prompt", value=st.session_state['explorer_base_prompt']) inpaint_negative_prompt = st.text_input( "Negative Prompt", value='edge,branches, frame, fractals, text' + DefaultProjectSettingParams.batch_negative_prompt) if 'precision_cropping_inpainted_image_uuid' not in st.session_state: diff --git a/utils/ml_processor/replicate/replicate.py b/utils/ml_processor/replicate/replicate.py index 284b43b2..d1b94c64 100644 --- a/utils/ml_processor/replicate/replicate.py +++ b/utils/ml_processor/replicate/replicate.py @@ -193,7 +193,7 @@ def inpainting(self, video_name, input_image, prompt, negative_prompt): input_image = open(input_image, "rb") start_time = time.time() - output = model.predict(mask=mask, image=input_image,prompt=prompt, invert_mask=True, negative_prompt=negative_prompt,num_inference_steps=25) + output = model.predict(mask=mask, image=input_image,prompt=prompt, invert_mask=True, negative_prompt=negative_prompt,num_inference_steps=25, strength=0.99) end_time = time.time() log = log_model_inference(model, end_time - start_time, prompt=prompt, invert_mask=True, negative_prompt=negative_prompt,num_inference_steps=25) self.update_usage_credits(end_time - start_time) From 60582fabc73618a0ae66b718422bf21b5b7e5e46 Mon Sep 17 00:00:00 2001 From: peter942 Date: Sun, 24 Dec 2023 06:07:43 +0000 Subject: [PATCH 09/11] Fixing zipping --- ui_components/methods/file_methods.py | 52 ++++++++++++++------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/ui_components/methods/file_methods.py b/ui_components/methods/file_methods.py index 7add1ce5..c8547b1e 100644 --- a/ui_components/methods/file_methods.py +++ b/ui_components/methods/file_methods.py @@ -224,45 +224,49 @@ def load_from_env(key): val = get_key(dotenv_path=ENV_FILE_PATH, key_to_get=key) return val -def zip_images(image_locations, zip_filename='images.zip'): - with zipfile.ZipFile(zip_filename, 'w') as zip_file: - # Sort the image_locations list to ensure they are added in numerical order - image_locations.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) +import zipfile +import os +import requests +from PIL import Image +from io import BytesIO - # Determine the number of digits to pad based on the highest number in the list - padding = len(str(len(image_locations))) +def zip_images(image_locations, zip_filename='images.zip'): + # Calculate the number of digits needed for padding + num_digits = len(str(len(image_locations) - 1)) + with zipfile.ZipFile(zip_filename, 'w') as zip_file: for idx, image_location in enumerate(image_locations): - # Zero-padding the file index - image_name = f"{str(idx).zfill(padding)}.png" + # Pad the index with zeros + padded_idx = str(idx).zfill(num_digits) + image_name = f"{padded_idx}.png" if image_location.startswith('http'): - # Fetch the image over HTTP response = requests.get(image_location) image_data = response.content - # Convert to Image object to check and convert mode - img = Image.open(io.BytesIO(image_data)) - if img.mode != 'RGB': - img = img.convert('RGB') - # Save image data to zip using BytesIO - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr = img_byte_arr.getvalue() - zip_file.writestr(image_name, img_byte_arr) + + # Open the image for inspection and possible conversion + with Image.open(BytesIO(image_data)) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + + # Save the potentially converted image to a byte stream + with BytesIO() as output: + img.save(output, format='PNG') + zip_file.writestr(image_name, output.getvalue()) else: - # Open the image file to check and convert mode + # For local files, open, possibly convert, and then add to zip with Image.open(image_location) as img: if img.mode != 'RGB': img = img.convert('RGB') - # Save image to zip - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - zip_file.writestr(image_name, img_byte_arr.read()) + + img.save(image_name, format='PNG') + zip_file.write(image_name, image_name) + os.remove(image_name) # Clean up the temporary file return zip_filename + def create_duplicate_file(file: InternalFileObject, project_uuid=None) -> InternalFileObject: data_repo = DataRepo() From 00378efe4aa6b9d3eecc10847975a001fde3ce40 Mon Sep 17 00:00:00 2001 From: piyushK52 Date: Sun, 24 Dec 2023 17:47:12 +0530 Subject: [PATCH 10/11] speed fix --- app.py | 2 + ui_components/components/explorer_page.py | 62 +++++++------- ui_components/setup.py | 7 +- .../widgets/add_key_frame_element.py | 3 +- .../widgets/frame_movement_widgets.py | 5 +- utils/cache/cache.py | 5 ++ utils/cache/cache_methods.py | 81 +++++++++++++++++++ utils/common_utils.py | 10 ++- utils/data_repo/data_repo.py | 1 - 9 files changed, 140 insertions(+), 36 deletions(-) diff --git a/app.py b/app.py index 72c6098c..f0abffd5 100644 --- a/app.py +++ b/app.py @@ -94,6 +94,8 @@ def main(): from ui_components.setup import setup_app_ui setup_app_ui() + + st.session_state['maintain_state'] = False if __name__ == '__main__': try: diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py index 2f160fd5..72a59094 100644 --- a/ui_components/components/explorer_page.py +++ b/ui_components/components/explorer_page.py @@ -4,6 +4,7 @@ from ui_components.methods.file_methods import generate_pil_image from ui_components.methods.ml_methods import query_llama2 from ui_components.widgets.add_key_frame_element import add_key_frame +from utils.common_utils import refresh_app from utils.constants import MLQueryObject from utils.data_repo.data_repo import DataRepo from shared.constants import QUEUE_INFERENCE_QUERIES, AIModelType, InferenceType, InternalFileTag, InternalFileType, SortOrder @@ -326,7 +327,8 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de st.success("Added To Shortlist") time.sleep(0.3) st.rerun() - + + # -------- inference details -------------- if gallery_image_list[i + j].inference_log: log = gallery_image_list[i + j].inference_log # data_repo.get_inference_log_from_uuid(gallery_image_list[i + j].inference_log.uuid) if log: @@ -337,38 +339,40 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de with st.expander("Prompt Details", expanded=open_detailed_view_for_all): st.info(f"**Prompt:** {prompt}\n\n**Model:** {model}") - if "last_shot_number" not in st.session_state: - st.session_state["last_shot_number"] = 0 - if view not in ["explorer", "shortlist"]: - if view == "individual_shot": - shot_name = shot.name - else: - shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"]) - - if shot_name != "": - if shot_name == "**Create New Shot**": - shot_name = st.text_input("New shot name:", max_chars=40, key=f"shot_name_{gallery_image_list[i+j].uuid}") - if st.button("Create new shot", key=f"create_new_{gallery_image_list[i + j].uuid}", use_container_width=True): - new_shot = add_new_shot(project_uuid, name=shot_name) - add_key_frame(gallery_image_list[i + j], False, new_shot.uuid, len(data_repo.get_timing_list_from_shot(new_shot.uuid)), refresh_state=False) - # removing this from the gallery view - data_repo.update_file(gallery_image_list[i + j].uuid, tag="") - st.rerun() - - else: - if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True): - shot_number = shot_names.index(shot_name) + 1 - st.session_state["last_shot_number"] = shot_number - 1 - shot_uuid = shot_list[shot_number - 2].uuid - - add_key_frame(gallery_image_list[i + j], False, shot_uuid, len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False) - # removing this from the gallery view - data_repo.update_file(gallery_image_list[i + j].uuid, tag="") - st.rerun() else: st.warning("No inference data") else: st.warning("No data found") + + # ---------- add to shot btn --------------- + if "last_shot_number" not in st.session_state: + st.session_state["last_shot_number"] = 0 + if view not in ["explorer", "shortlist"]: + if view == "individual_shot": + shot_name = shot.name + else: + shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"]) + + if shot_name != "": + if shot_name == "**Create New Shot**": + shot_name = st.text_input("New shot name:", max_chars=40, key=f"shot_name_{gallery_image_list[i+j].uuid}") + if st.button("Create new shot", key=f"create_new_{gallery_image_list[i + j].uuid}", use_container_width=True): + new_shot = add_new_shot(project_uuid, name=shot_name) + add_key_frame(gallery_image_list[i + j], False, new_shot.uuid, len(data_repo.get_timing_list_from_shot(new_shot.uuid)), refresh_state=False) + # removing this from the gallery view + data_repo.update_file(gallery_image_list[i + j].uuid, tag="") + st.rerun() + + else: + if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True): + shot_number = shot_names.index(shot_name) + 1 + st.session_state["last_shot_number"] = shot_number - 1 + shot_uuid = shot_list[shot_number - 2].uuid + + add_key_frame(gallery_image_list[i + j], False, shot_uuid, len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False) + # removing this from the gallery view + data_repo.update_file(gallery_image_list[i + j].uuid, tag="") + refresh_app(maintain_state=True) st.markdown("***") else: diff --git a/ui_components/setup.py b/ui_components/setup.py index 09401807..c58400c1 100644 --- a/ui_components/setup.py +++ b/ui_components/setup.py @@ -99,7 +99,12 @@ def setup_app_ui(): reset_project_state() st.session_state["project_uuid"] = project_list[selected_index].uuid - check_project_meta_data(st.session_state["project_uuid"]) + if 'maintain_state' not in st.session_state: + st.session_state["maintain_state"] = False + + if not st.session_state["maintain_state"]: + check_project_meta_data(st.session_state["project_uuid"]) + update_app_setting_keys() if 'shot_uuid' not in st.session_state: diff --git a/ui_components/widgets/add_key_frame_element.py b/ui_components/widgets/add_key_frame_element.py index f461efb2..cfd29ab8 100644 --- a/ui_components/widgets/add_key_frame_element.py +++ b/ui_components/widgets/add_key_frame_element.py @@ -7,6 +7,7 @@ from ui_components.widgets.image_zoom_widgets import zoom_inputs from utils import st_memory +from utils.common_utils import refresh_app from utils.data_repo.data_repo import DataRepo @@ -119,4 +120,4 @@ def add_key_frame(selected_image: Union[Image.Image, InternalFileObject], inheri st.session_state['section_index'] = 0 if refresh_state: - st.rerun() \ No newline at end of file + refresh_app(maintain_state=True) \ No newline at end of file diff --git a/ui_components/widgets/frame_movement_widgets.py b/ui_components/widgets/frame_movement_widgets.py index 6c6b5ecb..4a8351d8 100644 --- a/ui_components/widgets/frame_movement_widgets.py +++ b/ui_components/widgets/frame_movement_widgets.py @@ -3,6 +3,7 @@ from ui_components.constants import WorkflowStageType from ui_components.methods.common_methods import add_image_variant, promote_image_variant, save_and_promote_image from ui_components.models import InternalFrameTimingObject +from utils.common_utils import refresh_app from utils.constants import ImageStage from utils.data_repo.data_repo import DataRepo @@ -60,7 +61,7 @@ def move_frame_back_button(timing_uuid, orientation): arrow = "⬆️" if st.button(arrow, key=f"move_frame_back_{timing_uuid}", help="Move frame back", use_container_width=True): move_frame(direction, timing_uuid) - st.rerun() + refresh_app(maintain_state=True) def move_frame_forward_button(timing_uuid, orientation): @@ -72,7 +73,7 @@ def move_frame_forward_button(timing_uuid, orientation): if st.button(arrow, key=f"move_frame_forward_{timing_uuid}", help="Move frame forward", use_container_width=True): move_frame(direction, timing_uuid) - st.rerun() + refresh_app(maintain_state=True) def delete_frame_button(timing_uuid, show_label=False): diff --git a/utils/cache/cache.py b/utils/cache/cache.py index 851973d8..93593e40 100644 --- a/utils/cache/cache.py +++ b/utils/cache/cache.py @@ -11,6 +11,11 @@ class CacheKey(ExtendedEnum): LOGGED_USER = "logged_user" FILE = "file" SHOT = "shot" + # temp items (only cached for speed boost) + LOG = 'log' + LOG_PAGES = 'log_pages' + PROJECT = 'project' + USER = 'user' class StCache: diff --git a/utils/cache/cache_methods.py b/utils/cache/cache_methods.py index aae7ff76..de5af2ef 100644 --- a/utils/cache/cache_methods.py +++ b/utils/cache/cache_methods.py @@ -1,6 +1,7 @@ import uuid from shared.logging.logging import AppLogger from utils.cache.cache import CacheKey, StCache +import streamlit as st logger = AppLogger() @@ -721,4 +722,84 @@ def _cache_duplicate_shot(self, *args, **kwargs): setattr(cls, '_original_duplicate_shot', cls.duplicate_shot) setattr(cls, "duplicate_shot", _cache_duplicate_shot) + # ---------------------- APPROXIMATE METHODS --------------------- + ''' + these methods output whatever is last cached in them, irrespective of the input/query params + these are only used in cases when there are minor changes in the app and we want to maintain + the last state (check 'maintain_state' inside the refresh_app function) + ''' + def _cache_get_all_inference_log_list(self, *args, **kwargs): + if 'maintain_state' in st.session_state and st.session_state['maintain_state']: + log_list = StCache.get_all(CacheKey.LOG.value) + if log_list and len(log_list): + return log_list, st.session_state['log_pages_approx'] + + original_func = getattr(cls, '_original_get_all_inference_log_list') + output_log_list, total_pages = original_func(self, *args, **kwargs) + if output_log_list and len(output_log_list): + StCache.delete_all(CacheKey.LOG.value) + StCache.add_all(output_log_list, CacheKey.LOG.value) + st.session_state['log_pages_approx'] = total_pages + + return output_log_list, total_pages + + setattr(cls, '_original_get_all_inference_log_list', cls.get_all_inference_log_list) + setattr(cls, "get_all_inference_log_list", _cache_get_all_inference_log_list) + + def _cache_get_project_from_uuid(self, *args, **kwargs): + if 'maintain_state' in st.session_state and st.session_state['maintain_state']: + project_list = StCache.get_all(CacheKey.PROJECT.value) + if project_list and len(project_list): + for project in project_list: + if str(project.uuid) == str(args[0]): + return project + + original_func = getattr(cls, '_original_get_project_from_uuid') + output_project = original_func(self, *args, **kwargs) + if output_project: + StCache.add(output_project, CacheKey.PROJECT.value) + + return output_project + + setattr(cls, '_original_get_project_from_uuid', cls.get_project_from_uuid) + setattr(cls, "get_project_from_uuid", _cache_get_project_from_uuid) + + def _cache_get_all_project_list(self, *args, **kwargs): + if 'maintain_state' in st.session_state and st.session_state['maintain_state']: + project_list = StCache.get_all(CacheKey.PROJECT.value) + if project_list and len(project_list): + res = [] + for project in project_list: + if str(project.user_uuid) == str(kwargs['user_id']): + res.append(project) + + if len(res): + return res + + original_func = getattr(cls, '_original_get_all_project_list') + output_project_list = original_func(self, *args, **kwargs) + if output_project_list: + StCache.add_all(output_project_list, CacheKey.PROJECT.value) + + return output_project_list + + setattr(cls, '_original_get_all_project_list', cls.get_all_project_list) + setattr(cls, "get_all_project_list", _cache_get_all_project_list) + + def _cache_get_all_user_list(self, *args, **kwargs): + if 'maintain_state' in st.session_state and st.session_state['maintain_state']: + user_list = StCache.get_all(CacheKey.USER.value) + if user_list and len(user_list): + return user_list + + original_func = getattr(cls, '_original_get_all_user_list') + user_list = original_func(self, *args, **kwargs) + if user_list and len(user_list): + StCache.add_all(user_list, CacheKey.USER.value) + + return user_list + + setattr(cls, '_original_get_all_user_list', cls.get_all_user_list) + setattr(cls, "get_all_user_list", _cache_get_all_user_list) + return cls \ No newline at end of file diff --git a/utils/common_utils.py b/utils/common_utils.py index 5c2a4240..aa70e8ef 100644 --- a/utils/common_utils.py +++ b/utils/common_utils.py @@ -8,7 +8,7 @@ import json from shared.constants import SERVER, ServerType from ui_components.models import InternalUserObject -from utils.cache.cache import StCache +from utils.cache.cache import CacheKey, StCache from utils.data_repo.data_repo import DataRepo from ui_components.constants import DefaultProjectSettingParams @@ -171,7 +171,8 @@ def reset_project_state(): "seed", "promote_new_generation", "use_new_settings", - "shot_uuid" + "shot_uuid", + "maintain_state" ] for k in keys_to_delete: @@ -242,3 +243,8 @@ def release_lock(key): data_repo = DataRepo() data_repo.release_lock(key) return True + + +def refresh_app(maintain_state=False): + st.session_state['maintain_state'] = maintain_state + st.rerun() \ No newline at end of file diff --git a/utils/data_repo/data_repo.py b/utils/data_repo/data_repo.py index 406b5694..ec4c234a 100644 --- a/utils/data_repo/data_repo.py +++ b/utils/data_repo/data_repo.py @@ -7,7 +7,6 @@ from shared.logging.logging import AppLogger from ui_components.models import InferenceLogObject, InternalAIModelObject, InternalAppSettingObject, InternalBackupObject, InternalFrameTimingObject, InternalProjectObject, InternalFileObject, InternalSettingObject, InternalShotObject, InternalUserObject from utils.cache.cache_methods import cache_data -import wrapt from utils.data_repo.api_repo import APIRepo From 758d91c5f105d89e7838765d972e6b905bbc973b Mon Sep 17 00:00:00 2001 From: piyushK52 Date: Sun, 24 Dec 2023 22:32:28 +0530 Subject: [PATCH 11/11] dimension issue fix --- ui_components/components/explorer_page.py | 2 -- ui_components/widgets/shot_view.py | 3 ++- utils/ml_processor/replicate/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ui_components/components/explorer_page.py b/ui_components/components/explorer_page.py index 72a59094..efd1a581 100644 --- a/ui_components/components/explorer_page.py +++ b/ui_components/components/explorer_page.py @@ -33,9 +33,7 @@ def columnn_selecter(): st_memory.number_input('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer") def explorer_page(project_uuid): - data_repo = DataRepo() - project_setting = data_repo.get_project_setting(project_uuid) st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]") diff --git a/ui_components/widgets/shot_view.py b/ui_components/widgets/shot_view.py index 5e0eb472..68eae2f8 100644 --- a/ui_components/widgets/shot_view.py +++ b/ui_components/widgets/shot_view.py @@ -15,6 +15,7 @@ from ui_components.models import InternalFrameTimingObject, InternalShotObject from ui_components.widgets.add_key_frame_element import add_key_frame,add_key_frame_section from ui_components.widgets.frame_movement_widgets import change_frame_shot, delete_frame_button, jump_to_single_frame_view_button, move_frame_back_button, move_frame_forward_button, replace_image_widget +from utils.common_utils import refresh_app from utils.data_repo.data_repo import DataRepo from utils import st_memory @@ -312,7 +313,7 @@ def timeline_view_buttons(idx, shot_uuid, replace_image_widget_toggle, copy_fram if st.button("🔁", key=f"copy_frame_{timing_list[idx].uuid}", use_container_width=True): pil_image = generate_pil_image(timing_list[idx].primary_image.location) add_key_frame(pil_image, False, st.session_state['shot_uuid'], timing_list[idx].aux_frame_index+1, refresh_state=False) - st.rerun() + refresh_app(maintain_state=True) if delete_frames_toggle: with btn4: diff --git a/utils/ml_processor/replicate/utils.py b/utils/ml_processor/replicate/utils.py index dcbfa338..bce7b5e5 100644 --- a/utils/ml_processor/replicate/utils.py +++ b/utils/ml_processor/replicate/utils.py @@ -72,8 +72,8 @@ def get_model_params_from_query_obj(model, query_obj: MLQueryObject): data = { "prompt" : query_obj.prompt, "negative_prompt" : query_obj.negative_prompt, - "width" : max(768, query_obj.width), # 768 is the default for sdxl - "height" : max(768, query_obj.height), + "width" : 768 if query_obj.width == 512 else 1024, # 768 is the default for sdxl + "height" : 768 if query_obj.height == 512 else 1024, "prompt_strength": query_obj.strength, "mask": mask }