Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/staging'
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushK52 committed Dec 24, 2023
2 parents 51732c6 + 29814a3 commit e8f5ed9
Show file tree
Hide file tree
Showing 24 changed files with 297 additions and 114 deletions.
2 changes: 2 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def main():

from ui_components.setup import setup_app_ui
setup_app_ui()

st.session_state['maintain_state'] = False

if __name__ == '__main__':
try:
Expand Down
11 changes: 9 additions & 2 deletions ui_components/components/adjust_shot_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@
from ui_components.widgets.shot_view import shot_keyframe_element
from ui_components.components.explorer_page import gallery_image_view
from ui_components.components.explorer_page import generate_images_element
from ui_components.widgets.frame_selector import frame_selector_widget
from ui_components.widgets.frame_selector import frame_selector_widget, frame_view
from utils import st_memory
from utils.data_repo.data_repo import DataRepo



def adjust_shot_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings):
def adjust_shot_page(shot_uuid: str, h2):
data_repo = DataRepo()
shot = data_repo.get_shot_from_uuid(shot_uuid)

with h2:
frame_selector_widget(show=['shot_selector'])

st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]")

st.markdown("***")

with st.sidebar:
frame_view(view='Video')

shot_keyframe_element(st.session_state["shot_uuid"], 4, position="Individual")
# with st.expander("📋 Explorer Shortlist",expanded=True):
shot_explorer_view = st_memory.menu('',["Shortlist", "Explore"],
Expand Down
10 changes: 8 additions & 2 deletions ui_components/components/animate_shot_page.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import streamlit as st
from ui_components.widgets.frame_selector import frame_selector_widget
from ui_components.widgets.frame_selector import frame_selector_widget, frame_view
from ui_components.widgets.variant_comparison_grid import variant_comparison_grid
from ui_components.widgets.animation_style_element import animation_style_element
from utils.data_repo.data_repo import DataRepo

def animate_shot_page(shot_uuid: str,h2,data_repo,shot,timing_list, project_settings):
def animate_shot_page(shot_uuid: str, h2):
data_repo = DataRepo()
shot = data_repo.get_shot_from_uuid(shot_uuid)

with h2:
frame_selector_widget(show=['shot_selector'])
with st.sidebar:
frame_view()

st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}] > :orange[{shot.name}]")
st.markdown("***")
variant_comparison_grid(st.session_state['shot_uuid'], stage="Shots")
Expand Down
70 changes: 36 additions & 34 deletions ui_components/components/explorer_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ui_components.methods.file_methods import generate_pil_image
from ui_components.methods.ml_methods import query_llama2
from ui_components.widgets.add_key_frame_element import add_key_frame
from utils.common_utils import refresh_app
from utils.constants import MLQueryObject
from utils.data_repo.data_repo import DataRepo
from shared.constants import QUEUE_INFERENCE_QUERIES, AIModelType, InferenceType, InternalFileTag, InternalFileType, SortOrder
Expand All @@ -27,14 +28,12 @@ class InputImageStyling(ExtendedEnum):
def columnn_selecter():
f1, f2 = st.columns([1, 1])
with f1:
st_memory.slider('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer")
st_memory.number_input('Number of columns:', min_value=3, max_value=7, value=4,key="num_columns_explorer")
with f2:
st_memory.slider('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer")
st_memory.number_input('Items per page:', min_value=10, max_value=50, value=16, key="num_items_per_page_explorer")

def explorer_page(project_uuid):

data_repo = DataRepo()

project_setting = data_repo.get_project_setting(project_uuid)

st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['page']}]")
Expand Down Expand Up @@ -121,7 +120,7 @@ def generate_images_element(position='explorer', project_uuid=None, timing_uuid=

with b3:
edge_pil_img = None
strength_of_current_image = st_memory.slider("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.")
strength_of_current_image = st_memory.number_input("What % of the current image would you like to keep?", min_value=0, max_value=100, value=50, step=1, key="strength_of_current_image_key", help="This will determine how much of the current image will be kept in the final image.")
if type_of_transformation == InputImageStyling.EVOLVE_IMAGE.value:
prompt_strength = round(1 - (strength_of_current_image / 100), 2)
with c2:
Expand Down Expand Up @@ -326,7 +325,8 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de
st.success("Added To Shortlist")
time.sleep(0.3)
st.rerun()


# -------- inference details --------------
if gallery_image_list[i + j].inference_log:
log = gallery_image_list[i + j].inference_log # data_repo.get_inference_log_from_uuid(gallery_image_list[i + j].inference_log.uuid)
if log:
Expand All @@ -337,38 +337,40 @@ def gallery_image_view(project_uuid,page_number=1,num_items_per_page=20, open_de
with st.expander("Prompt Details", expanded=open_detailed_view_for_all):
st.info(f"**Prompt:** {prompt}\n\n**Model:** {model}")

if "last_shot_number" not in st.session_state:
st.session_state["last_shot_number"] = 0
if view not in ["explorer", "shortlist"]:
if view == "individual_shot":
shot_name = shot.name
else:
shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"])

if shot_name != "":
if shot_name == "**Create New Shot**":
shot_name = st.text_input("New shot name:", max_chars=40, key=f"shot_name_{gallery_image_list[i+j].uuid}")
if st.button("Create new shot", key=f"create_new_{gallery_image_list[i + j].uuid}", use_container_width=True):
new_shot = add_new_shot(project_uuid, name=shot_name)
add_key_frame(gallery_image_list[i + j], False, new_shot.uuid, len(data_repo.get_timing_list_from_shot(new_shot.uuid)), refresh_state=False)
# removing this from the gallery view
data_repo.update_file(gallery_image_list[i + j].uuid, tag="")
st.rerun()

else:
if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True):
shot_number = shot_names.index(shot_name) + 1
st.session_state["last_shot_number"] = shot_number - 1
shot_uuid = shot_list[shot_number - 2].uuid

add_key_frame(gallery_image_list[i + j], False, shot_uuid, len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False)
# removing this from the gallery view
data_repo.update_file(gallery_image_list[i + j].uuid, tag="")
st.rerun()
else:
st.warning("No inference data")
else:
st.warning("No data found")

# ---------- add to shot btn ---------------
if "last_shot_number" not in st.session_state:
st.session_state["last_shot_number"] = 0
if view not in ["explorer", "shortlist"]:
if view == "individual_shot":
shot_name = shot.name
else:
shot_name = st.selectbox('Add to shot:', shot_names, key=f"current_shot_sidebar_selector_{gallery_image_list[i + j].uuid}",index=st.session_state["last_shot_number"])

if shot_name != "":
if shot_name == "**Create New Shot**":
shot_name = st.text_input("New shot name:", max_chars=40, key=f"shot_name_{gallery_image_list[i+j].uuid}")
if st.button("Create new shot", key=f"create_new_{gallery_image_list[i + j].uuid}", use_container_width=True):
new_shot = add_new_shot(project_uuid, name=shot_name)
add_key_frame(gallery_image_list[i + j], False, new_shot.uuid, len(data_repo.get_timing_list_from_shot(new_shot.uuid)), refresh_state=False)
# removing this from the gallery view
data_repo.update_file(gallery_image_list[i + j].uuid, tag="")
st.rerun()

else:
if st.button(f"Add to shot", key=f"add_{gallery_image_list[i + j].uuid}", help="Promote this variant to the primary image", use_container_width=True):
shot_number = shot_names.index(shot_name) + 1
st.session_state["last_shot_number"] = shot_number - 1
shot_uuid = shot_list[shot_number - 2].uuid

add_key_frame(gallery_image_list[i + j], False, shot_uuid, len(data_repo.get_timing_list_from_shot(shot_uuid)), refresh_state=False)
# removing this from the gallery view
data_repo.update_file(gallery_image_list[i + j].uuid, tag="")
refresh_app(maintain_state=True)

st.markdown("***")
else:
Expand Down
34 changes: 15 additions & 19 deletions ui_components/components/frame_styling_page.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,31 @@
import streamlit as st
from shared.constants import ViewType
from streamlit_option_menu import option_menu

from ui_components.widgets.cropping_element import cropping_selector_element
from ui_components.widgets.frame_selector import frame_selector_widget
from ui_components.widgets.frame_selector import frame_selector_widget, frame_view
from ui_components.widgets.add_key_frame_element import add_key_frame, add_key_frame_element
from ui_components.widgets.timeline_view import timeline_view
from ui_components.components.explorer_page import generate_images_element
from ui_components.widgets.animation_style_element import animation_style_element
from ui_components.widgets.video_cropping_element import video_cropping_element
from ui_components.widgets.inpainting_element import inpainting_element
from ui_components.widgets.drawing_element import drawing_element
from ui_components.widgets.sidebar_logger import sidebar_logger
# from ui_components.components.explorer_page import explorer_element,gallery_image_view
from ui_components.widgets.variant_comparison_grid import variant_comparison_grid
from ui_components.widgets.shot_view import shot_keyframe_element
from utils import st_memory


from ui_components.constants import CreativeProcessType, DefaultProjectSettingParams, DefaultTimingStyleParams

from ui_components.constants import CreativeProcessType
from utils.data_repo.data_repo import DataRepo


def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings):
def frame_styling_page(shot_uuid: str, h2):
data_repo = DataRepo()
shot = data_repo.get_shot_from_uuid(shot_uuid)
timing_list = data_repo.get_timing_list_from_shot(shot_uuid)


if len(timing_list) == 0:
with h2:
frame_selector_widget(show=['shot_selector','frame_selector'])

st.markdown("#### There are no frames present in this shot yet.")




else:
with st.sidebar:
with h2:
Expand All @@ -46,12 +38,16 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se
menu_icon="cast", default_index=st.session_state.get('styling_view_index', 0), \
key="styling_view_selector", orientation="horizontal", \
styles={"nav-link": {"font-size": "15px", "margin": "0px", "--hover-color": "#eee"}, "nav-link-selected": {"background-color": "orange"}})

frame_view(view="Key Frame")

st.markdown(f"#### :red[{st.session_state['main_view_type']}] > :green[{st.session_state['frame_styling_view_type']}] > :orange[{st.session_state['styling_view']}] > :blue[{shot.name} - #{st.session_state['current_frame_index']}]")


variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value)

st.markdown("***")
if st.session_state['styling_view'] == "Generate":
variant_comparison_grid(st.session_state['current_frame_uuid'], stage=CreativeProcessType.STYLING.value)

with st.expander("🛠️ Generate Variants + Prompt Settings", expanded=True):
generate_images_element(position='individual', project_uuid=shot.project.uuid, timing_uuid=st.session_state['current_frame_uuid'])

Expand All @@ -65,6 +61,6 @@ def frame_styling_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_se

elif st.session_state['styling_view'] == "Scribbling":
with st.expander("📝 Draw On Image", expanded=True):
drawing_element(timing_list,project_settings, shot_uuid)
drawing_element(shot_uuid)


46 changes: 46 additions & 0 deletions ui_components/components/query_logger_page.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
import streamlit as st
from ui_components.constants import DefaultTimingStyleParams
from utils.common_utils import get_current_user

from utils.data_repo.data_repo import DataRepo

def query_logger_page():
st.header("Inference Log list")

data_repo = DataRepo()
current_user = get_current_user()
b1, b2 = st.columns([1, 1])

total_log_table_pages = st.session_state['total_log_table_pages'] if 'total_log_table_pages' in st.session_state else DefaultTimingStyleParams.total_log_table_pages
page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1)
inference_log_list, total_page_count = data_repo.get_all_inference_log_list(
page=page_number,
data_per_page=100
)

if total_log_table_pages != total_page_count:
st.session_state['total_log_table_pages'] = total_page_count
st.rerun()

data = {
'Project': [],
'Prompt': [],
'Model': [],
'Inference time (sec)': [],
'Cost ($)': [],
'Status': []
}

for log in inference_log_list:
data['Project'].append(log.project.name)
prompt = json.loads(log.input_params).get('prompt', '') if log.input_params else ''
data['Prompt'].append(prompt)
model_name = json.loads(log.output_details).get('model_name', '') if log.output_details else ''
data['Model'].append(model_name)
data['Inference time (sec)'].append(round(log.total_inference_time, 3))
data['Cost ($)'].append(round(log.total_inference_time * 0.004, 3))
data['Status'].append(log.status)


st.table(data=data)
7 changes: 5 additions & 2 deletions ui_components/components/timeline_view_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from ui_components.components.explorer_page import gallery_image_view
from streamlit_option_menu import option_menu
from utils import st_memory
from utils.data_repo.data_repo import DataRepo

def timeline_view_page(shot_uuid: str, h2):
data_repo = DataRepo()
shot = data_repo.get_shot_from_uuid(shot_uuid)

def timeline_view_page(shot_uuid: str, h2,data_repo,shot,timing_list, project_settings):
with st.sidebar:

views = CreativeProcessType.value_list()

if "view" not in st.session_state:
Expand Down
1 change: 1 addition & 0 deletions ui_components/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DefaultTimingStyleParams:
animation_tool = AnimationToolType.G_FILM.value
animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value
model = None
total_log_table_pages = 1

class DefaultProjectSettingParams:
batch_prompt = ""
Expand Down
36 changes: 32 additions & 4 deletions ui_components/methods/file_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,49 @@ def load_from_env(key):
val = get_key(dotenv_path=ENV_FILE_PATH, key_to_get=key)
return val

import zipfile
import os
import requests
from PIL import Image
from io import BytesIO

def zip_images(image_locations, zip_filename='images.zip'):
# Calculate the number of digits needed for padding
num_digits = len(str(len(image_locations) - 1))

with zipfile.ZipFile(zip_filename, 'w') as zip_file:
for idx, image_location in enumerate(image_locations):
# image_name = os.path.basename(image_location)
image_name = f"{idx}.png"
# Pad the index with zeros
padded_idx = str(idx).zfill(num_digits)
image_name = f"{padded_idx}.png"

if image_location.startswith('http'):
response = requests.get(image_location)
image_data = response.content
zip_file.writestr(image_name, image_data)

# Open the image for inspection and possible conversion
with Image.open(BytesIO(image_data)) as img:
if img.mode != 'RGB':
img = img.convert('RGB')

# Save the potentially converted image to a byte stream
with BytesIO() as output:
img.save(output, format='PNG')
zip_file.writestr(image_name, output.getvalue())
else:
zip_file.write(image_location, image_name)
# For local files, open, possibly convert, and then add to zip
with Image.open(image_location) as img:
if img.mode != 'RGB':
img = img.convert('RGB')

img.save(image_name, format='PNG')
zip_file.write(image_name, image_name)
os.remove(image_name) # Clean up the temporary file

return zip_filename



def create_duplicate_file(file: InternalFileObject, project_uuid=None) -> InternalFileObject:
data_repo = DataRepo()

Expand Down
2 changes: 1 addition & 1 deletion ui_components/methods/ml_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def inpainting(input_image: str, prompt, negative_prompt, timing_uuid, mask_in_p
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
strength=1.0,
strength=0.99,
queue_inference=QUEUE_INFERENCE_QUERIES
)

Expand Down
Loading

0 comments on commit e8f5ed9

Please sign in to comment.