From d3e80707f8af45788ea938b4a1527f30ec209e0c Mon Sep 17 00:00:00 2001 From: fangfangssj <99968055+fangfangssj@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:18:48 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90HEU=E3=80=91[Paddle=20Tensor=20?= =?UTF-8?q?=E7=AC=AC=E4=BA=8C=E6=9C=9F=20API=E9=B2=81=E6=A3=92=E6=80=A7?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA]=20paddle.linalg.qr=20=E6=94=AF=E6=8C=81=200?= =?UTF-8?q?-size=20tensor=20=E4=B8=8E=20=E5=A4=8D=E6=95=B0=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=20(#70481)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 0size * support complex * add grad m > n * fix * fix test * fix * fix * fix DCU * fix m, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/qr_kernel.cc b/paddle/phi/kernels/cpu/qr_kernel.cc index 194906ae1dc346..9d3ac1990ebe57 100644 --- a/paddle/phi/kernels/cpu/qr_kernel.cc +++ b/paddle/phi/kernels/cpu/qr_kernel.cc @@ -17,12 +17,187 @@ #include #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/diagonal_kernel.h" +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h" namespace phi { +template +static DenseTensor Fill(const Context& ctx, + std::vector shape, + T fill_value) { + DenseTensor ret; + ret.Resize(common::make_ddim(shape)); + ctx.template Alloc(&ret); + funcs::SetConstant()(ctx, &ret, fill_value); + return ret; +} + +template +static DenseTensor identity_matrix(const Context& ctx, common::DDim shape) { + DenseTensor M = + Fill(ctx, common::vectorize(shape), T(0)); + size_t rank = M.dims().size(); + int64_t M_diag_len = std::min(M.dims()[rank - 1], M.dims()[rank - 2]); + std::vector M_diag_shape; + for (size_t i = 0; i < rank - 2; ++i) { + M_diag_shape.push_back(M.dims()[i]); + } + M_diag_shape.push_back(M_diag_len); + DenseTensor M_diag = Fill( + ctx, common::vectorize(make_ddim(M_diag_shape)), T(1)); + M = FillDiagonalTensor(ctx, M, M_diag, 0, rank - 2, rank - 1); + return M; +} + +template +struct QrFunctor { + void operator()(const Context& ctx, + const DenseTensor& x, + bool compute_q, + bool reduced_mode, + DenseTensor* q, + DenseTensor* r) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = static_cast(x_dims[x_rank - 2]); + int n = static_cast(x_dims[x_rank - 1]); + int min_mn = std::min(m, n); + int k = reduced_mode ? 
min_mn : m; + int64_t batch_size = static_cast(x.numel() / (m * n)); + int x_stride = m * n; + int q_stride = m * k; + int r_stride = k * n; + auto* x_data = x.data>(); + T* q_data = nullptr; + if (compute_q) { + q_data = ctx.template Alloc>( + q, batch_size * m * k * sizeof(phi::dtype::Real)); + } + auto* r_data = ctx.template Alloc>( + r, batch_size * k * n * sizeof(phi::dtype::Real)); + + // Implement QR by calling Eigen + for (int i = 0; i < batch_size; ++i) { + const T* x_matrix_ptr = x_data + i * x_stride; + T* r_matrix_ptr = r_data + i * r_stride; + using EigenDynamicMatrix = + Eigen::Matrix; + auto x_matrix = Eigen::Map(x_matrix_ptr, m, n); + Eigen::HouseholderQR qr(x_matrix); + if (reduced_mode) { + auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n); + auto r_matrix_view = + qr_top_matrix.template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } else { + auto r_matrix_view = + qr.matrixQR().template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } + + if (compute_q) { + T* q_matrix_ptr = q_data + i * q_stride; + if (reduced_mode) { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } else { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, m); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } + } + } + } +}; + +template +struct QrFunctor, Context> { + void operator()(const Context& ctx, + const DenseTensor& x, + bool compute_q, + bool reduced_mode, + DenseTensor* q, + DenseTensor* r) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = static_cast(x_dims[x_rank - 2]); + int n = static_cast(x_dims[x_rank - 1]); + int min_mn = std::min(m, n); + int k = reduced_mode ? 
min_mn : m; + int batch_size = static_cast(x.numel() / (m * n)); + int x_stride = m * n; + int q_stride = m * k; + int r_stride = k * n; + auto* x_data = x.data>(); + phi::dtype::complex* q_data = nullptr; + if (compute_q) { + q_data = ctx.template Alloc>( + q, batch_size * m * k * sizeof(phi::dtype::complex)); + } + auto* r_data = ctx.template Alloc>( + r, batch_size * k * n * sizeof(phi::dtype::complex)); + + // Implement QR by calling Eigen + for (int i = 0; i < batch_size; ++i) { + const phi::dtype::complex* x_matrix_ptr = x_data + i * x_stride; + phi::dtype::complex* r_matrix_ptr = r_data + i * r_stride; + using EigenDynamicMatrix = Eigen::Matrix, + Eigen::Dynamic, + Eigen::Dynamic, + Eigen::RowMajor>; + auto x_matrix = Eigen::Map( + reinterpret_cast*>(x_matrix_ptr), m, n); + Eigen::HouseholderQR qr(x_matrix); + if (reduced_mode) { + auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n); + auto r_matrix_view = + qr_top_matrix.template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, + r_matrix.data(), + r_matrix.size() * sizeof(phi::dtype::complex)); + } else { + auto r_matrix_view = + qr.matrixQR().template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, + r_matrix.data(), + r_matrix.size() * sizeof(phi::dtype::complex)); + } + + if (compute_q) { + phi::dtype::complex* q_matrix_ptr = q_data + i * q_stride; + if (reduced_mode) { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, + q_matrix.data(), + q_matrix.size() * sizeof(phi::dtype::complex)); + } else { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, m); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, + q_matrix.data(), + q_matrix.size() * sizeof(phi::dtype::complex)); + } + } + } + } +}; + template void QrKernel(const Context& ctx, const DenseTensor& x, @@ -32,65 +207,27 @@ void QrKernel(const Context& ctx, bool compute_q = false; bool reduced_mode = false; std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); - auto numel = x.numel(); - PADDLE_ENFORCE_GT( - numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); - auto x_dims = x.dims(); - int x_rank = x_dims.size(); - int m = static_cast(x_dims[x_rank - 2]); - int n = static_cast(x_dims[x_rank - 1]); - int min_mn = std::min(m, n); - int k = reduced_mode ? 
min_mn : m; - int batch_size = static_cast(numel / (m * n)); - int x_stride = m * n; - int q_stride = m * k; - int r_stride = k * n; - auto* x_data = x.data>(); - T* q_data = nullptr; - if (compute_q) { - q_data = ctx.template Alloc>( - q, batch_size * m * k * sizeof(phi::dtype::Real)); - } - auto* r_data = ctx.template Alloc>( - r, batch_size * k * n * sizeof(phi::dtype::Real)); - - // Implement QR by calling Eigen - for (int i = 0; i < batch_size; ++i) { - const T* x_matrix_ptr = x_data + i * x_stride; - T* r_matrix_ptr = r_data + i * r_stride; - using EigenDynamicMatrix = - Eigen::Matrix; - auto x_matrix = Eigen::Map(x_matrix_ptr, m, n); - Eigen::HouseholderQR qr(x_matrix); - if (reduced_mode) { - auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n); - auto r_matrix_view = - qr_top_matrix.template triangularView(); - auto r_matrix = EigenDynamicMatrix(r_matrix_view); - memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + if (x.numel() == 0) { + if (q->numel() == 0) { + q->Resize(q->dims()); } else { - auto r_matrix_view = - qr.matrixQR().template triangularView(); - auto r_matrix = EigenDynamicMatrix(r_matrix_view); - memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); - } - - if (compute_q) { - T* q_matrix_ptr = q_data + i * q_stride; - if (reduced_mode) { - auto q_matrix = - qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn); - q_matrix.transposeInPlace(); - memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); - } else { - auto q_matrix = qr.householderQ() * EigenDynamicMatrix::Identity(m, m); - q_matrix.transposeInPlace(); - memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); - } + *q = identity_matrix(ctx, q->dims()); } + r->Resize(r->dims()); + ctx.template Alloc(q); + ctx.template Alloc(r); + return; } + QrFunctor()(ctx, x, compute_q, reduced_mode, q, r); } } // namespace phi -PD_REGISTER_KERNEL(qr, CPU, ALL_LAYOUT, phi::QrKernel, float, double) {} +PD_REGISTER_KERNEL(qr, + CPU, + ALL_LAYOUT, + phi::QrKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/qr_grad_kernel.cu b/paddle/phi/kernels/gpu/qr_grad_kernel.cu index 9f59ee53c1bb88..59a4d0b5aeb413 100644 --- a/paddle/phi/kernels/gpu/qr_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/qr_grad_kernel.cu @@ -16,5 +16,16 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/qr_grad_kernel_impl.h" +#ifdef PADDLE_WITH_HIP PD_REGISTER_KERNEL(qr_grad, GPU, ALL_LAYOUT, phi::QrGradKernel, float, double) { } +#else +PD_REGISTER_KERNEL(qr_grad, + GPU, + ALL_LAYOUT, + phi::QrGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#endif diff --git a/paddle/phi/kernels/gpu/qr_kernel.cu b/paddle/phi/kernels/gpu/qr_kernel.cu index 39f2826eed7f60..e153093f937504 100644 --- a/paddle/phi/kernels/gpu/qr_kernel.cu +++ b/paddle/phi/kernels/gpu/qr_kernel.cu @@ -22,10 +22,13 @@ #include #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" #include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/kernels/diagonal_kernel.h" +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h" @@ -39,132 +42,103 @@ namespace phi { template static DenseTensor 
Fill(const Context& ctx, - std::vector shape, - float fill_value) { + std::vector shape, + T fill_value) { DenseTensor ret; ret.Resize(common::make_ddim(shape)); ctx.template Alloc(&ret); - funcs::SetConstant()(ctx, &ret, T(fill_value)); + funcs::SetConstant()(ctx, &ret, fill_value); return ret; } +template +static DenseTensor identity_matrix(const Context& ctx, common::DDim shape) { + DenseTensor M = + Fill(ctx, common::vectorize(shape), T(0)); + size_t rank = M.dims().size(); + int64_t M_diag_len = std::min(M.dims()[rank - 1], M.dims()[rank - 2]); + std::vector M_diag_shape; + for (size_t i = 0; i < rank - 2; ++i) { + M_diag_shape.push_back(M.dims()[i]); + } + M_diag_shape.push_back(M_diag_len); + DenseTensor M_diag = Fill( + ctx, common::vectorize(make_ddim(M_diag_shape)), T(1)); + M = FillDiagonalTensor(ctx, M, M_diag, 0, rank - 2, rank - 1); + return M; +} + template -void QrKernel(const Context& ctx, - const DenseTensor& x, - const std::string& mode, - DenseTensor* q, - DenseTensor* r) { - bool compute_q; - bool reduced_mode; - std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); - auto numel = x.numel(); - PADDLE_ENFORCE_GT( - numel, 0, errors::PreconditionNotMet("The input of QR is empty.")); - auto x_dims = x.dims(); - int x_rank = x_dims.size(); - int m = x_dims[x_rank - 2]; - int n = x_dims[x_rank - 1]; - int min_mn = std::min(m, n); - int k = reduced_mode ? min_mn : m; - int batch_size = numel / (m * n); - int qr_stride = m * n; - int tau_stride = min_mn; - - if (compute_q) { +struct QrFunctor { + void operator()(const Context& ctx, + const DenseTensor& x, + bool compute_q, + bool reduced_mode, + DenseTensor* q, + DenseTensor* r) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int64_t batch_size = static_cast(x.numel() / (m * n)); + int qr_stride = m * n; + int tau_stride = min_mn; + + if (compute_q) { + ctx.template Alloc>( + q, batch_size * m * k * sizeof(phi::dtype::Real)); + } ctx.template Alloc>( - q, batch_size * m * k * sizeof(phi::dtype::Real)); - } - ctx.template Alloc>( - r, batch_size * k * n * sizeof(phi::dtype::Real)); - - // Note: allocate temporary tensors because of lacking in-place operatios. 
- // Prepare qr - DenseTensor qr; - ctx.template Alloc>( - &qr, size_t(batch_size * m * n * sizeof(phi::dtype::Real))); - // BatchedGeqrf performs computation in-place and 'qr' must be a copy of - // input - phi::Copy(ctx, x, ctx.GetPlace(), false, &qr); - - // Prepare tau - auto tau_dims_vec = common::vectorize(x_dims); - tau_dims_vec.pop_back(); - tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; - DenseTensor tau = Fill(ctx, tau_dims_vec, 0); - - // Transpose 'qr' to conform the column-major order - auto tmp_qr = TransposeLast2Dim(ctx, qr); - phi::Copy(ctx, tmp_qr, qr.place(), false, &qr); - auto qr_data = ctx.template Alloc>(&qr); - auto tau_data = ctx.template Alloc>(&tau); - - BatchedGeqrf( - ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); - - if (reduced_mode) { - auto trans_qr = TransposeLast2Dim(ctx, qr); - auto sliced_qr = Slice( - ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}); - auto tmp_r = TrilTriu(ctx, sliced_qr, 0, false); - // Transpose 'tmp_r' to restore the original row-major order - phi::Copy(ctx, tmp_r, r->place(), false, r); - } else { - auto trans_qr = TransposeLast2Dim(ctx, qr); - auto tmp_r = TrilTriu(ctx, trans_qr, 0, false); - // Transpose 'tmp_r' to restore the original row-major order - phi::Copy(ctx, tmp_r, r->place(), false, r); - } + r, batch_size * k * n * sizeof(phi::dtype::Real)); + + // Note: allocate temporary tensors because of lacking in-place operatios. + // Prepare qr + DenseTensor qr; + ctx.template Alloc>( + &qr, size_t(batch_size * m * n * sizeof(phi::dtype::Real))); + // BatchedGeqrf performs computation in-place and 'qr' must be a copy of + // input + phi::Copy(ctx, x, ctx.GetPlace(), false, &qr); + + // Prepare tau + auto tau_dims_vec = common::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + DenseTensor tau = Fill(ctx, tau_dims_vec, T(0)); + + // Transpose 'qr' to conform the column-major order + auto tmp_qr = TransposeLast2Dim(ctx, qr); + phi::Copy(ctx, tmp_qr, qr.place(), false, &qr); + auto qr_data = ctx.template Alloc>(&qr); + auto tau_data = ctx.template Alloc>(&tau); + + BatchedGeqrf( + ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); - if (compute_q) { - // Perform QRGQR for Q using the result from GEQRF - // Transpose 'q' to restore the original row-major order if (reduced_mode) { - BatchedOrgqr(ctx, - batch_size, - m, - min_mn, - min_mn, - qr_data, - m, - tau_data, - qr_stride, - tau_stride); - auto trans_q = TransposeLast2Dim(ctx, qr); - auto sliced_q = Slice( - ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}); - phi::Copy(ctx, sliced_q, q->place(), false, q); + auto trans_qr = TransposeLast2Dim(ctx, qr); + auto sliced_qr = Slice( + ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}); + auto tmp_r = TrilTriu(ctx, sliced_qr, 0, false); + // Transpose 'tmp_r' to restore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); } else { - if (m > n) { - auto new_qr_dims_vec = common::vectorize(x_dims); - new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; - DenseTensor new_qr = Fill(ctx, new_qr_dims_vec, 0); - auto new_qr_data = ctx.template Alloc>(&new_qr); - auto new_qr_stride = m * m; - for (int i = 0; i < batch_size; ++i) { - memory_utils::Copy(ctx.GetPlace(), - (new_qr_data + i * new_qr_stride), - ctx.GetPlace(), - (qr_data + i * qr_stride), - qr_stride * sizeof(phi::dtype::Real), - ctx.stream()); - } + auto trans_qr = TransposeLast2Dim(ctx, qr); + auto tmp_r = TrilTriu(ctx, trans_qr, 0, false); + // 
Transpose 'tmp_r' to restore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); + } + + if (compute_q) { + // Perform QRGQR for Q using the result from GEQRF + // Transpose 'q' to restore the original row-major order + if (reduced_mode) { BatchedOrgqr(ctx, batch_size, m, - m, min_mn, - new_qr_data, - m, - tau_data, - new_qr_stride, - tau_stride); - auto trans_q = TransposeLast2Dim(ctx, new_qr); - phi::Copy(ctx, trans_q, q->place(), false, q); - } else { - BatchedOrgqr(ctx, - batch_size, - m, - m, min_mn, qr_data, m, @@ -173,11 +147,209 @@ void QrKernel(const Context& ctx, tau_stride); auto trans_q = TransposeLast2Dim(ctx, qr); auto sliced_q = Slice( - ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}); + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}); phi::Copy(ctx, sliced_q, q->place(), false, q); + } else { + if (m > n) { + auto new_qr_dims_vec = common::vectorize(x_dims); + new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; + DenseTensor new_qr = Fill(ctx, new_qr_dims_vec, T(0)); + auto new_qr_data = ctx.template Alloc>(&new_qr); + auto new_qr_stride = m * m; + for (int i = 0; i < batch_size; ++i) { + memory_utils::Copy(ctx.GetPlace(), + (new_qr_data + i * new_qr_stride), + ctx.GetPlace(), + (qr_data + i * qr_stride), + qr_stride * sizeof(phi::dtype::Real), + ctx.stream()); + } + BatchedOrgqr(ctx, + batch_size, + m, + m, + min_mn, + new_qr_data, + m, + tau_data, + new_qr_stride, + tau_stride); + auto trans_q = TransposeLast2Dim(ctx, new_qr); + phi::Copy(ctx, trans_q, q->place(), false, q); + } else { + BatchedOrgqr(ctx, + batch_size, + m, + m, + min_mn, + qr_data, + m, + tau_data, + qr_stride, + tau_stride); + auto trans_q = TransposeLast2Dim(ctx, qr); + auto sliced_q = Slice( + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}); + phi::Copy(ctx, sliced_q, q->place(), false, q); + } } } } +}; + +template +struct QrFunctor, Context> { + void operator()(const Context& ctx, + const DenseTensor& x, + bool compute_q, + bool reduced_mode, + DenseTensor* q, + DenseTensor* r) { + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = x.numel() / (m * n); + int qr_stride = m * n; + int tau_stride = min_mn; + if (compute_q) { + ctx.template Alloc>( + q, batch_size * m * k * sizeof(phi::dtype::complex)); + } + ctx.template Alloc>( + r, batch_size * k * n * sizeof(phi::dtype::complex)); + // Note: allocate temporary tensors because of lacking in-place operatios. 
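+    // Overall flow of the complex specialization (it mirrors the real-valued
+    // functor above): copy x into a column-major scratch buffer, factor it
+    // with batched geqrf, take the upper triangle as R, and rebuild Q from
+    // the Householder reflectors with ungqr.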
+ // Prepare qr + DenseTensor qr; + ctx.template Alloc>( + &qr, size_t(batch_size * m * n * sizeof(phi::dtype::complex))); + // BatchedGeqrf performs computation in-place and 'qr' must be a copy of + // input + phi::Copy(ctx, x, ctx.GetPlace(), false, &qr); + // Prepare tau + auto tau_dims_vec = common::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + DenseTensor tau = + Fill, Context>(ctx, tau_dims_vec, T(0)); + // Transpose 'qr' to conform the column-major order + auto tmp_qr = TransposeLast2Dim, Context>(ctx, qr); + phi::Copy(ctx, tmp_qr, qr.place(), false, &qr); + auto qr_data = ctx.template Alloc>(&qr); + auto tau_data = ctx.template Alloc>(&tau); + BatchedGeqrf>( + ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, tau_stride); + if (reduced_mode) { + auto trans_qr = + TransposeLast2Dim, Context>(ctx, qr); + auto sliced_qr = Slice, Context>( + ctx, trans_qr, {trans_qr.dims().size() - 2}, {0}, {min_mn}); + auto tmp_r = + TrilTriu, Context>(ctx, sliced_qr, 0, false); + // Transpose 'tmp_r' to restore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); + } else { + auto trans_qr = + TransposeLast2Dim, Context>(ctx, qr); + auto tmp_r = + TrilTriu, Context>(ctx, trans_qr, 0, false); + // Transpose 'tmp_r' to restore the original row-major order + phi::Copy(ctx, tmp_r, r->place(), false, r); + } + if (compute_q) { + // Perform QRGQR for Q using the result from GEQRF + // Transpose 'q' to restore the original row-major order + if (reduced_mode) { + BatchedOrgqr>(ctx, + batch_size, + m, + min_mn, + min_mn, + qr_data, + m, + tau_data, + qr_stride, + tau_stride); + auto trans_q = + TransposeLast2Dim, Context>(ctx, qr); + auto sliced_q = Slice, Context>( + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {min_mn}); + phi::Copy(ctx, sliced_q, q->place(), false, q); + } else { + if (m > n) { + auto new_qr_dims_vec = common::vectorize(x_dims); + new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; + DenseTensor new_qr = + Fill, Context>(ctx, new_qr_dims_vec, T(0)); + auto new_qr_data = + ctx.template Alloc>(&new_qr); + auto new_qr_stride = m * m; + for (int i = 0; i < batch_size; ++i) { + memory_utils::Copy(ctx.GetPlace(), + (new_qr_data + i * new_qr_stride), + ctx.GetPlace(), + (qr_data + i * qr_stride), + qr_stride * sizeof(phi::dtype::complex), + ctx.stream()); + } + BatchedOrgqr>(ctx, + batch_size, + m, + m, + min_mn, + new_qr_data, + m, + tau_data, + new_qr_stride, + tau_stride); + auto trans_q = + TransposeLast2Dim, Context>(ctx, new_qr); + phi::Copy(ctx, trans_q, q->place(), false, q); + } else { + BatchedOrgqr>(ctx, + batch_size, + m, + m, + min_mn, + qr_data, + m, + tau_data, + qr_stride, + tau_stride); + auto trans_q = + TransposeLast2Dim, Context>(ctx, qr); + auto sliced_q = Slice, Context>( + ctx, trans_q, {trans_q.dims().size() - 1}, {0}, {m}); + phi::Copy(ctx, sliced_q, q->place(), false, q); + } + } + } + } +}; + +template +void QrKernel(const Context& ctx, + const DenseTensor& x, + const std::string& mode, + DenseTensor* q, + DenseTensor* r) { + bool compute_q; + bool reduced_mode; + std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); + if (x.numel() == 0) { + if (q->numel() == 0) { + q->Resize(q->dims()); + } else { + *q = identity_matrix(ctx, q->dims()); + } + r->Resize(r->dims()); + ctx.template Alloc(q); + ctx.template Alloc(r); + return; + } + QrFunctor()(ctx, x, compute_q, reduced_mode, q, r); } #ifdef PADDLE_WITH_HIP @@ -335,6 +507,120 @@ void BatchedGeqrf(const GPUContext& dev_ctx, } } 
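+// The specializations below wrap cuSOLVER's complex QR factorization
+// (cusolverDnCgeqrf / cusolverDnZgeqrf). As in the existing float/double
+// wrappers, the batch is processed one matrix at a time and the per-matrix
+// `info` result is copied back to the host and checked.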
+template <> +void BatchedGeqrf>( + const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + phi::dtype::complex* a, + int lda, + phi::dtype::complex* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf_bufferSize( + handle, m, n, reinterpret_cast(a), lda, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(common::make_ddim({lwork})); + phi::dtype::complex* workspace_ptr = + dev_ctx.template Alloc>(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(common::make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + phi::dtype::complex* a_working_ptr = &a[i * a_stride]; + phi::dtype::complex* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCgeqrf( + handle, + m, + n, + reinterpret_cast(a_working_ptr), + lda, + reinterpret_cast(tau_working_ptr), + reinterpret_cast(workspace_ptr), + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory_utils::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + common::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedGeqrf>( + const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + phi::dtype::complex* a, + int lda, + phi::dtype::complex* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf_bufferSize( + handle, m, n, reinterpret_cast(a), lda, &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(common::make_ddim({lwork})); + phi::dtype::complex* workspace_ptr = + dev_ctx.template Alloc>(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(common::make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + phi::dtype::complex* a_working_ptr = &a[i * a_stride]; + phi::dtype::complex* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZgeqrf( + handle, + m, + n, + reinterpret_cast(a_working_ptr), + lda, + reinterpret_cast(tau_working_ptr), + reinterpret_cast(workspace_ptr), + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory_utils::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + common::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + template <> void BatchedOrgqr(const GPUContext& dev_ctx, int batch_size, @@ -446,8 +732,151 @@ void BatchedOrgqr(const GPUContext& dev_ctx, "For batch [%d]: CUSolver QR is not zero. 
[%d]", i, info_h)); } } + +template <> +void BatchedOrgqr>( + const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + phi::dtype::complex* a, + int lda, + phi::dtype::complex* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr_bufferSize( + handle, + m, + n, + k, + reinterpret_cast(a), + lda, + reinterpret_cast(tau), + &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(common::make_ddim({lwork})); + phi::dtype::complex* workspace_ptr = + dev_ctx.template Alloc>(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(common::make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + phi::dtype::complex* a_working_ptr = &a[i * a_stride]; + phi::dtype::complex* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnCungqr( + handle, + m, + n, + k, + reinterpret_cast(a_working_ptr), + lda, + reinterpret_cast(tau_working_ptr), + reinterpret_cast(workspace_ptr), + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory_utils::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + common::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +template <> +void BatchedOrgqr>( + const GPUContext& dev_ctx, + int batch_size, + int m, + int n, + int k, + phi::dtype::complex* a, + int lda, + phi::dtype::complex* tau, + int a_stride, + int tau_stride) { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr_bufferSize( + handle, + m, + n, + k, + reinterpret_cast(a), + lda, + reinterpret_cast(tau), + &lwork)); + + DenseTensor workspace = DenseTensor(); + workspace.Resize(common::make_ddim({lwork})); + phi::dtype::complex* workspace_ptr = + dev_ctx.template Alloc>(&workspace); + + DenseTensor info = DenseTensor(); + info.Resize(common::make_ddim({1})); + int* info_d = dev_ctx.template Alloc(&info); + + for (int i = 0; i < batch_size; ++i) { + phi::dtype::complex* a_working_ptr = &a[i * a_stride]; + phi::dtype::complex* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnZungqr( + handle, + m, + n, + k, + reinterpret_cast(a_working_ptr), + lda, + reinterpret_cast(tau_working_ptr), + reinterpret_cast(workspace_ptr), + lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory_utils::Copy(phi::CPUPlace(), + &info_h, + dev_ctx.GetPlace(), + info_d, + sizeof(int), + dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, + 0, + common::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} #endif } // namespace phi +#ifdef PADDLE_WITH_HIP PD_REGISTER_KERNEL(qr, GPU, ALL_LAYOUT, phi::QrKernel, float, double) {} +#else +PD_REGISTER_KERNEL(qr, + GPU, + ALL_LAYOUT, + phi::QrKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} +#endif diff --git a/paddle/phi/kernels/impl/qr_grad_kernel_impl.h b/paddle/phi/kernels/impl/qr_grad_kernel_impl.h index e015909d6e7b56..de6b63efbfd703 100644 --- a/paddle/phi/kernels/impl/qr_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/qr_grad_kernel_impl.h @@ -13,6 +13,7 @@ // limitations under the License. 
#pragma once +#include "paddle/phi/common/complex.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" @@ -20,9 +21,12 @@ #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/concat_kernel.h" +#include "paddle/phi/kernels/diagonal_kernel.h" #include "paddle/phi/kernels/elementwise_add_kernel.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/fill_diagonal_tensor_kernel.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/matmul_kernel.h" @@ -61,7 +65,7 @@ void QrGradKernel(const Context& ctx, const DenseTensor& dR = r_grad; DenseTensor& dA = *x_grad; - ctx.template Alloc>(&dA); + ctx.template Alloc(&dA); phi::funcs::SetConstant()(ctx, &dA, T(0)); bool compute_q, reduced; @@ -91,15 +95,17 @@ void QrGradKernel(const Context& ctx, const DenseTensor& A UNUSED, const DenseTensor& Q, const DenseTensor& R) -> DenseTensor { - // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable - // Programming Tensor Networks. - // https://arxiv.org/abs/1903.09650 Section 3. QR factorization + // Roberts, D., & Roberts, L. (2020). QR and LQ Decomposition Matrix + // Backpropagation Algorithms for Square, Wide, and Deep Matrices and Their + // Software Implementation. https://arxiv.org/abs/2009.10071v4 // dR^H DenseTensor R_term; if (dR.initialized()) { - R_term = - Matmul(ctx, R, TransposeLast2Dim(ctx, dR)); + R_term = Matmul( + ctx, + R, + TransposeLast2Dim(ctx, Conj(ctx, dR))); } else { R_term = Fill(ctx, common::vectorize(R.dims()), 0); } @@ -107,19 +113,55 @@ void QrGradKernel(const Context& ctx, // dQ^H * Q DenseTensor Q_term; if (dQ.initialized()) { - Q_term = - Matmul(ctx, TransposeLast2Dim(ctx, dQ), Q); + Q_term = Matmul( + ctx, + TransposeLast2Dim(ctx, Conj(ctx, dQ)), + Q); } else { Q_term = Fill(ctx, common::vectorize(R.dims()), 0); } DenseTensor M_tmp1 = Subtract(ctx, R_term, Q_term); - + DenseTensor M; +#ifdef PADDLE_WITH_HIP // Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity DenseTensor M_tril_0 = TrilTriu(ctx, M_tmp1, 0, true); DenseTensor M_tril_1 = TrilTriu(ctx, M_tmp1, -1, true); - DenseTensor M = Add( + M = Add( ctx, M_tril_0, TransposeLast2Dim(ctx, M_tril_1)); +#else + if (std::is_same>::value || + std::is_same>::value) { + DenseTensor M_tril_tmp = TrilTriu(ctx, M_tmp1, -1, true); + DenseTensor M_tril = + Add(ctx, + M_tril_tmp, + TransposeLast2Dim( + ctx, Conj(ctx, M_tril_tmp))); + + size_t rank = M_tmp1.dims().size(); + DenseTensor M_diag_tmp = + Diagonal(ctx, M_tmp1, 0, rank - 2, rank - 1); + DenseTensor M_diag_real = Real(ctx, M_diag_tmp); + DenseTensor M_diag_imag = Fill, Context>( + ctx, common::vectorize(M_diag_real.dims()), 0); + + DenseTensor M_diag; + M_diag.Resize(M_diag_real.dims()); + ctx.template Alloc(&M_diag); + phi::ComplexKernel>( + ctx, M_diag_real, M_diag_imag, &M_diag); + + M = FillDiagonalTensor( + ctx, M_tril, M_diag, 0, rank - 2, rank - 1); + } else { + // Compute M = (tril(M) + tril(M).mH()) * 0.5 Identity + DenseTensor M_tril_0 = TrilTriu(ctx, M_tmp1, 0, true); + DenseTensor M_tril_1 = TrilTriu(ctx, M_tmp1, -1, true); + M = Add( + ctx, M_tril_0, TransposeLast2Dim(ctx, M_tril_1)); + } +#endif DenseTensor rhs_term; if (dQ.initialized()) { @@ -151,23 +193,25 @@ void QrGradKernel(const 
Context& ctx, auto Y = Slice(ctx, A, {A.dims().size() - 1}, {m}, {n}); auto U = Slice(ctx, R, {R.dims().size() - 1}, {0}, {m}); - DenseTensor dY, dX, dV, dR_tmp, dQ_prime; + DenseTensor dY, dX, dV, dU, dQ_prime; if (dR.initialized()) { dV = Slice(ctx, dR, {dR.dims().size() - 1}, {m}, {n}); - dR_tmp = Slice(ctx, dR, {dR.dims().size() - 1}, {0}, {m}); + dU = Slice(ctx, dR, {dR.dims().size() - 1}, {0}, {m}); // Y * dV^H - dQ_prime = - Matmul(ctx, Y, TransposeLast2Dim(ctx, dV)); + dQ_prime = Matmul( + ctx, + Y, + TransposeLast2Dim(ctx, Conj(ctx, dV))); } else { dV = Fill(ctx, common::vectorize(Y.dims()), 0); dQ_prime = Fill(ctx, common::vectorize(Q.dims()), 0); } if (dQ.initialized()) { - dQ_prime = Add(ctx, dQ_prime, dQ); + dQ_prime = Add(ctx, dQ, dQ_prime); } - dX = m_gt_n_case(ctx, dQ_prime, dR_tmp, A, Q, U); + dX = m_gt_n_case(ctx, dQ_prime, dU, A, Q, U); dY = Matmul(ctx, Q, dV); // Concatenate dX and dY to get dA. auto dA_tmp = Concat(ctx, {&dX, &dY}, -1); diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 2c6508200ed1ae..de3296efb7d6ff 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -3387,7 +3387,7 @@ def qr( Args: x (Tensor): The input tensor. Its shape should be `[..., M, N]`, where ... is zero or more batch dimensions. M and N can be arbitrary - positive number. The data type of x should be float32 or float64. + positive number. The data type of x supports float, double, complex64, complex128. mode (str, optional): A flag to control the behavior of qr. Suppose x's shape is `[..., M, N]` and denoting `K = min(M, N)`: If mode = "reduced", qr op will return reduced Q and R matrices, @@ -3429,7 +3429,9 @@ def qr( else: return q, r else: - check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'qr') + check_variable_and_dtype( + x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'qr' + ) check_type(mode, 'mode', str, 'qr') helper = LayerHelper('qr', **locals()) q = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/test/legacy_test/test_qr_op.py b/test/legacy_test/test_qr_op.py index 3d53b02012fcde..8ec5413cde55c0 100644 --- a/test/legacy_test/test_qr_op.py +++ b/test/legacy_test/test_qr_op.py @@ -13,11 +13,11 @@ # limitations under the License. import itertools -import os import unittest import numpy as np from op_test import OpTest +from utils import dygraph_guard, static_guard import paddle from paddle import base, static @@ -26,14 +26,14 @@ class TestQrOp(OpTest): def setUp(self): - paddle.enable_static() - self.python_api = paddle.linalg.qr - np.random.seed(7) - self.op_type = "qr" - a, q, r = self.get_input_and_output() - self.inputs = {"X": a} - self.attrs = {"mode": self.get_mode()} - self.outputs = {"Q": q, "R": r} + with static_guard(): + self.python_api = paddle.linalg.qr + np.random.seed(7) + self.op_type = "qr" + a, q, r = self.get_input_and_output() + self.inputs = {"X": a} + self.attrs = {"mode": self.get_mode()} + self.outputs = {"Q": q, "R": r} def get_dtype(self): return "float64" @@ -44,31 +44,20 @@ def get_mode(self): def get_shape(self): return (11, 11) + def _get_places(self): + places = [] + places.append(base.CPUPlace()) + if core.is_compiled_with_cuda(): + places.append(base.CUDAPlace(0)) + return places + def get_input_and_output(self): dtype = self.get_dtype() shape = self.get_shape() mode = self.get_mode() assert mode != "r", "Cannot be backward in r mode." 
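+        # np.linalg.qr broadcasts over leading batch dimensions, so the
+        # reference Q and R for the whole stack come from a single call below.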
a = np.random.rand(*shape).astype(dtype) - m = a.shape[-2] - n = a.shape[-1] - min_mn = min(m, n) - if mode == "reduced": - k = min_mn - else: - k = m - q_shape = list(a.shape[:-2]) - q_shape.extend([m, k]) - r_shape = list(a.shape[:-2]) - r_shape.extend([k, n]) - q = np.zeros(q_shape).astype(dtype) - r = np.zeros(r_shape).astype(dtype) - batch_size = a.size // (a.shape[-1] * a.shape[-2]) - for i in range(batch_size): - coord = np.unravel_index(i, a.shape[:-2]) - tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) - q[coord] = tmp_q - r[coord] = tmp_r + q, r = np.linalg.qr(a, mode=mode) return a, q, r def test_check_output(self): @@ -120,52 +109,88 @@ def get_shape(self): return (2, 10, 12) +@unittest.skipIf( + core.is_compiled_with_xpu(), + "Skip XPU for complex dtype is not fully supported", +) +class TestQrOpcomplex(TestQrOp): + def get_input_and_output(self): + dtype = self.get_dtype() + shape = self.get_shape() + mode = self.get_mode() + assert mode != "r", "Cannot be backward in r mode." + a_real = np.random.rand(*shape).astype(dtype) + a_imag = np.random.rand(*shape).astype(dtype) + a = a_real + 1j * a_imag + q, r = np.linalg.qr(a, mode=mode) + return a, q, r + + +@unittest.skipIf( + core.is_compiled_with_xpu(), + "Skip XPU for complex dtype is not fully supported", +) +class TestQrOpcomplexCase1(TestQrOpcomplex): + def get_shape(self): + return (16, 15) + + +@unittest.skipIf( + core.is_compiled_with_xpu(), + "Skip XPU for complex dtype is not fully supported", +) +class TestQrOpcomplexCase2(TestQrOpcomplex): + def get_shape(self): + return (3, 16, 15) + + +@unittest.skipIf( + core.is_compiled_with_xpu(), + "Skip XPU for complex dtype is not fully supported", +) +class TestQrOpcomplexCase3(TestQrOpcomplex): + def get_shape(self): + return (12, 15) + + +@unittest.skipIf( + core.is_compiled_with_xpu(), + "Skip XPU for complex dtype is not fully supported", +) +class TestQrOpcomplexCase4(TestQrOpcomplex): + def get_shape(self): + return (3, 12, 15) + + class TestQrAPI(unittest.TestCase): def test_dygraph(self): - paddle.disable_static() - np.random.seed(7) - def run_qr_dygraph(shape, mode, dtype): if dtype == "float32": np_dtype = np.float32 elif dtype == "float64": np_dtype = np.float64 - a = np.random.rand(*shape).astype(np_dtype) - m = a.shape[-2] - n = a.shape[-1] - min_mn = min(m, n) - if mode == "reduced" or mode == "r": - k = min_mn + elif dtype == "complex64": + np_dtype = np.complex64 + elif dtype == "complex128": + np_dtype = np.complex128 + if np.issubdtype(np_dtype, np.complexfloating): + a_dtype = np.float32 if np_dtype == np.complex64 else np.float64 + a_real = np.random.rand(*shape).astype(a_dtype) + a_imag = np.random.rand(*shape).astype(a_dtype) + a = a_real + 1j * a_imag else: - k = m - np_q_shape = list(a.shape[:-2]) - np_q_shape.extend([m, k]) - np_r_shape = list(a.shape[:-2]) - np_r_shape.extend([k, n]) - np_q = np.zeros(np_q_shape).astype(np_dtype) - np_r = np.zeros(np_r_shape).astype(np_dtype) + a = np.random.rand(*shape).astype(np_dtype) places = [] - if ( - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() - in ['1', 'true', 'on'] - or not core.is_compiled_with_cuda() - ): - places.append(base.CPUPlace()) + places.append('cpu') if core.is_compiled_with_cuda(): - places.append(base.CUDAPlace(0)) + places.append('gpu') for place in places: - batch_size = a.size // (a.shape[-1] * a.shape[-2]) - for i in range(batch_size): - coord = np.unravel_index(i, a.shape[:-2]) - if mode == "r": - tmp_r = np.linalg.qr(a[coord], mode=mode) - np_r[coord] = tmp_r - else: 
- tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) - np_q[coord] = tmp_q - np_r[coord] = tmp_r + if mode == "r": + np_r = np.linalg.qr(a, mode=mode) + else: + np_q, np_r = np.linalg.qr(a, mode=mode) - x = paddle.to_tensor(a, dtype=dtype) + x = paddle.to_tensor(a, dtype=dtype, place=place) if mode == "r": r = paddle.linalg.qr(x, mode=mode) np.testing.assert_allclose(r, np_r, rtol=1e-05, atol=1e-05) @@ -174,74 +199,64 @@ def run_qr_dygraph(shape, mode, dtype): np.testing.assert_allclose(q, np_q, rtol=1e-05, atol=1e-05) np.testing.assert_allclose(r, np_r, rtol=1e-05, atol=1e-05) - tensor_shapes = [ - (3, 5), - (5, 5), - (5, 3), # 2-dim Tensors - (2, 3, 5), - (3, 5, 5), - (4, 5, 3), # 3-dim Tensors - (2, 5, 3, 5), - (3, 5, 5, 5), - (4, 5, 5, 3), # 4-dim Tensors - ] - modes = ["reduced", "complete", "r"] - dtypes = ["float32", "float64"] - for tensor_shape, mode, dtype in itertools.product( - tensor_shapes, modes, dtypes - ): - run_qr_dygraph(tensor_shape, mode, dtype) + with dygraph_guard(): + np.random.seed(7) + tensor_shapes = [ + (0, 3), + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (0, 3, 5), + (4, 0, 5), + (5, 4, 0), + (2, 3, 5), + (3, 5, 5), + (4, 5, 3), # 3-dim Tensors + (0, 5, 3, 5), + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3), # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64", 'complex64', 'complex128'] + for tensor_shape, mode, dtype in itertools.product( + tensor_shapes, modes, dtypes + ): + run_qr_dygraph(tensor_shape, mode, dtype) def test_static(self): - paddle.enable_static() - np.random.seed(7) - def run_qr_static(shape, mode, dtype): if dtype == "float32": np_dtype = np.float32 elif dtype == "float64": np_dtype = np.float64 - a = np.random.rand(*shape).astype(np_dtype) - m = a.shape[-2] - n = a.shape[-1] - min_mn = min(m, n) - if mode == "reduced" or mode == "r": - k = min_mn + elif dtype == "complex64": + np_dtype = np.complex64 + elif dtype == "complex128": + np_dtype = np.complex128 + if np.issubdtype(np_dtype, np.complexfloating): + a_dtype = np.float32 if np_dtype == np.complex64 else np.float64 + a_real = np.random.rand(*shape).astype(a_dtype) + a_imag = np.random.rand(*shape).astype(a_dtype) + a = a_real + 1j * a_imag else: - k = m - np_q_shape = list(a.shape[:-2]) - np_q_shape.extend([m, k]) - np_r_shape = list(a.shape[:-2]) - np_r_shape.extend([k, n]) - np_q = np.zeros(np_q_shape).astype(np_dtype) - np_r = np.zeros(np_r_shape).astype(np_dtype) + a = np.random.rand(*shape).astype(np_dtype) places = [] - if ( - os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower() - in ['1', 'true', 'on'] - or not core.is_compiled_with_cuda() - ): - places.append(base.CPUPlace()) + places.append(paddle.CPUPlace()) if core.is_compiled_with_cuda(): - places.append(base.CUDAPlace(0)) + places.append(paddle.CUDAPlace(0)) for place in places: with static.program_guard(static.Program(), static.Program()): - batch_size = a.size // (a.shape[-1] * a.shape[-2]) - for i in range(batch_size): - coord = np.unravel_index(i, a.shape[:-2]) - if mode == "r": - tmp_r = np.linalg.qr(a[coord], mode=mode) - np_r[coord] = tmp_r - else: - tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) - np_q[coord] = tmp_q - np_r[coord] = tmp_r + if mode == "r": + np_r = np.linalg.qr(a, mode=mode) + else: + np_q, np_r = np.linalg.qr(a, mode=mode) x = paddle.static.data( name="input", shape=shape, dtype=dtype ) if mode == "r": r = paddle.linalg.qr(x, mode=mode) - exe = base.Executor(place) + exe = base.Executor(place=place) fetches = exe.run( feed={"input": a}, 
fetch_list=[r], @@ -251,7 +266,7 @@ def run_qr_static(shape, mode, dtype): ) else: q, r = paddle.linalg.qr(x, mode=mode) - exe = base.Executor(place) + exe = base.Executor(place=place) fetches = exe.run( feed={"input": a}, fetch_list=[q, r], @@ -263,23 +278,28 @@ def run_qr_static(shape, mode, dtype): fetches[1], np_r, rtol=1e-05, atol=1e-05 ) - tensor_shapes = [ - (3, 5), - (5, 5), - (5, 3), # 2-dim Tensors - (2, 3, 5), - (3, 5, 5), - (4, 5, 3), # 3-dim Tensors - (2, 5, 3, 5), - (3, 5, 5, 5), - (4, 5, 5, 3), # 4-dim Tensors - ] - modes = ["reduced", "complete", "r"] - dtypes = ["float32", "float64"] - for tensor_shape, mode, dtype in itertools.product( - tensor_shapes, modes, dtypes - ): - run_qr_static(tensor_shape, mode, dtype) + with static_guard(): + np.random.seed(7) + tensor_shapes = [ + (0, 3), + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (0, 3, 5), + (4, 0, 5), + (5, 4, 0), + (4, 5, 3), # 3-dim Tensors + (0, 5, 3, 5), + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3), # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64", 'complex64', 'complex128'] + for tensor_shape, mode, dtype in itertools.product( + tensor_shapes, modes, dtypes + ): + run_qr_static(tensor_shape, mode, dtype) if __name__ == "__main__":
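    unittest.main()

A minimal usage sketch of the two behaviors this patch enables (complex dtypes
and 0-size inputs). The shapes and tolerances are illustrative assumptions, not
part of the patch, and assume a build with the new kernels registered:

    import numpy as np

    import paddle

    # Complex QR: Q @ R reconstructs the input.
    a = np.random.rand(4, 3) + 1j * np.random.rand(4, 3)
    x = paddle.to_tensor(a, dtype="complex64")
    q, r = paddle.linalg.qr(x, mode="reduced")
    np.testing.assert_allclose((q @ r).numpy(), a, rtol=1e-5, atol=1e-5)

    # 0-size QR: the outputs keep the inferred empty shapes.
    x0 = paddle.zeros([0, 3], dtype="float32")
    q0, r0 = paddle.linalg.qr(x0, mode="reduced")
    print(q0.shape, r0.shape)  # expected: [0, 0] [0, 3]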