Skip to content

Commit

Permalink
Merge pull request #58 from banodoco/green-head
Browse files Browse the repository at this point in the history
Green head
  • Loading branch information
piyushK52 authored Dec 22, 2023
2 parents e600224 + 849d5bd commit c74a521
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 104 deletions.
9 changes: 5 additions & 4 deletions backend/db_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,7 @@ def get_timing_list_from_project(self, project_uuid=None):
return InternalResponse(payload, 'timing list fetched', True)

def get_timing_list_from_shot(self, shot_uuid):

shot: Shot = Shot.objects.filter(uuid=shot_uuid, is_disabled=False).first()
if not shot:
return InternalResponse({}, 'invalid shot', False)
Expand Down Expand Up @@ -847,13 +848,13 @@ def create_timing(self, **kwargs):

attributes._data['canny_image_id'] = canny_image.id

if 'primay_image_id' in attributes.data:
if attributes.data['primay_image_id'] != None:
primay_image: InternalFileObject = InternalFileObject.objects.filter(uuid=attributes.data['primay_image_id'], is_disabled=False).first()
if 'primary_image_id' in attributes.data:
if attributes.data['primary_image_id'] != None:
primay_image: InternalFileObject = InternalFileObject.objects.filter(uuid=attributes.data['primary_image_id'], is_disabled=False).first()
if not primay_image:
return InternalResponse({}, 'invalid primary image uuid', False)

attributes._data['primay_image_id'] = primay_image.id
attributes._data['primary_image_id'] = primay_image.id

timing = Timing.objects.create(**attributes.data)
payload = {
Expand Down
2 changes: 1 addition & 1 deletion backend/serializers/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class CreateTimingDao(serializers.Serializer):
mask_id = serializers.CharField(max_length=100, required=False)
canny_image_id = serializers.CharField(max_length=100, required=False)
shot_id = serializers.CharField(max_length=100)
primary_image = serializers.CharField(max_length=100, required=False)
primary_image_id = serializers.CharField(max_length=100, required=False)
alternative_images = serializers.CharField(max_length=100, required=False)
notes = serializers.CharField(max_length=1024, required=False)
aux_frame_index = serializers.IntegerField(required=False)
Expand Down
44 changes: 33 additions & 11 deletions ui_components/components/frame_styling_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@


def frame_styling_page(shot_uuid: str):

data_repo = DataRepo()
shot = data_repo.get_shot_from_uuid(shot_uuid)

timing_list = data_repo.get_timing_list_from_shot(shot_uuid)

project_settings = data_repo.get_project_setting(shot.project.uuid)

if "strength" not in st.session_state:
Expand All @@ -39,7 +42,7 @@ def frame_styling_page(shot_uuid: str):
st.session_state['num_inference_steps'] = DefaultProjectSettingParams.batch_num_inference_steps
st.session_state['transformation_stage'] = DefaultProjectSettingParams.batch_transformation_stage

if "current_frame_uuid" not in st.session_state:
if "current_frame_uuid" not in st.session_state and len(timing_list) > 0:
timing = data_repo.get_timing_list_from_shot(shot_uuid)[0]
st.session_state['current_frame_uuid'] = timing.uuid
st.session_state['current_frame_index'] = timing.aux_frame_index + 1
Expand Down Expand Up @@ -89,6 +92,7 @@ def frame_styling_page(shot_uuid: str):
if st.session_state['page'] == CreativeProcessType.MOTION.value:

with st.sidebar:

if 'shot_view_manual_select' not in st.session_state:
st.session_state['shot_view_manual_select'] = None

Expand Down Expand Up @@ -123,27 +127,43 @@ def frame_styling_page(shot_uuid: str):
elif st.session_state['shot_view'] == "Adjust Frames":
st.markdown("***")
shot_keyframe_element(shot_uuid, 4, position="Individual")
with st.expander("📋 Explorer Shortlist",expanded=True):

# with st.expander("📋 Explorer Shortlist",expanded=True):
shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"],
icons=['grid-3x3','airplane'],
menu_icon="cast",
default_index=st.session_state.get('shot_explorer_view', 0),
key="shot_explorer_view", orientation="horizontal",
styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#868c91"}})
st.markdown("***")
if shot_explorer_view == "Shortlist":
project_setting = data_repo.get_project_setting(shot.project.uuid)
page_number = st.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True)

page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True)
gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=True, num_columns=4,view="individual_shot", shot=shot)
elif shot_explorer_view == "Explore":
project_setting = data_repo.get_project_setting(shot.project.uuid)
page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True)
generate_images_element(position='explorer', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid'])
st.markdown("***")
gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=8, open_detailed_view_for_all=False, shortlist=False, num_columns=4,view="individual_shot", shot=shot)
#with st.expander("🤏 Crop, Move & Rotate Image", expanded=True):
# video_cropping_element(shot_uuid)

elif st.session_state['page'] == CreativeProcessType.STYLING.value:




with st.sidebar:


with st.sidebar:

st.session_state['styling_view'] = st_memory.menu('',\
["Generate", "Crop/Move", "Inpainting","Scribbling"], \
icons=['magic', 'crop', "paint-bucket", 'pencil'], \
menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \
key="styling_view_selector", orientation="horizontal", \
styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "#0068c9"}})
frame_selector_widget()


if st.session_state['styling_view'] == "Generate":
variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value)
Expand All @@ -161,18 +181,20 @@ def frame_styling_page(shot_uuid: str):
elif st.session_state['styling_view'] == "Scribbling":
with st.expander("📝 Draw On Image", expanded=True):
drawing_element(timing_list,project_settings, shot_uuid)
with st.sidebar:
frame_selector_widget()



# -------------------- TIMELINE VIEW --------------------------
elif st.session_state['frame_styling_view_type'] == "Timeline":
if st.session_state['page'] == "Key Frames":

with st.sidebar:
with st.expander("📋 Explorer Shortlist",expanded=True):

if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"):
project_setting = data_repo.get_project_setting(shot.project.uuid)
page_number = st.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True)
page_number = st.radio("Select page:", options=range(1, project_setting.total_shortlist_gallery_pages + 1), horizontal=True)
gallery_image_view(shot.project.uuid, page_number=page_number, num_items_per_page=10, open_detailed_view_for_all=False, shortlist=True, num_columns=2,view="sidebar")

timeline_view(shot_uuid, "Key Frames")
Expand All @@ -182,7 +204,7 @@ def frame_styling_page(shot_uuid: str):
# -------------------- SIDEBAR NAVIGATION --------------------------
with st.sidebar:


with st.expander("🔍 Generation Log", expanded=True):
if st_memory.toggle("Open", value=True, key="generaton_log_toggle"):
sidebar_logger(shot_uuid)
Expand Down
3 changes: 2 additions & 1 deletion ui_components/methods/common_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,8 @@ def process_inference_output(**kwargs):
name=filename,
type=InternalFileType.IMAGE.value,
hosted_url=output[0] if isinstance(output, list) else output,
inference_log_id=log.uuid
inference_log_id=log.uuid,
project_id=timing.shot.project.uuid,
)

add_image_variant(output_file.uuid, timing_uuid)
Expand Down
6 changes: 4 additions & 2 deletions ui_components/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def setup_app_ui():
update_app_setting_keys()

if 'shot_uuid' not in st.session_state:
shot_list = data_repo.get_shot_list(st.session_state["project_uuid"])
shot_list = data_repo.get_shot_list(st.session_state["project_uuid"])
st.session_state['shot_uuid'] = shot_list[0].uuid

# print uuids of shots

if "current_frame_index" not in st.session_state:
st.session_state['current_frame_index'] = 1
Expand Down Expand Up @@ -179,7 +181,7 @@ def setup_app_ui():

if st.session_state["manual_select"] != None:
st.session_state["manual_select"] = None

frame_styling_page(st.session_state["shot_uuid"])

elif st.session_state["main_view_type"] == "Tools & Settings":
Expand Down
29 changes: 20 additions & 9 deletions ui_components/widgets/add_key_frame_element.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
from typing import Union
import streamlit as st
from shared.constants import AnimationStyleType
from ui_components.constants import CreativeProcessType, WorkflowStageType
from ui_components.models import InternalFileObject
from ui_components.models import InternalFileObject, InternalFrameTimingObject
from ui_components.widgets.image_zoom_widgets import zoom_inputs

from utils import st_memory
Expand All @@ -11,7 +12,7 @@

from utils.constants import ImageStage
from ui_components.methods.file_methods import generate_pil_image,save_or_host_file
from ui_components.methods.common_methods import apply_image_transformations, clone_styling_settings, create_frame_inside_shot, save_uploaded_image
from ui_components.methods.common_methods import add_image_variant, apply_image_transformations, clone_styling_settings, create_frame_inside_shot, save_new_image, save_uploaded_image
from PIL import Image


Expand Down Expand Up @@ -83,20 +84,30 @@ def add_key_frame(selected_image: Union[Image.Image, InternalFileObject], inheri
len_shot_timing_list = len(timing_list) if len(timing_list) > 0 else 0
target_frame_position = len_shot_timing_list if target_frame_position is None else target_frame_position
target_aux_frame_index = min(len(timing_list), target_frame_position)
_ = create_frame_inside_shot(shot_uuid, target_aux_frame_index)

timing_list = data_repo.get_timing_list_from_shot(shot_uuid)
# updating the newly created frame timing

save_uploaded_image(selected_image, shot.project.uuid, timing_list[target_aux_frame_index].uuid, WorkflowStageType.SOURCE.value)
save_uploaded_image(selected_image, shot.project.uuid, timing_list[target_aux_frame_index].uuid, WorkflowStageType.STYLED.value)
if isinstance(selected_image, InternalFileObject):
saved_image = selected_image
else:
saved_image = save_new_image(selected_image, shot.project.uuid)

timing_data = {
"shot_id": shot_uuid,
"animation_style": AnimationStyleType.CREATIVE_INTERPOLATION.value,
"aux_frame_index": target_aux_frame_index,
"source_image_id": saved_image.uuid,
"primary_image_id": saved_image.uuid,
}
timing: InternalFrameTimingObject = data_repo.create_timing(**timing_data)

add_image_variant(saved_image.uuid, timing.uuid)

timing_list = data_repo.get_timing_list_from_shot(shot_uuid)
if update_cur_frame_idx:
# this part of code updates current_frame_index when a new keyframe is added
if inherit_styling_settings == "Yes" and st.session_state['current_frame_index']:
clone_styling_settings(st.session_state['current_frame_index'] - 1, timing_list[target_aux_frame_index-1].uuid)

if len(timing_list) == 1:
if len(timing_list) <= 1:
st.session_state['current_frame_index'] = 1
st.session_state['current_frame_uuid'] = timing_list[0].uuid
else:
Expand Down
Loading

0 comments on commit c74a521

Please sign in to comment.