Skip to content

Commit

Permalink
Merge pull request #47 from MatthewSZhang/minibatch-doc
Browse files Browse the repository at this point in the history
DOC debug increase n_random
  • Loading branch information
MatthewSZhang authored Feb 21, 2025
2 parents 392f028 + 01ddab3 commit a633884
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/plot_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from sklearn.linear_model import LogisticRegression

data, labels = load_iris(return_X_y=True)
baseline_lr = LogisticRegression(max_iter=110).fit(data, labels)
baseline_lr = LogisticRegression(max_iter=1000).fit(data, labels)

# %%
# Random data pruning
Expand All @@ -40,7 +40,7 @@
def _random_pruning(X, y, n_samples_to_select: int, random_state: int):
rng = np.random.default_rng(random_state)
ids_random = rng.choice(y.size, n_samples_to_select, replace=False)
pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_random], y[ids_random])
pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_random], y[ids_random])
return pruned_lr.coef_, pruned_lr.intercept_


Expand Down Expand Up @@ -72,9 +72,9 @@ def _fastcan_pruning(
).fit(X)
atoms = kmeans.cluster_centers_
ids_fastcan = minibatch(
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, tol=1e-9, verbose=0
X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
)
pruned_lr = LogisticRegression(max_iter=110).fit(X[ids_fastcan], y[ids_fastcan])
pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_fastcan], y[ids_fastcan])
print(atoms[-1], ids_fastcan[-10:])
return pruned_lr.coef_, pruned_lr.intercept_

Expand Down Expand Up @@ -112,4 +112,4 @@ def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
plt.show()


plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=10)
plot_box(data, labels, baseline_lr, n_samples_to_select=100, n_random=100)

0 comments on commit a633884

Please sign in to comment.