rel_pos_cuda.cpp
#include <torch/extension.h>
#include <vector>
#include <stdio.h>
// CUDA forward declarations
torch::Tensor relative_positioning_forward_2d_cuda(
    torch::Tensor logits, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int h_q, const int w_q, const int h_k, const int w_k, const bool use_mask);

std::vector<torch::Tensor> relative_positioning_backward_2d_cuda(
    torch::Tensor grad_out, const int h_q, const int w_q, const int h_k, const int w_k);

torch::Tensor relative_positioning_forward_3d_cuda(
    torch::Tensor logits, torch::Tensor r_t, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k, const bool use_mask);

std::vector<torch::Tensor> relative_positioning_backward_3d_cuda(
    torch::Tensor grad_out, const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k);

torch::Tensor fuse_all_cuda(
    torch::Tensor q, torch::Tensor k, torch::Tensor rh, torch::Tensor rw,
    torch::Tensor uk, torch::Tensor uh, torch::Tensor uw, torch::Tensor m,
    const int h_q, const int w_q, const int h_k, const int w_k, const int num_heads);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor relative_positioning_forward_2d(
    torch::Tensor logits, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int h_q, const int w_q, const int h_k, const int w_k, const bool use_mask) {
    CHECK_CUDA(logits);  // logits is N(==B*Nh), H_q*W_q, H_k*W_k
    CHECK_CUDA(r_h);     // r_h is N(==B*Nh), H_q*W_q, H_k+H_q-1
    CHECK_CUDA(r_w);     // r_w is N(==B*Nh), H_q*W_q, W_k+W_q-1
    CHECK_CUDA(mask);    // mask is a bool tensor of {N(==B*Nh), H_q*W_q, H_k*W_k}, {H_q*W_q, H_k*W_k} or {1, 1, 1}
    return relative_positioning_forward_2d_cuda(logits, r_h, r_w, mask, h_q, w_q, h_k, w_k, use_mask);
}

std::vector<torch::Tensor> relative_positioning_backward_2d(
    torch::Tensor grad_out, const int h_q, const int w_q, const int h_k, const int w_k) {
    CHECK_INPUT(grad_out);
    auto grads = relative_positioning_backward_2d_cuda(grad_out, h_q, w_q, h_k, w_k);
    // The forward pass only adds relative-position terms to the logits, so the gradient
    // w.r.t. the logits is grad_out itself; prepend it to the gradients from the kernel.
    grads.insert(grads.begin(), grad_out.clone());
    return grads;
}

torch::Tensor relative_positioning_forward_3d(
    torch::Tensor logits, torch::Tensor r_t, torch::Tensor r_h, torch::Tensor r_w, torch::Tensor mask,
    const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k, const bool use_mask) {
    CHECK_CUDA(logits);  // logits is B*Nh, Tq*Hq*Wq, Tk*Hk*Wk
    CHECK_CUDA(r_t);     // r_t is B*Nh, Tq*Hq*Wq, Tk+Tq-1
    CHECK_CUDA(r_h);     // r_h is B*Nh, Tq*Hq*Wq, Hk+Hq-1
    CHECK_CUDA(r_w);     // r_w is B*Nh, Tq*Hq*Wq, Wk+Wq-1
    CHECK_CUDA(mask);    // mask is a bool tensor of {B*Nh, Tq*Hq*Wq, Tk*Hk*Wk}, {Tq*Hq*Wq, Tk*Hk*Wk} or {1, 1, 1}
    return relative_positioning_forward_3d_cuda(logits, r_t, r_h, r_w, mask, t_q, h_q, w_q, t_k, h_k, w_k, use_mask);
}

std::vector<torch::Tensor> relative_positioning_backward_3d(
    torch::Tensor grad_out, const int t_q, const int h_q, const int w_q, const int t_k, const int h_k, const int w_k) {
    CHECK_INPUT(grad_out);
    auto grads = relative_positioning_backward_3d_cuda(grad_out, t_q, h_q, w_q, t_k, h_k, w_k);
    // As in the 2d case, the gradient w.r.t. the logits is grad_out itself.
    grads.insert(grads.begin(), grad_out.clone());
    return grads;
}

torch::Tensor fuse_all(
    torch::Tensor q, torch::Tensor k, torch::Tensor rh, torch::Tensor rw,
    torch::Tensor uk, torch::Tensor uh, torch::Tensor uw, torch::Tensor m,
    const int h_q, const int w_q, const int h_k, const int w_k, const int num_heads) {
    CHECK_INPUT(q);
    CHECK_INPUT(k);
    CHECK_INPUT(rh);
    CHECK_INPUT(rw);
    CHECK_INPUT(uk);
    CHECK_INPUT(uh);
    CHECK_INPUT(uw);
    CHECK_INPUT(m);
    return fuse_all_cuda(q, k, rh, rw, uk, uh, uw, m, h_q, w_q, h_k, w_k, num_heads);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_2d", &relative_positioning_forward_2d, "adds 2d relative positioning logits to the main logits (cuda, forward)");
    m.def("backward_2d", &relative_positioning_backward_2d, "adds 2d relative positioning logits to the main logits (cuda, backward)");
    m.def("forward_3d", &relative_positioning_forward_3d, "adds 3d relative positioning logits to the main logits (cuda, forward)");
    m.def("backward_3d", &relative_positioning_backward_3d, "adds 3d relative positioning logits to the main logits (cuda, backward)");
    m.def("fuse_all_2d", &fuse_all, "calculates logits (cuda, forward)");
}
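
// Example usage from Python (a minimal sketch, not part of the original file): the extension
// can be JIT-compiled with torch.utils.cpp_extension.load and the bound functions called with
// the shapes documented above. The module name "rel_pos_cuda" and the companion kernel file
// "rel_pos_cuda_kernel.cu" are assumptions; substitute the names used by this repository's
// build setup.
//
//     import torch
//     from torch.utils.cpp_extension import load
//
//     rel_pos = load(name="rel_pos_cuda",
//                    sources=["rel_pos_cuda.cpp", "rel_pos_cuda_kernel.cu"])
//
//     B, Nh, h_q, w_q, h_k, w_k = 2, 4, 8, 8, 8, 8
//     logits = torch.randn(B * Nh, h_q * w_q, h_k * w_k, device="cuda")
//     r_h = torch.randn(B * Nh, h_q * w_q, h_k + h_q - 1, device="cuda")
//     r_w = torch.randn(B * Nh, h_q * w_q, w_k + w_q - 1, device="cuda")
//     mask = torch.ones(1, 1, 1, dtype=torch.bool, device="cuda")  // placeholder, use_mask=False
//
//     out = rel_pos.forward_2d(logits, r_h, r_w, mask, h_q, w_q, h_k, w_k, False)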