-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdatasets.py
42 lines (31 loc) · 1.43 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import torch
from torchvision import datasets, transforms
from shapenet import ShapeNet
class AxisScaling(object):
    """Random per-axis scaling augmentation for a ``(surface, point)`` pair.

    Both tensors are multiplied by the same random per-axis factor drawn
    uniformly from ``interval``. They are then jointly rescaled so the largest
    absolute surface coordinate is ``0.999999``. Optionally, Gaussian jitter
    is then added to the surface, and the surface is clamped to ``[-1, 1]``.

    Args:
        interval: ``(low, high)`` bounds of the uniform per-axis scale factor.
        jitter: If True, add ``0.005 * N(0, 1)`` noise to the surface after
            normalisation and clamp it to ``[-1, 1]``.

    Raises:
        TypeError: If ``interval`` is not a tuple.
        ValueError: If ``interval`` does not have exactly two elements.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True):
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(interval, tuple):
            raise TypeError(f"interval must be a tuple, got {type(interval).__name__}")
        if len(interval) != 2:
            raise ValueError(f"interval must be (low, high), got {interval!r}")
        self.interval = interval
        self.jitter = jitter

    def __call__(self, surface, point):
        """Apply the augmentation.

        Args:
            surface: Surface sample tensor. Its trailing dimension must
                broadcast against ``(1, 3)``, presumably xyz coordinates
                (TODO confirm).
            point: Query-point tensor in the same coordinate frame as ``surface``.

        Returns:
            Tuple ``(surface, point)`` of new augmented tensors. The input
            tensors are not modified in place.
        """
        low, high = self.interval
        # One factor per axis, uniform in [low, high). Previously this was
        # hard-coded to [0.75, 1.25) and silently ignored ``self.interval``.
        scaling = torch.rand(1, 3) * (high - low) + low
        surface = surface * scaling
        point = point * scaling

        # Normalise both tensors by the surface extent, using the same factor
        # so they stay aligned; the surface ends up strictly inside [-1, 1].
        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        surface *= scale
        point *= scale

        if self.jitter:
            surface += 0.005 * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)
        return surface, point
def build_shape_surface_occupancy_dataset(split, args):
    """Build the ShapeNet surface/occupancy dataset for ``split``.

    The ``'train'`` split applies ``AxisScaling((0.75, 1.25), True)`` and
    samples 1024 points (``sampling=True, num_samples=1024``). Every other
    split (``'val'`` or anything else) uses no transform and
    ``sampling=False``. All splits return surfaces with
    ``surface_sampling=True`` and ``pc_size=args.point_cloud_size``.

    Args:
        split: Dataset split name, forwarded to ``ShapeNet``.
        args: Namespace-like object providing ``data_path`` and
            ``point_cloud_size``.

    Returns:
        A ``ShapeNet`` instance configured for the split.
    """
    data_path = args.data_path
    # Settings shared by every split.
    common = dict(
        split=split,
        return_surface=True,
        surface_sampling=True,
        pc_size=args.point_cloud_size,
    )
    if split == 'train':
        transform = AxisScaling((0.75, 1.25), True)
        return ShapeNet(data_path, transform=transform, sampling=True, num_samples=1024, **common)
    # 'val' and all other splits previously had two byte-identical branches;
    # they share one configuration here so it only needs editing once.
    return ShapeNet(data_path, transform=None, sampling=False, **common)
if __name__ == "__main__":
    # This module only defines dataset helpers; running it directly does nothing.
    pass