- Neural Implicit Representation for Building Digital Twins of
- Unknown Articulated Objects
-
-
-
- Yijia Weng, Bowen Wen, Jonathan Tremblay, Valts Blukis
-
- Dieter Fox, Leonidas Guibas, Stan Birchfield
-
-
-
-
-
-
-
- Abstract
-
-
-
- We tackle the problem of building digital twins of unknown articulated objects from two RGBD scans of the object at different articulation states. We decompose the problem into two stages, each addressing distinct aspects. Our method first reconstructs object-level shape at each state, then recovers the underlying articulation model including part segmentation and joint articulations that associate the two states. By explicitly modeling point-level correspondences and exploiting cues from images, 3D reconstructions, and kinematics, our method yields more accurate and stable results compared to prior work. It also handles more than one movable part and does not rely on any object shape or structure priors.
-
-
-
-
-
-
-
-
-
-
- Demo: Interacting with the Multi-Part Digital Twin in Simulation
-
-
-
-
-
- Our reconstructed digital twin can be readily imported into simulation environments and interacted with. The video above shows an example interaction sequence in Isaac Gym, played at 4x speed.
-
-
-
-
-
-
-
- Qualitative Results on PARIS Two-Part Object Dataset
-
-
-
-
-
- Visualization of reconstruction results on the PARIS Two-Part Object Dataset from PARIS, PARIS* (PARIS augmented with depth supervision), and our method. We run each method 10 times with different random initializations and show typical trials whose performance is closest to the average.
-
-
-
-
-
-
-
- Qualitative Results on SAPIEN Multi-Part Objects
-
-
-
-
-
- Visualization of reconstruction results on SAPIEN Multi-Part Objects from PARIS*-m (PARIS augmented with depth supervision and extended to handle multiple movable parts) and our method. We run each method 10 times with different random initializations and show typical trials whose performance is closest to the average.
-
-
-
-
-
-
-
-
-
- Method Overview
-
-
-
-
-
- Given multi-view RGB-D observations of an articulated object at two different joint configurations, we aim to reconstruct its per-part geometry and articulation model.
- We decompose this problem into two stages, each with a distinct focus. Our method performs per-state object-level reconstruction in the first stage, then recovers the articulation model in the second stage. We parameterize the articulation model with a part segmentation field and per-part rigid transformations, from which we explicitly derive a point correspondence field that associates the two reconstructions. The point correspondence field can be effectively supervised with a set of losses (consistency, matching, and collision) that leverage cues from 3D reconstruction, image feature matching, and kinematics.
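-
- Below is a minimal, self-contained sketch of this parameterization, written for illustration only: it is not our released implementation, and all names, shapes, and the placeholder SDF/occupancy fields are hypothetical. It shows how a soft part segmentation and per-part rigid transforms induce a point correspondence field between the two states, and how consistency, matching, and collision terms could be placed on that field.
-
- import torch
-
- def correspondence_field(pts, seg_logits, part_rot, part_trans):
-     """Warp state-0 points to state 1 under a toy articulation model.
-     pts: (N, 3), seg_logits: (N, P), part_rot: (P, 3, 3), part_trans: (P, 3)."""
-     seg = torch.softmax(seg_logits, dim=-1)                         # (N, P) soft part assignment
-     cand = torch.einsum('pij,nj->npi', part_rot, pts) + part_trans  # (N, P, 3) per-part candidates
-     corr = (seg.unsqueeze(-1) * cand).sum(dim=1)                    # (N, 3) expected correspondences
-     return corr, cand, seg
-
- def articulation_losses(corr, cand, seg, sdf_state1, occ_state1, matched_pts):
-     """Toy versions of the consistency, matching, and collision terms (weights omitted)."""
-     consistency = sdf_state1(corr).abs().mean()                     # warped points should lie on the state-1 surface
-     matching = (corr - matched_pts).norm(dim=-1).mean()             # agree with matches lifted from image features
-     occ = occ_state1(cand.reshape(-1, 3)).reshape(seg.shape)        # (N, P) per-candidate occupancy
-     collision = torch.relu((seg * occ).sum(dim=-1) - 1).pow(2).mean()  # penalize overlapping parts
-     return consistency + matching + collision
-
- # Toy usage with random points and placeholder fields.
- N, P = 1024, 3
- pts = torch.rand(N, 3) * 2 - 1
- part_rot = torch.eye(3).expand(P, 3, 3).clone()
- part_trans = torch.zeros(P, 3)
- part_trans[1, 0] = 0.1                                              # pretend part 1 slides along x
- sdf_state1 = lambda x: x.norm(dim=-1) - 0.5                         # placeholder SDF of state 1
- occ_state1 = lambda x: (sdf_state1(x) < 0).float()                  # placeholder occupancy of state 1
- corr, cand, seg = correspondence_field(pts, torch.randn(N, P), part_rot, part_trans)
- loss = articulation_losses(corr, cand, seg, sdf_state1, occ_state1, matched_pts=corr.detach())
-
- In the actual method, the segmentation field and the two reconstructions are neural fields optimized jointly, and the losses are additionally weighted by estimated visibility and by weights derived from the reconstructed SDF.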
-
-
-
-
-
-
-
-
-
-
- Acknowledgements
-
-
-
- The website template was borrowed from Michaël Gharbi.
-
-
-
-
-
-
diff --git a/js/app.js b/js/app.js
deleted file mode 100644
index bbfd90d..0000000
--- a/js/app.js
+++ /dev/null
@@ -1,40 +0,0 @@
-
-$(document).ready(function() {
- var editor = CodeMirror.fromTextArea(document.getElementById("bibtex"), {
- lineNumbers: false,
- lineWrapping: true,
- readOnly:true
- });
- $(function () {
- $('[data-toggle="tooltip"]').tooltip()
- });
-
-
-// var frameNumber = 0, // start video at frame 0
-// // lower numbers = faster playback
-// playbackConst = 500,
-// // get page height from video duration
-// setHeight = document.getElementById("main"),
-// // select video element
-// vid = document.getElementById('v0');
-// // var vid = $('#v0')[0]; // jquery option
-
-
-
-
-// // Use requestAnimationFrame for smooth playback
-// function scrollPlay(){
-// var frameNumber = window.pageYOffset/playbackConst;
-// vid.currentTime = frameNumber;
-// window.requestAnimationFrame(scrollPlay);
-// console.log('scroll');
-// }
-
-// // dynamically set the page height according to video length
-// vid.addEventListener('loadedmetadata', function() {
-// setHeight.style.height = Math.floor(vid.duration) * playbackConst + "px";
-// });
-
-
-// window.requestAnimationFrame(scrollPlay);
-});
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..a5ccf54
--- /dev/null
+++ b/main.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import argparse
+import os
+import imageio.v2 as imageio
+import cv2
+import numpy as np
+import json
+from tqdm import tqdm
+import copy
+import ruamel.yaml
+yaml = ruamel.yaml.YAML()
+import time
+import open3d as o3d
+import multiprocessing
+from os.path import join as pjoin
+from model import ArtiModel
+from utils.py_utils import set_seed, set_logging_file, list_to_array
+from utils.geometry_utils import compute_scene_bounds, mask_and_normalize_data
+
+try:
+ multiprocessing.set_start_method('spawn')
+except RuntimeError:  # the start method may already have been set
+    pass
+
+
+def run(cfg_dir, data_dir, save_dir, test_only=False, no_wandb=False, ckpt_path=None,
+ num_parts=2, denoise=False):
+
+ with open(pjoin(cfg_dir, 'config.yml'), 'r') as f:
+ cfg = yaml.load(f)
+
+ cfg['data_dir'] = data_dir
+ cfg['save_dir'] = save_dir
+ cfg['slot_num'] = num_parts
+
+ os.makedirs(save_dir, exist_ok=True)
+ set_logging_file(pjoin(save_dir, 'log.txt'))
+
+    # Derive the experiment name (used for wandb) from the path components after 'runs/';
+    # fall back to the full save_dir if it does not live under a 'runs' directory.
+    folders = save_dir.split('/')
+    exp_name = '/'.join(folders[folders.index('runs') + 1:]) if 'runs' in folders else save_dir
+
+ cfg['bounding_box'] = np.array(cfg['bounding_box']).reshape(2, 3)
+
+ if 'random_seed' not in cfg:
+ cfg['random_seed'] = int(time.time()) % 200003
+
+ set_seed(cfg['random_seed'])
+
+ K = np.loadtxt(pjoin(data_dir, 'cam_K.txt')).reshape(3, 3)
+
+ keyframes = yaml.load(open(pjoin(data_dir, 'init_keyframes.yml'), 'r'))
+ keys = list(keyframes.keys())
+
+ frame_ids = []
+ timesteps = []
+ cam_in_obs = []
+ for k in keys:
+ cam_in_ob = np.array(keyframes[k]['cam_in_ob']).reshape(4, 4)
+ cam_in_obs.append(cam_in_ob)
+ timesteps.append(float(keyframes[k]['time']))
+ frame_ids.append(k.replace('frame_', ''))
+ cam_in_obs = np.array(cam_in_obs)
+ timesteps = np.array(timesteps)
+
+ max_timestep = np.max(timesteps) + 1
+ normalized_timesteps = timesteps / max_timestep
+
+ frame_names, rgbs, depths, masks = [], [], [], []
+
+ rgb_dir = pjoin(data_dir, 'color_segmented')
+ for frame_id in frame_ids:
+ rgb_file = pjoin(rgb_dir, f'{frame_id}.png')
+ rgb = imageio.imread(rgb_file)
+ rgb_wh = rgb.shape[:2]
+ depth = cv2.imread(rgb_file.replace('color_segmented', 'depth_filtered'), -1) / 1e3
+ depth_wh = depth.shape[:2]
+ if rgb_wh[0] != depth_wh[0] or rgb_wh[1] != depth_wh[1]:
+ depth = cv2.resize(depth, (rgb_wh[1], rgb_wh[0]), interpolation=cv2.INTER_NEAREST)
+
+ mask = cv2.imread(rgb_file.replace('color_segmented', 'mask'), -1)
+ if len(mask.shape) == 3:
+ mask = mask[..., 0]
+
+ frame_names.append(rgb_file)
+ rgbs.append(rgb)
+ depths.append(depth)
+ masks.append(mask)
+
+ glcam_in_obs = cam_in_obs
+
+ scene_normalization_path = pjoin(data_dir, 'scene_normalization.npz')
+ if os.path.exists(scene_normalization_path):
+ scene_info = np.load(scene_normalization_path, allow_pickle=True)
+ sc_factor, translation = scene_info['sc_factor'], scene_info['translation']
+ pcd_normalized = scene_info['pcd_normalized']
+ pcd_normalized = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pcd_normalized.astype(np.float64)))
+ else:
+ sc_factor, translation, pcd_real_scale, pcd_normalized = compute_scene_bounds(frame_names, glcam_in_obs, K,
+ use_mask=True,
+ base_dir=save_dir,
+ rgbs=np.array(rgbs),
+ depths=np.array(depths),
+ masks=np.array(masks),
+ cluster=denoise, eps=0.01,
+ min_samples=5,
+ sc_factor=None,
+ translation_cvcam=None)
+ np.savez_compressed(scene_normalization_path, sc_factor=sc_factor, translation=translation,
+ pcd_normalized=np.asarray(pcd_normalized.points))
+
+ cfg['sc_factor'] = float(sc_factor)
+ cfg['translation'] = translation
+
+ print("sc factor", sc_factor, 'translation', translation)
+
+ rgbs, depths, masks, poses = mask_and_normalize_data(np.array(rgbs), depths=np.array(depths),
+ masks=np.array(masks),
+ poses=glcam_in_obs,
+ sc_factor=cfg['sc_factor'],
+ translation=cfg['translation'])
+ cfg['sampled_frame_ids'] = np.arange(len(rgbs))
+
+ nerf = ArtiModel(cfg, frame_names=frame_names, images=rgbs, depths=depths, masks=masks,
+ poses=poses, timesteps=normalized_timesteps, K=K, build_octree_pcd=pcd_normalized,
+ use_wandb=not no_wandb, exp_name=exp_name, max_timestep=int(max_timestep),
+ test_only=test_only)
+ nerf.max_timestep = max_timestep
+
+ if ckpt_path is not None:
+ assert os.path.exists(ckpt_path)
+ else:
+ if os.path.exists(pjoin(save_dir, 'ckpt', 'model_latest.pth')):
+ ckpt_path = pjoin(save_dir, 'ckpt', 'model_latest.pth')
+
+ if ckpt_path is not None:
+ nerf.load_weights(ckpt_path)
+
+ if test_only:
+ nerf.test()
+ else:
+ nerf.initialize_correspondence()
+ corr_path = pjoin(data_dir, cfg['correspondence_name'])
+ if os.path.exists(corr_path):
+ for filename in tqdm(os.listdir(corr_path)):
+ if filename.endswith('npz'):
+ cur_corr = np.load(pjoin(corr_path, filename), allow_pickle=True)['data']
+ try:
+ cur_corr = list(cur_corr)
+ nerf.load_correspondence(cur_corr)
+                    except Exception:  # skip malformed correspondence files
+                        pass
+ nerf.finalize_correspondence()
+
+ print("Start training")
+ nerf.train()
+
+ with open(pjoin(save_dir, 'config.yml'), 'w') as ff:
+ tmp = copy.deepcopy(cfg)
+ for k in tmp.keys():
+ if isinstance(tmp[k], np.ndarray):
+ tmp[k] = tmp[k].tolist()
+ yaml.dump(tmp, ff)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str)
+ parser.add_argument('--cfg_dir', type=str)
+ parser.add_argument('--save_dir', type=str, default=None)
+ parser.add_argument('--ckpt_path', type=str, default=None)
+ parser.add_argument('--num_parts', type=int, default=2)
+ parser.add_argument('--test_only', action='store_true')
+ parser.add_argument('--no_wandb', action='store_true')
+ parser.add_argument('--denoise', action='store_true')
+ args = parser.parse_args()
+
+ save_dir = args.save_dir if args.save_dir is not None else args.cfg_dir
+ run(args.cfg_dir, args.data_dir, save_dir, test_only=args.test_only, no_wandb=args.no_wandb,
+ ckpt_path=args.ckpt_path, num_parts=args.num_parts, denoise=args.denoise)
diff --git a/img/qual_multi_part.mp4 b/media/sim_interaction_4x.gif
similarity index 56%
rename from img/qual_multi_part.mp4
rename to media/sim_interaction_4x.gif
index 5c3ef8c..976186c 100644
Binary files a/img/qual_multi_part.mp4 and b/media/sim_interaction_4x.gif differ
diff --git a/img/teaser.png b/media/teaser.png
similarity index 100%
rename from img/teaser.png
rename to media/teaser.png
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..940a22d
--- /dev/null
+++ b/model.py
@@ -0,0 +1,1840 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import os.path
+import cv2
+import torch
+import numpy as np
+import imageio, trimesh
+import json
+import logging
+import wandb
+import common
+import copy
+from itertools import permutations
+from tqdm import tqdm
+from os.path import join as pjoin
+from utils.nerf_utils import ray_box_intersection_batch, \
+ get_sdf_loss, get_camera_rays_np, get_pixel_coords_np, to8b
+from utils.geometry_utils import to_homo, transform_pts, OctreeManager, get_voxel_pts, \
+ DepthFuser, VoxelVisibility, VoxelSDF, sdf_voxel_from_mesh
+from network import PartArticulationNet, SHEncoder, GridEncoder, FeatureVolume, NeRFSmall
+from utils.articulation_utils import save_axis_mesh, interpret_transforms, eval_axis_and_state, read_gt as read_axis_gt
+from eval.eval_mesh import eval_CD, cluster_meshes
+
+"""
+train_loop(batch of rays): call [render], compute losses
+- train_loop_forward: similar
+
+render: call [batchify_rays], split the results into ['rgb_map'] and others
+batchify_rays: call [render_rays] in chunks, concat the results
+render_rays: sample points, run [run_network] or [run_network_for_forward_only]
+"""
+
+
+def inverse_transform(transform):
+ rot = transform['rot']
+ trans = transform['trans']
+ return {'rot': rot.T, 'trans': -np.matmul(rot.T, trans.reshape(3, 1)).reshape(-1)}
+
+
+def batchify(fn, chunk):
+ """Constructs a version of 'fn' that applies to smaller batches.
+ """
+ if chunk is None:
+ return fn
+
+ def ret(inputs):
+ return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
+
+ return ret
+
+
+def compute_near_far_and_filter_rays(cam_in_world, rays, cfg):
+ '''
+ @cam_in_world: in normalized space
+ @rays: (...,D) in camera
+ Return:
+ (-1,D+2) with near far
+ '''
+ D = rays.shape[-1]
+ rays = rays.reshape(-1, D)
+ dirs_unit = rays[:, :3] / np.linalg.norm(rays[:, :3], axis=-1).reshape(-1, 1)
+ dirs = (cam_in_world[:3, :3] @ rays[:, :3].T).T
+ origins = (cam_in_world @ to_homo(np.zeros(dirs.shape)).T).T[:, :3]
+ bounds = np.array(cfg['bounding_box']).reshape(2, 3)
+ tmin, tmax = ray_box_intersection_batch(origins, dirs, bounds)
+ tmin = tmin.data.cpu().numpy()
+ tmax = tmax.data.cpu().numpy()
+ ishit = tmin >= 0
+ near = (dirs_unit * tmin.reshape(-1, 1))[:, 2]
+ far = (dirs_unit * tmax.reshape(-1, 1))[:, 2]
+ good_rays = rays[ishit]
+ near = near[ishit]
+ far = far[ishit]
+ near = np.abs(near)
+ far = np.abs(far)
+ good_rays = np.concatenate((good_rays, near.reshape(-1, 1), far.reshape(-1, 1)), axis=-1) # (N,8+2)
+
+ return good_rays
+
+
+@torch.no_grad()
+def sample_rays_uniform(N_samples, near, far, lindisp=False, perturb=True):
+ '''
+ @near: (N_ray,1)
+ '''
+ N_ray = near.shape[0]
+ t_vals = torch.linspace(0., 1., steps=N_samples, device=near.device).reshape(1, -1)
+ if not lindisp:
+ z_vals = near * (1. - t_vals) + far * (t_vals)
+ else:
+ z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals)) # (N_ray,N_sample)
+
+ if perturb > 0.:
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
+ lower = torch.cat([z_vals[..., :1], mids], -1)
+ t_rand = torch.rand(z_vals.shape, device=far.device)
+ z_vals = lower + (upper - lower) * t_rand
+ z_vals = torch.clip(z_vals, near, far)
+
+ return z_vals.reshape(N_ray, N_samples)
+
+
+class DataLoader:
+ def __init__(self, rays, batch_size):
+ self.rays = rays
+ self.batch_size = batch_size
+ self.pos = 0
+ self.ids = torch.randperm(len(self.rays))
+
+ def __next__(self):
+ if self.pos + self.batch_size < len(self.ids):
+ self.batch_ray_ids = self.ids[self.pos:self.pos + self.batch_size]
+ out = self.rays[self.batch_ray_ids]
+ self.pos += self.batch_size
+ return out.cuda()
+
+ self.ids = torch.randperm(len(self.rays))
+ self.pos = self.batch_size
+ self.batch_ray_ids = self.ids[:self.batch_size]
+ return self.rays[self.batch_ray_ids].cuda()
+
+
+class IndexDataLoader:
+ def __init__(self, indices, batch_size):
+ self.indices = indices
+ self.batch_size = batch_size
+ self.pos = 0
+ self.ids = torch.randperm(len(self.indices))
+
+ def __next__(self):
+ if self.pos + self.batch_size < len(self.ids):
+ out = self.indices[self.ids[self.pos:self.pos + self.batch_size]]
+ self.pos += self.batch_size
+ return out
+
+ self.ids = torch.randperm(len(self.indices))
+ self.pos = self.batch_size
+ return self.indices[self.ids[:self.batch_size]]
+
+
+class ArtiModel:
+ def __init__(self, cfg, frame_names, images, depths, masks, poses, timesteps, K,
+ build_octree_pcd=None, use_wandb=True, exp_name=None, max_timestep=0,
+ test_only=False):
+ '''
+        poses: OpenGL convention, camera poses w.r.t. the object (object frame normalized to [-1, 1] or [0, 1]); z- forward, y up
+        K: camera intrinsics
+
+ '''
+ self.cfg = cfg
+ self.frame_names = frame_names
+ self.frame_name2id = {'_'.join(frame_name.split('/')[-1].split('.')[0].split('_')[-2:]): id for id, frame_name in enumerate(frame_names)}
+ self.images = images
+ self.depths = depths
+ self.masks = masks
+ self.poses = poses
+ self.timesteps = timesteps
+ self.all_timesteps = np.unique(timesteps)
+ self.all_timesteps.sort()
+ self.all_timesteps = torch.tensor(self.all_timesteps).cuda()
+ self.max_timestep = max_timestep
+ self.cnc_timesteps = {'init': 0.0, 'last': (self.max_timestep - 1.0) / self.max_timestep}
+ assert self.cnc_timesteps['last'] == self.all_timesteps[-1]
+ self.K = K.copy()
+
+ self.load_gt()
+
+        self.build_octree_pts = np.asarray(build_octree_pcd.points).copy()  # Make it picklable
+
+ self.save_dir = self.cfg['save_dir']
+
+ self.H, self.W = self.images[0].shape[:2]
+ self.tensor_K = torch.tensor(self.K, device='cuda:0', dtype=torch.float32)
+
+ self.octree_m = None
+ if self.cfg['use_octree']:
+ self.build_octree()
+
+ self.create_nerf()
+ self.create_optimizer()
+
+ self.amp_scaler = torch.cuda.amp.GradScaler(enabled=self.cfg['amp'])
+
+ self.total_step = self.cfg['n_step']
+ self.global_step = 0
+ self.freeze_recon_step = self.cfg['freeze_recon_step']
+
+ self.c2w_array = torch.tensor(poses).float().cuda()
+
+ if not test_only:
+ rays_ = {cnc_name: [] for cnc_name in self.cnc_timesteps}
+ num_rays_ = {cnc_name: 0 for cnc_name in self.cnc_timesteps}
+ pixel_to_ray_id = {cnc_name: {} for cnc_name in self.cnc_timesteps}
+ for frame_i in tqdm(range(len(self.timesteps))):
+ for cnc_name in self.cnc_timesteps:
+ if self.timesteps[frame_i] == self.cnc_timesteps[cnc_name]:
+ frame_rays, frame_pixel_to_ray_id = self.make_frame_rays(frame_i)
+ rays_[cnc_name].append(frame_rays)
+ frame_pixel_to_ray_id[np.where(frame_pixel_to_ray_id >= 0)] += num_rays_[cnc_name]
+ pixel_to_ray_id[cnc_name][frame_i] = frame_pixel_to_ray_id
+ num_rays_[cnc_name] += len(frame_rays)
+ self.pixel_to_ray_id = pixel_to_ray_id
+
+ rays_dict = {}
+ for cnc_name in self.cnc_timesteps:
+ rays_dict[cnc_name] = np.concatenate(rays_[cnc_name], axis=0)
+
+ for cnc_name in self.cnc_timesteps:
+ rays_dict[cnc_name] = torch.tensor(rays_dict[cnc_name], dtype=torch.float).cuda()
+
+ self.rays_dict = rays_dict
+
+ self.data_loader = {cnc_name: DataLoader(rays=self.rays_dict[cnc_name], batch_size=self.cfg['N_rand'])
+ for cnc_name in self.rays_dict}
+
+ self.loss_weights = {key: torch.tensor(value).float().cuda() for key, value in self.cfg['loss_weights'].items()}
+ self.loss_schedule = {} if 'loss_schedule' not in self.cfg else self.cfg['loss_schedule']
+
+ self.use_wandb = use_wandb and not test_only
+ if self.use_wandb:
+            wandb.init(project='art-nerf', name=exp_name, config=self.cfg)
+
+ self.depth_fuser = {}
+ for cnc_name in self.cnc_timesteps:
+ cur_frame_idx = np.where(self.timesteps == self.cnc_timesteps[cnc_name])
+ self.depth_fuser[cnc_name] = DepthFuser(self.tensor_K, self.c2w_array[cur_frame_idx],
+ self.depths[cur_frame_idx].squeeze(-1),
+ self.masks[cur_frame_idx].squeeze(-1),
+ self.get_truncation(),
+ near=self.cfg['near'] * self.cfg['sc_factor'],
+ far=self.cfg['far'] * self.cfg['sc_factor'])
+ self.load_visibility_grid()
+
+ def load_visibility_grid(self):
+ self.visibility_grid = {}
+ for cnc_name in self.cnc_timesteps:
+ visibility_path = pjoin(self.cfg['data_dir'], f'{cnc_name}_visibility.npz')
+ if os.path.exists(visibility_path):
+ visibility = np.load(visibility_path, allow_pickle=True)['data']
+ else:
+ query_pts = get_voxel_pts(self.cfg['sdf_voxel_size'])
+ old_shape = tuple(query_pts.shape[:3])
+ query_pts = torch.tensor(query_pts.astype(np.float32).reshape(-1, 3)).float().cuda()
+
+ if self.octree_m is not None:
+ vox_size = self.cfg['octree_raytracing_voxel_size'] * self.cfg['sc_factor']
+ level = int(np.floor(np.log2(2.0 / vox_size)))
+
+ chunk = 160000
+ all_valid = []
+ for i in range(0, query_pts.shape[0], chunk):
+ cur_pts = query_pts[i: i + chunk]
+ center_ids = self.octree_m.get_center_ids(cur_pts, level)
+ valid = center_ids >= 0
+ all_valid.append(valid)
+ valid = torch.cat(all_valid, dim=0)
+ else:
+ valid = torch.ones(len(query_pts), dtype=bool).cuda()
+
+ flat = query_pts[valid]
+ chunk = 160000
+ observed = []
+ for i in range(0, flat.shape[0], chunk):
+ observed.append(self.depth_fuser[cnc_name].query(flat[i:i + chunk]))
+
+ observed = torch.cat(observed, dim=0)
+
+ visibility = np.zeros(len(query_pts), dtype=bool)
+ visibility[valid.cpu().numpy()] = observed.cpu().numpy()
+
+ np.savez_compressed(visibility_path, data=visibility.reshape(old_shape))
+
+ visibility = visibility.reshape(old_shape)
+
+ self.visibility_grid[cnc_name] = VoxelVisibility(visibility)
+
+ def initialize_correspondence(self):
+ self.correspondence = {cnc_name: [] for cnc_name in self.cnc_timesteps}
+ self.corr_src_id_slice = 0
+ self.corr_tgt_frame_slice = 1
+ self.corr_tgt_pixel_silce = [2, 3]
+
+ def load_correspondence(self, corr_list, downsample=10):
+
+ def rev_pixel(pixel):
+ return pixel * np.array([1, -1]).reshape(1, 2) + np.array([0, self.H - 1]).reshape(1, 2)
+
+ for corr in corr_list:
+ for order in [1, -1]:
+ src_name, tgt_name = list(corr.keys())[::order]
+                # Stored pixel coordinates follow image convention (y = 0 at the top), i.e. they index the images directly.
+                src_pixel, tgt_pixel = corr[src_name], corr[tgt_name]
+                # Flip y so the coordinates match the ray convention used below (y = 0 at the bottom).
+                src_pixel = rev_pixel(src_pixel)
+                tgt_pixel = rev_pixel(tgt_pixel)
+ cnc_name = {0: 'init', 1: 'last'}[int(src_name.split('_')[0])]
+ if src_name not in self.frame_name2id or tgt_name not in self.frame_name2id:
+ continue
+ src_frame_id = self.frame_name2id[src_name]
+ tgt_frame_id = self.frame_name2id[tgt_name]
+ src_idx = src_pixel[:, 1] * self.W + src_pixel[:, 0]
+ src_ray_ids = self.pixel_to_ray_id[cnc_name][src_frame_id][src_idx].reshape(-1, 1)
+ valid_idx = np.where(src_ray_ids >= 0)[0]
+ target_length = max(500, len(valid_idx) // downsample)
+ final_idx = np.random.permutation(valid_idx)[:target_length]
+ tgt_frame_ids = np.ones_like(src_ray_ids) * tgt_frame_id
+ cur_corr = np.concatenate([src_ray_ids, tgt_frame_ids, tgt_pixel], axis=-1)
+ self.correspondence[cnc_name].append(cur_corr[final_idx])
+
+ def finalize_correspondence(self):
+ self.correspondence = {cnc_name: None if len(corr_list) == 0 else np.concatenate(corr_list, axis=0) for cnc_name, corr_list in self.correspondence.items()}
+
+ upper_limit = self.H * self.W * len(self.frame_names) * 5
+ self.correspondence = {cnc_name: None if corr is None else torch.tensor(corr[np.random.permutation(len(corr))[:upper_limit]]).cuda() for cnc_name, corr in self.correspondence.items()}
+ self.corr_loader = {cnc_name: None if self.correspondence[cnc_name] is None else DataLoader(rays=self.correspondence[cnc_name], batch_size=self.cfg['N_rand']) for cnc_name in self.correspondence}
+
+ def plot_loss(self, loss_dict, step):
+ if self.use_wandb:
+ wandb.log(loss_dict, step=step)
+
+ def create_nerf(self, device=torch.device("cuda")):
+
+ models = {}
+ for cnc_name in self.cnc_timesteps:
+ embed_fn = GridEncoder(input_dim=3, n_levels=self.cfg['num_levels'],
+ log2_hashmap_size=self.cfg['log2_hashmap_size'],
+ desired_resolution=self.cfg['finest_res'], base_resolution=self.cfg['base_res'],
+ level_dim=self.cfg['feature_grid_dim'])
+ embed_fn = embed_fn.to(device)
+ input_ch = embed_fn.out_dim
+ models[f'{cnc_name}_embed_fn'] = embed_fn
+
+ embeddirs_fn = SHEncoder(self.cfg['multires_views'])
+ input_ch_views = embeddirs_fn.out_dim
+ models[f'{cnc_name}_embeddirs_fn'] = embeddirs_fn
+
+ model = NeRFSmall(num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64,
+ input_ch=input_ch, input_ch_views=input_ch_views).to(device)
+ model = model.to(device)
+ models[f'{cnc_name}_model'] = model
+
+ embed_bwdflow_fn = FeatureVolume(out_dim=self.cfg['feature_vol_dim'], res=self.cfg['feature_vol_res'], num_dim=3)
+ embed_bwdflow_fn = embed_bwdflow_fn.to(device)
+ models[f'{cnc_name}_embed_bwdflow_fn'] = embed_bwdflow_fn
+
+ embed_fwdflow_fn = FeatureVolume(out_dim=self.cfg['feature_vol_dim'], res=self.cfg['feature_vol_res'], num_dim=3)
+ embed_fwdflow_fn = embed_fwdflow_fn.to(device)
+ models[f'{cnc_name}_embed_fwdflow_fn'] = embed_fwdflow_fn
+
+ fwdflow_ch = self.cfg['feature_vol_dim']
+
+ if self.cfg['share_motion'] and cnc_name == 'last':
+ inv_transform = lambda: models['init_deformation_model'].get_raw_slot_transform()
+ else:
+ inv_transform = None
+ deformation_model = PartArticulationNet(device=device, feat_dim=fwdflow_ch,
+ slot_num=self.cfg['slot_num'],
+ slot_hard=self.cfg['slot_hard'],
+ gt_transform=None,
+ inv_transform=inv_transform,
+ fix_base=self.cfg.get('fix_base', True),
+ gt_joint_types=None if not self.cfg['use_gt_joint_type'] else self.gt_joint_types)
+
+ deformation_model = deformation_model.to(device)
+
+ models[f'{cnc_name}_deformation_model'] = deformation_model
+
+ self.models = models
+ print(models)
+
+ def make_frame_rays(self, frame_id):
+
+ def get_last_ray_slice_idx(rays, num):
+ if num == 1:
+ return rays.shape[-1] - 1
+ else:
+ return list(range(rays.shape[-1] - num, rays.shape[-1]))
+ mask = self.masks[frame_id, ..., 0].copy()
+
+ rays = get_camera_rays_np(self.H, self.W,
+ self.K) # [self.H, self.W, 3] We create rays frame-by-frame to save memory
+ self.ray_dir_slice = get_last_ray_slice_idx(rays, 3)
+
+        rays = np.concatenate([rays, frame_id * np.ones(self.depths[frame_id].shape)], -1)  # [H, W, 4]
+ self.ray_frame_id_slice = get_last_ray_slice_idx(rays, 1)
+
+        rays = np.concatenate([rays, self.depths[frame_id]], -1)  # [H, W, 5]
+ self.ray_depth_slice = get_last_ray_slice_idx(rays, 1)
+
+ ray_types = np.zeros((self.H, self.W, 1)) # 0 is good; 1 is invalid depth (uncertain)
+ invalid_depth = ((self.depths[frame_id, ..., 0] < self.cfg['near'] * self.cfg['sc_factor']) | (
+ self.depths[frame_id, ..., 0] > self.cfg['far'] * self.cfg['sc_factor'])) & (mask > 0)
+ ray_types[invalid_depth] = 1
+        rays = np.concatenate((rays, ray_types), axis=-1)  # [H, W, 6]
+ self.ray_type_slice = get_last_ray_slice_idx(rays, 1)
+
+ rays = np.concatenate([rays, get_pixel_coords_np(self.H, self.W, self.K)], axis=-1)
+ self.ray_coords_slice = get_last_ray_slice_idx(rays, 2)
+
+        rays = np.concatenate([rays, self.images[frame_id]], -1)  # [H, W, 11]
+ self.ray_rgb_slice = get_last_ray_slice_idx(rays, 3)
+
+        rays = np.concatenate([rays, self.masks[frame_id] > 0], -1)  # [H, W, 12]
+ self.ray_mask_slice = get_last_ray_slice_idx(rays, 1)
+
+        rays = np.concatenate([rays, self.timesteps[frame_id] * np.ones(self.depths[frame_id].shape)], -1)  # [H, W, 13]
+ self.ray_time_slice = get_last_ray_slice_idx(rays, 1)
+
+ n = rays.shape[-1]
+
+ dilate = 60
+ kernel = np.ones((dilate, dilate), np.uint8)
+ mask = cv2.dilate(mask, kernel, iterations=1)
+
+ if self.cfg['rays_valid_depth_only']:
+ mask[invalid_depth] = 0
+
+ vs, us = np.where(mask > 0)
+ cur_rays = rays[vs, us].reshape(-1, n)
+ cur_rays = cur_rays[cur_rays[:, self.ray_type_slice] == 0]
+ cur_rays = compute_near_far_and_filter_rays(self.poses[frame_id], cur_rays, self.cfg)
+
+        self.ray_near_slice, self.ray_far_slice = get_last_ray_slice_idx(cur_rays, 2)  # near/far were appended to cur_rays above
+
+ if self.cfg['use_octree']:
+ rays_o_world = (self.poses[frame_id] @ to_homo(np.zeros((len(cur_rays), 3))).T).T[:, :3]
+ rays_o_world = torch.from_numpy(rays_o_world).cuda().float()
+ rays_unit_d_cam = cur_rays[:, :3] / np.linalg.norm(cur_rays[:, :3], axis=-1).reshape(-1, 1)
+ rays_d_world = (self.poses[frame_id][:3, :3] @ rays_unit_d_cam.T).T
+ rays_d_world = torch.from_numpy(rays_d_world).cuda().float()
+
+ vox_size = self.cfg['octree_raytracing_voxel_size'] * self.cfg['sc_factor']
+ level = int(np.floor(np.log2(2.0 / vox_size)))
+ near, far, _, ray_depths_in_out = self.octree_m.ray_trace(rays_o_world, rays_d_world, level=level)
+ near = near.cpu().numpy()
+ valid = (near > 0).reshape(-1)
+ cur_rays = cur_rays[valid]
+
+ cur_ray_coords = cur_rays[:, self.ray_coords_slice] # [N, 2], x in [0, W - 1], y in [0, H - 1]
+ coords = (cur_ray_coords[:, 1] * self.W + cur_ray_coords[:, 0]).astype(np.int32)
+ pixel_to_ray_id = np.ones(self.H * self.W) * -1
+ pixel_to_ray_id[coords] = np.arange(len(coords))
+
+ return cur_rays, pixel_to_ray_id
+
+ def build_octree(self):
+
+ pts = torch.tensor(self.build_octree_pts).cuda().float() # Must be within [-1,1]
+ octree_smallest_voxel_size = self.cfg['octree_smallest_voxel_size'] * self.cfg['sc_factor']
+ finest_n_voxels = 2.0 / octree_smallest_voxel_size
+ max_level = int(np.ceil(np.log2(finest_n_voxels)))
+ octree_smallest_voxel_size = 2.0 / (2 ** max_level)
+
+ dilate_radius = int(np.ceil(self.cfg['octree_dilate_size'] / self.cfg['octree_smallest_voxel_size']))
+ dilate_radius = max(1, dilate_radius)
+ logging.info(f"Octree voxel dilate_radius:{dilate_radius}")
+ shifts = []
+ for dx in [-1, 0, 1]:
+ for dy in [-1, 0, 1]:
+ for dz in [-1, 0, 1]:
+ shifts.append([dx, dy, dz])
+ shifts = torch.tensor(shifts).cuda().long() # (27,3)
+ coords = torch.floor((pts + 1) / octree_smallest_voxel_size).long() # (N,3)
+ dilated_coords = coords.detach().clone()
+ for iter in range(dilate_radius):
+ dilated_coords = (dilated_coords[None].expand(shifts.shape[0], -1, -1) + shifts[:, None]).reshape(-1, 3)
+ dilated_coords = torch.unique(dilated_coords, dim=0)
+ pts = (dilated_coords + 0.5) * octree_smallest_voxel_size - 1
+ pts = torch.clip(pts, -1, 1)
+
+ assert pts.min() >= -1 and pts.max() <= 1
+ self.octree_m = OctreeManager(pts, max_level)
+
+ def create_optimizer(self):
+ params = []
+ for k in self.models:
+ if self.models[k] is not None:
+ params += list(self.models[k].parameters())
+
+ param_groups = [{'name': 'basic', 'params': params, 'lr': self.cfg['lrate']}]
+
+ self.optimizer = torch.optim.Adam(param_groups, betas=(0.9, 0.999), weight_decay=0, eps=1e-15)
+ self.param_groups_init = copy.deepcopy(self.optimizer.param_groups)
+
+ def load_weights(self, ckpt_path):
+ print('Reloading from', ckpt_path)
+ ckpt = torch.load(ckpt_path)
+ for key in self.models:
+ self.models[key].load_state_dict(ckpt[key])
+
+ if 'octree' in ckpt:
+ self.octree_m = OctreeManager(octree=ckpt['octree'])
+ self.optimizer.load_state_dict(ckpt['optimizer'])
+ self.global_step = ckpt['global_step']
+
+ if self.global_step >= self.freeze_recon_step:
+ self.freeze_recon()
+
+ def freeze_recon(self):
+ print("----------------freeze recon--------------")
+ for cnc_name in self.cnc_timesteps:
+ for suffix in ['model', 'embed_fn', 'embeddirs_fn']:
+ model_key = f'{cnc_name}_{suffix}'
+ if model_key in self.models:
+ for param in self.models[model_key].parameters():
+ param.requires_grad = False
+
+ def save_weights(self, output_path):
+ data = {
+ 'global_step': self.global_step,
+ 'optimizer': self.optimizer.state_dict(),
+ }
+ for key in self.models:
+ data[key] = self.models[key].state_dict()
+
+ if self.octree_m is not None:
+ data['octree'] = self.octree_m.octree
+
+ output_dir = os.path.dirname(output_path)
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(data, output_path)
+ print('Saved checkpoints at', output_path)
+ latest_path = pjoin(output_dir, 'model_latest.pth')
+ if latest_path != output_path:
+ os.system(f'cp {output_path} {latest_path}')
+
+ def schedule_lr(self):
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ init_lr = self.param_groups_init[i]['lr']
+ new_lrate = init_lr * (self.cfg['decay_rate'] ** (float(self.global_step) / self.total_step))
+ param_group['lr'] = new_lrate
+
+ def load_gt(self):
+ gt_path = pjoin(self.cfg['data_dir'], 'gt')
+ if not os.path.exists(gt_path):
+ print("No gt")
+ self.gt_dict = None
+ return
+
+ gt_joint_list = read_axis_gt(pjoin(gt_path, 'trans.json'))
+ gt_rot_list = [gt_joint['rotation'] for gt_joint in gt_joint_list]
+ gt_trans_list = [gt_joint['translation'] for gt_joint in gt_joint_list]
+
+ self.gt_joint_types = [gt_joint['type'] for gt_joint in gt_joint_list]
+
+ gt_dict = {'joint': gt_joint_list, 'rot': gt_rot_list, 'trans': gt_trans_list}
+ num_joints = len(gt_joint_list)
+
+ for gt_name in ('start', 'end'):
+ if len(gt_joint_list) > 1:
+ gt_meshes = [pjoin(gt_path, gt_name, f'{gt_name}_{mid}rotate.ply')
+ for mid in ['', 'static_'] + [f'dynamic_{i}_' for i in range(num_joints)]]
+ gt_w, gt_s, gt_d = gt_meshes[0], gt_meshes[1], gt_meshes[2:]
+ else:
+ gt_w, gt_s, gt_d = [pjoin(gt_path, gt_name, f'{gt_name}_{mid}rotate.ply')
+ for mid in ['', 'static_', 'dynamic_']]
+ gt_d = [gt_d]
+
+ gt_dict[f'mesh_{gt_name}'] = {'s': gt_s, 'd': gt_d, 'w': gt_w}
+
+ self.gt_dict = gt_dict
+
+ def get_truncation(self):
+ truncation = self.cfg['trunc']
+ truncation *= self.cfg['sc_factor']
+ return truncation
+
+ def query_full_sdf(self, cnc_name, queries):
+ sdf = self.recon_sdf_dict[cnc_name].query(queries.reshape(-1, 3)) / self.get_truncation()
+ sdf = sdf.reshape(queries.shape[:-1])
+ return sdf
+
+ def query_visibility(self, cnc_name, queries):
+ visibility = self.visibility_grid[cnc_name].query(queries.reshape(-1, 3))
+ visibility = visibility.reshape(queries.shape[:-1])
+ return visibility
+
+ def backward_flow(self, cnc_name, pts, valid_samples, training=True):
+
+ if valid_samples is None:
+ valid_samples = torch.ones((len(pts)), dtype=torch.bool, device=pts.device)
+
+ inputs_flat = pts # torch.cat([pts, timesteps], dim=-1)
+
+ embedded_bwdflow = torch.zeros((inputs_flat.shape[0], self.models[f'{cnc_name}_embed_bwdflow_fn'].out_dim),
+ device=inputs_flat.device)
+
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ embedded_bwdflow[valid_samples] = self.models[f'{cnc_name}_embed_bwdflow_fn'](
+ inputs_flat[valid_samples]).to(embedded_bwdflow.dtype)
+
+ embedded_bwdflow = embedded_bwdflow.float()
+
+ canonical_pts = []
+ bwd_attn_hard, bwd_attn_soft = [], []
+ raw_cnc, raw_slot_attn, raw_slot_sdf = [], [], []
+ all_max_attn = []
+ all_total_occ = []
+ all_non_max_occ = []
+ empty_slot_mask = []
+ canonical_pts_cand = []
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ chunk = self.cfg['netchunk']
+ for i in range(0, embedded_bwdflow.shape[0], chunk):
+ out = self.models[f'{cnc_name}_deformation_model'].back_deform(pts[i: i + chunk], embedded_bwdflow[i: i + chunk])
+ xyz_cnc = out['xyz_cnc'] # [N, S, 3]
+ num_pts, num_slots = xyz_cnc.shape[:2]
+ xyz_cnc = xyz_cnc.reshape(-1, 3)
+ raw_cnc.append(xyz_cnc)
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ embedded_fwd_cnc = self.models[f'{cnc_name}_embed_fwdflow_fn'](xyz_cnc.float()).float()
+ fwd_attn_hard, fwd_attn_raw = self.models[f'{cnc_name}_deformation_model'].forw_attn(xyz_cnc, embedded_fwd_cnc, training=training) # [N * S, S]
+
+ def pick_slot_attn(fwd_attn):
+ fwd_attn = fwd_attn.reshape(num_pts, num_slots, num_slots)
+ fwd_attn = fwd_attn[
+ torch.arange(num_pts).to(fwd_attn.device).long().reshape(-1, 1).repeat(1, num_slots), # [N, S]
+ torch.arange(num_slots).to(fwd_attn.device).long().reshape(1, -1).repeat(num_pts, 1), # [N, S]
+ torch.arange(num_slots).to(fwd_attn.device).long().reshape(1, -1).repeat(num_pts, 1)] # [N, S]
+ return fwd_attn
+
+ fwd_attn_hard = pick_slot_attn(fwd_attn_hard)
+ fwd_attn_raw = pick_slot_attn(fwd_attn_raw)
+
+                # pick_slot_attn takes the diagonal of the [N, S, S] attention: for each point,
+                # entry s is the probability that its slot-s backward candidate is itself assigned
+                # to slot s, giving an [N, S] per-slot assignment score.
+
+ raw_slot_attn.append(fwd_attn_raw) # for future analysis
+
+ sdf = self.query_full_sdf(cnc_name, xyz_cnc.float())
+ weights_from_sdf = self.get_occ_from_full_sdf(sdf)
+
+ weights_from_sdf = weights_from_sdf.reshape(num_pts, num_slots)
+ raw_slot_sdf.append(weights_from_sdf)
+
+ dots = fwd_attn_hard * weights_from_sdf # * weights_from_sdf # [N, S]
+ total_occ = torch.sum(dots, dim=-1)
+ non_max_occ = total_occ - torch.max(dots, dim=-1)[0]
+ dots = torch.cat([dots, torch.ones_like(dots[:, :1]) * self.cfg['empty_slot_weight']], dim=-1)
+
+ # let the stochasticity only happen in forward pass; just take their results (attn_hard), and run straight-through argmax
+ attn = dots / torch.sum(dots, dim=1, keepdim=True)
+ max_attn, index = attn.max(dim=1, keepdim=True) # [N]
+ y_hard = torch.zeros_like(attn, memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)
+ attn_hard = y_hard - attn.detach() + attn
+ attn_raw = attn
+
+ # make all indices other than the max index have a small value
+ all_max_attn.append(max_attn.reshape(-1))
+ all_total_occ.append(total_occ)
+ all_non_max_occ.append(non_max_occ)
+
+ xyz_base = torch.cat([xyz_cnc.reshape(num_pts, num_slots, 3), pts[i: i + chunk].reshape(-1, 1, 3)], dim=1)
+
+ chosen_cnc = (attn_hard.unsqueeze(-1) * xyz_base).sum(dim=1)
+ bwd_attn_hard.append(attn_hard[:, :-1])
+ bwd_attn_soft.append(attn_raw[:, :-1])
+ canonical_pts.append(chosen_cnc)
+ empty_slot_mask.append(attn_hard[:, -1])
+ canonical_pts_cand.append(xyz_cnc.reshape(num_pts, num_slots, 3))
+
+ canonical_pts = torch.cat(canonical_pts, dim=0).float()
+ if len(bwd_attn_hard) > 0 and bwd_attn_hard[0] is not None:
+ bwd_attn_hard = torch.cat(bwd_attn_hard, dim=0).float()
+ bwd_attn_soft = torch.cat(bwd_attn_soft, dim=0).float()
+ else:
+ bwd_attn_hard, bwd_attn_soft = None, None
+ if len(raw_cnc) > 0:
+ raw_cnc = torch.cat(raw_cnc, dim=0)
+ raw_slot_attn = torch.cat(raw_slot_attn, dim=0)
+ raw_slot_sdf = torch.cat(raw_slot_sdf, dim=0)
+ empty_slot_mask = torch.cat(empty_slot_mask, dim=0)
+ all_max_attn = torch.cat(all_max_attn, dim=0)
+ all_total_occ = torch.cat(all_total_occ, dim=0)
+ all_non_max_occ = torch.cat(all_non_max_occ, dim=0)
+ canonical_pts_cand = torch.cat(canonical_pts_cand, dim=0)
+ else:
+ raw_cnc, raw_slot_attn, raw_slot_sdf, empty_slot_mask, canonical_pts_cand = None, None, None, None, None
+
+ ret_dict = {'canonical_pts': canonical_pts, 'canonical_pts_cand': canonical_pts_cand,
+ 'bwd_attn_hard': bwd_attn_hard, 'bwd_attn_soft': bwd_attn_soft,
+ 'raw_cnc': raw_cnc, 'raw_slot_attn': raw_slot_attn, 'raw_slot_sdf': raw_slot_sdf,
+ 'empty_slot_mask': empty_slot_mask, 'max_attn': all_max_attn, 'total_occ': all_total_occ, 'non_max_occ': all_non_max_occ}
+
+ return ret_dict
+
+ def forward_flow(self, cnc_name, pts, valid_samples, training=True):
+
+ if valid_samples is None:
+ valid_samples = torch.ones((len(pts)), dtype=torch.bool, device=pts.device)
+
+ inputs_flat = pts
+ embedded_fwdflow = torch.zeros((inputs_flat.shape[0], self.models[f'{cnc_name}_embed_fwdflow_fn'].out_dim),
+ device=inputs_flat.device)
+
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ embedded_fwdflow[valid_samples] = self.models[f'{cnc_name}_embed_fwdflow_fn'](
+ inputs_flat[valid_samples]).to(embedded_fwdflow.dtype)
+
+ embedded_fwdflow = embedded_fwdflow.float()
+
+ world_pts = []
+ world_pts_cand = []
+ fwd_attn_hard, fwd_attn_soft = [], []
+ fwd_rot, fwd_trans = [], []
+ fwd_rot_cand = []
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ chunk = self.cfg['netchunk']
+ for i in range(0, embedded_fwdflow.shape[0], chunk):
+ out = self.models[f'{cnc_name}_deformation_model'].forw_deform(pts[i: i + chunk],
+ embedded_fwdflow[i: i + chunk],
+ training=training, gt_attn=None)
+ world_pts.append(out['world_pts'])
+ world_pts_cand.append(out['world_pts_cand'])
+ fwd_attn_hard.append(out['attn_hard']) # [N, S]
+ fwd_attn_soft.append(out['attn_soft'])
+ fwd_rot.append(out['rotation'])
+ fwd_trans.append(out['translation'])
+ fwd_rot_cand.append(out['rotation_cand'])
+ world_pts = torch.cat(world_pts, dim=0).float()
+ world_pts_cand = torch.cat(world_pts_cand, dim=0).float()
+ if fwd_attn_hard[0] is not None:
+ fwd_attn_hard = torch.cat(fwd_attn_hard, dim=0).float()
+ fwd_attn_soft = torch.cat(fwd_attn_soft, dim=0).float()
+ fwd_rot = torch.cat(fwd_rot, dim=0).float()
+ fwd_rot_cand = torch.cat(fwd_rot_cand, dim=0).float()
+ fwd_trans = torch.cat(fwd_trans, dim=0).float()
+ else:
+ fwd_attn_hard, fwd_attn_soft = None, None
+ return {'world_pts': world_pts, 'world_pts_cand': world_pts_cand,
+ 'fwd_attn_hard': fwd_attn_hard, 'fwd_attn_soft': fwd_attn_soft,
+ 'fwd_rot': fwd_rot, 'fwd_trans': fwd_trans, 'fwd_rot_cand': fwd_rot_cand,
+ 'cnc_features': embedded_fwdflow}
+
+ def project_to_pixel(self, cam_pts):
+ projection = torch.matmul(self.tensor_K[:2, :2],
+ (cam_pts[..., :2] /
+ torch.clip(-cam_pts[..., 2:3], min=1e-8)).transpose(0, 1)) + self.tensor_K[:2, 2:3]
+ projection = projection.transpose(0, 1)
+ return projection
+
+ def get_canonical_pts_from_world_pts(self, cnc_name, world_pts, timesteps, valid_samples):
+ ret = {}
+ first_mask = (timesteps == self.cnc_timesteps[cnc_name]).float()
+ if first_mask.mean() == 1:
+ canonical_pts = world_pts
+ else:
+ backward_flow = self.backward_flow(cnc_name, world_pts, valid_samples)
+ canonical_pts = backward_flow['canonical_pts']
+ for key in ['bwd_attn_soft', 'bwd_attn_hard', 'raw_cnc', 'raw_slot_attn', 'raw_slot_sdf', 'empty_slot_mask',
+ 'max_attn', 'total_occ', 'non_max_occ', 'canonical_pts_cand']:
+ if key in backward_flow and backward_flow[key] is not None:
+ ret[key] = backward_flow[key]
+ canonical_pts = first_mask * world_pts + (1 - first_mask) * canonical_pts
+
+ ret['canonical_pts'] = canonical_pts
+
+ return ret
+
+ def get_world_pts_from_canonical_pts(self, cnc_name, canonical_pts, timesteps, valid_samples, training=True):
+ ret = {}
+ first_mask = (timesteps == self.cnc_timesteps[cnc_name]).float()
+
+ forward_flow = self.forward_flow(cnc_name, canonical_pts, valid_samples, training=training)
+ world_pts = forward_flow['world_pts']
+ world_pts_cand = forward_flow['world_pts_cand']
+
+ world_pts = first_mask * canonical_pts + (1 - first_mask) * world_pts
+ world_pts_cand = first_mask.unsqueeze(1) * canonical_pts.unsqueeze(1) + (1 - first_mask.unsqueeze(1)) * world_pts_cand
+ for key in ['fwd_attn_soft', 'fwd_attn_hard', 'cnc_features', 'fwd_rot', 'fwd_trans', 'fwd_rot_cand']:
+ ret[key] = forward_flow[key]
+
+ ret.update({'world_pts': world_pts, 'world_pts_cand': world_pts_cand})
+
+ return ret
+
+ def summarize_loss(self, loss_dict):
+ loss = torch.tensor(0.).cuda()
+ for loss_name, weight in self.loss_weights.items():
+ if weight > 0 and loss_name in loss_dict:
+ if loss_name in self.loss_schedule and self.loss_schedule[loss_name] > self.global_step:
+ continue
+ loss += loss_dict[loss_name] * weight
+ return loss
+
+ def train_epilogue(self, cnc_name, loss_dict):
+ loss = self.summarize_loss(loss_dict)
+
+ if (self.global_step + 1) % self.cfg['i_print'] == 0:
+ msg = f"Iter: {self.global_step + 1}, {cnc_name}, "
+ metrics = {
+ 'loss': loss.item(),
+ }
+ metrics.update({loss_name: loss_dict[loss_name].item() for loss_name in loss_dict
+ if loss_name.startswith(cnc_name) or loss_name.startswith('self')})
+
+ for k in metrics.keys():
+ msg += f"{k}: {metrics[k]:.7f}, "
+ msg += "\n"
+ logging.info(msg)
+
+ if (self.global_step + 1) % self.cfg['i_wandb'] == 0 and self.use_wandb:
+ self.plot_loss({'total_loss': loss.item()}, self.global_step)
+ self.plot_loss({'lr': self.optimizer.state_dict()['param_groups'][0]['lr']},
+ self.global_step)
+ self.plot_loss(loss_dict, self.global_step)
+
+ if loss.requires_grad:
+ self.optimizer.zero_grad()
+
+ self.amp_scaler.scale(loss).backward()
+
+ self.amp_scaler.step(self.optimizer)
+ self.amp_scaler.update()
+
+ if (self.global_step + 1) % 10 == 0:
+ self.schedule_lr()
+
+ if (self.global_step + 1) % self.cfg['i_weights'] == 0 and cnc_name == 'last':
+ self.save_weights(output_path=os.path.join(self.save_dir, 'ckpt', f'model_{self.global_step + 1:07d}.pth'))
+
+ if (self.global_step + 1) % self.cfg['i_mesh'] == 0:
+ self.export_canonical(cnc_name, per_part=self.global_step >= self.freeze_recon_step)
+
+ if (self.global_step + 1) % self.cfg['i_img'] == 0 and self.global_step < self.freeze_recon_step:
+ ids = torch.unique(self.rays_dict[cnc_name][:, self.ray_frame_id_slice]).data.cpu().numpy().astype(int).tolist()
+ ids.sort()
+ ids = ids[::10][:10]
+
+ os.makedirs(pjoin(self.save_dir, 'step_img'), exist_ok=True)
+ dir = f"{self.save_dir}/step_img/step_{self.global_step + 1:07d}_{cnc_name}"
+ os.makedirs(dir, exist_ok=True)
+ for frame_idx in ids:
+ rgb, depth, ray_mask, gt_rgb, gt_depth, _ = self.render_images(cnc_name, frame_idx)
+ mask_vis = (rgb * 255 * 0.2 + ray_mask * 0.8).astype(np.uint8)
+ mask_vis = np.clip(mask_vis, 0, 255)
+ rgb = np.concatenate((rgb, gt_rgb), axis=1)
+ far = self.cfg['far'] * self.cfg['sc_factor']
+ gt_depth = np.clip(gt_depth, self.cfg['near'] * self.cfg['sc_factor'], far)
+ depth_vis = np.concatenate((to8b(depth / far), to8b(gt_depth / far)), axis=1)
+ depth_vis = np.tile(depth_vis[..., None], (1, 1, 3))
+ row = np.concatenate((to8b(rgb), depth_vis, mask_vis), axis=1)
+ img_name = self.frame_names[frame_idx].split('/')[-1].split('.')[-2]
+ imageio.imwrite(pjoin(dir, f'{img_name}.png'), row.astype(np.uint8))
+
+ def train_render_loop(self, cnc_name, batch):
+ target_s = batch[:, self.ray_rgb_slice] # Color (N,3)
+ target_d = batch[:, self.ray_depth_slice] # Normalized scale (N)
+
+ extras = self.render_rays(cnc_name=cnc_name, ray_batch=batch,
+ depth=target_d, lindisp=False, perturb=True)
+ loss_dict = {}
+
+ valid_samples = extras['valid_samples'] # (N_ray,N_samples)
+ N_rays, N_samples = valid_samples.shape
+
+ rgb = extras['rgb_map']
+
+ valid_samples = extras['valid_samples'] # (N_ray,N_samples)
+ z_vals = extras['z_vals'] # [N_rand, N_samples + N_importance]
+
+ sdf = extras['raw'][..., -1]
+
+ valid_rays = (valid_samples > 0).any(dim=-1).bool().reshape(N_rays) & (batch[:, self.ray_type_slice] == 0)
+ valid_sample_weights = valid_samples * valid_rays.view(-1, 1)
+
+ rgb_loss = (((rgb - target_s) ** 2 * valid_rays.view(-1, 1))).mean(dim=-1)
+
+ loss_dict['self_rgb'] = rgb_loss.mean()
+
+ truncation = self.get_truncation()
+
+ empty_loss, fs_loss, sdf_loss, front_mask, sdf_mask = get_sdf_loss(z_vals, target_d.reshape(-1, 1).expand(-1, N_samples),
+ sdf, truncation, self.cfg, return_mask=True,
+ rays_d=batch[:, self.ray_dir_slice])
+
+ for loss, loss_name in zip((fs_loss, empty_loss, sdf_loss), ('freespace', 'empty', 'sdf')):
+ loss = (loss * valid_sample_weights).mean(dim=-1)
+ loss_dict[f'self_{loss_name}'] = loss.mean()
+
+ return loss_dict
+
+ def forward_consistency(self, cnc_name, cnc_pts, cnc_viewdirs=None, valid_samples=None):
+
+ other_cnc_name = [name for name in self.cnc_timesteps if name != cnc_name][0]
+ target_timesteps = torch.ones_like(cnc_pts[..., :1]) * self.cnc_timesteps[other_cnc_name]
+ target_pts_dict = self.get_world_pts_from_canonical_pts(cnc_name, cnc_pts, target_timesteps, valid_samples.reshape(-1))
+ target_pts = target_pts_dict['world_pts']
+ target_pts_cand = target_pts_dict['world_pts_cand']
+
+ attn = target_pts_dict['fwd_attn_hard']
+ num_slots = target_pts_cand.shape[1]
+
+ target_rot = target_pts_dict['fwd_rot'] # [N, 3, 3]
+ target_rot_cand = target_pts_dict['fwd_rot_cand']
+
+ num_pts = len(target_pts)
+ target_pts_all = torch.cat([target_pts, target_pts_cand.reshape(-1, 3)], dim=0)
+
+ valid_samples_cand = valid_samples.unsqueeze(1).repeat(1, num_slots)
+ valid_samples_all = torch.cat([valid_samples, valid_samples_cand.reshape(-1)], dim=0)
+
+ if cnc_viewdirs is not None: # [N, 3]
+ target_viewdirs = torch.matmul(target_rot, cnc_viewdirs.unsqueeze(-1)).squeeze(-1) # [N, 3]
+ target_viewdirs_cand = torch.matmul(target_rot_cand.unsqueeze(0), cnc_viewdirs.unsqueeze(1).unsqueeze(-1)).squeeze(-1) # [N, P, 3]
+ target_viewdirs_all = torch.cat([target_viewdirs, target_viewdirs_cand.reshape(-1, 3)], dim=0)
+ sdf_only = False
+ else:
+ target_viewdirs_all = None
+ sdf_only = True
+
+ target_outputs_all, __ = self.query_object_field(other_cnc_name, target_pts_all, valid_samples_all, viewdirs=target_viewdirs_all,
+ sdf_only=sdf_only)
+ target_outputs, target_outputs_cand = target_outputs_all[:num_pts], target_outputs_all[num_pts:].reshape(-1, num_slots, target_outputs_all.shape[-1])
+
+ target_outputs_post = (attn.unsqueeze(-1) * target_outputs_cand).sum(dim=1)
+ target_outputs_post = target_outputs_post.reshape(target_outputs.shape)
+
+ target_outputs = target_outputs_post
+
+ target_full_sdf_cand = self.query_full_sdf(other_cnc_name, target_pts_cand.reshape(-1, 3).float())
+ target_full_sdf_cand = target_full_sdf_cand.reshape(-1, num_slots)
+ target_full_sdf_post = (attn * target_full_sdf_cand).sum(dim=1)
+ target_full_sdf_post = target_full_sdf_post.reshape(-1)
+ warped_sdf = target_full_sdf_post
+ target_vis_cand = self.query_visibility(other_cnc_name, target_pts_cand.reshape(-1, 3).float()).float()
+ warped_vis = (attn * target_vis_cand.reshape(-1, num_slots)).sum(dim=1).reshape(-1)
+
+ ret_dict = {'warped_sdf': warped_sdf, 'fwd_attn': target_pts_dict['fwd_attn_soft'], 'warped_vis': warped_vis}
+
+ if cnc_viewdirs is not None:
+ ret_dict['warped_rgb'] = target_outputs[..., :3]
+ return ret_dict
+
+ def backward_consistency(self, cnc_name, world_pts, valid_samples): # world_pts (N, 3)
+ world_pts = world_pts.detach()
+ other_cnc_name = [name for name in self.cnc_timesteps if name != cnc_name][0]
+ target_timesteps = torch.ones_like(world_pts[..., :1]) * self.cnc_timesteps[other_cnc_name]
+
+ canonical_pts_dict = self.get_canonical_pts_from_world_pts(cnc_name, world_pts, target_timesteps, valid_samples)
+
+ ret_dict = {f'bwd_{key}': canonical_pts_dict[key] for key in ['max_attn', 'non_max_occ', 'total_occ']}
+ ret_dict['bwd_attn'] = canonical_pts_dict['bwd_attn_soft']
+
+ return ret_dict
+
+ def compute_forward_losses(self, self_dict, forward_dict):
+ loss_dict = {}
+
+ weights_from_sdf = self_dict['weights']
+
+ #--------------Consistency------------
+
+ self_sdf = self_dict['sdf']
+ warped_sdf = forward_dict['warped_sdf']
+ self_vis = self_dict['visibility']
+ warped_vis = forward_dict['warped_vis']
+
+ vis_weight = self_vis * warped_vis
+ vis_discount = self.cfg.get('vis_discount', 1.)
+ vis_weight = (1 - vis_weight) + vis_weight * vis_discount
+ weights = weights_from_sdf * vis_weight # * self_vis * warped_vis
+ weights = weights / (weights_from_sdf.sum() + 1e-6)
+ cns_sdf = ((warped_sdf - self_sdf.detach()).abs() * weights).sum()
+ loss_dict[f'cns_sdf'] = cns_sdf
+
+ if 'rgb' in self_dict and 'warped_rgb' in forward_dict:
+
+ self_rgb = self_dict['rgb']
+ warped_rgb = forward_dict['warped_rgb']
+ cns_rgb = (((warped_rgb - self_rgb.detach()) ** 2).mean(dim=-1) * weights).sum()
+ loss_dict['cns_rgb'] = cns_rgb
+
+ return loss_dict
+
+ def compute_backward_losses(self, self_dict, backward_dict):
+ loss_dict = {}
+
+ total_occ = backward_dict['bwd_total_occ']
+ loss_dict['collision_occ'] = (torch.relu(total_occ - 1) ** 2).mean()
+
+ if 'occ' in self_dict:
+ occ = self_dict['occ']
+
+ vis_weight = self_dict['visibility'].float()
+
+ vis_discount = self.cfg.get('vis_discount', 1.)
+ vis_weight = (1 - vis_weight) + vis_weight * vis_discount
+
+ loss_dict['cns_occ'] = (((total_occ - occ) ** 2) * vis_weight).mean()
+
+ return loss_dict
+
+ def train_ray_loop(self, cnc_name, batch):
+ target_d = batch[:, self.ray_depth_slice] # Normalized scale (N)
+
+ sample_dict = self.sample_rays(cnc_name, batch, lindisp=False, perturb=True, depth=target_d)
+
+ cnc_pts, cnc_viewdirs, valid_samples = [sample_dict[key].reshape((-1, ) + sample_dict[key].shape[2:])
+ for key in ['cnc_pts', 'cnc_viewdirs', 'valid_samples']]
+
+ full_sdf = self.query_full_sdf(cnc_name, cnc_pts)
+ visibility = self.query_visibility(cnc_name, cnc_pts)
+ weights_from_sdf = self.get_weights_from_full_sdf(full_sdf)
+
+ self_outputs, _ = self.query_object_field(cnc_name, cnc_pts, valid_samples, cnc_viewdirs)
+ self_rgb = self_outputs[..., :3]
+ self_sdf = full_sdf
+ self_dict = {'rgb': self_rgb, 'sdf': self_sdf, 'weights': weights_from_sdf, 'visibility': visibility}
+ self_dict['occ'] = self.get_occ_from_full_sdf(full_sdf)
+
+ loss_dict = {}
+
+ forward_dict = self.forward_consistency(cnc_name, cnc_pts, cnc_viewdirs, valid_samples)
+ loss_dict.update(self.compute_forward_losses(self_dict, forward_dict))
+
+ backward_dict = self.backward_consistency(cnc_name, cnc_pts, valid_samples)
+ loss_dict.update(self.compute_backward_losses(self_dict, backward_dict))
+
+ return loss_dict
+
+ def sample_occ(self, cnc_name):
+ canonical_pts = torch.rand(self.cfg['occ_sample_space'], 3).cuda() * 2 - 1.
+ return canonical_pts
+
+ def train_occ_loop(self, cnc_name):
+ cnc_pts = self.sample_occ(cnc_name)
+
+ full_sdf = self.query_full_sdf(cnc_name, cnc_pts)
+
+ weights_from_sdf = self.get_weights_from_full_sdf(full_sdf)
+
+ valid_samples = torch.ones(len(cnc_pts), dtype=torch.bool, device=cnc_pts.device)
+
+ visibility = self.query_visibility(cnc_name, cnc_pts)
+
+ self_sdf = full_sdf
+ self_dict = {'sdf': self_sdf, 'weights': weights_from_sdf, 'visibility': visibility}
+ self_dict['occ'] = self.get_occ_from_full_sdf(full_sdf)
+
+ loss_dict = {}
+
+ forward_dict = self.forward_consistency(cnc_name, cnc_pts, cnc_viewdirs=None, valid_samples=valid_samples)
+ loss_dict.update(self.compute_forward_losses(self_dict, forward_dict))
+
+ backward_dict = self.backward_consistency(cnc_name, cnc_pts, valid_samples)
+ loss_dict.update(self.compute_backward_losses(self_dict, backward_dict))
+
+ return loss_dict
+
+ def train_corr_loop(self, cnc_name, corr_batch):
+ src_ray_id = corr_batch[:, self.corr_src_id_slice].long()
+ tgt_frame_id = corr_batch[:, self.corr_tgt_frame_slice].long()
+ tgt_gt_pixel = corr_batch[:, self.corr_tgt_pixel_silce]
+
+ batch = self.rays_dict[cnc_name][src_ray_id]
+ target_d = batch[:, self.ray_depth_slice] # Normalized scale (N)
+
+ sample_dict = self.sample_rays(cnc_name=cnc_name, ray_batch=batch, lindisp=False,
+ perturb=True, depth=target_d)
+ tgt_pred = self.run_network_corr(cnc_name=cnc_name, sample_dict=sample_dict, depth=target_d)
+
+ valid_rays = tgt_pred['valid']
+ tgt_world_pts = tgt_pred['target_world_pts']
+ valid_rays = valid_rays & (batch[:, self.ray_type_slice] == 0)
+
+ tgt_tf = self.c2w_array[tgt_frame_id]
+
+ tgt_cam_pts = (tgt_tf[:, :3, :3].transpose(-1, -2) @ (tgt_world_pts.unsqueeze(-1) - tgt_tf[:, :3, 3:])).squeeze(-1)
+ tgt_pixel = self.project_to_pixel(tgt_cam_pts)
+
+ corr_diff = (tgt_pixel - tgt_gt_pixel) * valid_rays.unsqueeze(-1)
+ corr_loss = torch.abs(corr_diff).sum() / valid_rays.sum() / self.H
+
+ loss_dict = {}
+
+ loss_dict['corr'] = corr_loss
+ loss_dict[f'{cnc_name}_corr'] = corr_loss
+
+ return loss_dict
+
+ def train_recon(self):
+ start_step = self.global_step
+
+ for iter in range(start_step, self.freeze_recon_step):
+ for cnc_name in self.cnc_timesteps:
+ batch = next(self.data_loader[cnc_name])
+ loss_dict = self.train_render_loop(cnc_name, batch.cuda())
+
+ self.train_epilogue(cnc_name, loss_dict)
+
+ self.global_step += 1
+
+ self.freeze_recon()
+
+ def train_arti(self):
+ start_step = self.global_step
+
+ start_corr = -1 if 'corr' not in self.loss_schedule else self.loss_schedule['corr']
+
+ for iter in range(start_step, self.total_step):
+ for cnc_name in self.cnc_timesteps:
+ batch = next(self.data_loader[cnc_name])
+ loss_dict = {}
+
+ if 'ray' in self.cfg['train_modes']:
+ ray_loss = self.train_ray_loop(cnc_name, batch.cuda())
+ loss_dict.update({f'ray_{key}': value for key, value in ray_loss.items()})
+ loss_dict.update({f'{cnc_name}_ray_{key}': value for key, value in ray_loss.items()})
+
+ if 'occ' in self.cfg['train_modes']:
+ occ_loss = self.train_occ_loop(cnc_name)
+ loss_dict.update({f'occ_{key}': value for key, value in occ_loss.items()})
+ loss_dict.update({f'{cnc_name}_occ_{key}': value for key, value in occ_loss.items()})
+
+ if iter % self.cfg['train_corr_n_step'] == 0 and iter > start_corr:
+ corr_batch = next(self.corr_loader[cnc_name])
+ loss_dict.update(self.train_corr_loop(cnc_name, corr_batch))
+
+ self.train_epilogue(cnc_name, loss_dict)
+
+ self.global_step += 1
+
+ def load_recon(self):
+ recon_path = pjoin(self.save_dir, 'recon')
+ os.makedirs(recon_path, exist_ok=True)
+
+ sdf_dict = {}
+
+ nerf_scale = self.cfg['sc_factor']
+ nerf_trans = self.cfg['translation'].reshape(1, 3)
+
+ for cnc_name in ['init', 'last']:
+ recon_mesh_path = pjoin(recon_path, f'{cnc_name}_all_clustered.obj')
+ if not os.path.exists(recon_mesh_path):
+ with torch.no_grad():
+ mesh = self.extract_canonical_info(cnc_name, isolevel=0.0,
+ voxel_size=self.cfg['mesh_resolution'], disable_octree=False,
+ per_part=False)
+
+ vertices_raw = mesh.vertices.copy()
+ vertices_raw = vertices_raw / nerf_scale - nerf_trans
+ ori_mesh = trimesh.Trimesh(vertices_raw, mesh.faces, vertex_colors=mesh.visual.vertex_colors, process=False)
+
+ ori_mesh.export(pjoin(recon_path, f'{cnc_name}_all.obj'))
+
+ cluster_meshes(pjoin(recon_path, f'{cnc_name}_all.obj'), [], None)
+
+ sdf_voxel_size = self.cfg['sdf_voxel_size']
+ sdf_path = pjoin(recon_path, f'{cnc_name}_{sdf_voxel_size:.3f}.npz')
+ if os.path.exists(sdf_path):
+ sdf = np.load(sdf_path, allow_pickle=True)['data']
+ else:
+ mesh = trimesh.load(recon_mesh_path)
+ mesh.vertices = (mesh.vertices + nerf_trans) * nerf_scale
+ print('mesh vertices range', np.asarray(mesh.vertices).min(axis=0),
+ np.asarray(mesh.vertices).max(axis=0))
+ pts, sdf = sdf_voxel_from_mesh(mesh, sdf_voxel_size)
+ np.savez_compressed(sdf_path, data=sdf)
+
+ voxel_sdf = VoxelSDF(sdf)
+
+ sdf_dict[cnc_name] = voxel_sdf
+
+ self.recon_sdf_dict = sdf_dict
+
+ def train(self):
+ if self.global_step < self.freeze_recon_step:
+ self.train_recon()
+
+ self.freeze_recon()
+
+ self.load_recon()
+
+ self.train_arti()
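+        # Minimal driver sketch (hypothetical; assumes this class is constructed elsewhere
+        # with a config dict and data loaders as used above -- the name 'trainer' is an
+        # assumption, not taken from the repository):
+        #   trainer.train()   # stage 1: per-state reconstruction, then stage 2: articulation
+        #   trainer.test()    # exports per-part meshes, joint axes, and metrics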
+
+ def test(self):
+ assert self.global_step >= self.freeze_recon_step, 'Stage 1 reconstruction incomplete.'
+
+ self.load_recon()
+
+ for cnc_name in self.cnc_timesteps:
+ self.export_canonical(cnc_name, per_part=True)
+
+ @torch.no_grad()
+ def sample_rays_uniform_occupied_voxels(self, rays_d, depths_in_out, lindisp=False, perturb=False,
+ depths=None, N_samples=None):
+ N_rays = rays_d.shape[0]
+ N_intersect = depths_in_out.shape[1]
+ dirs = rays_d / rays_d.norm(dim=-1, keepdim=True)
+
+ z_in_out = depths_in_out.cuda() * torch.abs(dirs[..., 2]).reshape(N_rays, 1, 1).cuda()
+
+ if depths is not None:
+ depths = depths.reshape(-1, 1)
+ trunc = self.get_truncation()
+ valid = (depths >= self.cfg['near'] * self.cfg['sc_factor']) & (
+ depths <= self.cfg['far'] * self.cfg['sc_factor']).expand(-1, N_intersect)
+ valid = valid & (z_in_out > 0).all(dim=-1) # (N_ray, N_intersect)
+ # upper_bound = 10.0 if not near_depth else 0.
+ z_in_out[valid] = torch.clip(z_in_out[valid],
+ min=torch.zeros_like(z_in_out[valid]),
+ max=torch.ones_like(z_in_out[valid]) * (
+ depths.reshape(-1, 1, 1).expand(-1, N_intersect, 2)[
+ valid] + trunc))
+
+ depths_lens = z_in_out[:, :, 1] - z_in_out[:, :, 0] # (N_ray,N_intersect)
+ z_vals_continous = sample_rays_uniform(N_samples,
+ torch.zeros((N_rays, 1), device=z_in_out.device).reshape(-1, 1),
+ depths_lens.sum(dim=-1).reshape(-1, 1), lindisp=lindisp,
+ perturb=perturb) # (N_ray,N_sample)
+
+ N_samples = z_vals_continous.shape[1]
+ z_vals = torch.zeros((N_rays, N_samples), dtype=torch.float, device=rays_d.device)
+ z_vals = common.sampleRaysUniformOccupiedVoxels(z_in_out.contiguous(), z_vals_continous.contiguous(), z_vals)
+ z_vals = z_vals.float().to(rays_d.device) # (N_ray,N_sample)
+
+ return z_vals, z_vals_continous
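+        # Samples are drawn uniformly over the total occupied length along each ray and then
+        # mapped back into the individual occupied intervals by the CUDA kernel
+        # common.sampleRaysUniformOccupiedVoxels (see mycuda/common.cu).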
+
+ def sample_rays(self, cnc_name, ray_batch, lindisp=False, perturb=False, depth=None):
+
+ N_rays = ray_batch.shape[0]
+
+ rays_d = ray_batch[:, self.ray_dir_slice] # in world frame
+ rays_o = torch.zeros_like(rays_d)
+ viewdirs = rays_d / rays_d.norm(dim=-1, keepdim=True)
+
+ frame_ids = ray_batch[:, self.ray_frame_id_slice].long()
+
+ tf = self.c2w_array[frame_ids]
+
+ rays_o_w = transform_pts(rays_o, tf)
+ viewdirs_w = (tf[:, :3, :3] @ viewdirs[:, None].permute(0, 2, 1))[:, :3, 0]
+ voxel_size = self.cfg['octree_raytracing_voxel_size'] * self.cfg['sc_factor']
+ level = int(np.floor(np.log2(2.0 / voxel_size)))
+ near, far, _, depths_in_out = self.octree_m.ray_trace(rays_o_w, viewdirs_w, level=level, debug=0)
+ z_vals, _ = self.sample_rays_uniform_occupied_voxels(rays_d=viewdirs,
+ depths_in_out=depths_in_out, lindisp=lindisp,
+ perturb=perturb, depths=depth,
+ N_samples=self.cfg['N_samples'])
+
+ if self.cfg['N_samples_around_depth'] > 0 and depth is not None:
+ valid_depth_mask = (depth >= self.cfg['near'] * self.cfg['sc_factor']) & (
+ depth <= self.cfg['far'] * self.cfg['sc_factor'])
+ valid_depth_mask = valid_depth_mask.reshape(-1)
+ trunc = self.get_truncation()
+ near_depth = depth[valid_depth_mask] - trunc
+ far_depth = depth[valid_depth_mask] + trunc * self.cfg['neg_trunc_ratio']
+ z_vals_around_depth = torch.zeros((N_rays, self.cfg['N_samples_around_depth']),
+ device=ray_batch.device).float()
+ # if torch.sum(inside_mask)>0:
+ z_vals_around_depth[valid_depth_mask] = sample_rays_uniform(self.cfg['N_samples_around_depth'],
+ near_depth.reshape(-1, 1),
+ far_depth.reshape(-1, 1), lindisp=lindisp,
+ perturb=perturb)
+ invalid_depth_mask = valid_depth_mask == 0
+
+ if invalid_depth_mask.any() and self.cfg['use_octree']:
+ z_vals_invalid, _ = self.sample_rays_uniform_occupied_voxels(rays_d=viewdirs[invalid_depth_mask],
+ depths_in_out=depths_in_out[
+ invalid_depth_mask], lindisp=lindisp,
+ perturb=perturb, depths=None,
+ N_samples=self.cfg[
+ 'N_samples_around_depth'])
+ z_vals_around_depth[invalid_depth_mask] = z_vals_invalid
+ else:
+ z_vals_around_depth[invalid_depth_mask] = sample_rays_uniform(self.cfg['N_samples_around_depth'],
+ near[invalid_depth_mask].reshape(-1, 1),
+ far[invalid_depth_mask].reshape(-1, 1),
+ lindisp=lindisp, perturb=perturb)
+
+ z_vals = torch.cat((z_vals, z_vals_around_depth), dim=-1)
+ valid_samples = torch.ones(z_vals.shape, dtype=torch.bool,
+ device=ray_batch.device) # During pose update if ray out of box, it becomes invalid
+
+ pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3]
+
+ N_ray, N_sample = pts.shape[:2]
+
+ tf_flat = tf[:, None].expand(-1, N_sample, -1, -1).reshape(-1, 4, 4)
+ cnc_pts = transform_pts(torch.reshape(pts, (-1, 3)), tf_flat)
+
+ valid_samples = valid_samples.bool() & (torch.abs(cnc_pts) <= 1).all(dim=-1).view(N_ray, N_sample).bool()
+ cnc_pts = cnc_pts.reshape(N_ray, N_sample, 3)
+
+ cnc_viewdirs = (tf[..., :3, :3] @ viewdirs[..., None])[..., 0] # (N_ray, 3)
+ cnc_viewdirs = cnc_viewdirs.unsqueeze(1).repeat(1, N_sample, 1)
+
+ return {'cnc_pts': cnc_pts, 'z_vals': z_vals, 'tf': tf, 'valid_samples': valid_samples,
+ 'cnc_viewdirs': cnc_viewdirs}
+
+ def render_rays(self, cnc_name, ray_batch, lindisp=False, perturb=False, depth=None):
+ sample_dict = self.sample_rays(cnc_name, ray_batch, lindisp=lindisp, perturb=perturb, depth=depth)
+
+ cnc_pts, cnc_viewdirs, valid_samples, z_vals = [sample_dict[key]
+ for key in ['cnc_pts', 'cnc_viewdirs', 'valid_samples', 'z_vals']]
+
+ cur_dict = self.run_network_render(cnc_name, cnc_pts, cnc_viewdirs,
+ valid_samples=valid_samples) # [N_rays, N_samples, 4]
+
+ raw, valid_samples = cur_dict['raw'], cur_dict['valid_samples']
+
+ rgb_map, weights = self.raw2outputs(raw, z_vals, valid_samples=valid_samples, depth=depth)
+
+ cur_dict.update({'rgb_map': rgb_map, 'weights': weights, 'z_vals': z_vals})
+
+ return cur_dict
+
+ def get_ray_sample_weights_from_depth(self, z_vals, depth=None, pred_sdf=None, valid_samples=None, truncation=None):
+ '''
+ z_vals, valid_samples: [N_rays, N_samples]
+ depth: [N_rays]
+ '''
+ if truncation is None:
+ truncation = self.get_truncation()
+
+ if depth is not None:
+ depth = depth.view(-1, 1)
+ sdf_from_depth = (depth - z_vals) / truncation
+ sdf = sdf_from_depth
+ else:
+ sdf = pred_sdf
+
+ weights = torch.sigmoid(sdf * self.cfg['sdf_lambda']) * torch.sigmoid(-sdf * self.cfg['sdf_lambda'])
+ mask = (sdf > -self.cfg['neg_trunc_ratio']) & (sdf < 1)
+ weights = weights * mask
+
+ if depth is not None:
+ invalid = (depth > self.cfg['far'] * self.cfg['sc_factor']).reshape(-1)
+ weights[invalid] = 0
+
+ if valid_samples is not None:
+ weights[valid_samples == 0] = 0
+
+ weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-10)
+
+ return weights
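+        # Sketch of the weighting above: with s = (depth - z) / truncation, each sample gets
+        #   w(s) = sigmoid(sdf_lambda * s) * sigmoid(-sdf_lambda * s),
+        # which peaks at the surface crossing s = 0 and decays symmetrically; the mask keeps
+        # only samples with -neg_trunc_ratio < s < 1 before per-ray renormalization.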
+
+ def raw2outputs(self, raw, z_vals, valid_samples=None, depth=None):
+
+ rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3]
+ sdf = raw[..., -1]
+
+ weights = self.get_ray_sample_weights_from_depth(z_vals, depth=depth, pred_sdf=sdf, valid_samples=valid_samples)
+
+ rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3]
+
+ return rgb_map, weights
+
+ def get_weights_from_full_sdf(self, ori_sdf):
+ trunc = self.cfg['sdf_weight_trunc']
+ neg_clamp = self.cfg['sdf_weight_neg_clamp']
+ sdf = torch.clamp_min(ori_sdf / trunc, -neg_clamp)
+ weights = torch.sigmoid(sdf * self.cfg['sdf_lambda']) * torch.sigmoid(-sdf * self.cfg['sdf_lambda'])
+ cut_off = self.cfg['sdf_weight_cutoff']
+ if cut_off > 0:
+ mask = ori_sdf < cut_off
+ weights = weights * mask
+
+ return weights
+
+ def get_occ_from_full_sdf(self, sdf):
+ thresh = self.cfg['sdf_to_occ_thresh']
+        return torch.clamp_min(1 - torch.relu(sdf / thresh + 0.5), 0)  # occ = 1 for sdf <= -thresh/2, occ = 0 for sdf >= thresh/2, linear ramp from 1 to 0 in between
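+        # Worked example, assuming sdf_to_occ_thresh = 0.01 (an illustrative value):
+        # sdf = -0.01 -> occ 1.0, sdf = 0.0 -> occ 0.5, sdf = +0.01 -> occ 0.0.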
+
+ def render_frame(self, cnc_name, rays, depth=None, lindisp=False, perturb=False):
+ """Render rays in chunks
+ """
+
+ all_ret = []
+ chunk = self.cfg['chunk']
+ for i in range(0, rays.shape[0], chunk):
+ ret = self.render_rays(cnc_name, rays[i:i + chunk],
+ depth=None if depth is None else depth[i: i + chunk],
+ lindisp=lindisp, perturb=perturb)
+ all_ret.append({key: value for key, value in ret.items() if value is not None})
+
+ def merge_list(l):
+ if isinstance(l[0], dict): # merge a list of dicts
+ return {key: merge_list(list(x[key] for x in l)) for key in l[0]}
+ elif isinstance(l[0], list):
+ return [merge_list(list(x[i] for x in l)) for i in range(len(l[0]))]
+ elif isinstance(l[0], torch.Tensor):
+ return torch.cat(l, 0)
+
+ all_ret = merge_list(all_ret)
+
+ return all_ret
+
+ def query_object_field(self, cnc_name, pts, valid_samples, viewdirs=None,
+ sdf_only=False, empty_slot_mask=None):
+ embedded = torch.zeros((pts.shape[0], self.models[f'{cnc_name}_embed_fn'].out_dim), device=pts.device)
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ embedded[valid_samples] = self.models[f'{cnc_name}_embed_fn'](pts[valid_samples]).to(embedded.dtype)
+ embedded = embedded.float()
+
+ # Add view directions
+ if not sdf_only and self.models[f'{cnc_name}_embeddirs_fn'] is not None:
+ embedded_dirs = self.models[f'{cnc_name}_embeddirs_fn'](viewdirs)
+ embedded = torch.cat([embedded, embedded_dirs], -1)
+
+ outputs = []
+ with torch.cuda.amp.autocast(enabled=self.cfg['amp']):
+ chunk = self.cfg['netchunk']
+ for i in range(0, embedded.shape[0], chunk):
+ if sdf_only:
+ out = self.models[f'{cnc_name}_model'].forward_sdf(embedded[i:i + chunk]).reshape(-1, 1)
+ else:
+ out = self.models[f'{cnc_name}_model'](embedded[i:i + chunk])
+ outputs.append(out)
+ outputs = torch.cat(outputs, dim=0).float()
+
+ if empty_slot_mask is not None:
+ empty_outputs = torch.zeros_like(outputs)
+ empty_outputs[..., -1] = 1
+ empty_slot_mask = empty_slot_mask.float().reshape(-1, 1)
+
+ outputs = (1 - empty_slot_mask) * outputs + empty_slot_mask * empty_outputs
+
+ return outputs, valid_samples
+
+ def run_network_corr(self, cnc_name, sample_dict, depth):
+
+ canonical_pts = sample_dict['cnc_pts'].reshape(-1, 3)
+ valid_samples = sample_dict['valid_samples']
+
+ other_cnc_name = [name for name in self.cnc_timesteps if name != cnc_name][0]
+ target_timesteps = torch.ones_like(canonical_pts[..., :1]) * self.cnc_timesteps[other_cnc_name]
+ target_pts_dict = self.get_world_pts_from_canonical_pts(cnc_name, canonical_pts, target_timesteps,
+ valid_samples.reshape(-1))
+ target_world_pts = target_pts_dict['world_pts']
+
+ # per-point 3D correspondence --> sum over the ray
+
+ weights = self.get_ray_sample_weights_from_depth(sample_dict['z_vals'],
+ depth=depth, valid_samples=sample_dict['valid_samples'])
+ sumw = weights.sum(dim=-1)
+ non_zero_weight = (sumw > 0)
+ target_world_pts = (target_world_pts.reshape(len(weights), -1, 3) * weights.unsqueeze(-1)).sum(dim=-2)
+ return {'target_world_pts': target_world_pts, 'valid': non_zero_weight}
+
+ def run_network_render(self, cnc_name, cnc_pts, cnc_viewdirs, valid_samples=None):
+ # only render the current frame; does not go through motion
+ # cnc_pts: [#rays, #samples, 3], cnc_viewdirs: [#rays, #samples, 3]
+ old_shape = cnc_pts.shape[:-1]
+
+ outputs, cur_valid_samples = self.query_object_field(cnc_name, cnc_pts.reshape(-1, 3),
+ valid_samples.reshape(-1),
+ cnc_viewdirs.reshape(-1, 3))
+
+ ret_dict = {'raw': outputs.reshape(old_shape + (outputs.shape[-1], )),
+ 'valid_samples': cur_valid_samples.reshape(old_shape)}
+
+ return ret_dict
+
+ def run_network_density(self, cnc_name, inputs, timestep):
+ inputs_flat = inputs.reshape(-1, inputs.shape[-1])
+
+ inputs_flat = torch.clip(inputs_flat, -1, 1)
+ valid_samples = torch.ones((len(inputs_flat)), device=inputs.device).bool()
+ empty_slot_mask = None
+
+ if timestep != self.cnc_timesteps[cnc_name]:
+
+ timesteps = torch.ones_like(inputs_flat[..., :1]) * timestep
+
+ if not inputs_flat.requires_grad:
+ inputs_flat.requires_grad = True
+
+ input_pts = inputs_flat
+ canonical_pts_dict = self.backward_flow(cnc_name, input_pts, timesteps, valid_samples, training=False)
+ inputs_flat = canonical_pts_dict['canonical_pts']
+ inputs_flat = torch.clip(inputs_flat, -1, 1)
+ empty_slot_mask = torch.zeros_like(inputs_flat[..., 0]) if 'empty_slot_mask' not in canonical_pts_dict or canonical_pts_dict['empty_slot_mask'] is None else canonical_pts_dict['empty_slot_mask']
+
+ outputs, valid_samples = self.query_object_field(cnc_name, inputs_flat, valid_samples,
+ sdf_only=True,
+ empty_slot_mask=empty_slot_mask)
+
+ return outputs, valid_samples
+
+ def run_network_canonical_info(self, cnc_name, inputs):
+
+ inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
+
+ inputs_flat = torch.clip(inputs_flat, -1, 1)
+ valid_samples = torch.ones((len(inputs_flat)), device=inputs.device).bool()
+
+ canonical_pts = inputs_flat
+ target_pts_dict = self.get_world_pts_from_canonical_pts(cnc_name, canonical_pts, torch.zeros_like(canonical_pts[..., :1]),
+ valid_samples.reshape(-1), training=False)
+
+ ret = {key: target_pts_dict[key] for key in ['fwd_attn_hard']}
+
+ return ret
+
+ def export_canonical(self, cnc_name, per_part=True, eval=True):
+ os.makedirs(pjoin(self.save_dir, 'results'), exist_ok=True)
+ result_path = pjoin(self.save_dir, 'results', f'step_{self.global_step + 1:07d}')
+ os.makedirs(result_path, exist_ok=True)
+
+ with torch.no_grad():
+ recon = self.extract_canonical_info(cnc_name, isolevel=0.0,
+ voxel_size=self.cfg['mesh_resolution'], disable_octree=False,
+ per_part=per_part)
+
+ if per_part:
+ all_mesh, part_meshes, part_rot, part_trans = recon
+ else:
+ all_mesh = recon
+
+ nerf_scale = self.cfg['sc_factor']
+ nerf_trans = self.cfg['translation'].reshape(1, 3)
+ inv_scale = 1.0 / nerf_scale
+ inv_trans = (-nerf_scale * nerf_trans).reshape(1, 3)
+
+ if per_part:
+ all_rot = part_rot.cpu().numpy().swapaxes(-1, -2)
+ all_trans = part_trans.cpu().numpy()
+ num_joints = len(all_rot) - 1
+
+ all_trans = (all_trans + inv_trans - np.matmul(all_rot, inv_trans[..., np.newaxis])[..., 0]) * inv_scale
+
+ pred_meshes = [pjoin(result_path, f'{cnc_name}_{suffix}.obj')
+ for suffix in ['all'] + [f'part_{i}' for i in range(num_joints + 1)]]
+ else:
+ pred_meshes = [pjoin(result_path, f'{cnc_name}_all.obj')]
+ part_meshes = []
+
+ gt_name = {'init': 'start', 'last': 'end'}[cnc_name]
+
+ for mesh, mesh_path in zip([all_mesh] + part_meshes, pred_meshes):
+ if mesh is not None:
+ vertices_raw = mesh.vertices.copy()
+ vertices_raw = inv_scale * (vertices_raw + inv_trans.reshape(1, 3))
+ ori_mesh = trimesh.Trimesh(vertices_raw, mesh.faces, vertex_colors=mesh.visual.vertex_colors, process=False)
+
+ ori_mesh.export(mesh_path)
+
+ mesh_name = '_'.join(mesh_path.split('/')[-1].split('.')[0].split('_')[1:])
+ if mesh_name.startswith('part'):
+ part_i = int(mesh_name.split('_')[-1])
+ cur_rot, cur_trans = all_rot[part_i], all_trans[part_i] # [3, 3] and [3]
+ vertices_raw = ori_mesh.vertices.copy()
+ vertices_raw = (np.matmul(cur_rot,
+ vertices_raw.transpose(1, 0)) + cur_trans.reshape(3, 1)).transpose(1, 0)
+ forward_mesh = trimesh.Trimesh(vertices_raw, ori_mesh.faces, process=False)
+ new_mesh_path = mesh_path.replace(mesh_name, f'{mesh_name}_forward')
+ forward_mesh.export(new_mesh_path)
+
+ all_metric_dict = {}
+
+ if not per_part:
+ pred_w = pred_meshes[0]
+ if self.gt_dict is not None and eval:
+ gt_w = self.gt_dict[f'mesh_{gt_name}']['w']
+ s, d_list, w = eval_CD(None, [], pred_w, None, [], gt_w)
+ all_metric_dict.update({'chamfer_whole': w * 1000})
+ else:
+ cluster_meshes(None, [], pred_w)
+
+ else:
+
+ pred_w, pred_s, pred_d_list = pred_meshes[0], pred_meshes[1], pred_meshes[2:]
+ if self.gt_dict is not None:
+ gt_w, gt_s, gt_d_list = list(self.gt_dict[f'mesh_{gt_name}'][key] for key in ['w', 's', 'd'])
+ else:
+ eval = False
+ cluster_meshes(pred_s, pred_d_list, pred_w)
+
+ base_rot, base_trans = all_rot[0], all_trans[0] # only works for one-level below base
+
+ all_perm = permutations(range(num_joints))
+
+ if not eval:
+ all_perm = []
+ for joint_type in ['prismatic', 'revolute']:
+ joint_info_list = []
+ for pred_i in range(num_joints):
+ part_rot, part_trans = all_rot[pred_i + 1], all_trans[pred_i + 1]
+
+ joint_info, rel_rot, rel_trans = interpret_transforms(base_rot, base_trans, part_rot, part_trans,
+ joint_type=joint_type)
+ joint_info.update({'rotation': rel_rot, 'translation': rel_trans})
+ save_axis_mesh(joint_info['axis_direction'], joint_info['axis_position'],
+ pjoin(result_path, f'{cnc_name}_axis_{pred_i}_{joint_type}.obj'))
+
+ joint_info_list.append(joint_info)
+
+ with open(pjoin(result_path, f'{cnc_name}_{joint_type}_motion.json'), 'w') as f:
+ json_joint_info = [
+ {key: value if not isinstance(value, np.ndarray) else value.tolist() for key, value in
+ joint_info.items()} for joint_info in joint_info_list]
+ json.dump(json_joint_info, f)
+
+ for p, perm in enumerate(all_perm):
+ perm_str = 'p' + ''.join(list(map(str, perm))) + '_'
+ if num_joints == 1:
+ perm_str = ''
+
+ joint_info_list = []
+ joint_metric = {}
+
+ for gt_i, pred_i in enumerate(perm):
+ joint_str = f'_{gt_i}' if len(perm) > 1 else ''
+ gt_joint = {key: value for key, value in self.gt_dict['joint'][gt_i].items()}
+ part_rot, part_trans = all_rot[pred_i + 1], all_trans[pred_i + 1]
+
+ joint_type = 'revolute' if gt_joint['joint_type'] == 'rotate' else 'prismatic'
+
+ joint_info, rel_rot, rel_trans = interpret_transforms(base_rot, base_trans, part_rot, part_trans, joint_type=joint_type)
+ joint_info.update({'rotation': rel_rot, 'translation': rel_trans})
+
+ joint_info_list.append(joint_info)
+
+ save_axis_mesh(joint_info['axis_direction'], joint_info['axis_position'],
+ pjoin(result_path, f'{cnc_name}_axis_{pred_i}_{joint_type}.obj'))
+
+ angle, distance, theta_diff = eval_axis_and_state(joint_info, gt_joint, joint_type=joint_type,
+ reverse=gt_name == 'end')
+ joint_metric.update({f'axis_angle{joint_str}': angle,
+ f'axis_dist{joint_str}': distance * 10,
+ f'theta_diff{joint_str}': theta_diff})
+
+ try:
+ s, d_list, w = eval_CD(pred_s, [pred_d_list[pred_i] for pred_i in perm], pred_w, gt_s, gt_d_list, gt_w)
+            except Exception:
+ s, d_list, w = -1, [-1 for _ in range(num_joints)], -1
+
+ metric_dict = {}
+ metric_dict.update(joint_metric)
+ metric_dict.update({'chamfer_static': s * 1000})
+ if len(d_list) > 1:
+ metric_dict.update({f'chamfer_dynamic_{i}': d_list[i] * 1000 for i in range(len(d_list))})
+ else:
+ metric_dict['chamfer_dynamic'] = d_list[0] * 1000
+ metric_dict.update({'chamfer_whole': w * 1000})
+
+ with open(pjoin(result_path, f'{cnc_name}_{perm_str}motion.json'), 'w') as f:
+ json_joint_info = [{key: value if not isinstance(value, np.ndarray) else value.tolist() for key, value in
+ joint_info.items()} for joint_info in joint_info_list]
+ json.dump(json_joint_info, f)
+
+ self.plot_loss({f'{cnc_name}_{perm_str}{key}': value for key, value in metric_dict.items()}, self.global_step)
+
+ with open(pjoin(result_path, f'{cnc_name}_{perm_str}metrics.json'), 'w') as f:
+ json.dump(metric_dict, f)
+
+ all_metric_dict.update({f'{perm_str}{key}': value for key, value in metric_dict.items()})
+
+ write_mode = 'w' if cnc_name == 'init' and p == 0 else 'a'
+ with open(pjoin(result_path, 'all_metrics'), write_mode) as f:
+ for metric in metric_dict:
+ metric_name = f'{cnc_name} {perm_str}{metric}'
+ print(f'{metric_name : >20}: {metric_dict[metric]:.6f}', file=f)
+
+ msg = f"{cnc_name}, iter: {self.global_step}, "
+ for k in all_metric_dict.keys():
+ msg += f"{k}: {all_metric_dict[k]:.4f}, "
+ msg += "\n"
+ logging.info(msg)
+
+
+ @torch.no_grad()
+ def extract_canonical_info(self, cnc_name, voxel_size=0.003, isolevel=0.0, disable_octree=False, per_part=True):
+ # Query network on dense 3d grid of points
+ voxel_size *= self.cfg['sc_factor'] # in "network space"
+ voxel_size = max(voxel_size, 0.004)
+
+ bounds = np.array(self.cfg['bounding_box']).reshape(2, 3)
+ x_min, x_max = bounds[0, 0], bounds[1, 0]
+ y_min, y_max = bounds[0, 1], bounds[1, 1]
+ z_min, z_max = bounds[0, 2], bounds[1, 2]
+ tx = np.arange(x_min + 0.5 * voxel_size, x_max, voxel_size)
+ ty = np.arange(y_min + 0.5 * voxel_size, y_max, voxel_size)
+ tz = np.arange(z_min + 0.5 * voxel_size, z_max, voxel_size)
+ N = len(tx)
+ query_pts = torch.tensor(
+ np.stack(np.meshgrid(tx, ty, tz, indexing='ij'), -1).astype(np.float32).reshape(-1, 3)).float().cuda()
+
+ if self.octree_m is not None and not disable_octree:
+ vox_size = self.cfg['octree_raytracing_voxel_size'] * self.cfg['sc_factor']
+ level = int(np.floor(np.log2(2.0 / vox_size)))
+
+ chunk = 160000
+ all_valid = []
+ for i in range(0, query_pts.shape[0], chunk):
+ cur_pts = query_pts[i: i + chunk]
+ center_ids = self.octree_m.get_center_ids(cur_pts, level)
+ valid = center_ids >= 0
+ all_valid.append(valid)
+ valid = torch.cat(all_valid, dim=0)
+ else:
+ valid = torch.ones(len(query_pts), dtype=bool).cuda()
+
+ flat = query_pts[valid]
+
+ sigma = []
+ cnc_features, attn_hard = [], []
+ chunk = self.cfg['netchunk']
+ for i in range(0, flat.shape[0], chunk):
+ inputs = flat[i:i + chunk]
+ with torch.no_grad():
+                outputs, valid_samples = self.run_network_density(cnc_name=cnc_name, inputs=inputs, timestep=self.cnc_timesteps[cnc_name])
+ slot_info = self.run_network_canonical_info(cnc_name=cnc_name, inputs=inputs)
+ attn_hard.append(slot_info['fwd_attn_hard'])
+ sigma.append(outputs)
+ sigma = torch.cat(sigma, dim=0)
+
+ observed = []
+ chunk = 120000
+
+ for i in range(0, flat.shape[0], chunk):
+ observed.append(self.depth_fuser[cnc_name].query(flat[i:i + chunk]))
+
+ observed = torch.cat(observed, dim=0)
+ sigma[~observed] = 1.
+
+ sigma_ = torch.ones((N ** 3)).float()
+ sigma_[valid.cpu()] = sigma.reshape(-1).cpu()
+ sigma = sigma_.reshape(N, N, N).data.numpy()
+
+ valid = valid.cpu().numpy()
+ valid = np.where(valid > 0)[0]
+
+ def get_mesh(sigma):
+
+ from skimage import measure
+ try:
+ vertices, triangles, normals, values = measure.marching_cubes(sigma, isolevel)
+ except Exception as e:
+ logging.info(f"ERROR Marching Cubes {e}")
+ return None
+
+ # Rescale and translate
+ voxel_size_ndc = np.array([tx[-1] - tx[0], ty[-1] - ty[0], tz[-1] - tz[0]]) / np.array(
+ [[tx.shape[0] - 1, ty.shape[0] - 1, tz.shape[0] - 1]])
+ offset = np.array([tx[0], ty[0], tz[0]])
+ vertices[:, :3] = voxel_size_ndc.reshape(1, 3) * vertices[:, :3] + offset.reshape(1, 3)
+
+ # Create mesh
+ mesh = trimesh.Trimesh(vertices, triangles, process=False)
+ return mesh
+
+ all_mesh = get_mesh(sigma)
+
+ if not per_part:
+ return all_mesh
+
+ attn_hard = torch.cat(attn_hard, dim=0)
+ rot, trans = self.models[f'{cnc_name}_deformation_model'].get_slot_motions()
+
+ if all_mesh is not None:
+ all_vertices = all_mesh.vertices
+ flat = torch.tensor(all_vertices).float().cuda()
+
+ chunk = self.cfg['netchunk']
+ vert_attn_hard = []
+ for i in range(0, flat.shape[0], chunk):
+ inputs = flat[i:i + chunk]
+ with torch.no_grad():
+ slot_info = self.run_network_canonical_info(cnc_name, inputs)
+ vert_attn_hard.append(slot_info['fwd_attn_hard'])
+ vert_attn_hard = torch.cat(vert_attn_hard, dim=0)
+ visual = np.zeros((len(vert_attn_hard), 3))
+ color_list = np.array(((0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0), (1, 0, 1), (0, 1, 1), (1, 1, 1), (0.5, 0.5, 0.5)))
+ for i in range(vert_attn_hard.shape[-1]):
+ visual[torch.where(vert_attn_hard[:, i] == 1)[0].cpu().numpy()] = color_list[i]
+ visual = trimesh.visual.ColorVisuals(mesh=all_mesh, face_colors=None,
+ vertex_colors=(visual * 255).astype(np.uint8))
+ all_mesh.visual = visual
+
+ part_meshes = []
+ for i in range(attn_hard.shape[-1]):
+ part_sigma = np.ones_like(sigma.reshape(-1))
+ idx = torch.where(attn_hard[:, i] == 1)[0].cpu().numpy()
+ part_sigma[valid[idx]] = sigma.reshape(-1)[valid[idx]]
+ part_meshes.append(get_mesh(part_sigma.reshape(sigma.shape)))
+
+ return all_mesh, part_meshes, rot, trans
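+        # In short: marching cubes on the fused SDF grid gives the full canonical mesh; the
+        # same grid is then masked by the hard part assignment (fwd_attn_hard) so each part
+        # mesh is re-extracted from its own voxels only, and the per-slot rigid motions
+        # (rot, trans) returned here are what export_canonical() interprets as joints.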
+
+ def render_images(self, cnc_name, img_i, cur_rays=None):
+ if cur_rays is None:
+ frame_ids = self.rays_dict[cnc_name][:, self.ray_frame_id_slice].cuda()
+ cur_rays = self.rays_dict[cnc_name][frame_ids == img_i].cuda()
+ gt_depth = cur_rays[:, self.ray_depth_slice]
+ gt_rgb = cur_rays[:, self.ray_rgb_slice].cpu()
+ ray_type = cur_rays[:, self.ray_type_slice].data.cpu().numpy()
+
+ ori_chunk = self.cfg['chunk']
+ self.cfg['chunk'] = copy.deepcopy(self.cfg['render_chunk'])
+ with torch.no_grad():
+ extras = self.render_frame(cnc_name=cnc_name, rays=cur_rays, lindisp=False, perturb=False,
+ depth=gt_depth)
+ self.cfg['chunk'] = ori_chunk
+
+ sdf = extras['raw'][..., -1] # full_sdf: will just use the partial one learned with color
+ z_vals = extras['z_vals']
+ signs = sdf[:, 1:] * sdf[:, :-1]
+ empty_rays = (signs > 0).all(dim=-1)
+ mask = signs < 0
+ inds = torch.argmax(mask.float(), axis=1)
+ inds = inds[..., None]
+ depth = torch.gather(z_vals, dim=1, index=inds)
+ depth[empty_rays] = self.cfg['far'] * self.cfg['sc_factor']
+ depth = depth[..., None].data.cpu().numpy()
+
+ rgb = extras['rgb_map']
+ rgb = rgb.data.cpu().numpy()
+
+ rgb_full = np.zeros((self.H, self.W, 3), dtype=float)
+ depth_full = np.zeros((self.H, self.W), dtype=float)
+ ray_mask_full = np.zeros((self.H, self.W, 3), dtype=np.uint8)
+ X = cur_rays[:, self.ray_dir_slice].data.cpu().numpy()
+ X[:, [1, 2]] = -X[:, [1, 2]]
+ projected = (self.K @ X.T).T
+ uvs = projected / projected[:, 2].reshape(-1, 1)
+ uvs = uvs.round().astype(int)
+ uvs_good = uvs[ray_type == 0]
+ ray_mask_full[uvs_good[:, 1], uvs_good[:, 0]] = [255, 0, 0]
+ uvs_uncertain = uvs[ray_type == 1]
+ ray_mask_full[uvs_uncertain[:, 1], uvs_uncertain[:, 0]] = [0, 255, 0]
+ rgb_full[uvs[:, 1], uvs[:, 0]] = rgb.reshape(-1, 3)
+ depth_full[uvs[:, 1], uvs[:, 0]] = depth.reshape(-1)
+ gt_rgb_full = np.zeros((self.H, self.W, 3), dtype=float)
+ gt_rgb_full[uvs[:, 1], uvs[:, 0]] = gt_rgb.reshape(-1, 3).data.cpu().numpy()
+ gt_depth_full = np.zeros((self.H, self.W), dtype=float)
+ gt_depth_full[uvs[:, 1], uvs[:, 0]] = gt_depth.reshape(-1).data.cpu().numpy()
+
+ return rgb_full, depth_full, ray_mask_full, gt_rgb_full, gt_depth_full, extras
diff --git a/mycuda/bindings.cpp b/mycuda/bindings.cpp
new file mode 100644
index 0000000..677d940
--- /dev/null
+++ b/mycuda/bindings.cpp
@@ -0,0 +1,7 @@
+#include <torch/extension.h>
+#include "common.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("sampleRaysUniformOccupiedVoxels", &sampleRaysUniformOccupiedVoxels);
+ m.def("postprocessOctreeRayTracing", &postprocessOctreeRayTracing);
+}
\ No newline at end of file
diff --git a/mycuda/common.cu b/mycuda/common.cu
new file mode 100644
index 0000000..49123f1
--- /dev/null
+++ b/mycuda/common.cu
@@ -0,0 +1,166 @@
+/*
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+*/
+
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <iostream>
+#include <vector>
+#include <algorithm>
+#include <math.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <assert.h>
+#include "common.h"
+#include "Eigen/Dense"
+
+
+
+/**
+ * @brief
+ *
+ * @tparam scalar_t
+ * @param z_sampled
+ * @param z_in_out
+ * @param z_vals
+ * @return __global__
+ */
+template <typename scalar_t>
+__global__ void sample_rays_uniform_occupied_voxels_kernel(const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> z_sampled, const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> z_in_out, torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> z_vals)
+{
+ const int i_ray = blockIdx.x * blockDim.x + threadIdx.x;
+ const int i_sample = blockIdx.y * blockDim.y + threadIdx.y;
+ if (i_ray>=z_sampled.size(0)) return;
+ if (i_sample>=z_sampled.size(1)) return;
+
+ int i_box = 0;
+ float z_remain = z_sampled[i_ray][i_sample];
+ auto z_in_out_cur_ray = z_in_out[i_ray];
+ const float eps = 1e-4;
+ const int max_n_box = z_in_out.size(1);
+
+ if (z_in_out_cur_ray[0][0]==0) return;
+
+ while (1)
+ {
+ if (i_box>=max_n_box)
+ {
+ if (z_remain<=eps)
+ {
+ z_vals[i_ray][i_sample] = z_in_out_cur_ray[max_n_box-1][1];
+ }
+ else
+ {
+ printf("ERROR sample_rays_uniform_occupied_voxels_kernel: z_remain=%f, i_ray=%d, i_sample=%d, i_box=%d, z_in_out_cur_ray=(%f,%f)\n",z_remain,i_ray,i_sample,i_box,z_in_out_cur_ray[i_box][0],z_in_out_cur_ray[i_box][1]);
+        for (int i=0;i<max_n_box;i++)
+        {
+          printf("z_in_out_cur_ray[%d]=(%f,%f)\n",i,z_in_out_cur_ray[i][0],z_in_out_cur_ray[i][1]);
+        }
+      }
+      return;
+    }
+
+    const float z_in = z_in_out_cur_ray[i_box][0];
+    const float z_out = z_in_out_cur_ray[i_box][1];
+
+    if (z_in==0 && z_out==0)   // zero padding marks the end of occupied intervals on this ray
+    {
+      if (z_remain<=eps && i_box>=1)
+      {
+        z_vals[i_ray][i_sample] = z_in_out_cur_ray[i_box-1][1];
+        return;
+      }
+      else
+      {
+        printf("ERROR sample_rays_uniform_occupied_voxels_kernel: z_remain=%f, i_ray=%d, i_sample=%d, i_box=%d, z_in_out_cur_ray=(%f,%f)\n",z_remain,i_ray,i_sample,i_box,z_in_out_cur_ray[i_box][0],z_in_out_cur_ray[i_box][1]);
+        for (int i=0;i<max_n_box;i++)
+        {
+          printf("z_in_out_cur_ray[%d]=(%f,%f)\n",i,z_in_out_cur_ray[i][0],z_in_out_cur_ray[i][1]);
+        }
+        return;
+      }
+    }
+
+    const float box_len = z_out - z_in;
+    if (z_remain<=box_len)
+    {
+      z_vals[i_ray][i_sample] = z_in + z_remain;
+      return;
+    }
+    z_remain -= box_len;
+    i_box++;
+  }
+}
+
+at::Tensor sampleRaysUniformOccupiedVoxels(const at::Tensor z_in_out, const at::Tensor z_sampled, at::Tensor z_vals)
+{
+  CHECK_INPUT(z_in_out);
+  CHECK_INPUT(z_sampled);
+  CHECK_INPUT(z_vals);
+
+  const int N_rays = z_sampled.sizes()[0];
+  const int N_samples = z_sampled.sizes()[1];
+  const int threadx = 16;
+  const int thready = 16;
+
+  AT_DISPATCH_FLOATING_TYPES(z_sampled.type(), "sample_rays_uniform_occupied_voxels_kernel", ([&]
+  {
+    sample_rays_uniform_occupied_voxels_kernel<scalar_t><<<{divCeil(N_rays,threadx),divCeil(N_samples,thready)}, {threadx,thready}>>>(z_sampled.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),z_in_out.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),z_vals.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
+ }));
+
+ return z_vals;
+}
+
+
+template <typename scalar_t>
+__global__ void postprocessOctreeRayTracingKernel(const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ray_index, const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> depth_in_out, const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> unique_intersect_ray_ids, const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> start_poss, torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> depths_in_out_padded)
+{
+ const int unique_id_pos = blockIdx.x * blockDim.x + threadIdx.x;
+ if (unique_id_pos>=unique_intersect_ray_ids.size(0)) return;
+ const int i_ray = unique_intersect_ray_ids[unique_id_pos];
+
+ int i_intersect = 0;
+ auto cur_depths_in_out_padded = depths_in_out_padded[i_ray];
+  for (int i=start_poss[unique_id_pos];i<ray_index.size(0);i++)
+  {
+    if (ray_index[i]!=i_ray) break;
+    if (i_intersect>=depths_in_out_padded.size(1)) break;
+    if (depth_in_out[i][0]>depth_in_out[i][1]) continue;
+ if (abs(depth_in_out[i][1]-depth_in_out[i][0])<1e-4) continue;
+
+ cur_depths_in_out_padded[i_intersect][0] = depth_in_out[i][0];
+ cur_depths_in_out_padded[i_intersect][1] = depth_in_out[i][1];
+ i_intersect++;
+
+ }
+}
+
+at::Tensor postprocessOctreeRayTracing(const at::Tensor ray_index, const at::Tensor depth_in_out, const at::Tensor unique_intersect_ray_ids, const at::Tensor start_poss, const int max_intersections, const int N_rays)
+{
+ CHECK_INPUT(ray_index);
+ CHECK_INPUT(depth_in_out);
+ CHECK_INPUT(start_poss);
+
+ const int n_unique_ids = unique_intersect_ray_ids.sizes()[0];
+ at::Tensor depths_in_out_padded = at::zeros({N_rays,max_intersections,2}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, 0).requires_grad(false));
+ dim3 threads = {256};
+ dim3 blocks = {divCeil(n_unique_ids,threads.x)};
+ AT_DISPATCH_FLOATING_TYPES(depth_in_out.type(), "postprocessOctreeRayTracingKernel", ([&]
+ {
+      postprocessOctreeRayTracingKernel<scalar_t><<<blocks, threads>>>(ray_index.packed_accessor32<long,1,torch::RestrictPtrTraits>(), depth_in_out.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(), unique_intersect_ray_ids.packed_accessor32<long,1,torch::RestrictPtrTraits>(), start_poss.packed_accessor32<long,1,torch::RestrictPtrTraits>(), depths_in_out_padded.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
+ }));
+
+ return depths_in_out_padded;
+}
+
diff --git a/mycuda/common.h b/mycuda/common.h
new file mode 100644
index 0000000..d37f45a
--- /dev/null
+++ b/mycuda/common.h
@@ -0,0 +1,29 @@
+/*
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+*/
+
+
+#pragma once
+
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+inline int divCeil(int a, int b)
+{
+ return (a+b-1)/b;
+};
+
+
+at::Tensor sampleRaysUniformOccupiedVoxels(const at::Tensor z_in_out, const at::Tensor z_sampled, at::Tensor z_vals);
+at::Tensor postprocessOctreeRayTracing(const at::Tensor ray_index, const at::Tensor depth_in_out, const at::Tensor unique_intersect_ray_ids, const at::Tensor start_poss, const int max_intersections, const int N_rays);
diff --git a/mycuda/setup.py b/mycuda/setup.py
new file mode 100644
index 0000000..e43cb1b
--- /dev/null
+++ b/mycuda/setup.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from setuptools import setup
+import os,sys
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+from torch.utils.cpp_extension import load
+
+code_dir = os.path.dirname(os.path.realpath(__file__))
+
+
+nvcc_flags = ['-Xcompiler', '-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__']
+c_flags = ['-O3', '-std=c++14']
+
+setup(
+ name='common',
+ extra_cflags=c_flags,
+ extra_cuda_cflags=nvcc_flags,
+ ext_modules=[
+ CUDAExtension('common', [
+ 'bindings.cpp',
+ 'common.cu',
+ ],extra_compile_args={'gcc': c_flags, 'nvcc': nvcc_flags}),
+ CUDAExtension('gridencoder', [
+ f"{code_dir}/torch_ngp_grid_encoder/gridencoder.cu",
+ f"{code_dir}/torch_ngp_grid_encoder/bindings.cpp",
+ ],extra_compile_args={'gcc': c_flags, 'nvcc': nvcc_flags}),
+ ],
+ include_dirs=[
+ "/usr/local/include/eigen3",
+ "/usr/include/eigen3",
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+})
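+
+# One possible way to build the extensions in place (assuming a CUDA toolchain that
+# matches the installed PyTorch; the command is an assumption, not taken from the
+# repository's documentation):
+#   python setup.py build_ext --inplace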
diff --git a/mycuda/torch_ngp_grid_encoder/LICENSE b/mycuda/torch_ngp_grid_encoder/LICENSE
new file mode 100644
index 0000000..56f524a
--- /dev/null
+++ b/mycuda/torch_ngp_grid_encoder/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 hawkey
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/mycuda/torch_ngp_grid_encoder/bindings.cpp b/mycuda/torch_ngp_grid_encoder/bindings.cpp
new file mode 100644
index 0000000..98cf755
--- /dev/null
+++ b/mycuda/torch_ngp_grid_encoder/bindings.cpp
@@ -0,0 +1,18 @@
+/*
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+*/
+
+#include <torch/extension.h>
+
+#include "gridencoder.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
+ m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
+}
\ No newline at end of file
diff --git a/mycuda/torch_ngp_grid_encoder/grid.py b/mycuda/torch_ngp_grid_encoder/grid.py
new file mode 100644
index 0000000..2077ee1
--- /dev/null
+++ b/mycuda/torch_ngp_grid_encoder/grid.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import os,sys,pdb
+code_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(code_dir)
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_bwd, custom_fwd
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+from torch.utils.cpp_extension import load
+import os,sys
+import gridencoder
+
+
+_gridtype_to_id = {
+ 'hash': 0,
+ 'tiled': 1,
+}
+
+class _grid_encode(Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False):
+ # inputs: [B, D], float in [0, 1]
+ # embeddings: [sO, C], float
+ # offsets: [L + 1], int
+ # RETURN: [B, F], float
+
+ inputs = inputs.contiguous()
+
+ B, D = inputs.shape # batch size, coord dim
+ L = offsets.shape[0] - 1 # level
+ C = embeddings.shape[1] # embedding dim for each level
+ S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
+ H = base_resolution # base resolution
+
+ # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
+ # if C % 2 != 0, force float, since half for atomicAdd is very slow.
+ if torch.is_autocast_enabled() and C % 2 == 0:
+ embeddings = embeddings.to(torch.half)
+
+ # L first, optimize cache for cuda kernel, but needs an extra permute later
+ outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+ if calc_grad_inputs:
+ dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+ else:
+ dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype) # placeholder... TODO: a better way?
+
+ gridencoder.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners)
+
+ # permute back to [B, L * C]
+ outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+ ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+ ctx.dims = [B, D, C, L, S, H, gridtype]
+ ctx.calc_grad_inputs = calc_grad_inputs
+ ctx.align_corners = align_corners
+
+ return outputs
+
+ @staticmethod
+ #@once_differentiable
+ @custom_bwd
+ def backward(ctx, grad):
+
+ inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+ B, D, C, L, S, H, gridtype = ctx.dims
+ calc_grad_inputs = ctx.calc_grad_inputs
+ align_corners = ctx.align_corners
+
+ # grad: [B, L * C] --> [L, B, C]
+ grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+ grad_embeddings = torch.zeros_like(embeddings)
+
+ if calc_grad_inputs:
+ grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+ else:
+ grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
+
+ gridencoder.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners)
+
+ if calc_grad_inputs:
+ grad_inputs = grad_inputs.to(inputs.dtype)
+ return grad_inputs, grad_embeddings, None, None, None, None, None, None
+ else:
+ return None, grad_embeddings, None, None, None, None, None, None
+
+
+grid_encode = _grid_encode.apply
+
+
+# https://github.com/ashawkey/torch-ngp/blob/main/gridencoder/grid.py
+class GridEncoder(nn.Module):
+ def __init__(self, input_dim=3, n_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False):
+ super().__init__()
+
+ per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (n_levels - 1))
+
+ self.input_dim = input_dim # coord dims, 2 or 3
+ self.n_levels = n_levels # num levels, each level multiply resolution by 2
+ self.level_dim = level_dim # encode channels per level
+ self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
+ self.log2_hashmap_size = log2_hashmap_size
+ self.base_resolution = base_resolution
+ self.out_dim = n_levels * level_dim
+ self.gridtype = gridtype
+ self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
+ self.align_corners = align_corners
+
+ # allocate parameters
+ offsets = []
+ offset = 0
+ self.max_params = 2 ** log2_hashmap_size
+ for i in range(n_levels):
+ resolution = int(np.ceil(base_resolution * per_level_scale ** i))
+ params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number
+ params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
+ print(f"level {i}, resolution: {resolution}")
+ offsets.append(offset)
+ offset += params_in_level
+ offsets.append(offset)
+ offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+ self.register_buffer('offsets', offsets)
+
+ self.n_params = offsets[-1] * level_dim
+
+ # parameters
+ self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
+ # self.embeddings = nn.Embedding(offset, level_dim, sparse=True)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ std = 1e-4
+ self.embeddings.data.uniform_(-std, std)
+ # self.embeddings.weight.data.uniform_(-std, std)
+
+
+ def __repr__(self):
+ return f"GridEncoder: input_dim={self.input_dim} n_levels={self.n_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.n_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
+
+
+ def forward(self, inputs, bound=1):
+ # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
+ # return: [..., num_levels * level_dim]
+
+ inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
+
+ #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
+
+ prefix_shape = list(inputs.shape[:-1])
+ inputs = inputs.view(-1, self.input_dim)
+
+ outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners)
+ outputs = outputs.view(prefix_shape + [self.out_dim])
+
+ #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
+
+ return outputs
\ No newline at end of file
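+
+# Minimal usage sketch for GridEncoder (assumes the gridencoder CUDA extension above has
+# been built; shapes follow the comments in forward(), values are illustrative only):
+#   enc = GridEncoder(input_dim=3, n_levels=16, level_dim=2,
+#                     base_resolution=16, desired_resolution=2048).cuda()
+#   x = torch.rand(4096, 3, device='cuda') * 2 - 1   # positions in [-1, 1]
+#   feats = enc(x, bound=1)                          # -> [4096, 32] features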
diff --git a/mycuda/torch_ngp_grid_encoder/gridencoder.cu b/mycuda/torch_ngp_grid_encoder/gridencoder.cu
new file mode 100644
index 0000000..8f5ae49
--- /dev/null
+++ b/mycuda/torch_ngp_grid_encoder/gridencoder.cu
@@ -0,0 +1,491 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/torch.h>
+
+#include <algorithm>
+#include <stdexcept>
+
+#include <stdint.h>
+#include <cstdio>
+
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
+static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
+ // requires CUDA >= 10 and ARCH >= 70
+ // this is very slow compared to float or __half2, and never used.
+ //return atomicAdd(reinterpret_cast<__half*>(address), val);
+}
+
+
+template <typename T>
+static inline __host__ __device__ T div_round_up(T val, T divisor) {
+ return (val + divisor - 1) / divisor;
+}
+
+
+template <uint32_t D>
+__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
+
+ // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
+ // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
+ // coordinates.
+ constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
+
+ uint32_t result = 0;
+ #pragma unroll
+ for (uint32_t i = 0; i < D; ++i) {
+ result ^= pos_grid[i] * primes[i];
+ }
+
+ return result;
+}
+
+
+template <uint32_t D, uint32_t C>
+__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
+ uint32_t stride = 1;
+ uint32_t index = 0;
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
+ index += pos_grid[d] * stride;
+ stride *= align_corners ? resolution: (resolution + 1);
+ }
+
+ // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
+ // gridtype: 0 == hash, 1 == tiled
+ if (gridtype == 0 && stride > hashmap_size) {
+ index = fast_hash(pos_grid);
+ }
+
+ return (index % hashmap_size) * C + ch;
+}
+
+
+/**
+ * @brief
+ *
+ * @tparam scalar_t
+ * @tparam D : point dimension usually 3D
+ * @tparam C : embedding dim 1/2/4
+ * @param inputs
+ * @param grid
+ * @param offsets
+ * @param outputs
+ * @param B
+ * @param L
+ * @param S
+ * @param H
+ * @param calc_grad_inputs
+ * @param dy_dx
+ * @param gridtype
+ * @param align_corners
+ * @return __global__
+ */
+template <typename scalar_t, uint32_t D, uint32_t C>
+__global__ void kernel_grid(
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ outputs,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const bool calc_grad_inputs,
+ scalar_t * __restrict__ dy_dx,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; // batch id
+
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+
+ grid += (uint32_t)offsets[level] * C;
+ inputs += b * D;
+ outputs += level * B * C + b * C;
+
+ bool flag_oob = false;
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ flag_oob = true;
+ }
+ }
+
+ if (flag_oob) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = 0;
+ }
+ if (calc_grad_inputs) {
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[d * C + ch] = 0;
+ }
+ }
+ }
+ return;
+ }
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ scalar_t results[C] = {0}; // temp results in register
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); // Hashing
+
+ // writing to register (fast)
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results[ch] += w * grid[index + ch];
+ }
+
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ outputs[ch] = results[ch];
+ }
+
+ if (calc_grad_inputs) {
+
+ dy_dx += b * D * L * C + level * D * C; // B L D C
+
+ #pragma unroll
+ for (uint32_t gd = 0; gd < D; gd++) {
+
+ scalar_t results_grad[C] = {0};
+
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
+ float w = scale;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t nd = 0; nd < D - 1; nd++) {
+ const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
+
+ if ((idx & (1 << nd)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+ pos_grid_local[gd] = pos_grid[gd];
+            uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+            pos_grid_local[gd] = pos_grid[gd] + 1;
+            uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
+ }
+ }
+
+ #pragma unroll
+ for (uint32_t ch = 0; ch < C; ch++) {
+ dy_dx[gd * C + ch] = results_grad[ch];
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
+__global__ void kernel_grid_backward(
+ const scalar_t * __restrict__ grad,
+ const float * __restrict__ inputs,
+ const scalar_t * __restrict__ grid,
+ const int * __restrict__ offsets,
+ scalar_t * __restrict__ grad_grid,
+ const uint32_t B, const uint32_t L, const float S, const uint32_t H,
+ const uint32_t gridtype,
+ const bool align_corners
+) {
+ const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
+ if (b >= B) return;
+
+ const uint32_t level = blockIdx.y;
+ const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
+
+ // locate
+ grad_grid += offsets[level] * C;
+ inputs += b * D;
+ grad += level * B * C + b * C + ch; // L, B, C
+
+ const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
+ const float scale = exp2f(level * S) * H - 1.0f;
+ const uint32_t resolution = (uint32_t)ceil(scale) + 1;
+
+ // check input range (should be in [0, 1])
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if (inputs[d] < 0 || inputs[d] > 1) {
+ return; // grad is init as 0, so we simply return.
+ }
+ }
+
+ // calculate coordinate
+ float pos[D];
+ uint32_t pos_grid[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
+ pos_grid[d] = floorf(pos[d]);
+ pos[d] -= (float)pos_grid[d];
+ }
+
+ scalar_t grad_cur[N_C] = {0}; // fetch to register
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ grad_cur[c] = grad[c];
+ }
+
+ // interpolate
+ #pragma unroll
+ for (uint32_t idx = 0; idx < (1 << D); idx++) {
+ float w = 1;
+ uint32_t pos_grid_local[D];
+
+ #pragma unroll
+ for (uint32_t d = 0; d < D; d++) {
+ if ((idx & (1 << d)) == 0) {
+ w *= 1 - pos[d];
+ pos_grid_local[d] = pos_grid[d];
+ } else {
+ w *= pos[d];
+ pos_grid_local[d] = pos_grid[d] + 1;
+ }
+ }
+
+        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
+
+ // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
+ // TODO: use float which is better than __half, if N_C % 2 != 0
+        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c += 2) {
+ // process two __half at once (by interpreting as a __half2)
+ __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
+ atomicAdd((__half2*)&grad_grid[index + c], v);
+ }
+ // float, or __half when N_C % 2 != 0 (which means C == 1)
+ } else {
+ #pragma unroll
+ for (uint32_t c = 0; c < N_C; c++) {
+ atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, uint32_t D, uint32_t C>
+__global__ void kernel_input_backward(
+ const scalar_t * __restrict__ grad,
+ const scalar_t * __restrict__ dy_dx,
+ scalar_t * __restrict__ grad_inputs,
+ uint32_t B, uint32_t L
+) {
+ const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
+ if (t >= B * D) return;
+
+ const uint32_t b = t / D;
+ const uint32_t d = t - b * D;
+
+ dy_dx += b * L * D * C;
+
+ scalar_t result = 0;
+
+ # pragma unroll
+ for (int l = 0; l < L; l++) {
+ # pragma unroll
+ for (int ch = 0; ch < C; ch++) {
+ result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
+ }
+ }
+
+ grad_inputs[t] = result;
+}
+
+
+template <typename scalar_t, uint32_t D>
+void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 512;
+ const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
+ switch (C) {
+        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+}
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
+// H: base resolution
+// dy_dx: [B, L * D * C]
+template <typename scalar_t>
+void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+        case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
+ }
+
+}
+
+template <typename scalar_t, uint32_t D>
+void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ static constexpr uint32_t N_THREAD = 256;
+ const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+ const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
+ switch (C) {
+ case 1:
+            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 2:
+            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 4:
+            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 8:
+            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
+            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+ break;
+ default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+ }
+}
+
+
+// grad: [L, B, C], float
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// grad_embeddings: [sO, C]
+// H: base resolution
+template <typename scalar_t>
+void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ switch (D) {
+        case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
+        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
+        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
+        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
+        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
+ default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
+ }
+}
+
+
+
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(outputs);
+ CHECK_CUDA(dy_dx);
+
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(outputs);
+ CHECK_CONTIGUOUS(dy_dx);
+
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(outputs);
+ CHECK_IS_FLOATING(dy_dx);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ embeddings.scalar_type(), "grid_encode_forward", ([&] {
+        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);
+ }));
+}
+
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {
+ CHECK_CUDA(grad);
+ CHECK_CUDA(inputs);
+ CHECK_CUDA(embeddings);
+ CHECK_CUDA(offsets);
+ CHECK_CUDA(grad_embeddings);
+ CHECK_CUDA(dy_dx);
+ CHECK_CUDA(grad_inputs);
+
+ CHECK_CONTIGUOUS(grad);
+ CHECK_CONTIGUOUS(inputs);
+ CHECK_CONTIGUOUS(embeddings);
+ CHECK_CONTIGUOUS(offsets);
+ CHECK_CONTIGUOUS(grad_embeddings);
+ CHECK_CONTIGUOUS(dy_dx);
+ CHECK_CONTIGUOUS(grad_inputs);
+
+ CHECK_IS_FLOATING(grad);
+ CHECK_IS_FLOATING(inputs);
+ CHECK_IS_FLOATING(embeddings);
+ CHECK_IS_INT(offsets);
+ CHECK_IS_FLOATING(grad_embeddings);
+ CHECK_IS_FLOATING(dy_dx);
+ CHECK_IS_FLOATING(grad_inputs);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "grid_encode_backward", ([&] {
+        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
+ }));
+
+}
diff --git a/mycuda/torch_ngp_grid_encoder/gridencoder.h b/mycuda/torch_ngp_grid_encoder/gridencoder.h
new file mode 100644
index 0000000..b093e78
--- /dev/null
+++ b/mycuda/torch_ngp_grid_encoder/gridencoder.h
@@ -0,0 +1,15 @@
+#ifndef _HASH_ENCODE_H
+#define _HASH_ENCODE_H
+
+#include <stdint.h>
+#include <torch/torch.h>
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [B, L * C], float
+// H: base resolution
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);
+
+#endif
\ No newline at end of file
diff --git a/network.py b/network.py
new file mode 100644
index 0000000..460c5a8
--- /dev/null
+++ b/network.py
@@ -0,0 +1,404 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mycuda.torch_ngp_grid_encoder.grid import GridEncoder
+
+
+class SHEncoder(nn.Module):
+ '''Spherical encoding
+ '''
+
+ def __init__(self, input_dim=3, degree=4):
+
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.degree = degree
+
+ assert self.input_dim == 3
+ assert self.degree >= 1 and self.degree <= 5
+
+ self.out_dim = degree ** 2
+
+ self.C0 = 0.28209479177387814
+ self.C1 = 0.4886025119029199
+ self.C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+ ]
+ self.C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+ ]
+ self.C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761
+ ]
+
+ def forward(self, input, **kwargs):
+
+ result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device)
+ x, y, z = input.unbind(-1)
+
+ result[..., 0] = self.C0
+ if self.degree > 1:
+ result[..., 1] = -self.C1 * y
+ result[..., 2] = self.C1 * z
+ result[..., 3] = -self.C1 * x
+ if self.degree > 2:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result[..., 4] = self.C2[0] * xy
+ result[..., 5] = self.C2[1] * yz
+ result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy)
+ # result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting...
+ result[..., 7] = self.C2[3] * xz
+ result[..., 8] = self.C2[4] * (xx - yy)
+ if self.degree > 3:
+ result[..., 9] = self.C3[0] * y * (3 * xx - yy)
+ result[..., 10] = self.C3[1] * xy * z
+ result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy)
+ result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
+ result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy)
+ result[..., 14] = self.C3[5] * z * (xx - yy)
+ result[..., 15] = self.C3[6] * x * (xx - 3 * yy)
+ if self.degree > 4:
+ result[..., 16] = self.C4[0] * xy * (xx - yy)
+ result[..., 17] = self.C4[1] * yz * (3 * xx - yy)
+ result[..., 18] = self.C4[2] * xy * (7 * zz - 1)
+ result[..., 19] = self.C4[3] * yz * (7 * zz - 3)
+ result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3)
+ result[..., 21] = self.C4[5] * xz * (7 * zz - 3)
+ result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1)
+ result[..., 23] = self.C4[7] * xz * (xx - 3 * yy)
+ result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
+
+ return result
+
+
+class FeatureVolume(nn.Module):
+ def __init__(self, out_dim, res, num_dim=3):
+ super().__init__()
+ self.grid = torch.nn.Parameter(torch.zeros([1, out_dim] + [res] * num_dim,
+ dtype=torch.float32))
+ self.out_dim = out_dim
+
+ def forward(self, pts):
+ feat = F.grid_sample(self.grid, pts[None, None, None, :, :], mode='bilinear',
+ align_corners=True) # [1, C, 1, 1, N]
+ return feat[0, :, 0, 0, :].permute(1, 0) # [N, C]
+
+
+class NeRFSmall(nn.Module):
+ def __init__(self, num_layers=3, hidden_dim=64, geo_feat_dim=15, num_layers_color=4, hidden_dim_color=64,
+ input_ch=3, input_ch_views=3):
+ super(NeRFSmall, self).__init__()
+
+ self.input_ch = input_ch
+ self.input_ch_views = input_ch_views
+
+ # sigma network
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+ self.geo_feat_dim = geo_feat_dim
+
+ sigma_net = []
+ for l in range(num_layers):
+ if l == 0:
+ in_dim = self.input_ch
+ else:
+ in_dim = hidden_dim
+
+ if l == num_layers - 1:
+ out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color
+ else:
+ out_dim = hidden_dim
+
+ sigma_net.append(nn.Linear(in_dim, out_dim, bias=True))
+ if l != num_layers - 1:
+ sigma_net.append(nn.ReLU(inplace=True))
+
+ self.sigma_net = nn.Sequential(*sigma_net)
+        torch.nn.init.constant_(self.sigma_net[-1].bias, -1.)  # Encourage the last layer to predict positive SDF
+
+ # color network
+ self.num_layers_color = num_layers_color
+ self.hidden_dim_color = hidden_dim_color
+
+ color_net = []
+ for l in range(num_layers_color):
+ if l == 0:
+ in_dim = self.input_ch_views + self.geo_feat_dim
+ else:
+ in_dim = hidden_dim
+
+ if l == num_layers_color - 1:
+ out_dim = 3 # 3 rgb
+ else:
+ out_dim = hidden_dim
+
+ color_net.append(nn.Linear(in_dim, out_dim, bias=True))
+ if l != num_layers_color - 1:
+ color_net.append(nn.ReLU(inplace=True))
+
+ self.color_net = nn.Sequential(*color_net)
+
+ def forward_sdf(self, x):
+ '''
+ @x: embedded positions
+ '''
+ h = self.sigma_net(x)
+ sigma, geo_feat = h[..., 0], h[..., 1:]
+ return sigma
+
+ def forward(self, x):
+ x = x.float()
+ input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
+
+ # sigma
+ h = input_pts
+ h = self.sigma_net(h)
+
+ sigma, geo_feat = h[..., 0], h[..., 1:]
+
+ # color
+ h = torch.cat([input_views, geo_feat], dim=-1)
+ color = self.color_net(h)
+
+ outputs = torch.cat([color, sigma.unsqueeze(dim=-1)], -1)
+
+ return outputs
+
+
+def rotat_from_6d(ortho6d):
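+    # Convert the 6D rotation parameterization (two 3D vectors) to a rotation matrix via
+    # Gram-Schmidt, following Zhou et al., "On the Continuity of Rotation Representations
+    # in Neural Networks" (CVPR 2019): normalize x, build z = x × y_raw, then y = z × x.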
+ def normalize_vector(v, return_mag=False):
+ batch = v.shape[0]
+ v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
+ v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda()))
+ v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
+ v = v / v_mag
+ if (return_mag == True):
+ return v, v_mag[:, 0]
+ else:
+ return v
+
+ def cross_product(u, v):
+ batch = u.shape[0]
+ i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
+ j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
+ k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
+ out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) # batch*3
+ return out
+
+ x_raw = ortho6d[:, 0:3] # batch*3 100
+ y_raw = ortho6d[:, 3:6] # batch*3
+ x = normalize_vector(x_raw) # batch*3 100
+ z = cross_product(x, y_raw) # batch*3
+ z = normalize_vector(z) # batch*3
+ y = cross_product(z, x) # batch*3
+ x = x.view(-1, 3, 1)
+ y = y.view(-1, 3, 1)
+ z = z.view(-1, 3, 1)
+ matrix = torch.cat((x, y, z), 2) # batch*3*3
+ return matrix
+
+
+class GumbelAttn:
+ def __init__(self, device, dtype, tau):
+ self.gumbel_dist = torch.distributions.gumbel.Gumbel(
+ torch.tensor(0., device=device, dtype=dtype),
+ torch.tensor(1., device=device, dtype=dtype))
+ self.tau = tau
+
+ def process_dots(self, dots, training=True):
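+        # During training, perturb the logits with Gumbel(0, 1) noise and divide by the
+        # temperature tau, so that a subsequent softmax yields Gumbel-Softmax samples;
+        # at test time the logits are passed through unchanged.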
+ if training:
+ gumbels = self.gumbel_dist.sample(dots.shape)
+ # gumbels = (torch.log(dots.softmax(dim=1) + 1e-7) + gumbels) / self.tau # ~Gumbel(logits,tau)
+ gumbels = (dots + gumbels) / self.tau
+ else:
+ gumbels = dots
+ return gumbels
+
+
+def compute_attn(logits, style, gumbel_module, training=True):
+ raw_attn = logits.softmax(dim=1)
+
+ if 'gumbel' in style and training:
+ gumbels = gumbel_module.process_dots(logits, training=training)
+ attn = gumbels.softmax(dim=1) # [1, S, N]
+ else:
+ attn = raw_attn
+
+ if 'soft' in style and training:
+ return attn, raw_attn
+ else:
+ index = attn.max(dim=1, keepdim=True)[1]
+ y_hard = torch.zeros_like(attn, memory_format=torch.legacy_contiguous_format).scatter_(1, index, 1.0)
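+        # Straight-through estimator: the forward pass uses the hard one-hot assignment,
+        # while gradients flow through the soft attention.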
+ attn_hard = y_hard - attn.detach() + attn
+ return attn_hard, raw_attn
+
+
+class PartArticulationNet(torch.nn.Module):
+ def __init__(self, device, feat_dim=20, num_layers=3, hidden_dim=64,
+ slot_name='motion_xyz', slot_num=16, slot_hard='hard',
+ gt_transform=None, inv_transform=None, fix_base=True, gt_joint_types=None):
+ super(PartArticulationNet, self).__init__()
+ self.device = device
+
+ self.slot_name = slot_name
+ self.slot_num = slot_num
+ self.slot_hard = slot_hard
+
+ self.gumbel_module = GumbelAttn(device=device, dtype=torch.float, tau=1.0)
+
+ # classification network
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+
+ if 'xyz' in slot_name:
+ feat_dim += 3
+
+ self.net = self.create_layer(feat_dim, hidden_dim, num_layers, slot_num)
+ self.net.to(device)
+
+ self.gt_joint_types = gt_joint_types # will use the same parameters, but interpret them based on the joint types
+
+ self.rotation = nn.Parameter(torch.Tensor([1, 0, 0, 0, 1, 0]).unsqueeze(0).repeat(slot_num, 1))
+ self.translation = nn.Parameter(torch.zeros(slot_num, 3))
+ self.inv_transform = inv_transform
+
+ self.gt_transform = gt_transform
+
+ self.fix_base = fix_base
+
+ def create_layer(self, dimin, width, layers, dimout):
+ if layers == 1:
+ return nn.Sequential(nn.Linear(dimin, dimout))
+ else:
+ return nn.Sequential(
+ nn.Linear(dimin, width), nn.ReLU(inplace=True),
+ *[nn.Sequential(nn.Linear(width, width), nn.ReLU(inplace=True)) for _ in range(layers - 2)],
+ nn.Linear(width, dimout), )
+
+ def get_raw_slot_transform(self): # note that the output R is transposed!!
+ if self.inv_transform is None:
+ rotat_back = self.rotation
+ trans_back = self.translation
+
+ if self.fix_base:
+ rotat_static = torch.tensor([[1, 0, 0], [0, 1, 0]], dtype=rotat_back.dtype,
+ device=rotat_back.device).reshape(1, 6)
+ trans_static = torch.tensor([0, 0, 0], dtype=trans_back.dtype, device=trans_back.device).reshape(1, 3)
+
+ if self.gt_transform is None:
+ rotat_other = rotat_back[1:]
+ trans_other = trans_back[1:]
+ else:
+ rotat_other = torch.tensor(self.gt_transform['rot'], dtype=rotat_back.dtype,
+ device=rotat_back.device)[:2].reshape(1, 6)
+ trans_other = torch.tensor(self.gt_transform['trans'], dtype=trans_back.dtype,
+ device=trans_back.device).reshape(1, 3)
+
+ rotat_all, trans_all = torch.cat([rotat_static, rotat_other], dim=0), \
+ torch.cat([trans_static, trans_other], dim=0)
+ else:
+ rotat_all, trans_all = rotat_back, trans_back
+
+ rotat_all = rotat_from_6d(rotat_all)
+
+ if self.gt_joint_types is not None:
+ revolute_rotation = rotat_all
+
+ mat = torch.eye(3).reshape(1, 3, 3).to(rotat_all.device) - rotat_all
+ U, S, Vh = torch.linalg.svd(mat.detach())
+ U = U * (S.unsqueeze(-2) > 0)
+ revolute_translation = torch.matmul(U, torch.matmul(U.transpose(-1, -2),
+ trans_all.unsqueeze(-1))).squeeze(-1)
+ prismatic_rotation = torch.eye(3).reshape(1, 3, 3).repeat(self.slot_num, 1, 1).to(rotat_all.device)
+ prismatic_translation = trans_all
+
+ revolute_mask = torch.tensor([joint_type == 'revolute' for joint_type in self.gt_joint_types],
+ dtype=revolute_rotation.dtype, device=revolute_rotation.device)
+
+ rotat_all = revolute_mask.reshape(-1, 1, 1) * revolute_rotation + \
+ (1 - revolute_mask.reshape(-1, 1, 1)) * prismatic_rotation
+ trans_all = revolute_mask.reshape(-1, 1) * revolute_translation + \
+ (1 - revolute_mask.reshape(-1, 1)) * prismatic_translation
+
+ return rotat_all, trans_all
+ else:
+ inv_rot, inv_trans = self.inv_transform() # [S, 3, 3], [S, 3]
+ rot = inv_rot.transpose(-1, -2)
+ trans = -torch.matmul(inv_rot, inv_trans.unsqueeze(-1)).squeeze(-1)
+ return rot, trans
+
+ def back_deform(self, xyz_smp, xyz_smp_embedded, training=True):
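+        # Map world-space sample points into every slot's canonical frame, i.e. apply the
+        # inverse of each candidate rigid transform: p_cnc = R^T (p - t). Note that
+        # get_raw_slot_transform() already returns the transposed rotation.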
+ num_points = len(xyz_smp)
+ slot_rotat, slot_trans = self.get_raw_slot_transform() # [S, xx]
+ slot_rotat = slot_rotat.unsqueeze(0).repeat(num_points, 1, 1, 1)
+ xyz = xyz_smp.unsqueeze(1).repeat(1, self.slot_num, 1) # [N, S, 3]
+ xyz_cnc = torch.einsum('nscd,nsde->nsce', slot_rotat, (xyz - slot_trans).unsqueeze(-1)).squeeze(-1) # [N, S, 3]
+ return {'xyz_cnc': xyz_cnc, 'inv': True}
+
+ def forw_attn(self, xyz_cnc, xyz_cnc_embedded, training=True):
+ feat_forw = xyz_cnc_embedded
+ if 'xyz' in self.slot_name:
+ feat_forw = torch.cat([feat_forw, xyz_cnc], dim=-1)
+ logits = self.net(feat_forw)
+ attn_hard, attn_soft = compute_attn(logits, self.slot_hard, self.gumbel_module, training=training)
+ return attn_hard, attn_soft
+
+ def forw_deform(self, xyz_cnc, xyz_cnc_embedded, training=True, gt_attn=None):
+ attn_hard, attn_soft = self.forw_attn(xyz_cnc, xyz_cnc_embedded, training=training)
+ if gt_attn is not None:
+ attn_hard, attn_soft = gt_attn, gt_attn
+
+ rotat_forw_cand, trans_forw_cand = self.get_raw_slot_transform()
+
+ rotat_forw = (attn_hard.unsqueeze(-1).unsqueeze(-1) * rotat_forw_cand.unsqueeze(0)).sum(dim=1)
+
+ trans_forw = (attn_hard.unsqueeze(-1) * trans_forw_cand.unsqueeze(0)).sum(dim=1)
+
+ xyz_smp_pred = torch.einsum('bcd,bde->bce', rotat_forw.permute(0, 2, 1), xyz_cnc.unsqueeze(-1)).squeeze(-1)
+ xyz_smp_pred = xyz_smp_pred + trans_forw
+
+ xyz_smp_pred_cand = torch.matmul(rotat_forw_cand.permute(0, 2, 1).unsqueeze(0),
+ xyz_cnc.unsqueeze(1).unsqueeze(-1)).squeeze(-1) # 1, S, 3, 3; N, 1, 3, 1 --> N, S, 3
+ xyz_smp_pred_cand = xyz_smp_pred_cand + trans_forw_cand.unsqueeze(0)
+
+ return {'attn_hard': attn_hard, 'attn_soft': attn_soft, # [N, S]
+ 'world_pts': xyz_smp_pred, 'world_pts_cand': xyz_smp_pred_cand, # [N, S, 3]
+ 'rotation': rotat_forw.permute(0, 2, 1), 'translation': trans_forw,
+ 'rotation_cand': rotat_forw_cand.permute(0, 2, 1)}
+
+ def get_slot_motions(self): # attn_hard: [N, S], feat_forw: [N, S, C], timesteps: [T, 1]
+ rot, trans = self.get_raw_slot_transform()
+ return rot, trans
+
+
diff --git a/preproc/gen_correspondence.py b/preproc/gen_correspondence.py
new file mode 100644
index 0000000..8bc7219
--- /dev/null
+++ b/preproc/gen_correspondence.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import sys
+from os.path import join as pjoin
+base_dir = os.path.dirname(__file__)
+sys.path.insert(0, base_dir)
+sys.path.insert(0, pjoin(base_dir, '..'))
+import numpy as np
+import argparse
+import yaml
+import cv2
+import imageio
+from PIL import Image
+from tqdm import tqdm
+
+from utils.py_utils import list_to_array
+from loftr_wrapper import LoftrRunner
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_path', type=str, default=None)
+ parser.add_argument('--src_name', type=str, default=None)
+ parser.add_argument('--tgt_name', type=str, default=None)
+ parser.add_argument('--output_path', type=str, default=None)
+ parser.add_argument('--top_k', type=int, default=-1)
+ parser.add_argument('--batch', action='store_true')
+ parser.add_argument('--cat', type=str, default=None)
+
+ return parser.parse_args()
+
+
+def read_img_dict(folder, name, poses=None):
+ rgb_subfolder = 'color_raw'
+ if not os.path.exists(pjoin(folder, rgb_subfolder)):
+ rgb_subfolder = 'color_segmented'
+ img_dict = {}
+ for key, subfolder in (('rgb', rgb_subfolder), ('mask', 'mask')):
+ img_name = pjoin(folder, subfolder, f'{name}.png')
+ img = np.array(Image.open(img_name))
+ if key == 'mask' and len(img.shape) == 3:
+ img = img[..., 0]
+ img_dict[key] = img
+ if poses is None:
+ poses = yaml.safe_load(open(pjoin(folder, 'init_keyframes.yml'), 'r'))
+ img_dict['cam2world'] = list_to_array(poses[f'frame_{name}']['cam_in_ob']).reshape(4, 4)
+
+ return img_dict, poses
+
+
+def run_source_img(loftr, folder, src_name, tgt_names=None, top_k=-1,
+ filter_level_list=['no_filter'], visualize=False, vis_path='test_loftr'):
+
+ src_time = int(src_name.split('_')[0])
+
+ K = np.loadtxt(pjoin(folder, 'cam_K.txt')).reshape(3, 3)
+
+ src_dict, poses = read_img_dict(folder, src_name)
+
+ src_position = src_dict['cam2world'][:3, 3]
+
+ if tgt_names is None:
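+        # Rank candidate target frames from the other articulation state by the squared
+        # distance between camera centers and keep only the top_k closest views.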
+ all_tgt = []
+ for frame_name in poses:
+ if poses[frame_name]['time'] != src_time:
+ tgt_position = list_to_array(poses[frame_name]['cam_in_ob']).reshape(4, 4)[:3, 3]
+ all_tgt.append(('_'.join(frame_name.split('_')[-2:]), ((src_position - tgt_position) ** 2).sum()))
+ all_tgt.sort(key=lambda x: x[1])
+ if top_k <= 0:
+ top_k = len(all_tgt)
+ tgt_names = [pair[0] for pair in all_tgt[:top_k]]
+
+ tgt_dicts = [read_img_dict(folder, tgt_name, poses)[0] for tgt_name in tgt_names]
+
+ all_corr = compute_correspondence(loftr, src_dict, tgt_dicts, K, filter_level_list=filter_level_list,
+ visualize=visualize, vis_path=vis_path)
+
+ results = {filter_level: [{src_name: tgt_corr[0], tgt_name: tgt_corr[1]}
+ for tgt_name, tgt_corr in zip(tgt_names, corr)]
+ for filter_level, corr in all_corr.items()}
+
+ return results
+
+
+def draw_corr(rgbA, rgbB, corrA, corrB, output_name):
+ vis = np.concatenate([rgbA, rgbB], axis=1)
+ radius = 2
+ for i in range(len(corrA)):
+ uvA = corrA[i]
+ uvB = corrB[i].copy()
+ uvB[0] += rgbA.shape[1]
+ color = tuple(np.random.randint(0, 255, size=(3)).tolist())
+ vis = cv2.circle(vis, uvA, radius=radius, color=color, thickness=1)
+ vis = cv2.circle(vis, uvB, radius=radius, color=color, thickness=1)
+ vis = cv2.line(vis, uvA, uvB, color=color, thickness=1, lineType=cv2.LINE_AA)
+ imageio.imwrite(f'{output_name}.png', vis.astype(np.uint8))
+
+
+def compute_correspondence(loftr, src_dict, tgt_dicts, K, visualize=False, vis_path='test_loftr',
+ filter_level_list=['no_filter']):
+ src_mask = src_dict['mask']
+ src_rgb = src_dict['rgb']
+ img_h, img_w = src_rgb.shape[:2]
+
+ # src_pixels = np.stack(np.where(src_seg > 0), axis=-1)
+ fx, _, cx = K[0]
+ _, fy, cy = K[1]
+
+ tgt_rgbs = np.stack([tgt_dict['rgb'] for tgt_dict in tgt_dicts], axis=0)
+ corres = loftr.predict(np.tile(src_rgb[np.newaxis], (len(tgt_rgbs), 1, 1, 1)), tgt_rgbs)
+
+ def get_valid_mask(mask, coords):
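+        # A match endpoint is kept only if it falls inside the image bounds and on the
+        # (nonzero) object mask.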
+ valid = np.logical_and(np.logical_and(coords[..., 0] >= 0, coords[..., 0] < img_w),
+ np.logical_and(coords[..., 1] >= 0, coords[..., 1] < img_h))
+ valid = np.logical_and(valid, mask[coords[..., 1], coords[..., 0]])
+ return valid
+
+ os.makedirs('debug', exist_ok=True)
+
+ filtered_corr = {key: [] for key in filter_level_list}
+
+ for i, tgt_dict in enumerate(tgt_dicts):
+ tgt_mask, tgt_rgb = tgt_dict['mask'], tgt_dict['rgb']
+ cur_corres = corres[i]
+ src_coords = cur_corres[:, :2].round().astype(int)
+ tgt_coords = cur_corres[:, 2:4].round().astype(int)
+ valid_mask = np.logical_and(get_valid_mask(src_mask, src_coords),
+ get_valid_mask(tgt_mask, tgt_coords))
+
+ loftr_total = len(valid_mask)
+ valid_total = sum(valid_mask)
+
+ src_coords = src_coords[np.where(valid_mask)[0]]
+ tgt_coords = tgt_coords[np.where(valid_mask)[0]]
+
+ if 'no_filter' in filter_level_list:
+ filtered_corr['no_filter'].append((src_coords, tgt_coords))
+
+ if visualize:
+ draw_corr(src_rgb, tgt_rgb, src_coords, tgt_coords, pjoin(vis_path, f'{i}_1_no_filter_{valid_total}_of_{loftr_total}'))
+
+ return filtered_corr
+
+
+def main(args):
+
+ loftr = LoftrRunner()
+ if args.batch:
+ paris_path = args.data_path
+ for folder in os.listdir(paris_path):
+ if os.path.isdir(pjoin(paris_path, folder)):
+ if args.cat is not None and not folder.startswith(args.cat):
+ continue
+ print(f'Running for folder {folder}')
+ run_folder(loftr, pjoin(paris_path, folder),
+ src_name=None, tgt_name=None, output_path=pjoin(paris_path, folder, 'correspondence_loftr'),
+ top_k=args.top_k)
+ else:
+ run_folder(loftr, args.data_path, args.src_name, args.tgt_name, args.output_path, args.top_k)
+
+
+def run_folder(loftr, folder, src_name, tgt_name, output_path, top_k):
+ if src_name is not None:
+ src_names = [src_name]
+ else:
+ all_frames = yaml.safe_load(open(pjoin(folder, 'init_keyframes.yml'), 'r'))
+ src_names = ['_'.join(frame_name.split('_')[-2:]) for frame_name in all_frames]
+
+ if tgt_name is not None:
+ tgt_names = [tgt_name]
+ save_name = f'tgt_{tgt_name}'
+ else:
+ tgt_names = None
+        save_name = 'tgt_all'
+ if top_k >= 0:
+ save_name = f'{save_name}_top{top_k}'
+
+ os.makedirs(output_path, exist_ok=True)
+
+ pbar = tqdm(src_names)
+ for src_name in pbar:
+ pbar.set_description(src_name)
+ results = run_source_img(loftr, folder, src_name, tgt_names, top_k=top_k, visualize=False)
+ for filter_level in results:
+ os.makedirs(pjoin(output_path, filter_level), exist_ok=True)
+ np.savez_compressed(pjoin(output_path, filter_level, f'src_{src_name}_{save_name}.npz'),
+ data=results[filter_level])
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/preproc/loftr_wrapper.py b/preproc/loftr_wrapper.py
new file mode 100644
index 0000000..b73a125
--- /dev/null
+++ b/preproc/loftr_wrapper.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import os
+import torchvision
+import torch
+import logging
+import numpy as np
+import sys
+from os.path import join as pjoin
+base_dir = os.path.dirname(__file__)
+sys.path.insert(0, base_dir)
+sys.path.insert(0, pjoin(base_dir, '..'))
+from external.LoFTR.src.loftr import default_cfg, LoFTR
+
+state_dict_path = pjoin(base_dir, '../external/LoFTR/weights/outdoor_ds.ckpt')
+
+class LoftrRunner:
+ def __init__(self):
+ default_cfg['match_coarse']['thr'] = 0.2
+ print("default_cfg", default_cfg)
+ self.matcher = LoFTR(config=default_cfg)
+ self.matcher.load_state_dict(torch.load(state_dict_path)['state_dict'])
+ self.matcher = self.matcher.eval().cuda()
+
+ @torch.no_grad()
+ def predict(self, rgbAs: np.ndarray, rgbBs: np.ndarray):
+ '''
+ @rgbAs: (N,H,W,C)
+ '''
+ image0 = torch.from_numpy(rgbAs).permute(0, 3, 1, 2).float().cuda()
+ image1 = torch.from_numpy(rgbBs).permute(0, 3, 1, 2).float().cuda()
+ if image0.shape[1] == 3:
+ image0 = torchvision.transforms.functional.rgb_to_grayscale(image0)
+ image1 = torchvision.transforms.functional.rgb_to_grayscale(image1)
+ image0 = image0 / 255.0
+ image1 = image1 / 255.0
+ default_value = image0[0, 0, 0, 0]
+ last_data = {'image0': image0, 'image1': image1,
+ # 'mask0': (image0[:, 0] != default_value).float(),
+ # 'mask1': (image1[:, 0] != default_value).float()
+ }
+ logging.info(f"image0: {last_data['image0'].shape}")
+
+ batch_size = 4
+ ret_keys = ['mkpts0_f', 'mkpts1_f', 'mconf', 'm_bids']
+ with torch.cuda.amp.autocast(enabled=True):
+ i_b = 0
+ for b in range(0, len(last_data['image0']), batch_size):
+ tmp = {'image0': last_data['image0'][b:b + batch_size],
+ 'image1': last_data['image1'][b:b + batch_size]}
+ with torch.no_grad():
+ self.matcher(tmp)
+ tmp['m_bids'] += i_b
+ for k in ret_keys:
+ if k not in last_data:
+ last_data[k] = []
+ last_data[k].append(tmp[k])
+ i_b += len(tmp['image0'])
+
+ logging.info("net forward")
+
+ for k in ret_keys:
+ last_data[k] = torch.cat(last_data[k], dim=0)
+
+ mkpts0 = last_data['mkpts0_f'].cpu().numpy()
+ mkpts1 = last_data['mkpts1_f'].cpu().numpy()
+ mconf = last_data['mconf'].cpu().numpy()
+ pair_ids = last_data['m_bids'].cpu().numpy()
+ logging.info(f"mconf, {mconf.min()} {mconf.max()}")
+ logging.info(f'pair_ids {pair_ids.shape}')
+ corres = np.concatenate((mkpts0.reshape(-1, 2), mkpts1.reshape(-1, 2), mconf.reshape(-1, 1)), axis=-1).reshape(
+ -1, 5).astype(np.float32)
+
+ logging.info(f'corres: {corres.shape}')
+ corres_tmp = []
+ for i in range(len(rgbAs)):
+ cur_corres = corres[pair_ids == i]
+ corres_tmp.append(cur_corres)
+ corres = corres_tmp
+
+ del last_data, image0, image1
+ torch.cuda.empty_cache()
+
+ return corres
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..5a13d60
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,127 @@
+# Neural Implicit Representation for Building Digital Twins of Unknown Articulated Objects (CVPR 2024)
+
+Yijia Weng, Bowen Wen, Jonathan Tremblay, Valts Blukis, Dieter Fox, Leonidas Guibas, Stan Birchfield
+
+[project](https://nvlabs.github.io/DigitalTwinArt/) | [arxiv](https://arxiv.org/abs/2404.01440)
+
+
+
+
+
+
+
+We tackle the problem of building digital twins of unknown articulated objects from two RGBD scans of the object at different articulation states. We decompose the problem into two stages, each addressing distinct aspects. Our method first reconstructs object-level shape at each state, then recovers the underlying articulation model including part segmentation and joint articulations that associate the two states. By explicitly modeling point-level correspondences and exploiting cues from images, 3D reconstructions, and kinematics, our method yields more accurate and stable results compared to prior work. It also handles more than one movable part and does not rely on any object shape or structure priors.
+
+## Installation
+
+### Clone Repository
+
+```bash
+git clone --recursive git@github.com:HalfSummer11/artnerf.git
+```
+
+### Install Dependencies
+
+We recommend using [Anaconda](https://www.anaconda.com/) to manage dependencies. Run `./env.sh` to install all dependencies in a dedicated conda environment `artnerf`. Our code has been tested on Ubuntu 20.04 with CUDA 11.3.
+
+ ```bash
+ bash env.sh
+ ```
+
+### Download Data
+
+
+Please download [our preprocessed, depth-augmented version](https://drive.google.com/file/d/1e_HWjw4usNHAkXkg6_o3QvfLD-YPR-1W/view?usp=drive_link) of [PARIS two-part object dataset](https://github.com/3dlg-hcvc/paris?tab=readme-ov-file#data) and [our synthetic multi-part object dataset](https://drive.google.com/file/d/186EskU7WtLU8CMgwY2swRUn5AB_fQ-qK/view?usp=drive_link) and unzip them into `./data` under project root.
+
+```bash
+artnerf
+├── data
+│ ├── paris
+│ │ ├── [instance]
+│ │ └── ...
+│ └── multi-part
+│ ├── [instance]
+│ └── ...
+
+```
+
+
+## Training and Evaluation
+
+### Generate 2D pixel matches with LoFTR
+
+We provide 2D pixel matches generated with LoFTR under each instance folder `${dataset}/${instance}/correspondence_loftr/no_filter`. You could also generate them from scratch.
+
+To start with, follow the instruction in [LoFTR](https://github.com/zju3dv/LoFTR?tab=readme-ov-file#installation) and download their model weights into `external/LoFTR/weights`.
+
+Run the following to generate pixel matches for `${dataset}/${instance}`.
+
+
+```bash
+cd preproc
+python gen_correspondence.py \
+    --data_path ../data/${dataset}/${instance} \
+    --output_path ../data/${dataset}/${instance}/correspondence_loftr \
+    --top_k=30
+```
+
+
+### Training
+
+Run the following to reconstruct an articulated object from `${dataset}/${instance}`.
+```bash
+# num_parts: 2 for paris two-part objects, 3 for synthetic multi-part objects
+# add --denoise to enable denoising for real instances, including paris/real_storage and paris/real_fridge
+python main.py \
+    --data_dir data/${dataset}/${instance} \
+    --cfg_dir config/release/ \
+    --num_parts ${num_parts} \
+    --save_dir runs/exp_${dataset}_${instance}
+```
+Checkpoints and reconstructions will be written to `runs/exp_${dataset}_${instance}`. Final results can be found in `runs/exp_${dataset}_${instance}/results/step_0004000`, including reconstructed part-level meshes (e.g. `init_part_0.obj`), axis meshes (e.g. `init_axis_0_revolute.obj`), and quantitative evaluations in `all_metrics`.
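+
+For reference, a results folder produced by the command above looks roughly like the following (illustrative only; the exact set of files may vary per instance and number of parts):
+
+```bash
+runs/exp_${dataset}_${instance}/results/step_0004000/
+├── init_part_0.obj           # reconstructed part-level meshes
+├── init_axis_0_revolute.obj  # reconstructed axis meshes
+└── all_metrics               # quantitative evaluations
+```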
+
+
+### Inference
+
+You could also download [our pretrained checkpoints](https://drive.google.com/file/d/15Nn2fo13URJ9IUzQXRQ4sWsum7ZSzhmA/view?usp=drive_link) and unzip them into `./runs` under project root.
+
+Run the following to generate a reconstruction from the checkpoint.
+
+```bash
+# num_parts: 2 for paris two-part objects, 3 for synthetic multi-part objects
+# add --denoise to enable denoising for real instances, including paris/real_storage and paris/real_fridge
+python main.py \
+    --data_dir data/${dataset}/${instance} \
+    --cfg_dir config/release/ \
+    --num_parts ${num_parts} \
+    --save_dir runs/pretrained_${dataset}_${instance} \
+    --ckpt_path runs/pretrained_${dataset}_${instance}/ckpt/model_latest.ckpt \
+    --test_only
+```
+
+### Evaluation
+
+Evaluations are automatically run after each reconstruction. Results are written to both the terminal and `all_metrics` in the corresponding results folder. You may also run the following to evaluate both the geometry and joint parameters of the predicted reconstruction (e.g. in `${exp_dir}/results/step_0004000`) offline.
+
+```bash
+python eval/eval_results.py --gt_path data/${dataset}/${instance}/gt --pred_path runs/${exp_dir}/results/step_${step_cnt}
+```
+
+
+## Citation
+
+
+```
+@inproceedings{weng2024neural,
+ title={Neural Implicit Representation for Building Digital Twins of Unknown Articulated Objects},
+ author={Yijia Weng and Bowen Wen and Jonathan Tremblay and Valts Blukis and Dieter Fox and Leonidas Guibas and Stan Birchfield},
+ booktitle={CVPR},
+ year={2024}
+}
+```
+
+
+## Acknowledgements
+
+This implementation is based on the following repositories. We thank the authors for open sourcing their great works!
+
++ [[CVPR 2023] BundleSDF: Neural 6-DoF Tracking and 3D Reconstruction of Unknown Objects](https://github.com/NVlabs/BundleSDF)
++ [[ICCV 2023] PARIS: Part-level Reconstruction and Motion Analysis for Articulated Objects](https://github.com/3dlg-hcvc/paris)
diff --git a/utils/articulation_utils.py b/utils/articulation_utils.py
new file mode 100644
index 0000000..a6350d3
--- /dev/null
+++ b/utils/articulation_utils.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import sys
+from os.path import join as pjoin
+base_path = os.path.dirname(__file__)
+sys.path.insert(0, pjoin(base_path, '..'))
+sys.path.insert(0, base_path)
+import json
+from scipy.spatial.transform import Rotation as R
+import yaml
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.spatial.transform import Rotation
+import trimesh
+import open3d as o3d
+
+from argparse import ArgumentParser
+from utils.py_utils import list_to_array
+
+
+def interpret_transforms(base_R, base_t, R, t, joint_type='revolute'):
+ """
+ base_R, base_t, R, t are all from canonical to world
+    rewrite the transformation as the global transformation (base_R, base_t) composed with a canonical-space transform {p -> R' p + t'}, i.e. R' and t' act in canonical space
+ R', t':
+ - revolute: R'p + t' = R'(p - a) + a, R' --> axis-theta representation; axis goes through a = (I - R')^{-1}t'
+ - prismatic: R' = I, t' = l * axis_direction
+ """
+ R = np.matmul(base_R.T, R)
+ t = np.matmul(base_R.T, (t - base_t).reshape(3, 1)).reshape(-1)
+
+ if joint_type == 'revolute':
+ rotvec = Rotation.from_matrix(R).as_rotvec()
+ theta = np.linalg.norm(rotvec, axis=-1)
+ axis_direction = rotvec / max(theta, (theta < 1e-8))
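+        # A point a on the rotation axis satisfies (I - R') a = t'; if (I - R') is
+        # (near-)singular, fall back to the origin.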
+ try:
+ axis_position = np.matmul(np.linalg.inv(np.eye(3) - R), t.reshape(3, 1)).reshape(-1)
+        except np.linalg.LinAlgError:  # TODO: find a better fallback when (I - R') is singular
+ axis_position = np.zeros(3)
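+        # Slide the recovered point along the axis to the point closest to the origin
+        # (remove its component parallel to the axis direction).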
+ axis_position += axis_direction * np.dot(axis_direction, -axis_position)
+ joint_info = {'axis_position': axis_position,
+ 'axis_direction': axis_direction,
+ 'theta': np.rad2deg(theta),
+ 'rotation': R, 'translation': t}
+
+ elif joint_type == 'prismatic':
+ theta = np.linalg.norm(t)
+ axis_direction = t / max(theta, (theta < 1e-8))
+ joint_info = {'axis_direction': axis_direction, 'axis_position': np.zeros(3), 'theta': theta,
+ 'rotation': R, 'translation': t}
+
+ return joint_info, R, t
+
+
+def read_pts_from_obj_file(obj_file):
+ tm = trimesh.load(obj_file)
+ return np.array(tm.vertices)
+
+
+def transform_pts_with_rt(pts, rot, trans):
+ return (np.matmul(rot, pts.T) + trans.reshape(3, 1)).T
+
+
+def read_gt(gt_path):
+ with open(gt_path, 'r') as f:
+ info = json.load(f)
+
+ all_trans_info = info['trans_info']
+ if isinstance(all_trans_info, dict):
+ all_trans_info = [all_trans_info]
+ ret_list = []
+ for trans_info in all_trans_info:
+ axis = trans_info['axis']
+ axis_o, axis_d = np.array(axis['o']), np.array(axis['d'])
+ axis_type = trans_info['type']
+ l, r = trans_info[axis_type]['l'], trans_info[axis_type]['r']
+
+ if axis_type == 'rotate':
+ rotvec = axis_d * np.deg2rad(r - l)
+ rot = R.from_rotvec(rotvec).as_matrix()
+ trans = np.matmul(np.eye(3) - rot, axis_o.reshape(3, 1)).reshape(-1)
+ joint_type = 'revolute'
+ else:
+ rot = np.eye(3)
+ trans = (r - l) * axis_d
+ joint_type = 'prismatic'
+ ret_list.append({'axis_position': axis_o, 'axis_direction': axis_d, 'theta': r - l, 'joint_type': axis_type, 'rotation': rot, 'translation': trans,
+ 'type': joint_type})
+ return ret_list
+
+
+def line_distance(a_o, a_d, b_o, b_d):
+ normal = np.cross(a_d, b_d)
+ normal_length = np.linalg.norm(normal)
+ if normal_length < 1e-6: # parallel
+ return np.linalg.norm(np.cross(b_o - a_o, a_d))
+ else:
+ return np.abs(np.dot(normal, a_o - b_o)) / normal_length
+
+
+def eval_axis_and_state(axis_a, axis_b, joint_type='revolute', reverse=False):
+ a_d, b_d = axis_a['axis_direction'], axis_b['axis_direction']
+
+ angle = np.rad2deg(np.arccos(np.dot(a_d, b_d) / np.linalg.norm(a_d) / np.linalg.norm(b_d)))
+ angle = min(angle, 180 - angle)
+
+ if joint_type == 'revolute':
+ a_o, b_o = axis_a['axis_position'], axis_b['axis_position']
+ distance = line_distance(a_o, a_d, b_o, b_d)
+
+ a_r, b_r = axis_a['rotation'], axis_b['rotation']
+ if reverse:
+ a_r = a_r.T
+
+ r_diff = np.matmul(a_r, b_r.T)
+ state = np.rad2deg(np.arccos(np.clip((np.trace(r_diff) - 1.0) * 0.5, a_min=-1, a_max=1)))
+ else:
+ distance = 0
+ a_t, b_t = axis_a['translation'], axis_b['translation']
+ if reverse:
+ a_t = -a_t
+
+ state = np.linalg.norm(a_t - b_t)
+
+ return angle, distance, state
+
+
+def geodesic_distance(pred_R, gt_R):
+ '''
+ q is the output from the network (rotation from t=0.5 to t=1)
+ gt_R is the GT rotation from t=0 to t=1
+ '''
+ pred_R, gt_R = pred_R.cpu(), gt_R.cpu()
+ R_diff = torch.matmul(pred_R, gt_R.T)
+ cos_angle = torch.clip((torch.trace(R_diff) - 1.0) * 0.5, min=-1., max=1.)
+ angle = torch.rad2deg(torch.arccos(cos_angle))
+ return angle
+
+
+def axis_metrics(motion, gt):
+ # pred axis
+ pred_axis_d = motion['axis_d'].cpu().squeeze(0)
+ pred_axis_o = motion['axis_o'].cpu().squeeze(0)
+ # gt axis
+ gt_axis_d = gt['axis_d']
+ gt_axis_o = gt['axis_o']
+ # angular difference between two vectors
+ cos_theta = torch.dot(pred_axis_d, gt_axis_d) / (torch.norm(pred_axis_d) * torch.norm(gt_axis_d))
+ ang_err = torch.rad2deg(torch.acos(torch.abs(cos_theta)))
+    # positional difference between the two axis lines
+ w = gt_axis_o - pred_axis_o
+ cross = torch.cross(pred_axis_d, gt_axis_d)
+ if (cross == torch.zeros(3)).sum().item() == 3:
+ pos_err = torch.tensor(0)
+ else:
+ pos_err = torch.abs(torch.sum(w * cross)) / torch.norm(cross)
+ return ang_err, pos_err
+
+
+def translational_error(motion, gt):
+ dist_half = motion['dist'].cpu()
+ dist = dist_half * 2.
+ gt_dist = gt['dist']
+
+ axis_d = F.normalize(motion['axis_d'].cpu().squeeze(0), p=2, dim=0)
+ gt_axis_d = F.normalize(gt['axis_d'].cpu(), p=2, dim=0)
+
+ err = torch.sqrt(((dist * axis_d - gt_dist * gt_axis_d) ** 2).sum())
+ return err
+
+
+def eval_step_info(path, step_name, gt_path, base_joint=0):
+ # gt_joint, gt_rot, gt_trans = read_gt(gt_path)
+ gt_joint_list = read_gt(gt_path)
+ gt_joint = gt_joint_list[0]
+
+ step_transform = np.load(pjoin(path, 'step_transform', f'{step_name}_part_pose.npz'), allow_pickle=True)['data'].item()
+
+ cfg = yaml.safe_load(open(pjoin(path, 'config_nerf.yml'), 'r'))
+ nerf_scale = cfg['sc_factor']
+ nerf_trans = np.array(cfg['translation'])
+ inv_scale = 1.0 / nerf_scale
+ inv_trans = (-nerf_scale * nerf_trans).reshape(1, 3)
+
+ rot = step_transform['rotation'].swapaxes(-1, -2) # somehow R^T is saved..
+ trans = step_transform['translation']
+ trans = (trans + inv_trans - np.matmul(rot, inv_trans[..., np.newaxis])[..., 0]) * inv_scale
+
+ num_parts = len(rot)
+
+ joint_dict = {}
+ for joint_i in range(num_parts):
+ if joint_i == base_joint:
+ continue
+ cur_joint, cur_rot, cur_trans = interpret_transforms(rot[base_joint], trans[base_joint], rot[joint_i], trans[joint_i])
+ print('cur joint', cur_joint)
+ print('gt joint', gt_joint)
+ angle, distance, theta_diff = eval_axis_and_state(cur_joint, gt_joint)
+ print('axis diff angle', angle, 'distance', distance)
+ print('joint state diff angle', theta_diff)
+
+ # cur_joint_tensor = {'axis_d': torch.tensor(cur_joint['axis_direction']).reshape(1, 3),
+ # 'axis_o': torch.tensor(cur_joint['axis_position']).reshape(1, 3)}
+ # gt_joint_tensor = {'axis_d': torch.tensor(gt_joint['axis_direction']),
+ # 'axis_o': torch.tensor(gt_joint['axis_position'])}
+ # print(axis_metrics(cur_joint_tensor, gt_joint_tensor))
+ """
+ cur_pts = transform_pts_with_rt(part_canon_pts[joint_i], cur_rot, cur_trans)
+ center = np.mean(cur_pts, axis=0)
+ if 'axis_position' in cur_joint:
+ p = cur_joint['axis_position']
+ u = cur_joint['axis_direction']
+ cur_joint['axis_position'] = p + (np.dot(center, u) - np.dot(p, u)) * u
+ cur_pt_dict[joint_i] = cur_pts
+ """
+ joint_dict[joint_i] = cur_joint
+
+
+def eval_result(path, step_name, gt_path):
+ gt_joint_list = read_gt(gt_path)
+ gt_joint = gt_joint_list[0]
+
+ for cnc in ['init', 'last']:
+ if cnc == 'last':
+ gt_joint['theta'] *= -1.
+ with open(pjoin(path, 'results', f'{step_name}', f'{cnc}_motion.json'), 'r') as f:
+ axis_info = json.load(f)
+
+ axis_info = {key: list_to_array(value) for key, value in axis_info.items()}
+ angle, distance, theta_diff = eval_axis_and_state(axis_info, gt_joint)
+ print('axis diff angle', angle, 'distance', distance)
+ print('joint state diff angle', theta_diff)
+
+def get_rotation_axis_angle(k, theta):
+ '''
+ Rodrigues' rotation formula
+ args:
+ * k: direction unit vector of the axis to rotate about
+ * theta: the (radian) angle to rotate with
+ return:
+ * 3x3 rotation matrix
+ '''
+ if np.linalg.norm(k) == 0.:
+ return np.eye(3)
+ k = k / np.linalg.norm(k)
+ kx, ky, kz = k[0], k[1], k[2]
+ cos, sin = np.cos(theta), np.sin(theta)
+ R = np.zeros((3, 3))
+ R[0, 0] = cos + (kx**2) * (1 - cos)
+ R[0, 1] = kx * ky * (1 - cos) - kz * sin
+ R[0, 2] = kx * kz * (1 - cos) + ky * sin
+ R[1, 0] = kx * ky * (1 - cos) + kz * sin
+ R[1, 1] = cos + (ky**2) * (1 - cos)
+ R[1, 2] = ky * kz * (1 - cos) - kx * sin
+ R[2, 0] = kx * kz * (1 - cos) - ky * sin
+ R[2, 1] = ky * kz * (1 - cos) + kx * sin
+ R[2, 2] = cos + (kz**2) * (1 - cos)
+ return R
+
+
+def save_axis_mesh(k, center, filepath):
+ '''support rotate only for now'''
+ axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.02, cone_radius=0.04, cylinder_height=1.0, cone_height=0.08)
+ arrow = np.array([0., 0., 1.], dtype=np.float32)
+ n = np.cross(arrow, k)
+ rad = np.arccos(np.dot(arrow, k))
+ R_arrow = get_rotation_axis_angle(n, rad)
+ axis.rotate(R_arrow, center=(0, 0, 0))
+ axis.translate(center[:3])
+ o3d.io.write_triangle_mesh(filepath, axis)
+
+
+
+def main(args):
+ if args.name is None:
+ eval_result(args.path, f'step_{args.step:07d}', args.gt_path)
+ else:
+ eval_step_info(args.path, f'step_{args.step:07d}_{args.name}', args.gt_path)
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument('--path', type=str)
+ parser.add_argument('--step', default=3500, type=int)
+ parser.add_argument('--base_joint', default=0, type=int)
+ parser.add_argument('--name', type=str, default=None)
+ parser.add_argument('--gt_path', type=str)
+
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
+
diff --git a/utils/geometry_utils.py b/utils/geometry_utils.py
new file mode 100644
index 0000000..c6dc7f2
--- /dev/null
+++ b/utils/geometry_utils.py
@@ -0,0 +1,486 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch
+import open3d as o3d
+import numpy as np
+import ruamel.yaml
+import os
+import logging
+import copy
+import joblib
+from PIL import Image
+import cv2
+from sklearn.cluster import DBSCAN
+
+os.environ['PYOPENGL_PLATFORM'] = 'egl'
+from mesh_to_sdf import mesh_to_sdf
+import torch.nn.functional as F
+import time
+import trimesh
+
+yaml = ruamel.yaml.YAML()
+try:
+ import kaolin
+except Exception as e:
+ print(f"Import kaolin failed, {e}")
+try:
+ from mycuda import common
+except:
+ pass
+
+glcam_in_cvcam = np.array([[1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, 1]])
+
+
+def toOpen3dCloud(points, colors=None, normals=None):
+ cloud = o3d.geometry.PointCloud()
+ cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64))
+ if colors is not None:
+ if colors.max() > 1:
+ colors = colors / 255.0
+ cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64))
+ if normals is not None:
+ cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64))
+ return cloud
+
+
+def depth2xyzmap(depth, K):
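+    # Back-project a depth map into a per-pixel XYZ map (camera frame) with the pinhole
+    # model: x = (u - cx) * z / fx, y = (v - cy) * z / fy; pixels with depth < 0.1 are zeroed.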
+ invalid_mask = (depth < 0.1)
+ H, W = depth.shape[:2]
+ vs, us = np.meshgrid(np.arange(0, H), np.arange(0, W), sparse=False, indexing='ij')
+ vs = vs.reshape(-1)
+ us = us.reshape(-1)
+ zs = depth.reshape(-1)
+ xs = (us - K[0, 2]) * zs / K[0, 0]
+ ys = (vs - K[1, 2]) * zs / K[1, 1]
+ pts = np.stack((xs.reshape(-1), ys.reshape(-1), zs.reshape(-1)), 1) # (N,3)
+ xyz_map = pts.reshape(H, W, 3).astype(np.float32)
+ xyz_map[invalid_mask] = 0
+ return xyz_map.astype(np.float32)
+
+
+def to_homo(pts):
+ '''
+ @pts: (N,3 or 2) will homogeneliaze the last dimension
+ '''
+ assert len(pts.shape) == 2, f'pts.shape: {pts.shape}'
+ homo = np.concatenate((pts, np.ones((pts.shape[0], 1))), axis=-1)
+ return homo
+
+
+def transform_pts(pts, tf):
+ """Transform 2d or 3d points
+ @pts: (...,3)
+ """
+ return (tf[..., :-1, :-1] @ pts[..., None] + tf[..., :-1, -1:])[..., 0]
+
+
+class OctreeManager:
+ def __init__(self, pts=None, max_level=None, octree=None):
+ if octree is None:
+ pts_quantized = kaolin.ops.spc.quantize_points(pts.contiguous(), level=max_level)
+ self.octree = kaolin.ops.spc.unbatched_points_to_octree(pts_quantized, max_level, sorted=False)
+ else:
+ self.octree = octree
+ lengths = torch.tensor([len(self.octree)], dtype=torch.int32).cpu()
+ self.max_level, self.pyramids, self.exsum = kaolin.ops.spc.scan_octrees(self.octree, lengths)
+ self.n_level = self.max_level + 1
+ self.point_hierarchies = kaolin.ops.spc.generate_points(self.octree, self.pyramids, self.exsum)
+ self.point_hierarchy_dual, self.pyramid_dual = kaolin.ops.spc.unbatched_make_dual(self.point_hierarchies,
+ self.pyramids[0])
+ self.trinkets, self.pointers_to_parent = kaolin.ops.spc.unbatched_make_trinkets(self.point_hierarchies,
+ self.pyramids[0],
+ self.point_hierarchy_dual,
+ self.pyramid_dual)
+ self.n_vox = len(self.point_hierarchies)
+ self.n_corners = len(self.point_hierarchy_dual)
+
+ def get_level_corner_quantized_points(self, level):
+ start = self.pyramid_dual[..., 1, level]
+ num = self.pyramid_dual[..., 0, level]
+ return self.point_hierarchy_dual[start:start + num]
+
+ def get_level_quantized_points(self, level):
+ start = self.pyramids[..., 1, level]
+ num = self.pyramids[..., 0, level]
+ return self.pyramids[start:start + num]
+
+ def get_trilinear_coeffs(self, x, level):
+ quantized = kaolin.ops.spc.quantize_points(x, level)
+ coeffs = kaolin.ops.spc.coords_to_trilinear_coeffs(x, quantized, level) # (N,8)
+ return coeffs
+
+ def get_center_ids(self, x, level):
+ pidx = kaolin.ops.spc.unbatched_query(self.octree, self.exsum, x, level, with_parents=False)
+ return pidx
+
+ def get_corners_ids(self, x, level):
+ pidx = kaolin.ops.spc.unbatched_query(self.octree, self.exsum, x, level, with_parents=False)
+ corner_ids = self.trinkets[pidx]
+ is_valid = torch.ones(len(x)).bool().to(x.device)
+ bad_ids = (pidx < 0).nonzero()[:, 0]
+ is_valid[bad_ids] = 0
+
+ return corner_ids, is_valid
+
+ def trilinear_interpolate(self, x, level, feat):
+ '''
+ @feat: (N_feature of current level, D)
+ '''
+ ############!NOTE direct API call cannot back prop well
+ # pidx = kaolin.ops.spc.unbatched_query(self.octree, self.exsum, x, level, with_parents=False)
+ # x = x.unsqueeze(0)
+ # interpolated = kaolin.ops.spc.unbatched_interpolate_trilinear(coords=x,pidx=pidx.int(),point_hierarchy=self.point_hierarchies,trinkets=self.trinkets, feats=feat, level=level)[0]
+ ##################
+
+ coeffs = self.get_trilinear_coeffs(x, level) # (N,8)
+ corner_ids, is_valid = self.get_corners_ids(x, level)
+ # if corner_ids.max()>=feat.shape[0]:
+ # pdb.set_trace()
+
+ corner_feat = feat[corner_ids[is_valid].long()] # (N,8,D)
+ out = torch.zeros((len(x), feat.shape[-1]), device=x.device).float()
+ out[is_valid] = torch.sum(coeffs[..., None][is_valid] * corner_feat, dim=1) # (N,D)
+
+ # corner_feat = feat[corner_ids.long()] #(N,8,D)
+ # out = torch.sum(coeffs[...,None]*corner_feat, dim=1) #(N,D)
+
+ return out, is_valid
+
+ def draw_boxes(self, level, outfile='/home/bowen/debug/corners.ply'):
+ centers = kaolin.ops.spc.unbatched_get_level_points(self.point_hierarchies.reshape(-1, 3),
+ self.pyramids.reshape(2, -1), level)
+ pts = (centers.float() + 0.5) / (2 ** level) * 2 - 1 # Normalize to [-1,1]
+ pcd = toOpen3dCloud(pts.data.cpu().numpy())
+ o3d.io.write_point_cloud(outfile.replace("corners", "centers"), pcd)
+
+ corners = kaolin.ops.spc.unbatched_get_level_points(self.point_hierarchy_dual, self.pyramid_dual, level)
+ pts = corners.float() / (2 ** level) * 2 - 1 # Normalize to [-1,1]
+ pcd = toOpen3dCloud(pts.data.cpu().numpy())
+ o3d.io.write_point_cloud(outfile, pcd)
+
+ def ray_trace(self, rays_o, rays_d, level, debug=False):
+ """Octree is in normalized [-1,1] world coordinate frame
+ 'rays_o': ray origin in normalized world coordinate system
+ 'rays_d': (N,3) unit length ray direction in normalized world coordinate system
+ 'octree': spc
+ @voxel_size: in the scale of [-1,1] space
+ Return:
+ ray_depths_in_out: traveling times, NOT the Z value
+ """
+ from mycuda import common
+
+ # Avoid corner cases. issuse in kaolin: https://github.com/NVIDIAGameWorks/kaolin/issues/490 and https://github.com/NVIDIAGameWorks/kaolin/pull/634
+ # rays_o = rays_o.clone() + 1e-7
+
+ ray_index, rays_pid, depth_in_out = kaolin.render.spc.unbatched_raytrace(self.octree, self.point_hierarchies,
+ self.pyramids[0], self.exsum, rays_o,
+ rays_d, level=level, return_depth=True,
+ with_exit=True)
+ if ray_index.size()[0] == 0:
+ ray_depths_in_out = torch.zeros((rays_o.shape[0], 1, 2))
+ rays_pid = -torch.ones_like(rays_o[:, :1])
+ rays_near = torch.zeros_like(rays_o[:, :1])
+ rays_far = torch.zeros_like(rays_o[:, :1])
+ return rays_near, rays_far, rays_pid, ray_depths_in_out
+
+ intersected_ray_ids, counts = torch.unique_consecutive(ray_index, return_counts=True)
+ max_intersections = counts.max().item()
+ start_poss = torch.cat([torch.tensor([0], device=counts.device), torch.cumsum(counts[:-1], dim=0)], dim=0)
+
+ ray_depths_in_out = common.postprocessOctreeRayTracing(ray_index.long().contiguous(), depth_in_out.contiguous(),
+ intersected_ray_ids.long().contiguous(),
+ start_poss.long().contiguous(), max_intersections,
+ rays_o.shape[0])
+
+ rays_far = ray_depths_in_out[:, :, 1].max(dim=-1)[0].reshape(-1, 1)
+ rays_near = ray_depths_in_out[:, 0, 0].reshape(-1, 1)
+
+ return rays_near, rays_far, rays_pid, ray_depths_in_out
+
+
+def find_biggest_cluster(pts, eps=0.005, min_samples=5):
+ dbscan = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1)
+ dbscan.fit(pts)
+ ids, cnts = np.unique(dbscan.labels_, return_counts=True)
+ best_id = ids[cnts.argsort()[-1]]
+ keep_mask = dbscan.labels_ == best_id
+ pts_cluster = pts[keep_mask]
+ return pts_cluster, keep_mask
+
+
+def compute_translation_scales(pts, max_dim=2, cluster=True, eps=0.005, min_samples=5):
+ if cluster:
+ pts, keep_mask = find_biggest_cluster(pts, eps, min_samples)
+ else:
+ keep_mask = np.ones((len(pts)), dtype=bool)
+ max_xyz = pts.max(axis=0)
+ min_xyz = pts.min(axis=0)
+ center = (max_xyz + min_xyz) / 2
+ sc_factor = max_dim / (max_xyz - min_xyz).max() # Normalize to [-1,1]
+ sc_factor *= 0.9
+ translation_cvcam = -center
+ return translation_cvcam, sc_factor, keep_mask
+
+
+def compute_scene_bounds_worker(color_file, K, glcam_in_world, use_mask, rgb=None, depth=None, mask=None):
+ if rgb is None:
+ depth_file = color_file.replace('images', 'depth_filtered')
+ mask_file = color_file.replace('images', 'masks')
+ rgb = np.array(Image.open(color_file))[..., :3]
+ depth = cv2.imread(depth_file, -1) / 1e3
+ xyz_map = depth2xyzmap(depth, K)
+ valid = depth >= 0.1
+ if use_mask:
+ if mask is None:
+ mask = cv2.imread(mask_file, -1)
+ valid = valid & (mask > 0)
+ pts = xyz_map[valid].reshape(-1, 3)
+ if len(pts) == 0:
+ return None
+ colors = rgb[valid].reshape(-1, 3)
+
+ pcd = toOpen3dCloud(pts, colors)
+
+ pcd = pcd.voxel_down_sample(0.01)
+ new_pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=2.0)
+ cam_in_world = glcam_in_world @ glcam_in_cvcam
+ new_pcd.transform(cam_in_world)
+
+ return np.asarray(new_pcd.points).copy(), np.asarray(new_pcd.colors).copy()
+
+
+def compute_scene_bounds(color_files, glcam_in_worlds, K, use_mask=True, base_dir=None, rgbs=None, depths=None,
+ masks=None, cluster=True, translation_cvcam=None, sc_factor=None, eps=0.06, min_samples=1):
+ # assert color_files is None or rgbs is None
+
+ if base_dir is None:
+ base_dir = os.path.dirname(color_files[0]) + '/../'
+ os.makedirs(base_dir, exist_ok=True)
+
+ args = []
+ if rgbs is not None:
+ for i in range(len(rgbs)):
+ args.append((color_files[i], K, glcam_in_worlds[i], use_mask, rgbs[i], depths[i], masks[i]))
+ else:
+ for i in range(len(color_files)):
+ args.append((color_files[i], K, glcam_in_worlds[i], use_mask))
+
+ logging.info(f"compute_scene_bounds_worker start")
+ ret = joblib.Parallel(n_jobs=6, prefer="threads")(joblib.delayed(compute_scene_bounds_worker)(*arg) for arg in args)
+ logging.info(f"compute_scene_bounds_worker done")
+
+ pcd_all = None
+ for r in ret:
+ if r is None or len(r[0]) == 0:
+ continue
+ if pcd_all is None:
+ pcd_all = toOpen3dCloud(r[0], r[1])
+ else:
+ pcd_all += toOpen3dCloud(r[0], r[1])
+
+ pcd = pcd_all.voxel_down_sample(eps / 5)
+
+ pts = np.asarray(pcd.points).copy()
+
+ def make_tf(translation_cvcam, sc_factor):
+ tf = np.eye(4)
+ tf[:3, 3] = translation_cvcam
+ tf1 = np.eye(4)
+ tf1[:3, :3] *= sc_factor
+ tf = tf1 @ tf
+ return tf
+
+ if translation_cvcam is None:
+ translation_cvcam, sc_factor, keep_mask = compute_translation_scales(pts, cluster=cluster, eps=eps,
+ min_samples=min_samples)
+ tf = make_tf(translation_cvcam, sc_factor)
+ else:
+ tf = make_tf(translation_cvcam, sc_factor)
+ tmp = copy.deepcopy(pcd)
+ tmp.transform(tf)
+ tmp_pts = np.asarray(tmp.points)
+ keep_mask = (np.abs(tmp_pts) < 1).all(axis=-1)
+
+ logging.info(f"compute_translation_scales done")
+
+ pcd = toOpen3dCloud(pts[keep_mask], np.asarray(pcd.colors)[keep_mask])
+ pcd_real_scale = copy.deepcopy(pcd)
+
+ with open(f'{base_dir}/normalization.yml', 'w') as ff:
+ tmp = {
+ 'translation_cvcam': translation_cvcam.tolist(),
+ 'sc_factor': float(sc_factor),
+ }
+ yaml.dump(tmp, ff)
+
+ pcd.transform(tf)
+
+ return sc_factor, translation_cvcam, pcd_real_scale, pcd
+
+
+BAD_DEPTH = 99
+BAD_COLOR = 128
+
+
+def mask_and_normalize_data(rgbs, depths, masks, poses, sc_factor, translation):
+ '''
+ @rgbs: np array (N,H,W,3)
+ @depths: (N,H,W)
+ @masks: (N,H,W)
+ @poses: (N,4,4)
+ '''
+ depths[depths < 0.1] = BAD_DEPTH
+ if masks is not None:
+ rgbs[masks == 0] = BAD_COLOR
+ depths[masks == 0] = BAD_DEPTH
+ masks = masks[..., None]
+
+ rgbs = (rgbs / 255.0).astype(np.float32)
+ depths *= sc_factor
+ depths = depths[..., None]
+ poses[:, :3, 3] += translation
+ poses[:, :3, 3] *= sc_factor
+
+ return rgbs, depths, masks, poses
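+
+# Note: the scaling above follows the same convention as make_tf() in compute_scene_bounds():
+# camera centers are translated by `translation` and then multiplied by `sc_factor`,
+# and depths are scaled by the same factor.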
+
+
+def get_voxel_pts(voxel_size):
+ one_dim = np.linspace(-1, 1, int(np.ceil(2.0 / voxel_size)) + 1)
+ x, y, z = np.meshgrid(one_dim, one_dim, one_dim) # [N, N, N]
+ pts = np.stack([x, y, z], axis=-1) # [N, N, N, 3]
+ return pts
+
+
+def sdf_voxel_from_mesh(mesh, voxel_size):  # mesh is already scaled to be in [-1, 1]
+    pts = get_voxel_pts(voxel_size)
+    sdf = mesh_to_sdf(mesh, pts.reshape(-1, 3), sign_method='depth')
+    sdf = sdf.reshape(pts.shape[:-1])
+
+    return pts, sdf
+
+
+class VoxelSDF:
+ def __init__(self, sdf): # sdf is by default from a grid [-1, 1]^3
+ self.sdf_grid = torch.FloatTensor(sdf).unsqueeze(0).unsqueeze(0).cuda()
+
+ def query(self, xyz):
+ zxy = xyz[..., [2, 0, 1]]
+ sdf = F.grid_sample(self.sdf_grid, zxy.unsqueeze(0).unsqueeze(0).unsqueeze(0),
+ mode='bilinear', padding_mode='border', align_corners=False)
+ return sdf.reshape(-1)
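+
+# Illustrative usage sketch (assumes `mesh` is a trimesh already normalized to [-1, 1]^3):
+#   pts, sdf = sdf_voxel_from_mesh(mesh, voxel_size=0.01)
+#   voxel_sdf = VoxelSDF(sdf)
+#   query_xyz = torch.rand(1024, 3, device='cuda') * 2 - 1   # random points in [-1, 1]^3
+#   sdf_values = voxel_sdf.query(query_xyz)                  # (1024,) trilinearly interpolated SDF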
+
+
+def extract_mesh(voxel_sdf, voxel_size=0.0099, isolevel=0):
+ x_min, x_max = -1, 1
+ y_min, y_max = -1, 1
+ z_min, z_max = -1, 1
+ tx = np.arange(x_min + 0.5 * voxel_size, x_max, voxel_size)
+ ty = np.arange(y_min + 0.5 * voxel_size, y_max, voxel_size)
+ tz = np.arange(z_min + 0.5 * voxel_size, z_max, voxel_size)
+ N = len(tx)
+ query_pts = torch.tensor(
+ np.stack(np.meshgrid(tx, ty, tz, indexing='ij'), -1).astype(np.float32).reshape(-1, 3)).float().cuda()
+
+ sigma = voxel_sdf.query(query_pts.reshape(-1, 3)).reshape(N, N, N).data.cpu().numpy()
+
+ from skimage import measure
+ try:
+ vertices, triangles, normals, values = measure.marching_cubes(sigma, isolevel)
+ except Exception as e:
+ print(f"ERROR Marching Cubes {e}")
+ return None
+
+ # Rescale and translate
+ voxel_size_ndc = np.array([tx[-1] - tx[0], ty[-1] - ty[0], tz[-1] - tz[0]]) / np.array(
+        [tx.shape[0] - 1, ty.shape[0] - 1, tz.shape[0] - 1])
+ offset = np.array([tx[0], ty[0], tz[0]])
+ vertices[:, :3] = voxel_size_ndc.reshape(1, 3) * vertices[:, :3] + offset.reshape(1, 3)
+
+ # Create mesh
+ mesh = trimesh.Trimesh(vertices, triangles, process=False)
+ return mesh
+
+
+def transform_mesh_to_world(mesh, sc_factor, translation):
+ nerf_scale = sc_factor
+ nerf_trans = translation
+ inv_scale = 1.0 / nerf_scale
+ inv_trans = (-nerf_scale * nerf_trans).reshape(1, 3)
+
+ vertices_raw = mesh.vertices.copy()
+ vertices_raw = inv_scale * (vertices_raw + inv_trans.reshape(1, 3))
+ ori_mesh = trimesh.Trimesh(vertices_raw, mesh.faces, process=False)
+ return ori_mesh
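+
+# Illustrative usage sketch: extract a mesh from the voxel SDF and undo the normalization,
+# assuming (sc_factor, translation) come from compute_translation_scales() above:
+#   mesh_ndc = extract_mesh(VoxelSDF(sdf), voxel_size=0.01)
+#   if mesh_ndc is not None:
+#       mesh_world = transform_mesh_to_world(mesh_ndc, sc_factor, translation)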
+
+
+class DepthFuser:
+ def __init__(self, K, c2w, depth, mask, trunc, near=0.01, far=10):
+ self.w2c = torch.linalg.inv(c2w) # [V, 4, 4], w2c
+ self.depth = torch.tensor(depth).to(K.device) # [V, H, W]
+ self.mask = torch.tensor(mask).to(K.device)
+ self.near = near
+ self.far = far
+ self.K = K
+ self.trunc = trunc
+ self.V, self.H, self.W = depth.shape
+
+ def query(self, pts): # pts [N, 3]
+ with torch.no_grad():
+ cam_pts = torch.matmul(self.w2c[:, :3, :3].unsqueeze(0), pts.unsqueeze(1).unsqueeze(-1)).squeeze(
+ -1) # [N, V, 3]
+ cam_pts = cam_pts + self.w2c[:, :3, 3].unsqueeze(0) # [N, V, 3]
+
+ cam_depth = -cam_pts[..., 2]
+
+ projection = torch.matmul(self.K[:2, :2].unsqueeze(0).unsqueeze(0), # [1, 1, 2, 2]
+ (cam_pts[..., :2] / torch.clip(-cam_pts[..., 2:3], min=1e-8)).unsqueeze(
+ -1)).squeeze(-1)
+ projection = projection + self.K[:2, 2].unsqueeze(0).unsqueeze(0) # [N, V, 2]
+
+ pixel = torch.round(projection).long()
+
+ valid_pixel = torch.logical_and(
+ torch.logical_and(pixel[..., 0] >= 0, pixel[..., 0] < self.W),
+ torch.logical_and(pixel[..., 1] >= 0, pixel[..., 1] < self.H))
+
+ py = self.H - 1 - torch.clamp(pixel[..., 1], 0, self.H - 1)
+ px = torch.clamp(pixel[..., 0], 0, self.W - 1)
+
+ view_idx = torch.arange(0, self.V).long().to(px.device).reshape(1, -1)
+ depth = self.depth[view_idx, py, px]
+ mask = self.mask[view_idx, py, px]
+
+ valid_depth = torch.logical_and(depth > self.near, depth < self.far)
+
+ before_depth = cam_depth <= depth + self.trunc
+
+ valid = torch.logical_and(torch.logical_and(valid_pixel, mask), valid_depth)
+
+ observed = torch.logical_and(valid, before_depth) # [N, V]
+ observed = torch.any(observed, dim=1) # [N]
+
+ return observed
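+
+# Illustrative usage sketch (assumes `K` and `c2w` are CUDA tensors, `depth`/`mask` are (V, H, W)
+# arrays in the OpenGL camera convention used above, and `world_pts` is an (N, 3) CUDA tensor):
+#   fuser = DepthFuser(K, c2w, depth, mask, trunc=0.01)
+#   observed = fuser.query(world_pts)  # (N,) bool: True if a point lies in front of (or within
+#                                      # `trunc` of) the observed surface in at least one view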
+
+
+class VoxelVisibility:
+ def __init__(self, visibility):
+ self.grid = torch.FloatTensor(visibility).unsqueeze(0).unsqueeze(0).cuda()
+
+ def query(self, xyz):
+ zxy = xyz[..., [2, 0, 1]]
+ visibility = F.grid_sample(self.grid, zxy.unsqueeze(0).unsqueeze(0).unsqueeze(0),
+ mode='nearest', padding_mode='zeros', align_corners=False)
+ return visibility.bool().reshape(-1)
diff --git a/utils/nerf_utils.py b/utils/nerf_utils.py
new file mode 100644
index 0000000..554a9d4
--- /dev/null
+++ b/utils/nerf_utils.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch
+import numpy as np
+
+to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
+
+
+def get_camera_rays_np(H, W, K):
+ """Get ray origins, directions from a pinhole camera."""
+ i, j = np.meshgrid(np.arange(W, dtype=np.float32),
+ np.arange(H, dtype=np.float32), indexing='xy')
+ dirs = np.stack([(i - K[0, 2]) / K[0, 0], -(j - K[1, 2]) / K[1, 1], -np.ones_like(i)], axis=-1)
+ return dirs
+
+
+def get_pixel_coords_np(H, W, K):
+ """Get ray origins, directions from a pinhole camera."""
+ i, j = np.meshgrid(np.arange(W, dtype=np.float32),
+ np.arange(H, dtype=np.float32), indexing='xy')
+ coords = np.stack([i, H - j - 1], axis=-1) # assume K[1, 2] * 2 == H, -(j - K[1, 2]) = -j + H - K[1, 2]
+ return coords
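+
+# Illustrative sketch of ray generation (assumes a 4x4 OpenGL camera-to-world pose `c2w`):
+#   dirs = get_camera_rays_np(H, W, K)                  # (H, W, 3) directions in the camera frame
+#   rays_d = dirs @ c2w[:3, :3].T                       # rotate into the world frame
+#   rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)  # all rays originate at the camera center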
+
+
+def get_masks(z_vals, target_d, truncation, cfg, dir_norm=None):
+    valid_depth_mask = (target_d >= cfg['near'] * cfg['sc_factor']) & (target_d <= cfg['far'] * cfg['sc_factor'])
+    front_mask = (z_vals < target_d - truncation)  # samples in free space, in front of the surface
+    back_mask = (z_vals > target_d + truncation * cfg['neg_trunc_ratio'])  # samples behind the surface
+
+    # samples inside the truncation band around a valid observed depth
+    sdf_mask = (1.0 - front_mask.float()) * (1.0 - back_mask.float()) * valid_depth_mask
+
+ return front_mask.bool(), sdf_mask.bool()
+
+
+def get_sdf_loss(z_vals, target_d, predicted_sdf, truncation, cfg, return_mask=False, rays_d=None):
+ dir_norm = rays_d.norm(dim=-1, keepdim=True)
+ front_mask, sdf_mask = get_masks(z_vals, target_d, truncation, cfg, dir_norm=dir_norm)
+
+    # free-space loss for rays whose depth reading is beyond the far plane: push the SDF up towards fs_sdf
+    mask = (target_d > cfg['far'] * cfg['sc_factor']) & (predicted_sdf < cfg['fs_sdf'])
+    fs_loss = ((predicted_sdf - cfg['fs_sdf']) * mask) ** 2
+
+    # empty-space loss for samples in front of the surface: push the SDF towards the truncation value (+1)
+    mask = front_mask & (target_d <= cfg['far'] * cfg['sc_factor']) & (predicted_sdf < 1)
+    empty_loss = torch.abs(predicted_sdf - 1) * mask
+
+    # near-surface loss: the implied depth (z + sdf * truncation) should match the observed depth
+    sdf_loss = ((z_vals + predicted_sdf * truncation) * sdf_mask - target_d * sdf_mask) ** 2
+
+ if return_mask:
+ return empty_loss, fs_loss, sdf_loss, front_mask, sdf_mask
+ return empty_loss, fs_loss, sdf_loss
+
+
+def ray_box_intersection_batch(origins, dirs, bounds):
+    '''
+    @origins: (N,3) ray origins, in the same coordinate frame as the bounding box
+    @dirs: (N,3) ray directions (normalized internally)
+    @bounds: (2,3) xyz_min and xyz_max
+    '''
+ if not torch.is_tensor(origins):
+ origins = torch.tensor(origins)
+ dirs = torch.tensor(dirs)
+ if not torch.is_tensor(bounds):
+ bounds = torch.tensor(bounds)
+
+ dirs = dirs / (torch.norm(dirs, dim=-1, keepdim=True) + 1e-10)
+ inv_dirs = 1 / dirs
+ bounds = bounds[None].expand(len(dirs), -1, -1) # (N,2,3)
+
+ sign = torch.zeros((len(dirs), 3)).long().to(dirs.device) # (N,3)
+ sign[:, 0] = (inv_dirs[:, 0] < 0)
+ sign[:, 1] = (inv_dirs[:, 1] < 0)
+ sign[:, 2] = (inv_dirs[:, 2] < 0)
+
+    tmin = (torch.gather(bounds[..., 0], dim=1, index=sign[:, 0].reshape(-1, 1)).reshape(-1)
+            - origins[:, 0]) * inv_dirs[:, 0]  # (N)
+    tmin[tmin < 0] = 0
+    tmax = (torch.gather(bounds[..., 0], dim=1, index=1 - sign[:, 0].reshape(-1, 1)).reshape(-1)
+            - origins[:, 0]) * inv_dirs[:, 0]
+    tymin = (torch.gather(bounds[..., 1], dim=1, index=sign[:, 1].reshape(-1, 1)).reshape(-1)
+             - origins[:, 1]) * inv_dirs[:, 1]
+    tymin[tymin < 0] = 0
+    tymax = (torch.gather(bounds[..., 1], dim=1, index=1 - sign[:, 1].reshape(-1, 1)).reshape(-1)
+             - origins[:, 1]) * inv_dirs[:, 1]
+
+ ishit = torch.ones(len(dirs)).bool().to(dirs.device)
+ ishit[(tmin > tymax) | (tymin > tmax)] = 0
+ tmin[tymin > tmin] = tymin[tymin > tmin]
+ tmax[tymax < tmax] = tymax[tymax < tmax]
+
+    tzmin = (torch.gather(bounds[..., 2], dim=1, index=sign[:, 2].reshape(-1, 1)).reshape(-1)
+             - origins[:, 2]) * inv_dirs[:, 2]
+    tzmin[tzmin < 0] = 0
+    tzmax = (torch.gather(bounds[..., 2], dim=1, index=1 - sign[:, 2].reshape(-1, 1)).reshape(-1)
+             - origins[:, 2]) * inv_dirs[:, 2]
+
+ ishit[(tmin > tzmax) | (tzmin > tmax)] = 0
+ tmin[tzmin > tmin] = tzmin[tzmin > tmin] # (N)
+ tmax[tzmax < tmax] = tzmax[tzmax < tmax]
+
+ tmin[ishit == 0] = -1
+ tmax[ishit == 0] = -1
+
+ return tmin, tmax
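+
+# Illustrative usage sketch: intersect rays with the normalized [-1, 1]^3 cube.
+#   bounds = np.array([[-1., -1., -1.], [1., 1., 1.]])
+#   tmin, tmax = ray_box_intersection_batch(rays_o, rays_d, bounds)  # (N,), (N,); -1 where a ray misses
+# Directions are normalized inside the function, so tmin/tmax are metric distances along each ray.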
+
diff --git a/utils/py_utils.py b/utils/py_utils.py
new file mode 100644
index 0000000..72edaa0
--- /dev/null
+++ b/utils/py_utils.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+
+import torch
+import numpy as np
+import importlib
+import logging
+
+
+def print_composite(data, beg=""):
+ if isinstance(data, dict):
+ print(f'{beg} dict, size = {len(data)}')
+ for key, value in data.items():
+ print(f' {beg}{key}:')
+ print_composite(value, beg + " ")
+ elif isinstance(data, list):
+ print(f'{beg} list, len = {len(data)}')
+ for i, item in enumerate(data):
+ print(f' {beg}item {i}')
+ print_composite(item, beg + " ")
+ elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor):
+ print(f'{beg} array of size {data.shape}')
+ else:
+ print(f'{beg} {data}')
+
+
+def list_to_array(l):
+    if isinstance(l, list):
+        return np.stack([list_to_array(x) for x in l], axis=0)
+    elif isinstance(l, str):
+        return np.array(float(l))
+    elif isinstance(l, (int, float)):
+        return np.array(l)
+
+
+def set_logging_format():
+ importlib.reload(logging)
+ FORMAT = '[%(filename)s:%(lineno)d %(funcName)s()] %(message)s'
+ logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def set_logging_file(file_path):
+ FORMAT = '[%(filename)s:%(lineno)d %(funcName)s()] %(message)s'
+ logging.basicConfig(level=logging.INFO, format=FORMAT,
+ handlers=[
+ logging.FileHandler(file_path),
+ logging.StreamHandler()
+ ])
+
+
+def set_seed(random_seed):
+ import torch, random
+ np.random.seed(random_seed)
+ random.seed(random_seed)
+ torch.manual_seed(random_seed)
+ torch.cuda.manual_seed_all(random_seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+