dataset.py
import argparse
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms as T

from rays import get_rays, get_ray_directions


class NerfData(Dataset):
    def __init__(self, data_dir, width, height, split='train'):
        self.images_dir = data_dir
        # self.scale = scale
        self.width = width
        self.height = height
        self.split = split
        self._read_meta(data_dir, split)
        # Needed for alpha compositing the sampled points along each ray
        # against a white background
        self.white_back = True

    # For 'train', a sample is a single ray and its corresponding pixel value;
    # for 'val' and 'test', a sample is all the rays of one image plus the image itself.
    def __len__(self):
        if self.split == 'train':
            # one sample per cached ray
            return len(self.all_images)
        elif self.split == 'val':
            # validate on a small, fixed number of images
            return 8
        else:
            return len(os.listdir(os.path.join(self.images_dir, self.split)))
    def _read_meta(self, path, split):
        # Load the camera poses and horizontal field of view for this split
        with open(os.path.join(path, f'transforms_{split}.json')) as f:
            meta = json.load(f)
        self.data = meta['frames']
        self.camera_angle_x = meta['camera_angle_x']
        self.focal_length = 0.5*self.width/np.tan(0.5*self.camera_angle_x)
        # Intrinsic camera matrix
        self.K = np.eye(3)
        self.K[0, 0] = self.focal_length
        self.K[1, 1] = self.focal_length
        self.K[0, 2] = self.width/2
        self.K[1, 2] = self.height/2
        # Near/far bounds along each ray (the values used for the synthetic Blender scenes)
        self.near = 2.0
        self.far = 6.0
        self.bounds = np.array([self.near, self.far])
        self.pinhole_extrinsics = np.array(
            [self.height, self.width, self.focal_length])
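        # The focal length follows from the pinhole model: half the image width
        # subtends half the horizontal field of view, so
        #     tan(camera_angle_x / 2) = (W / 2) / f  =>  f = 0.5 * W / tan(0.5 * camera_angle_x)
        # which gives an intrinsics matrix of the standard form
        #     K = [[f, 0, W/2],
        #          [0, f, H/2],
        #          [0, 0,   1 ]]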
"""
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
https://ksimek.github.io/2013/08/13/intrinsic/
"""
        # The ray direction for a given pixel is the same across all images,
        # so it is computed once here and rotated by each camera pose later.
        self.directions = get_ray_directions(self.height, self.width, self.focal_length)
        self.transform = T.ToTensor()
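        # Assumed behaviour of the helpers imported from rays.py (not shown here):
        # get_ray_directions returns an (H, W, 3) grid of per-pixel view directions in
        # the camera frame, and get_rays applies the camera-to-world pose c2w to yield
        # flattened (H*W, 3) ray origins and world-space directions.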
        # Cache all training data up front; val and test data can be generated per image.
        if self.split == 'train':
            self.all_rays = []
            self.all_images = []
            self.all_valid_masks = []
            for t, frame in enumerate(self.data):
                pose = np.array(frame['transform_matrix'])[:3, :4]
                c2w = torch.FloatTensor(pose)
                image_path = os.path.join(self.images_dir, f"{frame['file_path']}.png")
                img = Image.open(image_path)
                """
                https://www.linkedin.com/pulse/afternoon-debugging-e-commerce-image-processing-nikhil-rasiwasia/
                All images are 4-channel, with alpha as the last channel; to read the
                colours correctly the alpha has to be blended into RGB.
                Note: we don't need the alpha itself afterwards.
                """
                img = img.resize((self.width, self.height), Image.LANCZOS)
                # Reading with PIL and applying ToTensor keeps the alpha channel as well
                img = self.transform(img)  # (4, H, W)
                valid_mask = (img[-1] > 0).flatten()  # valid colour area, (H*W,)
                img = img.view(4, -1).permute(1, 0)  # (H*W, 4) RGBA
                img = img[:, :3]*img[:, -1:] + (1-img[:, -1:])  # blend A into RGB
                rays_o, rays_d = get_rays(self.directions, c2w)
                # rays_t = t*torch.ones(len(rays_o), 1)
                rays = torch.cat([rays_o, rays_d,
                                  self.near*torch.ones_like(rays_o[:, :1]),
                                  self.far*torch.ones_like(rays_o[:, :1])], dim=1)  # (H*W, 8)
                self.all_rays.append(rays)
                self.all_images.append(img)
                self.all_valid_masks.append(valid_mask)
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.data)*H*W, 8)
            self.all_images = torch.cat(self.all_images, 0)  # (len(self.data)*H*W, 3)
            # self.all_valid_masks = torch.cat(self.all_valid_masks, 0)
            # print(self.all_images.shape)
    def __getitem__(self, index):
        if self.split == 'train':
            # A training sample is a single ray and its target pixel colour
            sample = {'rays': self.all_rays[index],
                      'images': self.all_images[index],
                      }
        else:
            # A val/test sample is the full set of rays for one image
            frame = self.data[index]
            image_path = os.path.join(self.images_dir, f"{frame['file_path']}.png")
            img = Image.open(image_path)
            img = img.resize((self.width, self.height), Image.LANCZOS)
            img = self.transform(img)
            valid_mask = (img[-1] > 0).flatten()  # valid colour area, (H*W,)
            img = img.view(4, -1).permute(1, 0)  # (H*W, 4) RGBA
            img = img[:, :3]*img[:, -1:] + (1-img[:, -1:])  # blend A into RGB
            pose = np.array(frame['transform_matrix'])[:3, :4]
            c2w = torch.FloatTensor(pose)
            rays_o, rays_d = get_rays(self.directions, c2w)
            rays = torch.cat([rays_o, rays_d,
                              self.near*torch.ones_like(rays_o[:, :1]),
                              self.far*torch.ones_like(rays_o[:, :1])], dim=1)  # (H*W, 8)
            sample = {'rays': rays,
                      'images': img, 'c2w': c2w, 'valid_mask': valid_mask}
        return sample


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, required=True,
                        help='dataset root containing transforms_<split>.json and the images')
    parser.add_argument('--split', type=str, default='train')
    args = parser.parse_args()
    train_data = NerfData(args.image_dir, 400, 400, args.split)
    print(len(train_data))
    img = train_data[42]['images']
    print(img.squeeze().shape)
    print(img.shape)
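    # Illustrative only: a minimal sketch of how the training split is typically
    # consumed, using the standard torch DataLoader (the batch size is a placeholder).
    # Each batch holds 'rays' of shape (B, 8) and 'images' of shape (B, 3).
    if args.split == 'train':
        from torch.utils.data import DataLoader
        loader = DataLoader(train_data, batch_size=1024, shuffle=True)
        batch = next(iter(loader))
        print(batch['rays'].shape, batch['images'].shape)  # (1024, 8) and (1024, 3)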