diff --git a/algorithms/linfa-elasticnet/src/algorithm.rs b/algorithms/linfa-elasticnet/src/algorithm.rs
index 85a16f5e6..ee759144c 100644
--- a/algorithms/linfa-elasticnet/src/algorithm.rs
+++ b/algorithms/linfa-elasticnet/src/algorithm.rs
@@ -1,12 +1,19 @@
 use approx::{abs_diff_eq, abs_diff_ne};
-use linfa::dataset::AsSingleTargets;
-use ndarray::{s, Array1, ArrayBase, ArrayView1, ArrayView2, Axis, CowArray, Data, Ix1, Ix2};
+use ndarray::{
+    s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, CowArray, Data, Dimension, Ix1, Ix2,
+};
 use ndarray_linalg::{Inverse, Lapack};
 
 use linfa::traits::{Fit, PredictInplace};
-use linfa::{dataset::Records, DatasetBase, Float};
+use linfa::{
+    dataset::{AsMultiTargets, AsSingleTargets, AsTargets, Records},
+    DatasetBase, Float,
+};
 
-use super::{hyperparams::ElasticNetValidParams, ElasticNet, ElasticNetError, Result};
+use super::{
+    hyperparams::{ElasticNetValidParams, MultiTaskElasticNetValidParams},
+    ElasticNet, ElasticNetError, MultiTaskElasticNet, Result,
+};
 
 impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for ElasticNetValidParams<F>
 where
@@ -54,6 +61,63 @@ where
     }
 }
 
+impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, ElasticNetError> for MultiTaskElasticNetValidParams<F>
+where
+    F: Float + Lapack,
+    T: AsMultiTargets<Elem = F>,
+    D: Data<Elem = F>,
+{
+    type Object = MultiTaskElasticNet<F>;
+
+    /// Fit a multi-task Elastic Net model given a feature matrix `x` and a target
+    /// matrix `y`.
+    ///
+    /// The feature matrix `x` must have shape `(n_samples, n_features)`
+    ///
+    /// The target variable `y` must have shape `(n_samples, n_tasks)`
+    ///
+    /// Returns a `MultiTaskElasticNet` object which contains the fitted
+    /// parameters and can be used to `predict` values of the target variables
+    /// for new feature values.
+    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
+        let targets = dataset.targets().as_multi_targets();
+        let nsamples = dataset.nsamples();
+        let ntasks = targets.ncols();
+
+        let mut intercept = Array1::<F>::zeros(ntasks);
+        let mut y = Array2::<F>::zeros((nsamples, ntasks));
+
+        // fit one intercept per task and center each target column
+        for t in 0..ntasks {
+            let (intercept_t, y_t) =
+                compute_intercept(self.with_intercept(), targets.slice(s![.., t]));
+            intercept[t] = intercept_t;
+            y.slice_mut(s![.., t]).assign(&y_t);
+        }
+
+        let (hyperplane, duality_gap, n_steps) = block_coordinate_descent(
+            dataset.records().view(),
+            y.view(),
+            self.tolerance(),
+            self.max_iterations(),
+            self.l1_ratio(),
+            self.penalty(),
+        );
+
+        let y_est = dataset.records().dot(&hyperplane) + &intercept;
+
+        // try to calculate the variance
+        let variance = variance_params(dataset, y_est);
+
+        Ok(MultiTaskElasticNet {
+            hyperplane,
+            intercept,
+            duality_gap,
+            n_steps,
+            variance,
+        })
+    }
+}
+
 impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>> for ElasticNet<F> {
     /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
     /// `predict` returns the target variable according to elastic net
     /// learned from the training data distribution.
@@ -73,6 +137,28 @@ impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array1<F>> for ElasticNet<F> {
     }
 }
 
+impl<F: Float, D: Data<Elem = F>> PredictInplace<ArrayBase<D, Ix2>, Array2<F>>
+    for MultiTaskElasticNet<F>
+{
+    /// Given an input matrix `X`, with shape `(n_samples, n_features)`,
+    /// `predict` returns the target matrix according to the multi-task elastic
+    /// net learned from the training data distribution.
+    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array2<F>) {
+        assert_eq!(
+            x.nrows(),
+            y.nrows(),
+            "The number of data points must match the number of output targets."
+        );
+
+        *y = x.dot(&self.hyperplane) + &self.intercept;
+    }
+
+    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> {
+        // the hyperplane has shape (n_features, n_tasks), so its column count
+        // gives the number of tasks for the (n_samples, n_tasks) target
+        Array2::zeros((x.nrows(), self.hyperplane.ncols()))
+    }
+}
+
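A minimal end-to-end sketch of the API these impls add — hypothetical driver code, not part of the patch; it only relies on calls exercised by the tests further down (`params`, `l1_ratio`, `penalty`, `fit`, `hyperplane`, `intercept`, `predict`) plus the standard linfa prelude:

```rust
use linfa::prelude::*;
use linfa_elasticnet::{MultiTaskElasticNet, Result};
use ndarray::array;

fn main() -> Result<()> {
    // Three samples, one feature, two tasks (the toy data from the tests below).
    let dataset = Dataset::new(
        array![[-1.0], [0.0], [1.0]],
        array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]],
    );

    // Check the hyper-parameters and fit in one call.
    let model = MultiTaskElasticNet::params()
        .l1_ratio(0.5)
        .penalty(0.1)
        .fit(&dataset)?;

    // One column of the hyperplane and one intercept entry per task.
    println!("W = {:?}", model.hyperplane());
    println!("b = {:?}", model.intercept());

    // Predictions have shape (n_samples, n_tasks).
    let y_hat = model.predict(&array![[2.0], [3.0]]);
    println!("{:?}", y_hat);
    Ok(())
}
```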
 /// View the fitted parameters and make predictions with a fitted
 /// elastic net model
 impl<F: Float> ElasticNet<F> {
@@ -128,6 +214,58 @@ impl<F: Float> ElasticNet<F> {
     }
 }
 
+/// View the fitted parameters and make predictions with a fitted
+/// multi-task elastic net model
+impl<F: Float> MultiTaskElasticNet<F> {
+    /// Get the fitted hyperplane, one column per task
+    pub fn hyperplane(&self) -> &Array2<F> {
+        &self.hyperplane
+    }
+
+    /// Get the fitted intercepts, `[0., ..., 0.]` if no intercept was fitted.
+    /// Note that there are as many intercepts as tasks
+    pub fn intercept(&self) -> &Array1<F> {
+        &self.intercept
+    }
+
+    /// Get the number of steps taken in the optimization algorithm
+    pub fn n_steps(&self) -> u32 {
+        self.n_steps
+    }
+
+    /// Get the duality gap at the end of the optimization algorithm
+    pub fn duality_gap(&self) -> F {
+        self.duality_gap
+    }
+
+    /// Calculate the Z score
+    pub fn z_score(&self) -> Result<Array2<F>> {
+        self.variance
+            .as_ref()
+            .map(|variance| {
+                ndarray::Zip::from(&self.hyperplane)
+                    .and_broadcast(variance)
+                    .map_collect(|a, b| *a / b.sqrt())
+            })
+            .map_err(|err| err.clone())
+    }
+
+    /// Calculate the 95% confidence interval for every parameter
+    pub fn confidence_95th(&self) -> Result<Array2<(F, F)>> {
+        // the 95th percentile of the standard normal distribution
+        let p = F::cast(1.645);
+
+        self.variance
+            .as_ref()
+            .map(|variance| {
+                ndarray::Zip::from(&self.hyperplane)
+                    .and_broadcast(variance)
+                    .map_collect(|a, b| (*a - p * b.sqrt(), *a + p * b.sqrt()))
+            })
+            .map_err(|err| err.clone())
+    }
+}
+
 fn coordinate_descent<'a, F: Float>(
     x: ArrayView2<'a, F>,
     y: ArrayView1<'a, F>,
@@ -151,25 +289,24 @@ fn coordinate_descent<'a, F: Float>(
     while n_steps < max_steps {
         let mut w_max = F::zero();
         let mut d_w_max = F::zero();
-        for ii in 0..n_features {
-            if abs_diff_eq!(norm_cols_x[ii], F::zero()) {
+        for j in 0..n_features {
+            if abs_diff_eq!(norm_cols_x[j], F::zero()) {
                 continue;
             }
-            let w_ii = w[ii];
-            let x_slc: ArrayView1<F> = x.slice(s![.., ii]);
-            if abs_diff_ne!(w_ii, F::zero()) {
-                // FIXME: direct addition with loop might be faster as it does not have to allocate
-                r += &(&x_slc * w_ii);
+            let old_w_j = w[j];
+            let x_j: ArrayView1<F> = x.slice(s![.., j]);
+            if abs_diff_ne!(old_w_j, F::zero()) {
+                r.scaled_add(old_w_j, &x_j);
             }
-            let tmp: F = x_slc.dot(&r);
-            w[ii] = tmp.signum() * F::max(tmp.abs() - n_samples * l1_ratio * penalty, F::zero())
-                / (norm_cols_x[ii] + n_samples * (F::one() - l1_ratio) * penalty);
-            if abs_diff_ne!(w[ii], F::zero()) {
-                r -= &(&x_slc * w[ii]);
+            let tmp: F = x_j.dot(&r);
+            w[j] = tmp.signum() * F::max(tmp.abs() - n_samples * l1_ratio * penalty, F::zero())
+                / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty);
+            if abs_diff_ne!(w[j], F::zero()) {
+                r.scaled_add(-w[j], &x_j);
             }
-            let d_w_ii = (w[ii] - w_ii).abs();
-            d_w_max = F::max(d_w_max, d_w_ii);
-            w_max = F::max(w_max, w[ii].abs());
+            let d_w_j = (w[j] - old_w_j).abs();
+            d_w_max = F::max(d_w_max, d_w_j);
+            w_max = F::max(w_max, w[j].abs());
         }
         n_steps += 1;
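The `scaled_add` rewrite above resolves the removed FIXME: `r += &(&x_slc * w_ii)` first materializes the temporary array `&x_slc * w_ii` and then adds it, while `r.scaled_add(alpha, &x_j)` performs the fused `r += alpha * x_j` update in place with no allocation. A standalone sketch of the equivalence (hypothetical arrays, `f64` for concreteness):

```rust
use ndarray::array;

fn main() {
    let x_j = array![1.0, 2.0, 3.0];
    let w_j = 2.0;

    // Allocating form: builds the temporary `&x_j * w_j` before adding it.
    let mut r1 = array![0.5, 0.5, 0.5];
    r1 += &(&x_j * w_j);

    // In-place form used by the patch: r += w_j * x_j, no temporary.
    let mut r2 = array![0.5, 0.5, 0.5];
    r2.scaled_add(w_j, &x_j);

    assert_eq!(r1, r2);
}
```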
@@ -185,6 +322,83 @@ fn coordinate_descent<'a, F: Float>(
     (w, gap, n_steps)
 }
 
+fn block_coordinate_descent<'a, F: Float>(
+    x: ArrayView2<'a, F>,
+    y: ArrayView2<'a, F>,
+    tol: F,
+    max_steps: u32,
+    l1_ratio: F,
+    penalty: F,
+) -> (Array2<F>, F, u32) {
+    let n_samples = F::cast(x.shape()[0]);
+    let n_features = x.shape()[1];
+    let n_tasks = y.shape()[1];
+    // the parameters of the model
+    let mut w = Array2::<F>::zeros((n_features, n_tasks));
+    // the residuals: `Y - XW` (since W=0, this is just `Y` for now),
+    // the residuals are updated during the algorithm as the parameters change
+    let mut r = y.to_owned();
+    let mut n_steps = 0u32;
+    let norm_cols_x = x.map_axis(Axis(0), |col| col.dot(&col));
+    let mut gap = F::one() + tol;
+    let d_w_tol = tol;
+    let tol = tol * y.iter().map(|&y_ij| y_ij * y_ij).sum();
+    while n_steps < max_steps {
+        let mut w_max = F::zero();
+        let mut d_w_max = F::zero();
+        for j in 0..n_features {
+            if abs_diff_eq!(norm_cols_x[j], F::zero()) {
+                continue;
+            }
+            let old_w_j: ArrayView1<F> = w.slice(s![j, ..]);
+            let x_j: ArrayView1<F> = x.slice(s![.., j]);
+            let norm_old_w_j = old_w_j.dot(&old_w_j).sqrt();
+            if abs_diff_ne!(norm_old_w_j, F::zero()) {
+                // add back the contribution of the j-th block to the residuals
+                for i in 0..x.shape()[0] {
+                    r.slice_mut(s![i, ..]).scaled_add(x_j[i], &old_w_j);
+                }
+            }
+            let tmp = x_j.dot(&r);
+            w.slice_mut(s![j, ..]).assign(
+                &(block_soft_thresholding(tmp.view(), n_samples * l1_ratio * penalty)
+                    / (norm_cols_x[j] + n_samples * (F::one() - l1_ratio) * penalty)),
+            );
+            let norm_w_j = w.slice(s![j, ..]).dot(&w.slice(s![j, ..])).sqrt();
+            if abs_diff_ne!(norm_w_j, F::zero()) {
+                // subtract the updated contribution of the j-th block
+                for i in 0..x.shape()[0] {
+                    for t in 0..n_tasks {
+                        r[[i, t]] -= x_j[i] * w[[j, t]];
+                    }
+                }
+            }
+            let d_w_j = (norm_w_j - norm_old_w_j).abs();
+            d_w_max = F::max(d_w_max, d_w_j);
+            w_max = F::max(w_max, norm_w_j);
+        }
+        n_steps += 1;
+
+        if n_steps == max_steps - 1 || abs_diff_eq!(w_max, F::zero()) || d_w_max / w_max < d_w_tol {
+            // We've hit one potential stopping criterion;
+            // check the duality gap for the ultimate stopping criterion
+            gap = duality_gap_mtl(x.view(), y.view(), w.view(), r.view(), l1_ratio, penalty);
+            if gap < tol {
+                break;
+            }
+        }
+    }
+
+    (w, gap, n_steps)
+}
+
+fn block_soft_thresholding<'a, F: Float>(x: ArrayView1<'a, F>, threshold: F) -> Array1<F> {
+    let norm_x = x.dot(&x).sqrt();
+    if norm_x < threshold {
+        return Array1::<F>::zeros(x.len());
+    }
+    let scale = F::one() - threshold / norm_x;
+    &x * scale
+}
+
 fn duality_gap<'a, F: Float>(
     x: ArrayView2<'a, F>,
     y: ArrayView1<'a, F>,
@@ -209,28 +423,73 @@ fn duality_gap<'a, F: Float>(
     } else {
         (F::one(), r_norm2)
     };
-    let l1_norm = w.fold(F::zero(), |sum, w_i| sum + w_i.abs());
+    let l1_norm = w.iter().map(|w_i| w_i.abs()).sum();
     gap += l1_reg * l1_norm - const_ * r.dot(&y)
         + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
     gap
 }
 
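`block_soft_thresholding` generalizes the scalar soft-thresholding step of `coordinate_descent`: it is the proximal operator of `threshold * ||.||_2`, shrinking a whole row of `W` towards zero and zeroing it exactly when its norm falls below the threshold — this is what produces row-sparse solutions where a feature is dropped for all tasks at once. A small self-contained check of the same rule (hypothetical `bst` helper mirroring the generic function, specialized to `f64`):

```rust
use ndarray::{array, Array1, ArrayView1};

// Same shrinkage rule as `block_soft_thresholding`, specialized to f64.
fn bst(x: ArrayView1<f64>, threshold: f64) -> Array1<f64> {
    let norm_x = x.dot(&x).sqrt();
    if norm_x < threshold {
        return Array1::zeros(x.len());
    }
    &x * (1.0 - threshold / norm_x)
}

fn main() {
    // ||x|| = 5, so the row is scaled by 1 - 2/5 = 0.6.
    let x = array![3.0, 4.0];
    let shrunk = bst(x.view(), 2.0);
    assert!((shrunk[0] - 1.8).abs() < 1e-12);
    assert!((shrunk[1] - 2.4).abs() < 1e-12);

    // A threshold above the norm zeroes the entire row.
    assert_eq!(bst(x.view(), 6.0), array![0.0, 0.0]);
}
```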
-fn variance_params<F: Float + Lapack, T: AsSingleTargets<Elem = F>, D: Data<Elem = F>>(
+fn duality_gap_mtl<'a, F: Float>(
+    x: ArrayView2<'a, F>,
+    y: ArrayView2<'a, F>,
+    w: ArrayView2<'a, F>,
+    r: ArrayView2<'a, F>,
+    l1_ratio: F,
+    penalty: F,
+) -> F {
+    let half = F::cast(0.5);
+    let n_samples = F::cast(x.shape()[0]);
+    let l1_reg = l1_ratio * penalty * n_samples;
+    let l2_reg = (F::one() - l1_ratio) * penalty * n_samples;
+    let xta = x.t().dot(&r) - &w * l2_reg;
+
+    // the dual norm associated with the L2,1 penalty is the maximum row-wise L2 norm
+    let dual_norm_xta = xta
+        .map_axis(Axis(1), |x| x.dot(&x).sqrt())
+        .fold(F::zero(), |max_norm, &nrm| max_norm.max(nrm));
+    let r_norm2 = r.iter().map(|rij| rij.powi(2)).sum();
+    let w_norm2 = w.iter().map(|wij| wij.powi(2)).sum();
+    let (const_, mut gap) = if dual_norm_xta > l1_reg {
+        let const_ = l1_reg / dual_norm_xta;
+        let a_norm2 = r_norm2 * const_ * const_;
+        (const_, half * (r_norm2 + a_norm2))
+    } else {
+        (F::one(), r_norm2)
+    };
+    let rty = r.t().dot(&y);
+    let trace_rty = rty.diag().sum();
+    let l21_norm = w.map_axis(Axis(1), |wj| (wj.dot(&wj)).sqrt()).sum();
+    gap += l1_reg * l21_norm - const_ * trace_rty
+        + half * l2_reg * (F::one() + const_ * const_) * w_norm2;
+    gap
+}
+
+fn variance_params<F: Float + Lapack, T: AsTargets<Elem = F>, D: Data<Elem = F>>(
     ds: &DatasetBase<ArrayBase<D, Ix2>, T>,
-    y_est: Array1<F>,
+    y_est: T,
 ) -> Result<Array1<F>> {
     let nfeatures = ds.nfeatures();
     let nsamples = ds.nsamples();
 
-    // try to convert targets into a single target
-    let target = ds.as_single_targets();
+    let target = ds.targets().as_targets();
+    let ndim = target.ndim();
+
+    let ntasks: usize = match ndim {
+        1 => 1,
+        2 => *target.shape().last().unwrap(),
+        _ => {
+            return Err(ElasticNetError::IncorrectTargetShape);
+        }
+    };
+
+    let y_est = y_est.as_targets();
 
     // check that we have enough samples
     if nsamples < nfeatures + 1 {
         return Err(ElasticNetError::NotEnoughSamples);
     }
 
-    let var_target = (&target - &y_est).mapv(|x| x * x).sum() / F::cast(nsamples - nfeatures);
+    // normalize by the residual degrees of freedom summed over all tasks
+    let var_target =
+        (&target - &y_est).mapv(|x| x * x).sum() / F::cast(ntasks * (nsamples - nfeatures));
 
     let inv_cov = ds.records().t().dot(ds.records()).inv();
 
@@ -257,9 +516,9 @@ pub fn compute_intercept<F: Float>(
 #[cfg(test)]
 mod tests {
-    use super::{coordinate_descent, ElasticNet};
+    use super::{block_coordinate_descent, coordinate_descent, ElasticNet, MultiTaskElasticNet};
     use approx::assert_abs_diff_eq;
-    use ndarray::{array, s, Array, Array1, Array2};
+    use ndarray::{array, s, Array, Array1, Array2, Axis};
     use ndarray_rand::rand::SeedableRng;
     use ndarray_rand::rand_distr::Uniform;
     use ndarray_rand::RandomExt;
@@ -282,6 +541,17 @@ mod tests {
         squared_error(x, y, intercept, beta) + lambda * elastic_net_penalty(beta, alpha)
     }
 
+    fn elastic_net_multi_task_objective(
+        x: &Array2<f64>,
+        y: &Array2<f64>,
+        intercept: &Array1<f64>,
+        beta: &Array2<f64>,
+        alpha: f64,
+        lambda: f64,
+    ) -> f64 {
+        squared_error_mtl(x, y, intercept, beta) + lambda * elastic_net_mtl_penalty(beta, alpha)
+    }
+
     fn squared_error(x: &Array2<f64>, y: &Array1<f64>, intercept: f64, beta: &Array1<f64>) -> f64 {
         let mut resid = -x.dot(beta);
         resid -= intercept;
@@ -294,6 +564,20 @@ mod tests {
         result
     }
 
+    fn squared_error_mtl(
+        x: &Array2<f64>,
+        y: &Array2<f64>,
+        intercept: &Array1<f64>,
+        beta: &Array2<f64>,
+    ) -> f64 {
+        // residuals: Y - X * beta - intercept (intercept broadcast row-wise)
+        let mut resid = x.dot(beta);
+        resid = &resid * -1.;
+        resid = &resid - intercept + y;
+        let mut datafit: f64 = resid.iter().map(|rij| rij.powi(2)).sum();
+        datafit /= 2.0 * x.shape()[0] as f64;
+        datafit
+    }
+
     fn elastic_net_penalty(beta: &Array1<f64>, alpha: f64) -> f64 {
         let mut penalty = 0.0;
         for beta_j in beta {
@@ -302,6 +586,14 @@ mod tests {
         penalty
     }
 
+    fn elastic_net_mtl_penalty(beta: &Array2<f64>, alpha: f64) -> f64 {
+        let frob_norm: f64 = beta.iter().map(|beta_ij| beta_ij.powi(2)).sum();
+        let l21_norm: f64 = beta
+            .map_axis(Axis(1), |beta_j| (beta_j.dot(&beta_j)).sqrt())
+            .sum();
+        (1.0 - alpha) / 2.0 * frob_norm + alpha * l21_norm
+    }
+
     #[test]
     fn elastic_net_penalty_works() {
         let beta = array![-2.0, 1.0];
@@ -319,6 +611,31 @@ mod tests {
         assert_abs_diff_eq!(elastic_net_penalty(&beta2, 0.0), 0.0);
     }
 
+    #[test]
+    fn elastic_net_mtl_penalty_works() {
+        let beta = array![[-2.0, 1.0, 3.0], [3.0, 1.5, -1.7]];
+        assert_abs_diff_eq!(
+            elastic_net_mtl_penalty(&beta, 0.7),
+            9.472383565516601,
+            epsilon = 1e-12
+        );
+        assert_abs_diff_eq!(
+            elastic_net_mtl_penalty(&beta, 1.0),
+            7.501976522166574,
+            epsilon = 1e-12
+        );
+        assert_abs_diff_eq!(
+            elastic_net_mtl_penalty(&beta, 0.2),
+            12.756395304433315,
+            epsilon = 1e-12
+        );
+
+        let beta2 = array![[0., 0.], [0., 0.], [0., 0.]];
+        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
+        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 1.2), 0.0);
+        assert_abs_diff_eq!(elastic_net_mtl_penalty(&beta2, 0.8), 0.0);
+    }
+
     #[test]
     fn squared_error_works() {
         let x = array![[2.0, 1.0], [-1.0, 2.0]];
         let y = array![1.0, 1.0];
         let beta = array![1.0, 1.0];
assert_abs_diff_eq!(squared_error(&x, &y, 0.0, &beta), 0.25); } + #[test] + fn squared_error_mtl_works() { + let x = array![[1.2, 2.3], [-1.3, 0.3], [-1.3, 0.1]]; + let y = array![ + [0.2, 1.0, 0.0, 1.], + [-0.3, 0.7, 0.1, 2.], + [-0.3, 0.7, 2.3, 3.] + ]; + let beta = array![[2.3, 4.5, 1.2, -3.4], [1.2, -3.4, 0.7, -1.2]]; + assert_abs_diff_eq!( + squared_error_mtl(&x, &y, &array![0., 0., 0., 0.], &beta), + 41.66298333333333 + ); + let intercept = array![1., 3., 2., 0.3]; + assert_abs_diff_eq!( + squared_error_mtl(&x, &y, &intercept, &beta), + 29.059983333333335 + ); + } + #[test] fn coordinate_descent_lowers_objective() { let x = array![[1.0, 0.0], [0.0, 1.0]]; @@ -341,6 +678,22 @@ mod tests { assert!(objective_start > objective_end); } + #[test] + fn block_coordinate_descent_lowers_objective() { + let x = array![[1.0, 0., -0.3, 3.2], [0.3, 1.2, -0.6, 1.2]]; + let y = array![[0.3, -1.2, 0.7], [1.4, -3.2, 0.2]]; + let beta = array![[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]; + let intercept = array![0., 0., 0.]; + let alpha = 0.4; + let lambda = 0.002; + let objective_start = + elastic_net_multi_task_objective(&x, &y, &intercept, &beta, alpha, lambda); + let opt_result = block_coordinate_descent(x.view(), y.view(), 1e-4, 3, alpha, lambda); + let objective_end = + elastic_net_multi_task_objective(&x, &y, &intercept, &opt_result.0, alpha, lambda); + assert!(objective_start > objective_end); + } + #[test] fn lasso_zero_works() { let dataset = Dataset::from((array![[0.], [0.], [0.]], array![0., 0., 0.])); @@ -355,6 +708,20 @@ mod tests { assert_abs_diff_eq!(model.hyperplane(), &array![0.]); } + #[test] + fn mtl_lasso_zero_works() { + let dataset = Dataset::from((array![[0.], [0.], [0.]], array![[0.], [0.], [0.]])); + + let model = MultiTaskElasticNet::params() + .l1_ratio(1.0) + .penalty(0.1) + .fit(&dataset) + .unwrap(); + + assert_abs_diff_eq!(model.intercept(), &array![0.]); + assert_abs_diff_eq!(model.hyperplane(), &array![[0.]]); + } + #[test] fn lasso_toy_example_works() { // Test Lasso on a toy example for various values of alpha. @@ -389,6 +756,95 @@ mod tests { assert_abs_diff_eq!(model.duality_gap(), 0.0); } + #[test] + fn multitask_lasso_toy_example_works() { + // Test MultiTaskLasso on a toy example for various values of alpha. + // When validating this against sklearn notice that sklearn divides it + // against n_samples. 
+ let dataset = Dataset::new( + array![[-1.0], [0.0], [1.0]], + array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]], + ); + + // no intercept fitting + let t = array![[2.0], [3.0], [4.0]]; + let model = MultiTaskElasticNet::lasso() + .with_intercept(false) + .penalty(0.01) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., 0.]); + assert_abs_diff_eq!( + model.hyperplane(), + &array![[0.9851659, 0.1477748]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!( + model.predict(&t), + array![ + [1.9703319, 0.2955497], + [2.9554978, 0.4433246], + [3.9406638, 0.5910995] + ], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0); + + // input for prediction + let t = array![[2.0], [3.0], [4.0]]; + let model = MultiTaskElasticNet::lasso() + .penalty(1e-8) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., -1.5]); + assert_abs_diff_eq!(model.hyperplane(), &array![[0., 2.65]], epsilon = 1e-6); + assert_abs_diff_eq!( + model.predict(&t), + array![[0., 3.79999998], [0., 6.44999996], [0., 9.09999995]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0); + + let model = MultiTaskElasticNet::lasso() + .penalty(0.1) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., -1.4]); + assert_abs_diff_eq!(model.hyperplane(), &array![[0., 2.5]], epsilon = 1e-6); + assert_abs_diff_eq!( + model.predict(&t), + &array![[0., 3.6], [0., 6.1], [0., 8.6]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0); + + let model = MultiTaskElasticNet::lasso() + .penalty(0.5) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., -1.]); + assert_abs_diff_eq!(model.hyperplane(), &array![[0., 1.9]], epsilon = 1e-6); + assert_abs_diff_eq!( + model.predict(&t), + &array![[0., 2.8], [0., 4.7], [0., 6.6]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0); + + let model = MultiTaskElasticNet::lasso() + .penalty(1.0) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., -0.5]); + assert_abs_diff_eq!(model.hyperplane(), &array![[0.0, 1.15]], epsilon = 1e-6); + assert_abs_diff_eq!( + model.predict(&t), + &array![[0., 1.8], [0., 2.95], [0., 4.1]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0); + } + #[test] fn elastic_net_toy_example_works() { let dataset = Dataset::new(array![[-1.0], [0.0], [1.0]], array![-1.0, 0.0, 1.0]); @@ -426,6 +882,88 @@ mod tests { assert_abs_diff_eq!(model.duality_gap(), 0.0); } + #[test] + fn multitask_elasticnet_toy_example_works() { + // Test MultiTaskElasticNet on a toy example for various values of alpha + // and l1_ratio. When validating this against sklearn notice that sklearn + // divides it against n_samples. 
+ let dataset = Dataset::new( + array![[-1.0], [0.0], [1.0]], + array![[-1.0, 1.0], [0.0, -1.5], [1.0, 1.3]], + ); + + // no intercept fitting + let t = array![[2.0], [3.0], [4.0]]; + let model = MultiTaskElasticNet::params() + .with_intercept(false) + .l1_ratio(0.3) + .penalty(0.1) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., 0.]); + assert_abs_diff_eq!( + model.hyperplane(), + &array![[0.86470395, 0.12970559]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!( + model.predict(&t), + array![ + [1.7294079, 0.25941118], + [2.59411185, 0.38911678], + [3.4588158, 0.51882237] + ], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12); + + // input for prediction + let t = array![[2.0], [3.0], [4.0]]; + let model = MultiTaskElasticNet::params() + .l1_ratio(0.3) + .penalty(0.1) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., 0.26666666], epsilon = 1e-6); + assert_abs_diff_eq!( + model.hyperplane(), + &array![[0.86470395, 0.12970559]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!( + model.predict(&t), + array![ + [1.7294079, 0.52607785], + [2.59411185, 0.65578344], + [3.4588158, 0.78548904] + ], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12); + + let model = MultiTaskElasticNet::params() + .l1_ratio(0.5) + .penalty(0.1) + .fit(&dataset) + .unwrap(); + assert_abs_diff_eq!(model.intercept(), &array![0., 0.2666666], epsilon = 1e-6); + assert_abs_diff_eq!( + model.hyperplane(), + &array![[0.861237, 0.12918555]], + epsilon = 1e-6 + ); + assert_abs_diff_eq!( + model.predict(&t), + &array![ + [1.722474, 0.52503777], + [2.583711, 0.65422332], + [3.44494799, 0.78340887] + ], + epsilon = 1e-6 + ); + assert_abs_diff_eq!(model.duality_gap(), 0.0, epsilon = 1e-12); + } + #[test] fn elastic_net_2d_toy_example_works() { let dataset = Dataset::new(array![[1.0, 0.0], [0.0, 1.0]], array![3.0, 2.0]); @@ -522,7 +1060,7 @@ mod tests { [-5.514554978810590376e-03,5.068011873981870252e-02,-1.590626280073640167e-02,-6.764228304218700139e-02,4.934129593323050011e-02,7.916527725369119917e-02,-2.867429443567860031e-02,3.430885887772629900e-02,-1.811826730789670159e-02,4.448547856271539702e-02], [4.170844488444359899e-02,5.068011873981870252e-02,-1.590626280073640167e-02,1.728186074811709910e-02,-3.734373413344069942e-02,-1.383981589779990050e-02,-2.499265663159149983e-02,-1.107951979964190078e-02,-4.687948284421659950e-02,1.549073015887240078e-02], [-4.547247794002570037e-02,-4.464163650698899782e-02,3.906215296718960200e-02,1.215130832538269907e-03,1.631842733640340160e-02,1.528299104862660025e-02,-2.867429443567860031e-02,2.655962349378539894e-02,4.452837402140529671e-02,-2.593033898947460017e-02], - [-4.547247794002570037e-02,-4.464163650698899782e-02,-7.303030271642410587e-02,-8.141376581713200000e-02,8.374011738825870577e-02,2.780892952020790065e-02,1.738157847891100005e-01,-3.949338287409189657e-02,-4.219859706946029777e-03,3.064409414368320182e-03] + [-4.547247794002570037e-02,-4.464163650698899782e-02,-7.303030271642410587e-02,-8.141376581713200000e-02,8.374011738825870577e-02,2.780892952020790065e-02,1.738157847891100005e-01,-3.949338287409189657e-02,-4.219859706946029777e-03,3.064409414368320182e-03] ]; #[rustfmt::skip] let y = array![2.33e+02, 9.1e+01, 1.11e+02, 1.52e+02, 1.2e+02, 6.70e+01, 3.1e+02, 9.4e+01, 1.83e+02, 6.6e+01, 1.73e+02, 7.2e+01, 4.9e+01, 6.4e+01, 4.8e+01, 1.78e+02, 1.04e+02, 1.32e+02, 2.20e+02, 5.7e+01]; diff --git a/algorithms/linfa-elasticnet/src/error.rs 
b/algorithms/linfa-elasticnet/src/error.rs
index 9b44f258c..30fbe566b 100644
--- a/algorithms/linfa-elasticnet/src/error.rs
+++ b/algorithms/linfa-elasticnet/src/error.rs
@@ -25,6 +25,8 @@ pub enum ElasticNetError {
     InvalidPenalty(f32),
     #[error("invalid tolerance {0}")]
     InvalidTolerance(f32),
+    #[error("the target must be either a vector (ndim=1) or a matrix (ndim=2)")]
+    IncorrectTargetShape,
     #[error(transparent)]
     BaseCrate(#[from] linfa::Error),
 }
diff --git a/algorithms/linfa-elasticnet/src/hyperparams.rs b/algorithms/linfa-elasticnet/src/hyperparams.rs
index 8cb40c7d9..abd0f8e3c 100644
--- a/algorithms/linfa-elasticnet/src/hyperparams.rs
+++ b/algorithms/linfa-elasticnet/src/hyperparams.rs
@@ -15,7 +15,7 @@ use super::Result;
 /// A verified hyper-parameter set ready for the estimation of a ElasticNet regression model
 ///
 /// See [`ElasticNetParams`](crate::ElasticNetParams) for more informations.
-pub struct ElasticNetValidParams<F> {
+pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> {
     penalty: F,
     l1_ratio: F,
     with_intercept: bool,
@@ -23,7 +23,10 @@ pub struct ElasticNetValidParams<F> {
     tolerance: F,
 }
 
-impl<F: Float> ElasticNetValidParams<F> {
+pub type ElasticNetValidParams<F> = ElasticNetValidParamsBase<F, false>;
+pub type MultiTaskElasticNetValidParams<F> = ElasticNetValidParamsBase<F, true>;
+
+impl<F: Float, const MULTI_TASK: bool> ElasticNetValidParamsBase<F, MULTI_TASK> {
     pub fn penalty(&self) -> F {
         self.penalty
     }
@@ -54,6 +57,14 @@ impl<F: Float> ElasticNetValidParams<F> {
 /// + 0.5 * penalty * (1 - l1_ratio) * ||w||^2_2
 /// ```
 ///
+/// The multi-task version (where the target `Y` becomes a matrix with one column
+/// per task) is also supported and solves the following objective function:
+/// ```ignore
+/// 1 / (2 * n_samples) * ||Y - XW||^2_F
+///     + penalty * l1_ratio * ||W||_2,1
+///     + 0.5 * penalty * (1 - l1_ratio) * ||W||^2_F
+/// ```
+///
 /// The parameter set can be verified into a
 /// [`ElasticNetValidParams`](crate::hyperparams::ElasticNetValidParams) by calling
 /// [ParamGuard::check](Self::check). It is also possible to directly fit a model with
@@ -104,16 +115,21 @@ impl<F: Float> ElasticNetValidParams<F> {
 /// let model = checked_params.fit(&ds)?;
 /// # Ok::<(), ElasticNetError>(())
 /// ```
-pub struct ElasticNetParams<F>(ElasticNetValidParams<F>);
+pub struct ElasticNetParamsBase<F, const MULTI_TASK: bool>(
+    ElasticNetValidParamsBase<F, MULTI_TASK>,
+);
+
+pub type ElasticNetParams<F> = ElasticNetParamsBase<F, false>;
+pub type MultiTaskElasticNetParams<F> = ElasticNetParamsBase<F, true>;
 
-impl<F: Float> Default for ElasticNetParams<F> {
+impl<F: Float, const MULTI_TASK: bool> Default for ElasticNetParamsBase<F, MULTI_TASK> {
     fn default() -> Self {
         Self::new()
     }
 }
 
 /// Configure and fit a Elastic Net model
-impl<F: Float> ElasticNetParams<F> {
+impl<F: Float, const MULTI_TASK: bool> ElasticNetParamsBase<F, MULTI_TASK> {
     /// Create default elastic net hyper parameters
     ///
     /// By default, an intercept will be fitted. To disable fitting an
     /// intercept, call `with_intercept(false)` before calling `fit()`.
     ///
     /// To additionally normalize the feature matrix before fitting, call
     /// `fit_intercept_and_normalize()` before calling `fit()`. The feature
     /// matrix will not be normalized by default.
-    pub fn new() -> ElasticNetParams<F> {
-        Self(ElasticNetValidParams {
+    pub fn new() -> ElasticNetParamsBase<F, MULTI_TASK> {
+        Self(ElasticNetValidParamsBase {
             penalty: F::one(),
             l1_ratio: F::cast(0.5),
             with_intercept: true,
@@ -178,8 +194,8 @@ impl<F: Float> ElasticNetParams<F> {
     }
 }
 
-impl<F: Float> ParamGuard for ElasticNetParams<F> {
-    type Checked = ElasticNetValidParams<F>;
+impl<F: Float, const MULTI_TASK: bool> ParamGuard for ElasticNetParamsBase<F, MULTI_TASK> {
+    type Checked = ElasticNetValidParamsBase<F, MULTI_TASK>;
     type Error = ElasticNetError;
 
     /// Validate the hyper parameters
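The `MULTI_TASK` const generic lets the single- and multi-task estimators share one parameter struct, one set of accessors, and one `ParamGuard` implementation, while still being distinct types that the two `Fit` impls can dispatch on. A standalone miniature of the pattern (hypothetical names, independent of linfa):

```rust
// One generic base struct; two aliases that pick the variant at the type level.
struct ParamsBase<F, const MULTI_TASK: bool> {
    penalty: F,
}

type SingleTaskParams<F> = ParamsBase<F, false>;
type MultiTaskParams<F> = ParamsBase<F, true>;

impl<F: Copy, const MULTI_TASK: bool> ParamsBase<F, MULTI_TASK> {
    // Accessors are written once and apply to both aliases.
    fn penalty(&self) -> F {
        self.penalty
    }
}

fn main() {
    let single: SingleTaskParams<f64> = ParamsBase { penalty: 1.0 };
    let multi: MultiTaskParams<f64> = ParamsBase { penalty: 0.5 };
    println!("single: {}, multi: {}", single.penalty(), multi.penalty());
}
```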
diff --git a/algorithms/linfa-elasticnet/src/lib.rs b/algorithms/linfa-elasticnet/src/lib.rs
index 789f27c86..d2865093a 100644
--- a/algorithms/linfa-elasticnet/src/lib.rs
+++ b/algorithms/linfa-elasticnet/src/lib.rs
@@ -1,7 +1,7 @@
 #![doc = include_str!("../README.md")]
 
 use linfa::Float;
-use ndarray::Array1;
+use ndarray::{Array1, Array2};
 
 #[cfg(feature = "serde")]
 use serde_crate::{Deserialize, Serialize};
@@ -11,11 +11,11 @@ mod error;
 mod hyperparams;
 
 pub use error::{ElasticNetError, Result};
-pub use hyperparams::{ElasticNetParams, ElasticNetValidParams};
+pub use hyperparams::{ElasticNetParams, ElasticNetValidParams, MultiTaskElasticNetParams};
 
 #[cfg_attr(
     feature = "serde",
-    derive(Serialize, Deserialize),
+    derive(Serialize, Deserialize, Debug, Clone, PartialEq),
     serde(crate = "serde_crate")
 )]
 /// Elastic Net model
@@ -65,3 +65,44 @@ impl<F: Float> ElasticNet<F> {
         ElasticNetParams::new().l1_ratio(F::one())
     }
 }
+
+#[cfg_attr(
+    feature = "serde",
+    derive(Serialize, Deserialize, Debug, Clone, PartialEq),
+    serde(crate = "serde_crate")
+)]
+/// MultiTask Elastic Net model
+///
+/// This struct contains the parameters of a fitted multi-task elastic net model. This includes the
+/// coefficients (a 2-dimensional array), (optionally) the intercepts (a 1-dimensional array), the
+/// duality gap and the number of steps needed in the computation.
+///
+/// ## Model implementation
+///
+/// Block coordinate descent is widely used to solve optimization problems for generalized linear
+/// models such as the Group Lasso, multi-task Ridge, and multi-task Lasso. It cycles through blocks
+/// of parameters (here, the rows of `W`) and updates one block at a time while holding all the
+/// others fixed. The optimization routine stops when a criterion is satisfied (the dual
+/// sub-optimality gap falls below the tolerance, or the coefficients stop changing).
+pub struct MultiTaskElasticNet<F> {
+    hyperplane: Array2<F>,
+    intercept: Array1<F>,
+    duality_gap: F,
+    n_steps: u32,
+    variance: Result<Array1<F>>,
+}
+
+impl<F: Float> MultiTaskElasticNet<F> {
+    pub fn params() -> MultiTaskElasticNetParams<F> {
+        MultiTaskElasticNetParams::new()
+    }
+
+    /// Create a ridge-only multi-task model, equivalent to `l1_ratio = 0`
+    pub fn ridge() -> MultiTaskElasticNetParams<F> {
+        MultiTaskElasticNetParams::new().l1_ratio(F::zero())
+    }
+
+    /// Create a Lasso-only multi-task model, equivalent to `l1_ratio = 1`
+    pub fn lasso() -> MultiTaskElasticNetParams<F> {
+        MultiTaskElasticNetParams::new().l1_ratio(F::one())
+    }
+}
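The `||W||_2,1` term in the objective above is what distinguishes the multi-task penalty from an entry-wise L1 norm: it sums the Euclidean norms of the rows of `W`, so a feature is either kept or discarded for all tasks at once. A worked `f64` example of the two mixed norms, matching `elastic_net_mtl_penalty` in the tests (values chosen so every result is exactly representable):

```rust
use ndarray::{array, Axis};

fn main() {
    // Three features, two tasks.
    let w = array![[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]];

    // ||W||_2,1 = sum of row norms = 5 + 0 + 1 = 6.
    let l21: f64 = w.map_axis(Axis(1), |row| row.dot(&row).sqrt()).sum();
    assert_eq!(l21, 6.0);

    // ||W||^2_F = sum of squared entries = 9 + 16 + 0 + 0 + 1 + 0 = 26.
    let frob2: f64 = w.iter().map(|v| v * v).sum();
    assert_eq!(frob2, 26.0);

    // With l1_ratio = 0.5 and penalty = 1, the regularizer evaluates to
    // 0.5 * 6 + 0.5 * 0.5 * 26 = 9.5.
    let (l1_ratio, penalty) = (0.5, 1.0);
    let reg = penalty * l1_ratio * l21 + 0.5 * penalty * (1.0 - l1_ratio) * frob2;
    assert_eq!(reg, 9.5);
}
```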