Skip to content

Commit

Permalink
Adding Multi-Task ElasticNet support (#238)
Browse files Browse the repository at this point in the history
* added block coordinate descent function

* added duality_gap_mtl computation

* ENH cd pass to be consistent with bcd

* added prox operator for MTL Enet

* added helper functions for tests

* working ent mtl penalties

* bcd lower objective test pass

* added MultiTaskEnet struct

* added MTENET documentation

* added API MTENET

* added variance, z-score, conf interval for multitask ENET

* added multi-task estimators

* added tests for MTL

* added tests for Enet and MTL

* WIP: made variance params generic over the number of tasks

* added z_score and confidence_95th for MTL

* WIP make compute_variance generic over the dimension

* Replace for loops in block_coordinate_descent with general_mat_mul calls

* Bring back generic compute_intercept

* Replace manual norm calculations with norm trait calls

* Add docs and derives to multi task types

* Add example for multitask_elasticnet

* Rename shape() calls to nrows and ncols

Co-authored-by: Pierre-Antoine Bannier <[email protected]>
  • Loading branch information
YuhanLiin and PABannier authored Nov 12, 2022
1 parent 44b244c commit 21357e2
Show file tree
Hide file tree
Showing 7 changed files with 754 additions and 67 deletions.
2 changes: 1 addition & 1 deletion algorithms/linfa-elasticnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ thiserror = "1.0"
linfa = { version = "0.6.0", path = "../.." }

[dev-dependencies]
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes"] }
linfa-datasets = { version = "0.6.0", path = "../../datasets", features = ["diabetes", "linnerud"] }
ndarray-rand = "0.14"
rand_xoshiro = "0.6"
2 changes: 1 addition & 1 deletion algorithms/linfa-elasticnet/examples/elasticnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ fn main() -> Result<()> {
// load Diabetes dataset
let (train, valid) = linfa_datasets::diabetes().split_with_ratio(0.90);

// train pure LASSO model with 0.1 penalty
// train pure LASSO model with 0.3 penalty
let model = ElasticNet::params()
.penalty(0.3)
.l1_ratio(1.0)
Expand Down
24 changes: 24 additions & 0 deletions algorithms/linfa-elasticnet/examples/multitask_elasticnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use linfa::prelude::*;
use linfa_elasticnet::{MultiTaskElasticNet, Result};

fn main() -> Result<()> {
// load Diabetes dataset
let (train, valid) = linfa_datasets::linnerud().split_with_ratio(0.80);

// train pure LASSO model with 0.1 penalty
let model = MultiTaskElasticNet::params()
.penalty(0.1)
.l1_ratio(1.0)
.fit(&train)?;

println!("intercept: {}", model.intercept());
println!("params: {}", model.hyperplane());

println!("z score: {:?}", model.z_score());

// validate
let y_est = model.predict(&valid);
println!("predicted variance: {}", y_est.r2(&valid)?);

Ok(())
}
Loading

0 comments on commit 21357e2

Please sign in to comment.