diff --git a/src/dataset/impl_dataset.rs b/src/dataset/impl_dataset.rs index 6b4d19952..b0706580d 100644 --- a/src/dataset/impl_dataset.rs +++ b/src/dataset/impl_dataset.rs @@ -30,6 +30,7 @@ impl DatasetBase { targets, weights: Array1::zeros(0), feature_names: Vec::new(), + target_names: Vec::new(), } } @@ -81,13 +82,14 @@ impl DatasetBase { /// Updates the records of a dataset /// /// This function overwrites the records in a dataset. It also invalidates the weights and - /// feature names. + /// feature/target names. pub fn with_records(self, records: T) -> DatasetBase { DatasetBase { records, targets: self.targets, weights: Array1::zeros(0), feature_names: Vec::new(), + target_names: Vec::new(), } } @@ -100,6 +102,7 @@ impl DatasetBase { targets, weights: self.weights, feature_names: self.feature_names, + target_names: self.target_names, } } @@ -118,6 +121,15 @@ impl DatasetBase { self } + + /// Updates the target names of a dataset + pub fn with_target_names>(mut self, names: Vec) -> DatasetBase { + let target_names = names.into_iter().map(|x| x.into()).collect(); + + self.target_names = target_names; + + self + } } impl Dataset { @@ -153,6 +165,7 @@ impl> DatasetBase { targets, weights, feature_names, + target_names, .. } = self; @@ -163,9 +176,17 @@ impl> DatasetBase { targets: targets.map(fnc), weights, feature_names, + target_names, } } + /// Returns target names + /// + /// A target name gives a human-readable string describing the purpose of a single target. + pub fn target_names(&self) -> &[String] { + &self.target_names + } + /// Return the number of targets in the dataset /// /// # Example @@ -226,6 +247,7 @@ where DatasetBase::new(records, targets) .with_feature_names(self.feature_names.clone()) .with_weights(self.weights.clone()) + .with_target_names(self.target_names.clone()) } /// Iterate over features @@ -299,11 +321,13 @@ where }; let dataset1 = DatasetBase::new(records_first, targets_first) .with_weights(first_weights) - .with_feature_names(self.feature_names.clone()); + .with_feature_names(self.feature_names.clone()) + .with_target_names(self.target_names.clone()); let dataset2 = DatasetBase::new(records_second, targets_second) .with_weights(second_weights) - .with_feature_names(self.feature_names.clone()); + .with_feature_names(self.feature_names.clone()) + .with_target_names(self.target_names.clone()); (dataset1, dataset2) } @@ -349,7 +373,8 @@ where label, DatasetBase::new(self.records().view(), targets) .with_feature_names(self.feature_names.clone()) - .with_weights(self.weights.clone()), + .with_weights(self.weights.clone()) + .with_target_names(self.target_names.clone()), ) }) .collect()) @@ -405,6 +430,7 @@ impl, I: Dimension> From> targets: empty_targets, weights: Array1::zeros(0), feature_names: Vec::new(), + target_names: Vec::new(), } } } @@ -421,6 +447,7 @@ where targets: rec_tar.1, weights: Array1::zeros(0), feature_names: Vec::new(), + target_names: Vec::new(), } } } diff --git a/src/dataset/impl_targets.rs b/src/dataset/impl_targets.rs index 36692d5c5..34f8967b4 100644 --- a/src/dataset/impl_targets.rs +++ b/src/dataset/impl_targets.rs @@ -231,6 +231,7 @@ where weights: Array1::from(weights), targets, feature_names: self.feature_names.clone(), + target_names: self.target_names.clone(), } } } diff --git a/src/dataset/iter.rs b/src/dataset/iter.rs index d1608f9ab..691fc3de4 100644 --- a/src/dataset/iter.rs +++ b/src/dataset/iter.rs @@ -77,18 +77,24 @@ where if self.target_or_feature && self.dataset.nfeatures() <= self.idx { return None; } - let mut records = self.dataset.records.view(); let mut targets = self.dataset.targets.as_targets(); let feature_names; + let target_names; let weights = self.dataset.weights.clone(); if !self.target_or_feature { // This branch should only run for 2D targets targets.collapse_axis(Axis(1), self.idx); feature_names = self.dataset.feature_names.clone(); + if self.dataset.target_names.is_empty() { + target_names = Vec::new(); + } else { + target_names = vec![self.dataset.target_names[self.idx].clone()]; + } } else { records.collapse_axis(Axis(1), self.idx); + target_names = self.dataset.target_names.clone(); if self.dataset.feature_names.len() == records.len_of(Axis(1)) { feature_names = vec![self.dataset.feature_names[self.idx].clone()]; } else { @@ -103,6 +109,7 @@ where targets, weights, feature_names, + target_names, }; Some(dataset_view) diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 1a221805e..c59e0d200 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -164,6 +164,7 @@ impl Deref for Pr { /// * `targets`: a two-/one-dimension matrix with dimensionality (nsamples, ntargets) /// * `weights`: optional weights for each sample with dimensionality (nsamples) /// * `feature_names`: optional descriptive feature names with dimensionality (nfeatures) +/// * `target_names`: optional descriptive target names with dimensionality (ntargets) /// /// # Trait bounds /// @@ -180,6 +181,7 @@ where pub weights: Array1, feature_names: Vec, + target_names: Vec, } /// Targets with precomputed, counted labels @@ -343,6 +345,19 @@ mod tests { assert!(dataset.into_single_target().targets.shape() == [10]); } + #[test] + fn set_target_name() { + let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![0., 1.]) + .with_target_names(vec!["test"]); + assert_eq!(dataset.target_names, vec!["test"]); + } + + #[test] + fn empty_target_name() { + let dataset = Dataset::new(array![[1., 2.], [1., 2.]], array![[0., 1.], [2., 3.]]); + assert_eq!(dataset.target_names, Vec::::new()); + } + #[test] fn dataset_implements_required_methods() { let mut rng = SmallRng::seed_from_u64(42); @@ -540,7 +555,7 @@ mod tests { let dataset = Dataset::new( array![[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]], array![[1, 2], [3, 4], [5, 6]], - ); + ).with_target_names(vec!["a", "b"]); let res = dataset .target_iter() @@ -549,6 +564,13 @@ mod tests { assert_eq!(res, &[array![1, 3, 5], array![2, 4, 6]]); + let mut iter = dataset.target_iter(); + let first = iter.next(); + let second = iter.next(); + + assert_eq!(vec!["a"], first.unwrap().target_names()); + assert_eq!(vec!["b"], second.unwrap().target_names()); + let res = dataset .feature_iter() .map(|x| x.records)