Skip to content

Commit

Permalink
refactor: Move camera setup logic to Camera class
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Jun 15, 2024
1 parent 3f08f91 commit f3eedb1
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 81 deletions.
44 changes: 44 additions & 0 deletions camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,50 @@ def __init__(self, distance=100, max_distance=500, focal_length=100):
self.max_distance = max_distance
self.focal_length = focal_length

def setup_camera_and_cards(self, image_slices, depths, sensor_width=35.0):
    """
    Set up the camera matrix and the card corners in 3D space.

    Reads ``self.focal_length``, ``self.max_distance`` and
    ``self.camera_distance`` from the camera instance.

    Args:
        image_slices (list): A list of image slices. Only the shape of the
            first slice is read; it may be (height, width) or
            (height, width, channels).
        depths (list): A list of threshold depths for each image slice.
            Entries beyond ``len(image_slices)`` are ignored.
        sensor_width (float, optional): The width of the camera sensor, in
            the same unit as ``self.focal_length``. Defaults to 35.0.

    Returns:
        tuple: A tuple containing the 3x3 float32 camera matrix and a list
        with one (4, 3) float32 array of card corners per image slice.

    Raises:
        IndexError: If ``image_slices`` is empty or ``depths`` has fewer
            entries than ``image_slices``.
    """
    num_slices = len(image_slices)
    if num_slices == 0:
        raise IndexError("image_slices must contain at least one slice")
    if len(depths) < num_slices:
        raise IndexError(
            f"expected at least {num_slices} depths, got {len(depths)}")

    # Only the first two dimensions matter, so 2-D (grayscale) slices work
    # as well as slices that carry a channel axis.
    image_height, image_width = image_slices[0].shape[:2]

    # Calculate the focal length in pixels
    focal_length_px = (image_width * self.focal_length) / sensor_width

    # Set up the camera intrinsic parameters; the principal point is the
    # image centre.
    camera_matrix = np.array([[focal_length_px, 0, image_width / 2],
                              [0, focal_length_px, image_height / 2],
                              [0, 0, 1]], dtype=np.float32)

    # Set up the card corners in 3D space
    card_corners_3d_list = []
    for depth in depths[:num_slices]:
        # The thresholds start with 0 and end with 255. Depth 255 maps to
        # z = 0 (the closest card) and depth 0 maps to z = max_distance.
        z = self.max_distance * ((255 - depth) / 255.0)

        # Scale the card with its distance (z + camera_distance) so that a
        # pinhole camera placed at z = -camera_distance sees every card
        # projected onto the full image_width x image_height frame.
        card_width = (image_width * (z + self.camera_distance)) / focal_length_px
        card_height = (image_height * (z + self.camera_distance)) / focal_length_px

        half_width = card_width / 2
        half_height = card_height / 2
        card_corners_3d_list.append(np.array([
            [-half_width, -half_height, z],
            [half_width, -half_height, z],
            [half_width, half_height, z],
            [-half_width, half_height, z],
        ], dtype=np.float32))

    return camera_matrix, card_corners_3d_list


def to_json(self):
return {
'position': self._camera_position.tolist(),
Expand Down
17 changes: 5 additions & 12 deletions components.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from controller import AppState, CompositeMode
from utils import to_image_url, filename_add_version, find_square_bounding_box
from inpainting import patch_image, create_inpainting_pipeline
from segmentation import setup_camera_and_cards, render_view, remove_mask_from_alpha
from segmentation import render_view, remove_mask_from_alpha
from stabilityai import StabilityAI


Expand Down Expand Up @@ -1502,12 +1502,9 @@ def make_navigation_callbacks(app):
Input(C.NAV_ZOOM_IN, 'n_clicks'),
Input(C.NAV_ZOOM_OUT, 'n_clicks'),
State(C.STORE_APPSTATE_FILENAME, 'data'),
State(C.SLIDER_CAMERA_DISTANCE, 'value'),
State(C.SLIDER_FOCAL_LENGTH, 'value'),
State(C.SLIDER_MAX_DISTANCE, 'value'),
State(C.LOGS_DATA, 'data'),
prevent_initial_call=True)
def navigate_image(reset, up, down, left, right, zoom_in, zoom_out, filename, camera_distance, focal_length, max_distance, logs):
def navigate_image(reset, up, down, left, right, zoom_in, zoom_out, filename, logs):
if filename is None:
raise PreventUpdate()

Expand All @@ -1526,7 +1523,7 @@ def navigate_image(reset, up, down, left, right, zoom_in, zoom_out, filename, ca

if nav_clicked == C.NAV_RESET:
camera_position = np.array(
[0, 0, -camera_distance], dtype=np.float32)
[0, 0, -state.camera.camera_distance], dtype=np.float32)
else:
# Move the camera position based on the navigation button clicked
# The distance should be configurable
Expand All @@ -1542,13 +1539,9 @@ def navigate_image(reset, up, down, left, right, zoom_in, zoom_out, filename, ca
camera_position += switch[nav_clicked]

state.camera.camera_position = camera_position
state.camera.camera_distance = camera_distance
state.camera.focal_length = focal_length
state.camera.max_distance = max_distance

camera_matrix, card_corners_3d_list = setup_camera_and_cards(
state.image_slices, state.image_depths,
state.camera.camera_distance, state.camera.max_distance, state.camera.focal_length)
camera_matrix, card_corners_3d_list = state.camera.setup_camera_and_cards(
state.image_slices, state.image_depths)

image = render_view(state.image_slices, camera_matrix,
card_corners_3d_list, camera_position)
Expand Down
55 changes: 3 additions & 52 deletions segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# for exporting a 3d scene
from gltf import export_gltf
from camera import Camera


def generate_depth_map(image, model: DepthEstimationModel, progress_callback=None):
Expand Down Expand Up @@ -139,53 +140,6 @@ def create_slice_from_mask(image, mask, num_expand=50):
return masked_image


def setup_camera_and_cards(image_slices, depths, camera_distance=100.0, max_distance=100.0, focal_length=100.0, sensor_width=35.0):
    """
    Set up the camera intrinsic parameters and the card corners in 3D space.

    Args:
        image_slices (list): A list of image slices. Only the shape of the
            first slice is read; it may be (height, width) or
            (height, width, channels).
        depths (list): A list of threshold depths for each image slice.
            Entries beyond ``len(image_slices)`` are ignored.
        camera_distance (float, optional): The distance between the camera and the cards. Defaults to 100.0.
        max_distance (float, optional): The maximum distance for the cards. Defaults to 100.0.
        focal_length (float, optional): The focal length of the camera. Defaults to 100.0.
        sensor_width (float, optional): The width of the camera sensor, in
            the same unit as ``focal_length``. Defaults to 35.0.

    Returns:
        tuple: A tuple containing the 3x3 float32 camera matrix and a list
        with one (4, 3) float32 array of card corners per image slice.

    Raises:
        IndexError: If ``image_slices`` is empty or ``depths`` has fewer
            entries than ``image_slices``.
    """
    num_slices = len(image_slices)
    if num_slices == 0:
        raise IndexError("image_slices must contain at least one slice")
    if len(depths) < num_slices:
        raise IndexError(
            f"expected at least {num_slices} depths, got {len(depths)}")

    # Only the first two dimensions matter, so 2-D (grayscale) slices work
    # as well as slices that carry a channel axis.
    image_height, image_width = image_slices[0].shape[:2]

    # Calculate the focal length in pixels
    focal_length_px = (image_width * focal_length) / sensor_width

    # Set up the camera intrinsic parameters; the principal point is the
    # image centre.
    camera_matrix = np.array([[focal_length_px, 0, image_width / 2],
                              [0, focal_length_px, image_height / 2],
                              [0, 0, 1]], dtype=np.float32)

    # Set up the card corners in 3D space
    card_corners_3d_list = []
    for depth in depths[:num_slices]:
        # The thresholds start with 0 and end with 255. Depth 255 maps to
        # z = 0 (the closest card) and depth 0 maps to z = max_distance.
        z = max_distance * ((255 - depth) / 255.0)

        # Scale the card with its distance (z + camera_distance) so that a
        # pinhole camera placed at z = -camera_distance sees every card
        # projected onto the full image_width x image_height frame.
        card_width = (image_width * (z + camera_distance)) / focal_length_px
        card_height = (image_height * (z + camera_distance)) / focal_length_px

        half_width = card_width / 2
        half_height = card_height / 2
        card_corners_3d_list.append(np.array([
            [-half_width, -half_height, z],
            [half_width, -half_height, z],
            [half_width, half_height, z],
            [-half_width, half_height, z],
        ], dtype=np.float32))

    return camera_matrix, card_corners_3d_list


def render_view(image_slices, camera_matrix, card_corners_3d_list, camera_position):
"""
Render the current view of the camera.
Expand Down Expand Up @@ -384,11 +338,8 @@ def process_image(image_path, output_path, num_slices=5,
image_slices.append(slice_image)

# Set up the camera and cards
camera_distance = 100.0
max_distance = 500.0
focal_length = 100.0
camera_matrix, card_corners_3d_list = setup_camera_and_cards(
image_slices, thresholds[1:], camera_distance, max_distance, focal_length)
camera = Camera(100.0, 500.0, 100.0)
camera_matrix, card_corners_3d_list = camera.setup_camera_and_cards(image_slices, thresholds[1:])

# Render the initial view
camera_position = np.array([0, 0, -100], dtype=np.float32)
Expand Down
16 changes: 9 additions & 7 deletions test_webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
update_threshold_values, click_event,
copy_to_clipboard, export_state_as_gltf, slice_upload, update_slices)
from controller import AppState
from segmentation import setup_camera_and_cards
from utils import to_image_url
from camera import Camera
import constants as C
Expand Down Expand Up @@ -250,18 +249,21 @@ def setUp(self):
self.state.depth_filename.return_value = self.mock_depth_file
self.mock_depth_file.exists.return_value = True

self.camera = Camera(10, 100, 50)

self.state.upscaled_filename.return_value = Path("upscaled_file.png")
self.state.image_slices_filenames = [
Path(f"slice_{i}.png") for i in range(3)]
self.state.MODEL_FILE = "model.gltf"
self.state.camera = self.camera

@patch("webui.generate_depth_map")
@patch("webui.postprocess_depth_map")
@patch("webui.export_gltf")
def test_export_state_as_gltf(self, mock_export_gltf, mock_postprocess_depth_map, mock_generate_depth_map):
# Test case 1: Displacement scale is 0
camera_matrix, card_corners_3d_list = setup_camera_and_cards(
self.state.image_slices, self.state.image_depths, 10, 100, 50)
camera_matrix, card_corners_3d_list = self.camera.setup_camera_and_cards(
self.state.image_slices, self.state.image_depths)
mock_export_gltf.return_value = Path("output.gltf")

result = export_state_as_gltf(
Expand Down Expand Up @@ -294,8 +296,8 @@ def test_export_state_as_gltf(self, mock_export_gltf, mock_postprocess_depth_map
def test_export_state_as_gltf_with_displacement(
self, mock_export_gltf, mock_postprocess_depth_map, mock_generate_depth_map, mock_image_fromarray):
# Test case 2: Displacement scale is greater than 0
camera_matrix, card_corners_3d_list = setup_camera_and_cards(
self.state.image_slices, self.state.image_depths, 10, 100, 50)
camera_matrix, card_corners_3d_list = self.camera.setup_camera_and_cards(
self.state.image_slices, self.state.image_depths)

mock_export_gltf.return_value = Path("output.gltf")

Expand Down Expand Up @@ -339,8 +341,8 @@ def test_export_state_as_gltf_with_displacement(
@patch("webui.export_gltf")
def test_export_state_as_gltf_with_upscaled(self, mock_export_gltf):
# Test case 3: Upscaled slices exist
camera_matrix, card_corners_3d_list = setup_camera_and_cards(
self.state.image_slices, self.state.image_depths, 10, 100, 50)
camera_matrix, card_corners_3d_list = self.camera.setup_camera_and_cards(
self.state.image_slices, self.state.image_depths)

# Pretend the upscaled file exists
mock_upscaled_file = MagicMock()
Expand Down
16 changes: 6 additions & 10 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
analyze_depth_histogram,
generate_image_slices,
create_slice_from_mask,
setup_camera_and_cards,
export_gltf,
blend_with_alpha,
remove_mask_from_alpha,
Expand Down Expand Up @@ -1214,9 +1213,8 @@ def export_state_as_gltf(
state, filename,
camera,
displacement_scale, modelname='midas', support_dof=False):
camera_matrix, card_corners_3d_list = setup_camera_and_cards(
state.image_slices,
state.image_depths, camera.camera_distance, camera.max_distance, camera.focal_length)
camera_matrix, card_corners_3d_list = state.camera.setup_camera_and_cards(
state.image_slices, state.image_depths)

depth_filenames = []
if displacement_scale > 0:
Expand Down Expand Up @@ -1331,20 +1329,18 @@ def slice_upload(contents, filename, logs):
Input(C.BTN_EXPORT_ANIMATION, 'n_clicks'),
State(C.STORE_APPSTATE_FILENAME, 'data'),
State(C.SLIDER_NUM_FRAMES, 'value'),
State(C.SLIDER_CAMERA_DISTANCE, 'value'),
State(C.SLIDER_MAX_DISTANCE, 'value'),
State(C.SLIDER_FOCAL_LENGTH, 'value'),
State(C.LOGS_DATA, 'data'),
running=[(Output(C.BTN_EXPORT_ANIMATION, 'disabled'), True, False)],
prevent_initial_call=True)
def export_animation(n_clicks, filename, num_frames, camera_distance, max_distance, focal_length, logs):
def export_animation(n_clicks, filename, num_frames, logs):
if n_clicks is None or filename is None:
raise PreventUpdate()

state = AppState.from_cache(filename)

camera_matrix, card_corners_3d_list = setup_camera_and_cards(
state.image_slices, state.image_depths, camera_distance, max_distance, focal_length)
camera_distance = state.camera.camera_distance
camera_matrix, card_corners_3d_list = state.camera.setup_camera_and_cards(
state.image_slices, state.image_depths)

# Render the initial view
camera_position = np.array([0, 0, -camera_distance], dtype=np.float32)
Expand Down

0 comments on commit f3eedb1

Please sign in to comment.