Skip to content

Commit

Permalink
【HEU】[Paddle Tensor 第二期 API鲁棒性增强] paddle.linalg.qr 支持 0-size tensor 与…
Browse files Browse the repository at this point in the history
… 复数类型 (#70481)

* 0size

* support complex

* add grad m > n

* fix

* fix test

* fix

* fix

* fix DCU

* fix m<n

* fix coverage

* fix
  • Loading branch information
fangfangssj authored Jan 21, 2025
1 parent 079d457 commit d3e8070
Show file tree
Hide file tree
Showing 7 changed files with 967 additions and 318 deletions.
10 changes: 8 additions & 2 deletions paddle/phi/kernels/cpu/qr_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h"

PD_REGISTER_KERNEL(qr_grad, CPU, ALL_LAYOUT, phi::QrGradKernel, float, double) {
}
// Registers the CPU backward (gradient) kernel for paddle.linalg.qr.
// complex<float>/complex<double> were added together with complex support
// in the forward kernel.
PD_REGISTER_KERNEL(qr_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::QrGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
247 changes: 192 additions & 55 deletions paddle/phi/kernels/cpu/qr_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,187 @@
#include <Eigen/Dense>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/diagonal_kernel.h"
#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"

namespace phi {

// Creates a DenseTensor of the given shape with every element set to
// `fill_value`.
//
// ctx:        device context used for allocation and the fill op.
// shape:      dimensions of the tensor to create. Taken by const reference:
//             the vector is only read, so copying it on every call was an
//             unnecessary allocation (performance-unnecessary-value-param).
// fill_value: value broadcast to all elements.
//
// Returns the newly allocated, constant-filled tensor.
template <class T, class Context>
static DenseTensor Fill(const Context& ctx,
                        const std::vector<int64_t>& shape,
                        T fill_value) {
  DenseTensor ret;
  ret.Resize(common::make_ddim(shape));
  ctx.template Alloc<T>(&ret);
  funcs::SetConstant<Context, T>()(ctx, &ret, fill_value);
  return ret;
}

// Builds a (possibly batched) identity matrix of the given shape: zeros
// everywhere except for ones on the main diagonal of the trailing two
// dimensions. Used to produce Q for 0-size inputs where Q must still be
// a well-formed orthogonal matrix.
template <class T, class Context>
static DenseTensor identity_matrix(const Context& ctx, common::DDim shape) {
  DenseTensor eye =
      Fill<T, Context>(ctx, common::vectorize<int64_t>(shape), T(0));
  size_t ndim = eye.dims().size();
  int64_t diag_len = std::min(eye.dims()[ndim - 1], eye.dims()[ndim - 2]);
  // The diagonal tensor keeps every batch dimension and appends the
  // diagonal length as its innermost dimension.
  std::vector<int64_t> diag_dims;
  diag_dims.reserve(ndim - 1);
  for (size_t i = 0; i + 2 < ndim; ++i) {
    diag_dims.push_back(eye.dims()[i]);
  }
  diag_dims.push_back(diag_len);
  DenseTensor ones = Fill<T, Context>(
      ctx, common::vectorize<int64_t>(make_ddim(diag_dims)), T(1));
  // Scatter the ones onto the main diagonal (offset 0) of the last two axes.
  return FillDiagonalTensor<T, Context>(
      ctx, eye, ones, 0, ndim - 2, ndim - 1);
}

// QR decomposition functor for real element types (float/double).
// Decomposes each [m, n] matrix of the batch as x = Q * R by calling
// Eigen's HouseholderQR on the CPU.
template <typename T, typename Context>
struct QrFunctor {
  // ctx: device context used to allocate the outputs.
  // x: input tensor of shape [..., m, n].
  // compute_q: when false, only R is produced and q is left unallocated.
  // reduced_mode: "reduced" QR (Q: m x min(m,n), R: min(m,n) x n) versus
  //               "complete" QR (Q: m x m, R: m x n).
  void operator()(const Context& ctx,
                  const DenseTensor& x,
                  bool compute_q,
                  bool reduced_mode,
                  DenseTensor* q,
                  DenseTensor* r) {
    auto x_dims = x.dims();
    int x_rank = x_dims.size();
    int m = static_cast<int>(x_dims[x_rank - 2]);
    int n = static_cast<int>(x_dims[x_rank - 1]);
    int min_mn = std::min(m, n);
    // k = number of columns of Q and rows of R, depending on the mode.
    int k = reduced_mode ? min_mn : m;
    int64_t batch_size = static_cast<int64_t>(x.numel() / (m * n));
    int x_stride = m * n;
    int q_stride = m * k;
    int r_stride = k * n;
    // For real T, phi::dtype::Real<T> is T itself, so these pointers are T*.
    auto* x_data = x.data<phi::dtype::Real<T>>();
    T* q_data = nullptr;
    if (compute_q) {
      // NOTE(review): the second Alloc argument looks like a requested byte
      // count — confirm against the Context::Alloc contract.
      q_data = ctx.template Alloc<phi::dtype::Real<T>>(
          q, batch_size * m * k * sizeof(phi::dtype::Real<T>));
    }
    auto* r_data = ctx.template Alloc<phi::dtype::Real<T>>(
        r, batch_size * k * n * sizeof(phi::dtype::Real<T>));

    // Implement QR by calling Eigen, one matrix of the batch at a time.
    for (int i = 0; i < batch_size; ++i) {
      const T* x_matrix_ptr = x_data + i * x_stride;
      T* r_matrix_ptr = r_data + i * r_stride;
      using EigenDynamicMatrix =
          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
      // Map the raw input buffer as a row-major m x n matrix (no copy).
      auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
      Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
      if (reduced_mode) {
        // Reduced mode: R is the upper-triangular part of the top
        // min(m, n) rows of Eigen's packed QR matrix.
        auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
        auto r_matrix_view =
            qr_top_matrix.template triangularView<Eigen::Upper>();
        // Materialize the view so the data is contiguous before memcpy.
        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
        memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
      } else {
        // Complete mode: R keeps all m rows (zeros below the diagonal).
        auto r_matrix_view =
            qr.matrixQR().template triangularView<Eigen::Upper>();
        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
        memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
      }

      if (compute_q) {
        T* q_matrix_ptr = q_data + i * q_stride;
        if (reduced_mode) {
          // Materialize the thin Q by applying the Householder sequence to
          // an m x min(m, n) identity.
          auto q_matrix =
              qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
          // NOTE(review): the transpose before the raw copy changes the
          // element ordering written to q_data — presumably to match the
          // layout the caller expects; confirm against the Q output layout.
          q_matrix.transposeInPlace();
          memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
        } else {
          // Complete mode: full m x m orthogonal Q.
          auto q_matrix =
              qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
          q_matrix.transposeInPlace();
          memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
        }
      }
    }
  }
};

// QrFunctor specialization for complex element types
// (phi::dtype::complex<float> / phi::dtype::complex<double>).
// Mirrors the real-valued functor but maps the input buffer as
// std::complex<T> so Eigen's complex HouseholderQR can be used.
template <typename T, typename Context>
struct QrFunctor<phi::dtype::complex<T>, Context> {
  // ctx: device context used to allocate the outputs.
  // x: input tensor of shape [..., m, n] with complex elements.
  // compute_q: when false, only R is produced and q is left unallocated.
  // reduced_mode: "reduced" QR (Q: m x min(m,n), R: min(m,n) x n) versus
  //               "complete" QR (Q: m x m, R: m x n).
  void operator()(const Context& ctx,
                  const DenseTensor& x,
                  bool compute_q,
                  bool reduced_mode,
                  DenseTensor* q,
                  DenseTensor* r) {
    auto x_dims = x.dims();
    int x_rank = x_dims.size();
    int m = static_cast<int>(x_dims[x_rank - 2]);
    int n = static_cast<int>(x_dims[x_rank - 1]);
    int min_mn = std::min(m, n);
    // k = number of columns of Q and rows of R, depending on the mode.
    int k = reduced_mode ? min_mn : m;
    // NOTE(review): batch_size is int here but int64_t in the real-valued
    // functor — consider unifying.
    int batch_size = static_cast<int>(x.numel() / (m * n));
    int x_stride = m * n;
    int q_stride = m * k;
    int r_stride = k * n;
    auto* x_data = x.data<phi::dtype::complex<T>>();
    phi::dtype::complex<T>* q_data = nullptr;
    if (compute_q) {
      // NOTE(review): the second Alloc argument looks like a requested byte
      // count — confirm against the Context::Alloc contract.
      q_data = ctx.template Alloc<phi::dtype::complex<T>>(
          q, batch_size * m * k * sizeof(phi::dtype::complex<T>));
    }
    auto* r_data = ctx.template Alloc<phi::dtype::complex<T>>(
        r, batch_size * k * n * sizeof(phi::dtype::complex<T>));

    // Implement QR by calling Eigen, one matrix of the batch at a time.
    for (int i = 0; i < batch_size; ++i) {
      const phi::dtype::complex<T>* x_matrix_ptr = x_data + i * x_stride;
      phi::dtype::complex<T>* r_matrix_ptr = r_data + i * r_stride;
      using EigenDynamicMatrix = Eigen::Matrix<std::complex<T>,
                                               Eigen::Dynamic,
                                               Eigen::Dynamic,
                                               Eigen::RowMajor>;
      // Reinterpret the phi complex buffer as std::complex for Eigen.
      // NOTE(review): assumes phi::dtype::complex<T> is layout-compatible
      // with std::complex<T> — confirm.
      auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(
          reinterpret_cast<const std::complex<T>*>(x_matrix_ptr), m, n);
      Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
      if (reduced_mode) {
        // Reduced mode: R is the upper-triangular part of the top
        // min(m, n) rows of Eigen's packed QR matrix.
        auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
        auto r_matrix_view =
            qr_top_matrix.template triangularView<Eigen::Upper>();
        // Materialize the view so the data is contiguous before memcpy.
        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
        memcpy(r_matrix_ptr,
               r_matrix.data(),
               r_matrix.size() * sizeof(phi::dtype::complex<T>));
      } else {
        // Complete mode: R keeps all m rows (zeros below the diagonal).
        auto r_matrix_view =
            qr.matrixQR().template triangularView<Eigen::Upper>();
        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
        memcpy(r_matrix_ptr,
               r_matrix.data(),
               r_matrix.size() * sizeof(phi::dtype::complex<T>));
      }

      if (compute_q) {
        phi::dtype::complex<T>* q_matrix_ptr = q_data + i * q_stride;
        if (reduced_mode) {
          // Materialize the thin Q by applying the Householder sequence to
          // an m x min(m, n) identity.
          auto q_matrix =
              qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
          // NOTE(review): the transpose before the raw copy changes the
          // element ordering written to q_data — presumably to match the
          // layout the caller expects; confirm against the Q output layout.
          q_matrix.transposeInPlace();
          memcpy(q_matrix_ptr,
                 q_matrix.data(),
                 q_matrix.size() * sizeof(phi::dtype::complex<T>));
        } else {
          // Complete mode: full m x m unitary Q.
          auto q_matrix =
              qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
          q_matrix.transposeInPlace();
          memcpy(q_matrix_ptr,
                 q_matrix.data(),
                 q_matrix.size() * sizeof(phi::dtype::complex<T>));
        }
      }
    }
  }
};

template <typename T, typename Context>
void QrKernel(const Context& ctx,
const DenseTensor& x,
Expand All @@ -32,65 +207,27 @@ void QrKernel(const Context& ctx,
bool compute_q = false;
bool reduced_mode = false;
std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode);
auto numel = x.numel();
PADDLE_ENFORCE_GT(
numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
auto x_dims = x.dims();
int x_rank = x_dims.size();
int m = static_cast<int>(x_dims[x_rank - 2]);
int n = static_cast<int>(x_dims[x_rank - 1]);
int min_mn = std::min(m, n);
int k = reduced_mode ? min_mn : m;
int batch_size = static_cast<int>(numel / (m * n));
int x_stride = m * n;
int q_stride = m * k;
int r_stride = k * n;
auto* x_data = x.data<phi::dtype::Real<T>>();
T* q_data = nullptr;
if (compute_q) {
q_data = ctx.template Alloc<phi::dtype::Real<T>>(
q, batch_size * m * k * sizeof(phi::dtype::Real<T>));
}
auto* r_data = ctx.template Alloc<phi::dtype::Real<T>>(
r, batch_size * k * n * sizeof(phi::dtype::Real<T>));

// Implement QR by calling Eigen
for (int i = 0; i < batch_size; ++i) {
const T* x_matrix_ptr = x_data + i * x_stride;
T* r_matrix_ptr = r_data + i * r_stride;
using EigenDynamicMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
if (reduced_mode) {
auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
auto r_matrix_view =
qr_top_matrix.template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
if (x.numel() == 0) {
if (q->numel() == 0) {
q->Resize(q->dims());
} else {
auto r_matrix_view =
qr.matrixQR().template triangularView<Eigen::Upper>();
auto r_matrix = EigenDynamicMatrix(r_matrix_view);
memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
}

if (compute_q) {
T* q_matrix_ptr = q_data + i * q_stride;
if (reduced_mode) {
auto q_matrix =
qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
} else {
auto q_matrix = qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
q_matrix.transposeInPlace();
memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
}
*q = identity_matrix<T, Context>(ctx, q->dims());
}
r->Resize(r->dims());
ctx.template Alloc<T>(q);
ctx.template Alloc<T>(r);
return;
}
QrFunctor<T, Context>()(ctx, x, compute_q, reduced_mode, q, r);
}

} // namespace phi

PD_REGISTER_KERNEL(qr, CPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
// Registers the CPU forward kernel for paddle.linalg.qr, including the
// newly supported complex dtypes.
PD_REGISTER_KERNEL(qr,
                   CPU,
                   ALL_LAYOUT,
                   phi::QrKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
11 changes: 11 additions & 0 deletions paddle/phi/kernels/gpu/qr_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,16 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h"

#ifdef PADDLE_WITH_HIP
// NOTE(review): complex dtypes are registered only on the non-HIP path —
// presumably deliberate (the change history mentions a DCU fix); confirm
// before extending complex support to HIP.
PD_REGISTER_KERNEL(qr_grad, GPU, ALL_LAYOUT, phi::QrGradKernel, float, double) {
}
#else
// CUDA path: gradient kernel supports real and complex dtypes.
PD_REGISTER_KERNEL(qr_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::QrGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
#endif
Loading

0 comments on commit d3e8070

Please sign in to comment.