-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathpoint_utils.py
139 lines (112 loc) · 3.63 KB
/
point_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
from torch.autograd import Function
import torch.nn as nn
import sys
import pointlib as pl
class KNNOperation(Function):
    """Autograd ``Function`` that forwards a k-nearest-neighbour query to ``pl.knn_points``."""

    @staticmethod
    def forward(ctx, pointsa, pointsb, knn):
        # type: (Any, torch.Tensor, torch.Tensor, int) -> torch.Tensor
        r"""
        Runs ``pl.knn_points`` on contiguous copies of the two point sets

        Parameters
        ----------
        pointsa : torch.Tensor
            first point set handed to ``pl.knn_points``
            (presumably (B, N, 3) -- TODO confirm against pointlib)
        pointsb : torch.Tensor
            second point set handed to ``pl.knn_points``
            (presumably (B, M, 3) -- TODO confirm against pointlib)
        knn : int
            the number of nearest neighbours, forwarded unchanged

        Returns
        -------
        nnidx : torch.Tensor
            the value returned by ``pl.knn_points``
            (presumably the neighbour indices -- TODO confirm against pointlib)
        """
        # The extension is always given contiguous memory.
        nnidx = pl.knn_points(pointsa.contiguous(), pointsb.contiguous(), knn)
        return nnidx

    @staticmethod
    def backward(ctx, grad_out=None):
        r"""
        No gradient is propagated to any input of ``forward``

        Returns
        -------
        Tuple[None, None, None]
            one entry per ``forward`` input (pointsa, pointsb, knn)
        """
        return None, None, None


knn_operation = KNNOperation.apply
class FurthestPointSampling(Function):
    """Autograd ``Function`` that forwards to ``pl.furthest_point_sampling``."""

    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        # The extension is always given contiguous memory.
        return pl.furthest_point_sampling(xyz.contiguous(), npoint)

    @staticmethod
    def backward(ctx, grad_out=None):
        r"""
        No gradient is propagated to any input of ``forward``

        Parameters
        ----------
        ctx : Any
            autograd context (autograd passes it as the first positional argument)
        grad_out : torch.Tensor, optional
            gradient w.r.t. the output of ``forward``; ignored

        Returns
        -------
        Tuple[None, None]
            one entry per ``forward`` input (xyz, npoint)
        """
        return None, None


furthest_point_sample = FurthestPointSampling.apply
class GatherOperation(Function):
    """Autograd ``Function`` pairing ``pl.gather_points`` with ``pl.gather_points_grad``."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Gathers the features selected by idx

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor
        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        _, C, N = features.size()
        # Make idx contiguous once and keep that same tensor for backward, so
        # both extension calls receive contiguous index memory.
        idx = idx.contiguous()
        ctx.for_backwards = (idx, C, N)
        return pl.gather_points(features.contiguous(), idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            gradient w.r.t. the output of ``forward``

        Returns
        -------
        torch.Tensor
            gradient of the features, as computed by ``pl.gather_points_grad``
        None
            idx receives no gradient
        """
        # C is stored alongside idx and N but is not needed here.
        idx, _, N = ctx.for_backwards
        grad_features = pl.gather_points_grad(grad_out.contiguous(), idx, N)
        return grad_features, None


gather_operation = GatherOperation.apply
class GroupingOperation(Function):
    """Autograd ``Function`` pairing ``pl.group_points`` with ``pl.group_points_grad``."""

    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Groups the features selected by idx

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor

        Raises
        ------
        ValueError
            if idx or features is not a 3-dimensional tensor
        """
        # Explicit rank check; a non-3-D idx is rejected with ValueError.
        if idx.dim() != 3:
            raise ValueError(
                "idx must be a 3-dimensional tensor, got %d dimension(s)" % idx.dim()
            )
        # Unpacking raises ValueError when features is not 3-dimensional.
        _, _, N = features.size()
        # Make idx contiguous once and keep that same tensor for backward, so
        # both extension calls receive contiguous index memory.
        idx = idx.contiguous()
        ctx.for_backwards = (idx, N)
        return pl.group_points(features.contiguous(), idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            idx receives no gradient
        """
        idx, N = ctx.for_backwards
        grad_features = pl.group_points_grad(grad_out.contiguous(), idx, N)
        return grad_features, None


grouping_operation = GroupingOperation.apply