diff --git a/deployment/InstantNGP/taichi_ngp/new_kernels.py b/deployment/InstantNGP/taichi_ngp/new_kernels.py deleted file mode 100644 index 46db9ae..0000000 --- a/deployment/InstantNGP/taichi_ngp/new_kernels.py +++ /dev/null @@ -1,18 +0,0 @@ -import taichi as ti -from taichi.math import vec3 - -@ti.kernel -def get_rays(pose: ti.types.ndarray(), - directions: ti.types.ndarray(), - rays_o: ti.types.ndarray(), - rays_d: ti.types.ndarray()): - #print(directions.shape) - for i in ti.ndrange(directions.shape[0]): - #for i in range(10): - c2w = pose[None] - mat_result = directions[i] @ c2w[:, :3].transpose() - ray_d = vec3(mat_result[0, 0], mat_result[0, 1], mat_result[0, 2]) - ray_o = c2w[:, 3] - - rays_o[i] = ray_o - rays_d[i] = ray_d diff --git a/deployment/InstantNGP/taichi_ngp/taichi_ngp.py b/deployment/InstantNGP/taichi_ngp/taichi_ngp.py index fc46231..cdc633e 100644 --- a/deployment/InstantNGP/taichi_ngp/taichi_ngp.py +++ b/deployment/InstantNGP/taichi_ngp/taichi_ngp.py @@ -5,28 +5,28 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), "../../../")) -from modules.intersection import ray_aabb_intersect +from modules.intersection import ray_aabb_intersect, get_rays_test_kernel import numpy as np import taichi as ti from matplotlib import pyplot as plt -from kernels import args, np_type, data_type,\ +from .utils import args, np_type, data_type,\ rotate_scale, reset,\ - ray_intersect, raymarching_test_kernel,\ - rearange_index, hash_encode,\ - sigma_rgb_layer, composite_test,\ + raymarching_test_kernel,\ + re_arrange_index, radiance_field, composite_test,\ re_order, fill_ndarray,\ init_current_index, rotate_scale,\ initialize, load_deployment_model,\ cascades, grid_size, scale, \ NGP_res, NGP_N_rays, NGP_min_samples -from new_kernels import get_rays -ti.init(arch=ti.vulkan, - enable_fallback=False, - debug=False, - kernel_profiler=False) +ti.init( + arch=ti.vulkan, + enable_fallback=False, + debug=False, + kernel_profiler=False +) ######################### # Compile for AOT files # @@ -89,11 +89,11 @@ def prepare_aot_files(model): m = ti.aot.Module( caps=['spirv_has_int8', 'spirv_has_int16', 'spirv_has_float16']) m.add_kernel(reset) - m.add_kernel(ray_intersect) + m.add_field(get_rays_test_kernel) + m.add_kernel(ray_aabb_intersect) m.add_kernel(raymarching_test_kernel) - m.add_kernel(rearange_index) - m.add_kernel(hash_encode) - m.add_kernel(sigma_rgb_layer) + m.add_kernel(re_arrange_index) + m.add_kernel(radiance_field) m.add_kernel(composite_test) m.add_kernel(re_order) m.add_kernel(fill_ndarray) @@ -123,16 +123,30 @@ def update_model_weights(model): -def run_inference(max_samples, - T_threshold, - dist_to_focus=0.8, - len_dis=0.0) -> Tuple[float, int, int]: +def run_inference( + max_samples, + T_threshold, +) -> Tuple[float, int, int]: samples = 0 #rotate_scale(NGP_pose, 0.5, 0.5, 0.0, 2.5) - reset(NGP_counter, NGP_alive_indices, NGP_opacity, NGP_rgb) - - get_rays(NGP_pose, NGP_directions, NGP_rays_o, NGP_rays_d) - ray_aabb_intersect(NGP_hits_t, NGP_rays_o, NGP_rays_d, scale) + reset( + NGP_counter, + NGP_alive_indices, + NGP_opacity, + NGP_rgb + ) + get_rays_test_kernel( + NGP_pose, + NGP_directions, + NGP_rays_o, + NGP_rays_d + ) + ray_aabb_intersect( + NGP_hits_t, + NGP_rays_o, + NGP_rays_d, + scale + ) while samples < max_samples: N_alive = NGP_counter[0] @@ -144,23 +158,61 @@ def run_inference(max_samples, samples += N_samples launch_model_total = N_alive * N_samples - raymarching_test_kernel(NGP_counter, NGP_density_bitfield, NGP_hits_t, - NGP_alive_indices, NGP_rays_o, NGP_rays_d, - NGP_current_index, NGP_xyzs, NGP_dirs, - NGP_deltas, NGP_ts, NGP_run_model_ind, - NGP_N_eff_samples, N_samples) - rearange_index(NGP_model_launch, NGP_padd_block_network, NGP_temp_hit, - NGP_run_model_ind, launch_model_total) - hash_encode(NGP_hash_embedding, NGP_model_launch, NGP_xyzs, NGP_dirs, - NGP_deltas, NGP_xyzs_embedding, NGP_temp_hit) - sigma_rgb_layer(NGP_sigma_weights, NGP_rgb_weights, NGP_model_launch, - NGP_padd_block_network, NGP_xyzs_embedding, NGP_dirs, - NGP_out_1, NGP_out_3, NGP_temp_hit) - - composite_test(NGP_counter, NGP_alive_indices, NGP_rgb, NGP_opacity, - NGP_current_index, NGP_deltas, NGP_ts, NGP_out_3, - NGP_out_1, NGP_N_eff_samples, N_samples, T_threshold) - re_order(NGP_counter, NGP_alive_indices, NGP_current_index, N_alive) + raymarching_test_kernel( + NGP_counter, + NGP_density_bitfield, + NGP_hits_t, + NGP_alive_indices, + NGP_rays_o, + NGP_rays_d, + NGP_current_index, + NGP_xyzs, + NGP_dirs, + NGP_deltas, + NGP_ts, + NGP_run_model_ind, + NGP_N_eff_samples, + N_samples, + ) + re_arrange_index( + NGP_model_launch, + NGP_padd_block_network, + NGP_temp_hit, + NGP_run_model_ind, + launch_model_total, + ) + radiance_field( + NGP_hash_embedding, + NGP_model_launch, + NGP_xyzs, + NGP_sigma_weights, + NGP_rgb_weights, + NGP_padd_block_network, + NGP_dirs, + NGP_out_1, + NGP_out_3, + NGP_temp_hit, + ) + composite_test( + NGP_counter, + NGP_alive_indices, + NGP_rgb, + NGP_opacity, + NGP_current_index, + NGP_deltas, + NGP_ts, + NGP_out_3, + NGP_out_1, + NGP_N_eff_samples, + N_samples, + T_threshold, + ) + re_order( + NGP_counter, + NGP_alive_indices, + NGP_current_index, + N_alive + ) return samples, N_alive, N_samples @@ -168,8 +220,10 @@ def run_inference(max_samples, def inference_local(n=1): for _ in range(n): - samples, N_alive, N_samples = run_inference(max_samples=100, T_threshold=1e-2) - + _, _, _ = run_inference( + max_samples=100, + T_threshold=1e-2 + ) ti.sync() # Show inferenced image @@ -177,7 +231,6 @@ def inference_local(n=1): plt.imshow((rgb_np * 255).astype(np.uint8)) plt.show() - if __name__ == '__main__': model = load_deployment_model(args.model_path) initialize() @@ -230,7 +283,7 @@ def inference_local(n=1): # model parameters sigma_layer1_base = 16 * 16 layer1_base = 32 * 16 - NGP_hash_embedding = ti.ndarray(dtype=data_type, shape=(17956864, )) + NGP_hash_embedding = ti.ndarray(dtype=data_type, shape=(11176096, )) NGP_sigma_weights = ti.ndarray(dtype=data_type, shape=(sigma_layer1_base + 16 * 16, )) NGP_rgb_weights = ti.ndarray(dtype=data_type, diff --git a/deployment/InstantNGP/taichi_ngp/kernels.py b/deployment/InstantNGP/taichi_ngp/utils.py similarity index 74% rename from deployment/InstantNGP/taichi_ngp/kernels.py rename to deployment/InstantNGP/taichi_ngp/utils.py index fc62d9d..1c85ac0 100644 --- a/deployment/InstantNGP/taichi_ngp/kernels.py +++ b/deployment/InstantNGP/taichi_ngp/utils.py @@ -1,8 +1,8 @@ -import taichi as ti import os -import numpy as np -import argparse import wget +import argparse +import numpy as np +import taichi as ti from taichi.math import uvec3 def parse_arguments(): @@ -222,7 +222,7 @@ def fill_ndarray( @ti.kernel -def rearange_index( +def re_arrange_index( NGP_model_launch: ti.types.ndarray(ti.i32, ndim=0), NGP_padd_block_network: ti.types.ndarray(ti.i32, ndim=0), NGP_temp_hit: ti.types.ndarray(ti.i32, ndim=1), @@ -233,7 +233,7 @@ def rearange_index( if NGP_run_model_ind[i]: index = ti.atomic_add(NGP_model_launch[None], 1) NGP_temp_hit[index] = i - + NGP_run_model_ind[i] = 0 NGP_model_launch[None] += 1 NGP_padd_block_network[None] = ( (NGP_model_launch[None] + block_dim - 1) // block_dim) * block_dim @@ -259,67 +259,25 @@ def re_order(counter: ti.types.ndarray(ti.i32, ndim=1), # Taichi NGP Kernels # ###################### # Most of these kernels are modified from the training code -@ti.func -def _ray_aabb_intersec(ray_o, ray_d): - inv_d = 1.0 / ray_d - - t_min = (NGP_center - NGP_half_size - ray_o) * inv_d - t_max = (NGP_center + NGP_half_size - ray_o) * inv_d - - _t1 = ti.min(t_min, t_max) - _t2 = ti.max(t_min, t_max) - t1 = _t1.max() - t2 = _t2.min() - - return tf_vec2(t1, t2) - - -@ti.kernel -def ray_intersect( - counter: ti.types.ndarray(ti.i32, ndim=1), - NGP_pose: ti.types.ndarray(dtype=ti.types.matrix(3, 4, - dtype=data_type), - ndim=0), - NGP_directions: ti.types.ndarray(dtype=ti.types.matrix( - 1, 3, dtype=data_type), - ndim=1), - NGP_hits_t: ti.types.ndarray(dtype=tf_vec2, ndim=1), - NGP_rays_o: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_rays_d: ti.types.ndarray(dtype=tf_vec3, ndim=1), -): - for i in ti.ndrange(counter[0]): - c2w = NGP_pose[None] - mat_result = NGP_directions[i] @ c2w[:, :3].transpose() - ray_d = tf_vec3(mat_result[0, 0], mat_result[0, 1], mat_result[0, 2]) - ray_o = c2w[:, 3] - - t1t2 = _ray_aabb_intersec(ray_o, ray_d) - - if t1t2[1] > 0.0: - NGP_hits_t[i][0] = data_type(ti.max(t1t2[0], NEAR_DISTANCE)) - NGP_hits_t[i][1] = t1t2[1] - - NGP_rays_o[i] = ray_o - NGP_rays_d[i] = ray_d # Modified from "modules/ray_march.py" @ti.kernel def raymarching_test_kernel( - counter: ti.types.ndarray(ti.i32, ndim=1), - NGP_density_bitfield: ti.types.ndarray(dtype=ti.u32, ndim=1), - NGP_hits_t: ti.types.ndarray(dtype=tf_vec2, ndim=1), - NGP_alive_indices: ti.types.ndarray(dtype=ti.i32, ndim=1), - NGP_rays_o: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_rays_d: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_current_index: ti.types.ndarray(dtype=ti.i32, ndim=0), - NGP_xyzs: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_dirs: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_deltas: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_ts: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_run_model_ind: ti.types.ndarray(dtype=ti.i32, ndim=1), - NGP_N_eff_samples: ti.types.ndarray(dtype=ti.i32, - ndim=1), N_samples: int): - + counter: ti.types.ndarray(ti.i32, ndim=1), + NGP_density_bitfield: ti.types.ndarray(dtype=ti.u32, ndim=1), + NGP_hits_t: ti.types.ndarray(dtype=tf_vec2, ndim=1), + NGP_alive_indices: ti.types.ndarray(dtype=ti.i32, ndim=1), + NGP_rays_o: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_rays_d: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_current_index: ti.types.ndarray(dtype=ti.i32, ndim=0), + NGP_xyzs: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_dirs: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_deltas: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_ts: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_run_model_ind: ti.types.ndarray(dtype=ti.i32, ndim=1), + NGP_N_eff_samples: ti.types.ndarray(dtype=ti.i32, ndim=1), + N_samples: int, +): for n in ti.ndrange(counter[0]): c_index = NGP_current_index[None] r = NGP_alive_indices[n * 2 + c_index] @@ -381,82 +339,18 @@ def raymarching_test_kernel( NGP_alive_indices[n * 2 + c_index] = -1 -# Modified from "modules/hash_encoder_deploy.py" -@ti.kernel -def hash_encode( - NGP_hash_embedding: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_model_launch: ti.types.ndarray(ti.i32, ndim=0), - NGP_xyzs: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_dirs: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_deltas: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_xyzs_embedding: ti.types.ndarray(dtype=data_type, ndim=2), - NGP_temp_hit: ti.types.ndarray(ti.i32, ndim=1), -): - for sn in ti.ndrange(NGP_model_launch[None]): - for level in ti.static(range(NGP_level)): - xyz = NGP_xyzs[NGP_temp_hit[sn]] + 0.5 - offset = NGP_offsets[level] * 4 - - init_val0 = tf_vec1(0.0) - init_val1 = tf_vec1(1.0) - local_feature_0 = init_val0[0] - local_feature_1 = init_val0[0] - local_feature_2 = init_val0[0] - local_feature_3 = init_val0[0] - - scale = NGP_base_res * ti.exp( - level * NGP_per_level_scales) - 1.0 - resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1 - - pos = xyz * scale + 0.5 - pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) - pos -= pos_grid_uint - - for idx in ti.static(range(8)): - w = init_val1[0] - pos_grid_local = uvec3(0) - - for d in ti.static(range(3)): - if (idx & (1 << d)) == 0: - pos_grid_local[d] = pos_grid_uint[d] - w *= data_type(1 - pos[d]) - else: - pos_grid_local[d] = pos_grid_uint[d] + 1 - w *= data_type(pos[d]) - - index = 0 - stride = 1 - for c_ in ti.static(range(3)): - index += pos_grid_local[c_] * stride - stride *= resolution - - local_feature_0 += data_type( - w * NGP_hash_embedding[offset + index * 4]) - local_feature_1 += data_type( - w * NGP_hash_embedding[offset + index * 4 + 1]) - local_feature_2 += data_type( - w * NGP_hash_embedding[offset + index * 4 + 2]) - local_feature_3 += data_type( - w * NGP_hash_embedding[offset + index * 4 + 3]) - - NGP_xyzs_embedding[sn, level * 4] = local_feature_0 - NGP_xyzs_embedding[sn, level * 4 + 1] = local_feature_1 - NGP_xyzs_embedding[sn, level * 4 + 2] = local_feature_2 - NGP_xyzs_embedding[sn, level * 4 + 3] = local_feature_3 - - -# Taichi implementation of MLP @ti.kernel -def sigma_rgb_layer( - NGP_sigma_weights: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_rgb_weights: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_model_launch: ti.types.ndarray(dtype=ti.i32, ndim=0), - NGP_padd_block_network: ti.types.ndarray(dtype=ti.i32, ndim=0), - NGP_xyzs_embedding: ti.types.ndarray(dtype=data_type, ndim=2), - NGP_dirs: ti.types.ndarray(dtype=tf_vec3, ndim=1), - NGP_out_1: ti.types.ndarray(dtype=data_type, ndim=1), - NGP_out_3: ti.types.ndarray(data_type, ndim=2), - NGP_temp_hit: ti.types.ndarray(ti.i32, ndim=1), +def radiance_field( + NGP_hash_embedding: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_model_launch: ti.types.ndarray(ti.i32, ndim=0), + NGP_xyzs: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_sigma_weights: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_rgb_weights: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_padd_block_network: ti.types.ndarray(dtype=ti.i32, ndim=0), + NGP_dirs: ti.types.ndarray(dtype=tf_vec3, ndim=1), + NGP_out_1: ti.types.ndarray(dtype=data_type, ndim=1), + NGP_out_3: ti.types.ndarray(data_type, ndim=2), + NGP_temp_hit: ti.types.ndarray(ti.i32, ndim=1), ): ti.loop_config(block_dim=block_dim) # DO NOT REMOVE for sn in ti.ndrange(NGP_padd_block_network[None]): @@ -478,6 +372,58 @@ def sigma_rgb_layer( ti.simt.block.sync() if sn < did_launch_num: + xyzs_embedding = tf_vec32(0.0) + + for level in ti.static(range(NGP_level)): + xyz = NGP_xyzs[NGP_temp_hit[sn]] + 0.5 + offset = NGP_offsets[level] * 4 + + init_val0 = tf_vec1(0.0) + init_val1 = tf_vec1(1.0) + local_feature_0 = init_val0[0] + local_feature_1 = init_val0[0] + local_feature_2 = init_val0[0] + local_feature_3 = init_val0[0] + + scale = NGP_base_res * ti.exp( + level * NGP_per_level_scales) - 1.0 + resolution = ti.cast(ti.ceil(scale), ti.uint32) + 1 + + pos = xyz * scale + 0.5 + pos_grid_uint = ti.cast(ti.floor(pos), ti.uint32) + pos -= pos_grid_uint + + for idx in ti.static(range(8)): + w = init_val1[0] + pos_grid_local = uvec3(0) + + for d in ti.static(range(3)): + if (idx & (1 << d)) == 0: + pos_grid_local[d] = pos_grid_uint[d] + w *= data_type(1 - pos[d]) + else: + pos_grid_local[d] = pos_grid_uint[d] + 1 + w *= data_type(pos[d]) + + index = 0 + stride = 1 + for c_ in ti.static(range(3)): + index += pos_grid_local[c_] * stride + stride *= resolution + + local_feature_0 += data_type( + w * NGP_hash_embedding[offset + index * 4]) + local_feature_1 += data_type( + w * NGP_hash_embedding[offset + index * 4 + 1]) + local_feature_2 += data_type( + w * NGP_hash_embedding[offset + index * 4 + 2]) + local_feature_3 += data_type( + w * NGP_hash_embedding[offset + index * 4 + 3]) + + xyzs_embedding[level * 4] = local_feature_0 + xyzs_embedding[level * 4 + 1] = local_feature_1 + xyzs_embedding[level * 4 + 2] = local_feature_2 + xyzs_embedding[level * 4 + 3] = local_feature_3 s0 = init_val[0] s1 = init_val[0] @@ -490,8 +436,7 @@ def sigma_rgb_layer( for i in range(16): temp = init_val[0] for j in ti.static(range(16)): - temp += NGP_xyzs_embedding[sn, - j] * sigma_weight[i * 16 + j] + temp += xyzs_embedding[j] * sigma_weight[i * 16 + j] for j in ti.static(range(16)): sigma_output_val[j] += data_type(ti.max( diff --git a/deployment/InstantNGP/utils/app_fp32.cpp b/deployment/InstantNGP/utils/app_fp32.cpp index f1f5ce0..5175076 100644 --- a/deployment/InstantNGP/utils/app_fp32.cpp +++ b/deployment/InstantNGP/utils/app_fp32.cpp @@ -174,13 +174,17 @@ void App_nerf_f32::initialize(int img_width, int img_height, k_reset_[2] = opacity_; k_reset_[3] = rgb_; - k_ray_intersect_ = module_.get_kernel("ray_intersect"); - k_ray_intersect_[0] = counter_; - k_ray_intersect_[1] = pose_; - k_ray_intersect_[2] = directions_; - k_ray_intersect_[3] = hits_t_; - k_ray_intersect_[4] = rays_o_; - k_ray_intersect_[5] = rays_d_; + k_get_rays_ = module_.get_kernel("get_rays_test_kernel"); + k_get_rays_[0] = pose_; + k_get_rays_[1] = directions_; + k_get_rays_[2] = rays_o_; + k_get_rays_[3] = rays_d_; + + k_ray_intersect_ = module_.get_kernel("ray_aabb_intersect"); + k_ray_intersect_[0] = hits_t_; + k_ray_intersect_[1] = rays_o_; + k_ray_intersect_[2] = rays_d_; + k_ray_intersect_[3] = 0.5f; k_raymarching_test_kernel_ = module_.get_kernel("raymarching_test_kernel"); k_raymarching_test_kernel_[0] = counter_; @@ -197,31 +201,23 @@ void App_nerf_f32::initialize(int img_width, int img_height, k_raymarching_test_kernel_[11] = run_model_ind_; k_raymarching_test_kernel_[12] = N_eff_samples_; - k_rearange_index_ = module_.get_kernel("rearange_index"); + k_rearange_index_ = module_.get_kernel("re_arrange_index"); k_rearange_index_[0] = model_launch_; k_rearange_index_[1] = pad_block_network_; k_rearange_index_[2] = temp_hit_; k_rearange_index_[3] = run_model_ind_; - k_hash_encode_ = module_.get_kernel("hash_encode"); - k_hash_encode_[0] = hash_embedding_; - k_hash_encode_[1] = model_launch_; - k_hash_encode_[2] = xyzs_; - k_hash_encode_[3] = dirs_; - k_hash_encode_[4] = deltas_; - k_hash_encode_[5] = xyzs_embedding_; - k_hash_encode_[6] = temp_hit_; - - k_mlp_layer_ = module_.get_kernel("sigma_rgb_layer"); - k_mlp_layer_[0] = sigma_weights_; - k_mlp_layer_[1] = rgb_weights_; - k_mlp_layer_[2] = model_launch_; - k_mlp_layer_[3] = pad_block_network_; - k_mlp_layer_[4] = xyzs_embedding_; - k_mlp_layer_[5] = dirs_; - k_mlp_layer_[6] = out_1_; - k_mlp_layer_[7] = out_3_; - k_mlp_layer_[8] = temp_hit_; + k_radiance_field_ = module_.get_kernel("radiance_field"); + k_radiance_field_[0] = hash_embedding_; + k_radiance_field_[1] = model_launch_; + k_radiance_field_[2] = xyzs_; + k_radiance_field_[3] = sigma_weights_; + k_radiance_field_[4] = rgb_weights_; + k_radiance_field_[5] = pad_block_network_; + k_radiance_field_[6] = dirs_; + k_radiance_field_[7] = out_1_; + k_radiance_field_[8] = out_3_; + k_radiance_field_[9] = temp_hit_; k_composite_test_ = module_.get_kernel("composite_test"); k_composite_test_[0] = counter_; @@ -260,6 +256,7 @@ std::vector App_nerf_f32::run() { for (int n_time = 0; n_time < kRepeat; n_time += 1) { k_reset_.launch(); + k_get_rays_.launch(); k_ray_intersect_.launch(); samples = 0; @@ -283,9 +280,8 @@ std::vector App_nerf_f32::run() { k_rearange_index_[4] = launch_model_total; k_rearange_index_.launch(); - k_hash_encode_.launch(); + k_radiance_field_.launch(); - k_mlp_layer_.launch(); k_composite_test_[10] = N_samples; k_composite_test_.launch(); diff --git a/deployment/InstantNGP/utils/app_fp32.hpp b/deployment/InstantNGP/utils/app_fp32.hpp index 9a5d4c6..73b369b 100644 --- a/deployment/InstantNGP/utils/app_fp32.hpp +++ b/deployment/InstantNGP/utils/app_fp32.hpp @@ -7,11 +7,11 @@ class App_nerf_f32 { ti::Runtime runtime_; ti::AotModule module_; ti::Kernel k_reset_; + ti::Kernel k_get_rays_; ti::Kernel k_ray_intersect_; ti::Kernel k_raymarching_test_kernel_; ti::Kernel k_rearange_index_; - ti::Kernel k_hash_encode_; - ti::Kernel k_mlp_layer_; + ti::Kernel k_radiance_field_; ti::Kernel k_composite_test_; ti::Kernel k_re_order_; ti::Kernel k_fill_ndarray_; diff --git a/modules/intersection.py b/modules/intersection.py index f4ce45a..367aa28 100644 --- a/modules/intersection.py +++ b/modules/intersection.py @@ -1,9 +1,73 @@ -import taichi as ti import torch +import taichi as ti from taichi.math import vec3, vec2 - from .utils import NEAR_DISTANCE +mat3x4 = ti.types.matrix(3, 4, ti.f32) +mat1x3 = ti.types.matrix(1, 3, ti.f32) + + +@ti.func +def __get_rays(c2w, direction): + + mat_result = direction @ c2w[:, :3].transpose() + ray_d = vec3(mat_result[0, 0], mat_result[0, 1], mat_result[0, 2]) + ray_o = c2w[:, 3] + + return ray_o, ray_d + +@ti.kernel +def get_rays_test_kernel( + pose: ti.types.ndarray(dtype=mat3x4, ndim=0), + directions: ti.types.ndarray(dtype=mat1x3, ndim=1), + rays_o: ti.types.ndarray(dtype=vec3, ndim=1), + rays_d: ti.types.ndarray(dtype=vec3, ndim=1), +): + for i in ti.ndrange(directions.shape[0]): + c2w = pose[None] + direction = directions[i] + + ray_o, ray_d = __get_rays(c2w, direction) + + rays_o[i] = ray_o + rays_d[i] = ray_d + +@ti.kernel +def get_rays_train_kernel( + pose: ti.types.ndarray(dtype=mat3x4, ndim=1), + directions: ti.types.ndarray(dtype=mat1x3, ndim=1), + rays_o: ti.types.ndarray(dtype=vec3, ndim=1), + rays_d: ti.types.ndarray(dtype=vec3, ndim=1), +): + for i in ti.ndrange(directions.shape[0]): + c2w = pose[i] + direction = directions[i] + + ray_o, ray_d = __get_rays(c2w, direction) + + rays_o[i] = ray_o + rays_d[i] = ray_d + +def get_rays(directions, pose): + rays_o = torch.empty_like(directions) + rays_d = torch.empty_like(directions) + + if len(pose.shape) == 3: + get_rays_train_kernel( + pose, + directions.unsqueeze(1), + rays_o, + rays_d, + ) + else: + get_rays_test_kernel( + pose, + directions.unsqueeze(1), + rays_o, + rays_d, + ) + + return rays_o, rays_d @ti.kernel def ray_aabb_intersect( diff --git a/train.py b/train.py index d4fc1f7..88dae27 100644 --- a/train.py +++ b/train.py @@ -15,9 +15,9 @@ from gui import NGPGUI from opt import get_opts from datasets import dataset_dict -from datasets.ray_utils import get_rays from modules.networks import NGP +from modules.intersection import get_rays from modules.distortion import distortion_loss from modules.rendering import MAX_SAMPLES, render from modules.utils import depth2img, save_deployment_model