Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Multi-Task ElasticNet support #238

Merged
merged 47 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
81d0f1e
added block coordinate descent function
PABannier Jan 19, 2022
7e8408e
added duality_gap_mtl computation
PABannier Jan 19, 2022
1b69b8e
ENH cd pass to be consistent with bcd
PABannier Jan 19, 2022
e5b76e7
added prox operator for MTL Enet
PABannier Jan 19, 2022
196a2a5
added helper functions for tests
PABannier Jan 19, 2022
dc91b62
fix failing CD tests
PABannier Jan 19, 2022
3bb5333
working ent mtl penalties
PABannier Jan 19, 2022
f9f9959
bcd lower objective test pass
PABannier Jan 19, 2022
e23d876
added MultiTaskEnet struct
PABannier Jan 19, 2022
f316c95
added MTENET documentation
PABannier Jan 19, 2022
1ae6522
added API MTENET
PABannier Jan 19, 2022
f13a150
added variance, z-score, conf interval for multitask ENET
PABannier Jan 19, 2022
b347963
added multi-task estimators
PABannier Jan 22, 2022
80d9a02
added tests for MTL
PABannier Jan 22, 2022
3d981ef
pass comments
PABannier Jan 23, 2022
36f4b70
CLN files
PABannier Jan 23, 2022
15b280b
cleaner implementation
PABannier Jan 23, 2022
d81a5cd
cln tests
PABannier Jan 23, 2022
a21bc1a
changed map into fold
PABannier Feb 19, 2022
7f32afc
added tests for Enet and MTL
PABannier Feb 19, 2022
6c56e06
added incorrect target shape
PABannier Feb 19, 2022
a07fb39
WIP: made variance params generic over the number of tasks
PABannier Feb 19, 2022
75023d9
added z_score and confidence_95th for MTL
PABannier Feb 20, 2022
ba3a574
map instead of fold
PABannier Feb 20, 2022
114fc03
fix confidence interval and z-score
PABannier Feb 20, 2022
30744c0
converted back fold to map
PABannier Feb 21, 2022
3b5f37d
pass comments
PABannier Mar 3, 2022
c7de3e2
Merge branch 'master' of https://github.com/PABannier/linfa into mtl_…
PABannier Mar 20, 2022
bc1ad06
WIP make compute_variance generic over the dimension
PABannier Mar 20, 2022
b09e37b
Fix compiler errors
YuhanLiin Aug 13, 2022
a3cc0b4
Merge branch 'master' into mtl_elastic_net
YuhanLiin Aug 14, 2022
b7667d6
Fix multi enet tests
YuhanLiin Aug 14, 2022
d992c2c
Make compute_intercept generic
YuhanLiin Aug 14, 2022
d707040
Revert "Make compute_intercept generic"
YuhanLiin Aug 14, 2022
fa5ee8c
Replace for loops in block_coordinate_descent with general_mat_mul calls
YuhanLiin Aug 14, 2022
693cb85
Bring back generic compute_intercept
YuhanLiin Aug 14, 2022
1d2a68d
Replace manual norm calculations with norm trait calls
YuhanLiin Aug 14, 2022
8a0779d
Add docs and derives to multi task types
YuhanLiin Aug 14, 2022
6f6f026
Add example for multitask_elasticnet
YuhanLiin Aug 14, 2022
4cfbbd5
Remove bad derives
YuhanLiin Aug 14, 2022
5b125f9
Address review comments
YuhanLiin Aug 27, 2022
7e7b63c
Merge branch 'master' into mtl_elastic_net
YuhanLiin Aug 27, 2022
9f7d40a
Fix CI issues
YuhanLiin Aug 27, 2022
66895b6
Rename shape() calls to nrows and ncols
YuhanLiin Sep 2, 2022
f70f0b7
Merge branch 'master' into mtl_elastic_net
YuhanLiin Nov 12, 2022
c18c4db
Fix multitask elasticnet example
YuhanLiin Nov 12, 2022
9551892
Fix docs
YuhanLiin Nov 12, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion algorithms/linfa-elasticnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ thiserror = "1.0"
linfa = { version = "0.6.0", path = "../.." }

[dev-dependencies]
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes"] }
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes", "linnerud"] }
ndarray-rand = "0.14"
rand_xoshiro = "0.6"
2 changes: 1 addition & 1 deletion algorithms/linfa-elasticnet/examples/elasticnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ fn main() -> Result<()> {
// load Diabetes dataset
let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);

// train pure LASSO model with 0.1 penalty
// train pure LASSO model with 0.3 penalty
let model = ElasticNet::params()
.penalty(0.3)
.l1_ratio(1.0)
Expand Down
24 changes: 24 additions & 0 deletions algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use linfa::prelude::*;
use linfa_elasticnet::{MultiTaskElasticNet, Result};

fn main() -> Result<()> {
// load Diabetes dataset
let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.90);

// train pure LASSO model with 0.1 penalty
let model = MultiTaskElasticNet::params()
.penalty(0.1)
.l1_ratio(1.0)
.fit(&train)?;

println!("intercept: {}", model.intercept());
println!("params: {}", model.hyperplane());

println!("z score: {:?}", model.z_score());

// validate
let y_est = model.predict(&valid);
println!("predicted variance: {}", valid.r2(&y_est)?);

Ok(())
}
Loading