From 81f2493e6aa497c2546498471f84ce773d60448b Mon Sep 17 00:00:00 2001 From: Badr MOUFAD Date: Sun, 24 Apr 2022 12:03:14 +0200 Subject: [PATCH] add ``ConvergenceWarning`` in ``do_line_search`` --- celer/PN_logreg.pyx | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/celer/PN_logreg.pyx b/celer/PN_logreg.pyx index 646a33c2..83473f98 100644 --- a/celer/PN_logreg.pyx +++ b/celer/PN_logreg.pyx @@ -400,11 +400,12 @@ cpdef void do_line_search( floating[::1, :] X, floating[:] X_data, int[:] X_indices, int[:] X_indptr, int MAX_BACKTRACK_ITR, floating[:] y, floating[:] exp_Xw, floating[:] low_exp_Xw, - floating[:] aux, int[:] is_positive_label) nogil: + floating[:] aux, int[:] is_positive_label): cdef int i, ind, backtrack_itr cdef floating deriv cdef floating step_size = 1. + cdef floating atol = 1e-7 cdef int n_samples = y.shape[0] fcopy(&n_samples, &exp_Xw[0], &inc, &low_exp_Xw[0], &inc) @@ -417,15 +418,18 @@ cpdef void do_line_search( deriv = compute_derivative( w, WS, delta_w, X_delta_w, alpha, aux, step_size, y) - if deriv < 1e-7: + if deriv < atol: break else: step_size = step_size / 2. for i in range(n_samples): exp_Xw[i] = sqrt(exp_Xw[i] * low_exp_Xw[i]) else: - pass - # TODO what do we do in this case? + warnings.warn( + 'Line search failed to converge ' + f'deriv {deriv:.2e}, atol {atol:.2e}', + ConvergenceWarning + ) # a suitable step size is found, perform step: for ind in range(WS.shape[0]):