From 7acf68686edbee6c052d28d38edcf97b668e7e94 Mon Sep 17 00:00:00 2001 From: Apostolos Chalkis Date: Tue, 9 Jul 2024 19:16:46 -0600 Subject: [PATCH] implement vaidya minimizer and hessian computation --- .../inscribed_ellipsoid_rounding.hpp | 3 +- .../preprocess/rounding_util_functions.hpp | 39 ++++++++++++++++--- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/include/preprocess/inscribed_ellipsoid_rounding.hpp b/include/preprocess/inscribed_ellipsoid_rounding.hpp index 37468b9d9..4b027f61b 100644 --- a/include/preprocess/inscribed_ellipsoid_rounding.hpp +++ b/include/preprocess/inscribed_ellipsoid_rounding.hpp @@ -26,7 +26,8 @@ compute_inscribed_ellipsoid(Custom_MT A, VT b, VT const& x0, { return max_inscribed_ellipsoid(A, b, x0, maxiter, tol, reg); } else if constexpr (ellipsoid_type == EllipsoidType::LOG_BARRIER || - ellipsoid_type == EllipsoidType::VOLUMETRIC_BARRIER) + ellipsoid_type == EllipsoidType::VOLUMETRIC_BARRIER || + ellipsoid_type == EllipsoidType::VAIDYA_BARRIER) { return barrier_center_ellipsoid_linear_ineq(A, b, x0); } else diff --git a/include/preprocess/rounding_util_functions.hpp b/include/preprocess/rounding_util_functions.hpp index 6b94ace8b..7aa5cf5b0 100644 --- a/include/preprocess/rounding_util_functions.hpp +++ b/include/preprocess/rounding_util_functions.hpp @@ -22,7 +22,8 @@ enum EllipsoidType { MAX_ELLIPSOID = 1, LOG_BARRIER = 2, - VOLUMETRIC_BARRIER = 3 + VOLUMETRIC_BARRIER = 3, + VAIDYA_BARRIER = 4 }; template @@ -345,7 +346,8 @@ std::tuple init_step() if constexpr (BarrierType == EllipsoidType::LOG_BARRIER) { return {NT(1), NT(0.99)}; - } else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER) + } else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER || + BarrierType == EllipsoidType::VAIDYA_BARRIER) { return {NT(0.5), NT(0.4)}; } else { @@ -362,21 +364,43 @@ void get_barrier_hessian_grad(MT const& A, MT const& A_trans, VT const& b, b_Ax.noalias() = b - Ax; VT s = b_Ax.cwiseInverse(); VT s_sq = s.cwiseProduct(s); + VT sigma; // Hessian of the log-barrier function update_Atrans_Diag_A(H, A_trans, A, s_sq.asDiagonal()); + + if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER || + BarrierType == EllipsoidType::VAIDYA_BARRIER) + { + // Computing sigma(x)_i = (a_i^T H^{-1} a_i) / (b_i - a_i^Tx)^2 + MT_dense HA = solve_mat(llt, H, A_trans, obj_val); + MT_dense aiHai = HA.transpose().cwiseProduct(A); + sigma = (aiHai.rowwise().sum()).cwiseProduct(s_sq); + } + if constexpr (BarrierType == EllipsoidType::LOG_BARRIER) { grad.noalias() = A_trans * s; } else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER) { - // Computing sigma(x)_i = (a_i^T H^{-1} a_i) / (b_i - a_i^Tx)^2 - MT_dense HA = solve_mat(llt, H, A_trans, obj_val); - MT_dense aiHai = HA.transpose().cwiseProduct(A); - VT sigma = (aiHai.rowwise().sum()).cwiseProduct(s_sq); // Gradient of the volumetric barrier function grad.noalias() = A_trans * (s.cwiseProduct(sigma)); // Hessian of the volumetric barrier function update_Atrans_Diag_A(H, A_trans, A, s_sq.cwiseProduct(sigma).asDiagonal()); + } else if constexpr (BarrierType == EllipsoidType::VAIDYA_BARRIER) + { + const int m = b.size(), d = x.size(); + // Weighted gradient of the log barrier function + grad.noalias() = A_trans * s; + grad *= NT(d) / NT(m); + // Add the gradient of the volumetric function + grad.noalias() += A_trans * (s.cwiseProduct(sigma)); + // Weighted Hessian of the log barrier function + H *= NT(d) / NT(m); + // Add the Hessian of the volumetric function + MT Hvol(d, d); + update_Atrans_Diag_A(Hvol, A_trans, A, s_sq.cwiseProduct(sigma).asDiagonal()); + H += Hvol; + obj_val -= s.array().log().sum(); } else { static_assert(AssertBarrierFalseType::value, "Barrier type is not supported."); @@ -393,6 +417,9 @@ void get_step_next_iteration(NT const obj_val_prev, NT const obj_val, } else if constexpr (BarrierType == EllipsoidType::VOLUMETRIC_BARRIER) { step_iter *= (obj_val_prev <= obj_val - tol_obj) ? NT(0.9) : NT(0.999); + } else if constexpr (BarrierType == EllipsoidType::VAIDYA_BARRIER) + { + step_iter *= NT(0.999); } else { static_assert(AssertBarrierFalseType::value, "Barrier type is not supported.");