Skip to content

Commit

Permalink
Some simple camera controls
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 28, 2024
1 parent 5d3840d commit 53d51ba
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 9 deletions.
115 changes: 111 additions & 4 deletions webgpu/input_handler.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,135 @@
import js
import numpy as np

from .utils import create_proxy


class Transform:
def __init__(self):
self._mat = np.identity(4)
self._rot_mat = np.identity(4)
self._center = (0.5, 0.5, 0)
self._scale = 1

def translate(self, dx=0.0, dy=0.0, dz=0.0):
translation = np.array(
[[1, 0, 0, dx], [0, 1, 0, dy], [0, 0, 1, dz], [0, 0, 0, 1]]
)
self._mat = translation @ self._mat

def scale(self, s):
self._scale *= s

def rotate(self, ang_x, ang_y=0):
rx = np.radians(ang_x)
cx = np.cos(rx)
sx = np.sin(rx)

rotation_x = np.array(
[
[1, 0, 0, 0],
[0, cx, -sx, 0],
[0, sx, cx, 0],
[0, 0, 0, 1],
]
)

ry = np.radians(ang_y)
cy = np.cos(ry)
sy = np.sin(ry)
rotation_y = np.array(
[
[cy, 0, sy, 0],
[0, 1, 0, 0],
[-sy, 0, cy, 0],
[0, 0, 0, 1],
]
)

self._rot_mat = rotation_x @ rotation_y @ self._rot_mat

@property
def mat(self):
return self._mat @ self._rot_mat @ self._scale_mat @ self._center_mat

@property
def _center_mat(self):
cx, cy, cz = self._center
return np.array([[1, 0, 0, -cx], [0, 1, 0, -cy], [0, 0, 1, -cz], [0, 0, 0, 1]])

@property
def _scale_mat(self):
s = self._scale
return np.array([[s, 0, 0, 0], [0, s, 0, 0], [0, 0, s, 0], [0, 0, 0, 1]])


class InputHandler:
def __init__(self, canvas, uniforms, render_function=None):
self.canvas = canvas
self.uniforms = uniforms
self.render_function = render_function
self._is_moving = False
self._is_rotating = False

self._callbacks = {}
self.register_callbacks()
self.transform = Transform()

self.transform.scale(2)

self._update_uniforms()

def _update_uniforms(self):
near = 1
far = 100
proj_mat = np.array(
[
[near, 0, 0, 0],
[0, near, 0, 0],
[0, 0, -(near + far) / (far - near), -2 * near * far / (far - near)],
[0, 0, -1, 0],
]
)

view_mat = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, -3], [0, 0, 0, 1]])

def on_mousedown(self, _):
self._is_moving = True
mat = proj_mat @ view_mat @ self.transform.mat
mat = mat.transpose()
mat = mat.flatten()
for i in range(16):
self.uniforms.mat[i] = mat[i]

def _render(self):
self._update_uniforms()
if self.render_function:
js.requestAnimationFrame(self.render_function)

def on_mousedown(self, ev):
if ev.button == 0:
self._is_rotating = True
if ev.button == 1:
self._is_moving = True

def on_mouseup(self, _):
global _is_moving
self._is_moving = False
self._is_rotating = False
self._is_zooming = False

def on_mousewheel(self, ev):
self.transform.scale(1 - ev.deltaY / 1000)
self._render()

def on_mousemove(self, ev):
if self._is_rotating:
s = 0.3
self.transform.rotate(s * ev.movementY, s * ev.movementX)
self._render()
if self._is_moving:
self.uniforms.mat[12] += ev.movementX / self.canvas.width * 1.8
self.uniforms.mat[13] -= ev.movementY / self.canvas.height * 1.8
s = 0.01
self.transform.translate(s * ev.movementX, -s * ev.movementY)
self._render()

if self.render_function:
js.requestAnimationFrame(self.render_function)

Expand All @@ -46,6 +152,7 @@ def register_callbacks(self):
self.on("mousedown", self.on_mousedown)
self.on("mouseup", self.on_mouseup)
self.on("mousemove", self.on_mousemove)
self.on("wheel", self.on_mousewheel)

def __del__(self):
self.unregister_callbacks()
Expand Down
1 change: 1 addition & 0 deletions webgpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def render(time):
# )

render_function = create_proxy(render)
gpu.input_handler._update_uniforms()
gpu.input_handler.render_function = render_function

render_function.request_id = js.requestAnimationFrame(render_function)
Expand Down
17 changes: 12 additions & 5 deletions webgpu/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ struct VertexOutput1d {
@builtin(position) fragPosition: vec4<f32>,
@location(0) p: vec3<f32>,
@location(1) lam: f32,
@location(2) id: u32,
@location(2) @interpolate(flat) id: u32,
};

struct VertexOutput2d {
@builtin(position) fragPosition: vec4<f32>,
@location(0) p: vec3<f32>,
@location(1) lam: vec2<f32>,
@location(2) id: u32,
@location(2) @interpolate(flat) id: u32,
};

struct VertexOutput3d {
@builtin(position) fragPosition: vec4<f32>,
@location(0) p: vec3<f32>,
@location(1) lam: vec3<f32>,
@location(2) id: u32,
@location(2) @interpolate(flat) id: u32,
};

fn calcPosition(p: vec3<f32>) -> vec4<f32> {
Expand Down Expand Up @@ -114,12 +114,19 @@ fn mainVertexTrigP1Indexed(@builtin(vertex_index) vertexId: u32, @builtin(instan
}

@fragment
fn mainFragmentTrig(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) id: u32) -> @location(0) vec4<f32> {
fn mainFragmentTrig(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) @interpolate(flat) id: u32) -> @location(0) vec4<f32> {
checkClipping(p);
let value = evalTrig(id, 0u, lam);
return getColor(value);
}

@fragment
fn mainFragmentTrigMesh(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) @interpolate(flat) id: u32) -> @location(0) vec4<f32> {
checkClipping(p);
let value = id;
return vec4<f32>(0., 1.0, 0.0, 1.0);
}

@fragment
fn mainFragmentEdge(@location(0) p: vec3<f32>) -> @location(0) vec4<f32> {
checkClipping(p);
Expand All @@ -146,7 +153,7 @@ fn mainFragmentDeferred(@builtin(position) coord: vec4<f32>) -> @location(0) vec


@fragment
fn mainFragmentTrigToGBuffer(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) id: u32) -> @location(0) vec4<f32> {
fn mainFragmentTrigToGBuffer(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) @interpolate(flat) id: u32) -> @location(0) vec4<f32> {
checkClipping(p);
let value = evalTrig(id, 0u, lam);
return vec4<f32>(bitcast<f32>(id), lam, 0.0);
Expand Down

0 comments on commit 53d51ba

Please sign in to comment.