Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP isolate convergence warning in LogReg #225

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ celer.egg-info
# build
build

# virtual environment
venv
mathurinm marked this conversation as resolved.
Show resolved Hide resolved

# cache
.pytest_cache
Expand Down
1 change: 1 addition & 0 deletions celer/homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def celer_path(X, y, pb, eps=1e-3, n_alphas=100, alphas=None,
n_iters[t] = len(sol[2])

results = alphas, coefs, dual_gaps
# results = alphas, coefs, sol[2] # for isolate_conv_warn
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
if return_thetas:
results += (thetas,)
if return_n_iter:
Expand Down
25 changes: 25 additions & 0 deletions celer/tests/conv_warning/inspect_conv_warn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""For debbuging purposes."""

import numpy as np
from numpy.linalg import norm
from sklearn.utils.estimator_checks import check_estimator

from celer.dropin_sklearn import LogisticRegression
from celer.utils.testing import build_dataset


np.random.seed(1409)
X, y = build_dataset(
n_samples=30, n_features=60, sparse_X=True)
y = np.sign(y)
alpha_max = norm(X.T.dot(y), ord=np.inf) / 2
C = 20. / alpha_max

tol = 1e-4
clf1 = LogisticRegression(C=C, tol=tol, verbose=0)

generator = check_estimator(clf1, generate_only=True)
generator = list(generator)

for i, (estimator, check_estimator) in enumerate(generator[37:]):
check_estimator(estimator)
66 changes: 66 additions & 0 deletions celer/tests/conv_warning/isolate_conv_warn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
from numpy.linalg import norm
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

from celer.dropin_sklearn import LogisticRegression
from celer.tests.conv_warning.logs.dumped_data import DICT_DATA

import pickle


# To be executed only once to (re)generate the log file. Beforehand,
# uncomment the alternative `results` line in celer/homotopy.py so that
# `path` returns the per-iteration dual gaps instead of the final ones.
def simulate_LogReg_mul_alphas():
    """Solve LogReg paths for several alphas and pickle the dual-gap histories.

    Replays the data captured from sklearn's `check_fit_check_is_fitted`
    check (see `DICT_DATA`) at 4 values of C between C_max and the captured
    C, and dumps a dict mapping the fraction `alpha / alpha_max` (formatted
    as a string) to the returned gaps, into
    celer/tests/conv_warning/logs/gaps_for_mul_alphas.pkl.
    """
    check_props = DICT_DATA['check_fit_check_is_fitted']
    enc = LabelEncoder()

    X = check_props['X']
    y = check_props['y']
    y_ind = enc.fit_transform(y)  # encode labels to {0, 1}

    C = check_props['C']
    alpha_max = norm(X.T.dot(y_ind), ord=np.inf)
    C_max = 1 / alpha_max
    tol = 1e-14  # very tight tolerance, to force many iterations

    arr_C = np.linspace(C_max, C, num=5)[1:]  # skip first value (C_max itself)

    dict_gaps = {}
    for current_C in arr_C:
        # NOTE(review): the estimator is constructed with the captured C while
        # the path is solved at `current_C`; the constructor C looks unused by
        # `path` here — confirm, or pass current_C instead.
        clf = LogisticRegression(C=C, tol=tol, max_iter=100, verbose=0)
        _, gaps = clf.path(
            X, 2 * y_ind - 1, np.array([current_C]), solver="celer-pn")

        current_alpha = 1 / current_C
        plot_name = f'{current_alpha/alpha_max:.2e}'
        dict_gaps[plot_name] = gaps
    # save logs
    filename = 'celer/tests/conv_warning/logs/gaps_for_mul_alphas.pkl'
    with open(filename, 'wb') as f:
        pickle.dump(dict_gaps, f)


# Run once to regenerate the logs:
# simulate_LogReg_mul_alphas()

# Load the dual-gap histories dumped by `simulate_LogReg_mul_alphas`.
log_file = 'celer/tests/conv_warning/logs/gaps_for_mul_alphas.pkl'
with open(log_file, 'rb') as handle:
    dict_gaps = pickle.load(handle)

# One semi-log curve of dual gap vs. iteration per fraction of alpha_max.
fig, ax = plt.subplots()
for curve_label, gap_history in dict_gaps.items():
    ax.semilogy(gap_history, label=curve_label, marker='.')

ax.set_xlabel("iterations")
ax.set_ylabel("dual gap")
plt.title("LogReg for different alpha")

plt.grid()
plt.legend(title="Fraction of alpha_max")

plt.show()
204 changes: 204 additions & 0 deletions celer/tests/conv_warning/logs/dumped_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Datasets dumped from sklearn's estimator checks, inlined as Python.

Each ``DICT_DATA`` entry stores the exact inputs (``X``, ``y``) and solver
settings (``C``, ``tol``) seen during one sklearn check, so the call used by
the conv_warning debugging scripts can be replayed in isolation.
"""
import numpy as np


# Maps a sklearn check name to the {'C', 'X', 'y', 'tol'} dict captured there.
DICT_DATA = {}

# Inputs captured during the `check_fit_idempotent` estimator check.
tol = 1e-4
C = 5.3648013574915625
# 80 samples x 2 features; values cluster around 100 (not centered/scaled).
X = np.array([[99.30543214, 99.85036546],
              [98.82687659, 101.94362119],
              [99.260437, 101.5430146],
              [99.5444675, 100.01747916],
              [99.49034782, 99.5619257],
              [99.73199663, 100.8024564],
              [100.3130677, 99.14590426],
              [101.18802979, 100.31694261],
              [98.96575716, 100.68159452],
              [101.8831507, 98.65224094],
              [99.59682305, 101.22244507],
              [99.32566734, 100.03183056],
              [100.04575852, 99.81281615],
              [101.48825219, 101.89588918],
              [100.15494743, 100.37816252],
              [98.95144703, 98.57998206],
              [101.86755799, 99.02272212],
              [100.72909056, 100.12898291],
              [99.65208785, 100.15634897],
              [99.97181777, 100.42833187],
              [99.36567791, 99.63725883],
              [99.89678115, 100.4105985],
              [100.46566244, 98.46375631],
              [99.25524518, 99.17356146],
              [100.92085882, 100.31872765],
              [100.12691209, 100.40198936],
              [101.53277921, 101.46935877],
              [101.12663592, 98.92006849],
              [99.61267318, 99.69769725],
              [100.17742614, 99.59821906],
              [98.74720464, 100.77749036],
              [99.09270164, 100.0519454],
              [99.13877431, 101.91006495],
              [97.44701018, 100.6536186],
              [99.10453344, 100.3869025],
              [99.12920285, 99.42115034],
              [99.90154748, 99.33652171],
              [98.99978465, 98.4552289],
              [98.85253135, 99.56217996],
              [100.20827498, 100.97663904],
              [99.96071718, 98.8319065],
              [101.49407907, 99.79484174],
              [99.50196755, 101.92953205],
              [100.40234164, 99.31518991],
              [98.29372981, 101.9507754],
              [100.1666735, 100.63503144],
              [99.32753955, 99.64044684],
              [100.14404357, 101.45427351],
              [99.3563816, 97.77659685],
              [98.50874241, 100.4393917],
              [99.36415392, 100.67643329],
              [102.38314477, 100.94447949],
              [100.44386323, 100.33367433],
              [100.94942081, 100.08755124],
              [100.85683061, 99.34897441],
              [99.58638102, 99.25254519],
              [99.08717777, 101.11701629],
              [98.729515, 100.96939671],
              [99.48919486, 98.81936782],
              [100.57659082, 99.79170124],
              [100.06651722, 100.3024719],
              [98.36980165, 100.46278226],
              [100.37642553, 98.90059921],
              [98.83485016, 100.90082649],
              [98.70714309, 100.26705087],
              [100.76103773, 100.12167502],
              [100.39600671, 98.90693849],
              [101.13940068, 98.76517418],
              [98.89561666, 100.05216508],
              [98.38610215, 99.78725972],
              [99.68844747, 100.05616534],
              [101.76405235, 100.40015721],
              [100.97873798, 102.2408932],
              [100.94725197, 99.84498991],
              [102.26975462, 98.54563433],
              [100.95008842, 99.84864279],
              [101.17877957, 99.82007516],
              [100.52327666, 99.82845367],
              [100.77179055, 100.82350415],
              [98.68409259, 99.5384154]])
# Binary labels, one per row of X.
y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0,
              1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0,
              1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0,
              0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0])

DICT_DATA['check_fit_idempotent'] = {'C': C, 'X': X, 'y': y, 'tol': tol}

# Inputs captured during the `check_fit_check_is_fitted` estimator check.
# Reuses the `tol` defined above; C happens to match the previous entry.
C = 5.3648013574915625
# 100 samples x 2 features; values cluster around 100 (not centered/scaled).
X = np.array([[100.49671415, 99.8617357],
              [100.64768854, 101.52302986],
              [99.76584663, 99.76586304],
              [101.57921282, 100.76743473],
              [99.53052561, 100.54256004],
              [99.53658231, 99.53427025],
              [100.24196227, 98.08671976],
              [98.27508217, 99.43771247],
              [98.98716888, 100.31424733],
              [99.09197592, 98.5876963],
              [101.46564877, 99.7742237],
              [100.0675282, 98.57525181],
              [99.45561728, 100.11092259],
              [98.84900642, 100.37569802],
              [99.39936131, 99.70830625],
              [99.39829339, 101.85227818],
              [99.98650278, 98.94228907],
              [100.82254491, 98.77915635],
              [100.2088636, 98.04032988],
              [98.67181395, 100.19686124],
              [100.73846658, 100.17136828],
              [99.88435172, 99.6988963],
              [98.52147801, 99.28015579],
              [99.53936123, 101.05712223],
              [100.34361829, 98.23695984],
              [100.32408397, 99.61491772],
              [99.323078, 100.61167629],
              [101.03099952, 100.93128012],
              [99.16078248, 99.69078762],
              [100.33126343, 100.97554513],
              [99.52082576, 99.81434102],
              [98.89366503, 98.80379338],
              [100.81252582, 101.35624003],
              [99.92798988, 101.0035329],
              [100.36163603, 99.35488025],
              [100.36139561, 101.53803657],
              [99.96417396, 101.56464366],
              [97.3802549, 100.8219025],
              [100.08704707, 99.70099265],
              [100.09176078, 98.01243109],
              [99.78032811, 100.35711257],
              [101.47789404, 99.48172978],
              [99.1915064, 99.49824296],
              [100.91540212, 100.32875111],
              [99.4702398, 100.51326743],
              [100.09707755, 100.96864499],
              [99.29794691, 99.67233785],
              [99.60789185, 98.53648505],
              [100.29612028, 100.26105527],
              [100.00511346, 99.76541287],
              [98.58462926, 99.57935468],
              [99.65728548, 99.19772273],
              [99.83871429, 100.40405086],
              [101.8861859, 100.17457781],
              [100.25755039, 99.92555408],
              [98.08122878, 99.97348612],
              [100.06023021, 102.46324211],
              [99.80763904, 100.30154734],
              [99.96528823, 98.83132196],
              [101.14282281, 100.75193303],
              [100.79103195, 99.09061255],
              [101.40279431, 98.59814894],
              [100.58685709, 102.19045563],
              [99.00946367, 99.43370227],
              [100.09965137, 99.49652435],
              [98.44933657, 100.06856297],
              [98.93769629, 100.47359243],
              [99.08057577, 101.54993441],
              [99.21674671, 99.67793848],
              [100.81351722, 98.76913568],
              [100.22745993, 101.30714275],
              [98.39251677, 100.18463386],
              [100.25988279, 100.78182287],
              [98.76304929, 98.67954339],
              [100.52194157, 100.29698467],
              [100.25049285, 100.34644821],
              [99.31997528, 100.2322537],
              [100.29307247, 99.28564858],
              [101.86577451, 100.47383292],
              [98.8086965, 100.65655361],
              [99.02531833, 100.7870846],
              [101.15859558, 99.17931768],
              [100.96337613, 100.41278093],
              [100.82206016, 101.89679298],
              [99.75461188, 99.24626384],
              [99.11048557, 99.18418972],
              [99.92289829, 100.34115197],
              [100.2766908, 100.82718325],
              [100.01300189, 101.45353408],
              [99.73534317, 102.72016917],
              [100.62566735, 99.14284244],
              [98.9291075, 100.48247242],
              [99.77653721, 100.71400049],
              [100.47323762, 99.92717109],
              [99.15320628, 98.48515278],
              [99.55348505, 100.85639879],
              [100.21409374, 98.75426122],
              [100.17318093, 100.38531738],
              [99.11614256, 100.15372511],
              [100.05820872, 98.8570297]])
# Binary labels, one per row of X.
y = np.array([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0,
              0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0,
              0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
              1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0,
              0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1])

DICT_DATA['check_fit_check_is_fitted'] = {'C': C, 'X': X, 'y': y, 'tol': tol}
Binary file not shown.