Split AsTargets into AsSingleTargets and AsMultiTargets #203

Closed
wants to merge 32 commits into from

PABannier
Contributor

1) What was done before? linfa offers one trait AsTargets which supports both the single-task and the multi-task case.

2) What does this PR change? The goal of this PR is to split AsTargets into AsSingleTargets and AsMultiTargets, add the corresponding implementations, and integrate them into the linfa codebase.

3) Why is it better? As explained in #195, splitting AsTargets into AsSingleTargets and AsMultiTargets would make multi-task models much easier to implement.

@PABannier PABannier marked this pull request as draft February 20, 2022 15:03
@@ -123,7 +123,7 @@ impl<R: Records, S> DatasetBase<R, S> {
}
}

-impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
+impl<L, R: Records, T: AsSingleTargets<Elem = L>> DatasetBase<R, T> {
Contributor Author

@YuhanLiin The implementation for T: AsMultiTargets<Elem = L> would be very similar. What do you recommend to avoid duplicating very similar code?

Collaborator

Use AsMultiTargets instead of AsSingleTarget as the bound for these dataset impls. AsMultiTarget is more general than AsSingleTarget, and even without a blanket impl, types that implement AsSingleTarget probably also implement AsMultiTarget.

@@ -222,6 +222,17 @@ pub trait AsTargets {
}
Collaborator

This trait should only have the method as_single_target() and should only be implemented for 1D datasets. Same with the mutable version of the trait.

Collaborator

The as_single_target method shouldn't be fallible, so it won't have a default impl. Otherwise if we mark the wrong type with this trait we'd get panics. This trait should only be implemented on "single-target" types, such as Array1.
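The two review comments above can be illustrated with a minimal, std-only sketch. It is not linfa's actual code: `Targets1` and `Targets2` are hypothetical stand-ins for 1D and 2D `ndarray` targets, and the trait bodies are assumptions made for illustration. The point is that `as_single_target` returns a plain slice with no `Result` and no panic path, and that only the 1D type implements it, while both types implement the more general multi-target trait.

```rust
/// Hypothetical stand-in for a 1D target array (one label per sample).
pub struct Targets1(pub Vec<f64>);

/// Hypothetical stand-in for a 2D target array (rows = samples, columns = targets).
pub struct Targets2(pub Vec<Vec<f64>>);

/// General trait: every target type can be viewed as a 2D (samples x targets) table.
pub trait AsMultiTargets {
    fn as_multi_targets(&self) -> Vec<Vec<f64>>;

    /// Number of target columns, read from the first row (0 if there are no rows).
    fn ntargets(&self) -> usize {
        self.as_multi_targets().first().map_or(0, |row| row.len())
    }
}

/// Narrow trait: infallible, so it is only implemented for single-target types.
pub trait AsSingleTarget {
    fn as_single_target(&self) -> &[f64];
}

impl AsMultiTargets for Targets1 {
    fn as_multi_targets(&self) -> Vec<Vec<f64>> {
        // Each label becomes a one-column row.
        self.0.iter().map(|&v| vec![v]).collect()
    }
}

impl AsMultiTargets for Targets2 {
    fn as_multi_targets(&self) -> Vec<Vec<f64>> {
        self.0.clone()
    }
}

impl AsSingleTarget for Targets1 {
    fn as_single_target(&self) -> &[f64] {
        &self.0
    }
}

// `Targets2` deliberately has no `AsSingleTarget` impl, so passing it to a
// function bounded by `AsSingleTarget` is a compile error rather than a panic.
```

With this layout, a dataset method bounded by `AsMultiTargets` accepts both stand-in types, which is the reason given above for preferring it as the bound on the generic dataset impls.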

@PABannier
Contributor Author

PABannier commented Feb 21, 2022

@YuhanLiin I'm changing AsTargets into AsSingleTargets and AsMultiTargets in the linfa crates. I'm getting a lot of compiler errors in the tests because of DatasetView, which by definition uses a 2D array as its target. Since 2D arrays do not implement AsSingleTargets, passing a DatasetView to a fit function that expects a type implementing AsSingleTargets fails to compile. For instance, see the GaussianNB tests.

@PABannier
Contributor Author

PABannier commented Feb 21, 2022

@YuhanLiin I'd be tempted to create SingleTargetDataset and MultiTargetDataset types, but the naming is cumbersome and this change has far-reaching consequences for the whole codebase. One option would be to keep Dataset as the default, with a 2D array as the target, and to create a SingleDataset type that takes a 1D array as the target. But in this case, to be consistent with the naming, we should rename AsMultiTargets to AsTargets as you suggested above.

WDYT?

@YuhanLiin
Collaborator

Try this: change the definition of Dataset into pub type Dataset<D, T, I = Ix2> = DatasetBase<ArrayBase<OwnedRepr<D>, Ix2>, ArrayBase<OwnedRepr<T>, I>>; Do the same for DatasetView. Adding an extra type param for the dimensionality of targets allows the type to be used for both 1D and 2D targets, and the default value of Ix2 minimizes breakage. After that all you need to do is change the single-target dataset examples (such as GaussianNB) to use 1D instead of 2D arrays.
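The defaulted type parameter suggested above can be sketched without `ndarray`. Everything here is a std-only stand-in: `Arr`, `Ix1`, `Ix2`, the `Dim` trait, and `target_ndim` are hypothetical names used only to show the mechanism, and are not linfa's or ndarray's real definitions. The sketch shows that an alias written with two parameters keeps meaning "2D targets", while a third parameter opts into 1D targets.

```rust
use std::marker::PhantomData;

/// Hypothetical dimension markers standing in for ndarray's `Ix1` / `Ix2`.
pub struct Ix1;
pub struct Ix2;

/// Hypothetical helper trait exposing the number of axes of a marker.
pub trait Dim {
    const NDIM: usize;
}
impl Dim for Ix1 {
    const NDIM: usize = 1;
}
impl Dim for Ix2 {
    const NDIM: usize = 2;
}

/// Stand-in for an owned array whose dimensionality is tracked in the type.
pub struct Arr<T, I> {
    pub data: Vec<T>,
    _dim: PhantomData<I>,
}

impl<T, I: Dim> Arr<T, I> {
    pub fn new(data: Vec<T>) -> Self {
        Arr { data, _dim: PhantomData }
    }

    pub fn ndim(&self) -> usize {
        I::NDIM
    }
}

pub struct DatasetBase<R, T> {
    pub records: R,
    pub targets: T,
}

/// Records are always 2D; the target dimensionality `I` defaults to `Ix2`,
/// so existing `Dataset<D, T>` annotations keep their old meaning.
pub type Dataset<D, T, I = Ix2> = DatasetBase<Arr<D, Ix2>, Arr<T, I>>;

/// Works for any target dimensionality.
pub fn target_ndim<D, T, I: Dim>(ds: &Dataset<D, T, I>) -> usize {
    ds.targets.ndim()
}
```

Under these assumptions, only call sites that want single-target data need to spell out the third parameter, as in `Dataset<f64, f64, Ix1>`.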

impl_regression!(ArrayView2<'_, f32>, ArrayView2<'_, f32>, f32);
impl_regression!(ArrayView2<'_, f64>, ArrayView2<'_, f64>, f64);
// impl_regression!(Array2<f32>, Array2<f32>, f32);
// impl_regression!(Array2<f64>, Array2<f64>, f64);
Contributor Author

I commented out these tests, since the macro calls try_single_target. Should we remove them?

Collaborator

I'm pretty sure those are impls used by the rest of the library, not tests. This fitting algorithm is clearly single-target, so just remove the impls with 2D targets and leave the other ones.

self.insert_axis(Axis(1))
}
}
// impl<F, D: Data<Elem = F>> IntoTargets<ArrayBase<D, Ix2>> for ArrayBase<D, Ix1> {
Contributor Author

I don't think we need this impl anymore. It converts Array1d into Array2d in Dataset, which creates conflicting implementations in the tests.

Contributor Author

@PABannier PABannier Feb 22, 2022

This change has a lot of consequences. A lot of 1d targets were implicitly converted to Array2d, which is expected for functions like cross_validate to work. For DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix2>>, there are cross_validate and cross_validate_multi. IMHO, we should remove cross_validate_multi from the impl and leave cross_validate, then we implement cross_validate for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>. WDYT?

Collaborator

That whole trait should be removed. The fact that it has into as a method is a red flag, since that makes it a worse version of Into. From what I understand, the problem you're having currently is that, because most Dataset methods were implemented on 2D targets, they won't compile now that we have datasets with 1D targets. For methods that are applicable to both 1D and 2D targets (most methods are like this), place them in a generic impl block bound by AsMultiTargets, since as_multi_targets normalizes 1D and 2D targets into 2D arrays. Methods applicable only to 1D targets should be implemented for AsSingleTarget if possible; otherwise implement for ArrayBase<S, Ix1>. Methods applicable only to 2D targets should be implemented for AsMultiTargets if possible; otherwise implement for ArrayBase<S, Ix2>. This strategy should minimize caller breakage.

For cross_validate and cross_validate_multi, I'm pretty sure that cross_validate applies to both 1D and 2D targets, while cross_validate_multi applies only to 2D and works a bit differently. multi allows the evaluation function to differentiate between different targets, since its evaluation function takes Array2 instead of Array1. However, it's possible to express the semantics of 2D cross_validate in terms of cross_validate_multi, so we don't really need 2 functions. All we really need is one generic cross_validate function that covers both 1D and 2D targets, and it'll behave like cross_validate_multi in the 2D case. This should be easily doable by making the evaluation closure generic on the target dimensionality.
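The "one cross-validation API" idea above can be sketched in std-only Rust. This is an assumption-heavy illustration, not linfa's implementation: the hypothetical `eval` closure stands in for "fit on the training folds and score on the validation fold", and it returns one score per target. The function averages those scores over `k` folds and returns one row per parameter set and one column per target, so a single-target problem is simply the width-one case.

```rust
/// Hypothetical sketch: average per-target scores over `k` folds.
///
/// Returns a `params.len()` x `ntargets` table of scores, or the first error
/// produced by `eval`.
pub fn cross_validate<P, C, E>(k: usize, params: &[P], eval: C) -> Result<Vec<Vec<f64>>, E>
where
    // `eval(param, fold_index)` yields one score per target for that fold.
    C: Fn(&P, usize) -> Result<Vec<f64>, E>,
{
    let mut scores = Vec::with_capacity(params.len());
    for param in params {
        let mut acc: Vec<f64> = Vec::new();
        for fold in 0..k {
            let fold_scores = eval(param, fold)?;
            if acc.is_empty() {
                // Size the accumulator from the first fold's result.
                acc = vec![0.0; fold_scores.len()];
            }
            for (a, s) in acc.iter_mut().zip(fold_scores) {
                *a += s;
            }
        }
        // Turn the per-target sums into means over the folds.
        for a in acc.iter_mut() {
            *a /= k as f64;
        }
        scores.push(acc);
    }
    Ok(scores)
}
```

A caller that only cares about a single target reads column 0 of each row, which is how the 1D semantics can be expressed in terms of the multi-target ones.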

Contributor Author

It is well noted. I'm left confused by the FromTargetArray trait. It is not implemented for Array1d. However, bootstrap, bootstrap_samples, etc. are implemented with the FromTargetArray and AsMultiTargets traits. I only see one implementation of FromTargetArray for CountedTargets but none for Array1d and Array2d. I don't know how to deal with it.

Collaborator

FromTargetArray acts like the reverse of AsMultiTargets. It takes a 2D target array and converts it back into the original "target type", such as CountedTargets. If implemented on Array2 it converts the target array into itself (pretty much an identity operation). If implemented on CountedTargets it wraps the target array into a CountedTargets. Typically, a function like bootstrap takes some targets, calls as_multi_targets on them, manipulates the 2D target array, then calls from_target to turn it back into the original target type. This doesn't work on 1D targets right now, since FromTargetArray is not implemented on Array1. The impl should turn the 2D target array (which we know only has one target) into a 1D target array. Something like:

impl<'a, L: Clone + 'a, S: Data<Elem = L>> FromTargetArray<'a, L> for ArrayBase<S, Ix1> {
    type Owned = Array1<L>;
    type View = ArrayView1<'a, L>;

    fn new_targets(targets: Array2<L>) -> Self::Owned {
        // something like targets.reshape(targets.nrows()).unwrap()
    }

    fn new_targets_view(targets: ArrayView2<'a, L>) -> Self::View {
        // same here
    }
}
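The round trip described above (2D working array back to the original target type) can be sketched with plain vectors. The names below mirror the thread, but the types are hypothetical std-only stand-ins: `Targets2` is a row-major `Vec<Vec<f64>>` rather than `Array2`, and `Targets1` is a `Vec<f64>` rather than `Array1`. The view-returning method is left out of this sketch.

```rust
/// Hypothetical stand-in for a 2D target array: one inner `Vec` per sample.
pub type Targets2 = Vec<Vec<f64>>;

/// Hypothetical stand-in for a 1D target array: one label per sample.
pub type Targets1 = Vec<f64>;

/// Converts the normalized 2D target table back into the original target type.
pub trait FromTargetArray {
    type Owned;
    fn new_targets(targets: Targets2) -> Self::Owned;
}

impl FromTargetArray for Targets2 {
    type Owned = Targets2;

    // For 2D targets the conversion is the identity.
    fn new_targets(targets: Targets2) -> Targets2 {
        targets
    }
}

impl FromTargetArray for Targets1 {
    type Owned = Targets1;

    // For 1D targets every row must hold exactly one value; the single
    // column is flattened into a vector. Panics otherwise, which plays the
    // role of the `unwrap()` in the reshape idea above.
    fn new_targets(targets: Targets2) -> Targets1 {
        targets
            .into_iter()
            .map(|row| {
                assert_eq!(row.len(), 1, "expected exactly one target column");
                row[0]
            })
            .collect()
    }
}
```

A generic function such as a bootstrap helper would then be written against `T: FromTargetArray` and call `T::new_targets` on its 2D intermediate result.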

@PABannier
Contributor Author

@YuhanLiin Hey. Could you have a look at what I've done? I'm running into a lot of issues and it feels like it's never ending. I have errors stating Unsatisfied trait bounds. I don't have a holistic view of the crate and it's becoming increasingly hard to fix those issues. I'd be glad if you could help in fixing those issues.

Comment on lines -761 to -806
#[test]
fn test_st_cv_mt_all_correct() {
    let records =
        Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
    let targets = array![[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]];
    let mut dataset: Dataset<f64, f64> = (records, targets).into();
    let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }];
    let acc = dataset
        .cross_validate_multi(5, &params, |_pred, _truth| Ok(array![5., 6.]))
        .unwrap();
    assert_eq!(acc.dim(), (params.len(), dataset.ntargets()));
    assert_eq!(acc, array![[5., 6.], [5., 6.]])
}
#[test]
fn test_st_cv_mt_one_incorrect() {
    let records =
        Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
    let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
    let mut dataset: Dataset<f64, f64> = (records, targets).into();
    // second one should throw an error
    let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }];
    let err = dataset
        .cross_validate_multi(5, &params, |_pred, _truth| Ok(array![5.]))
        .unwrap_err();
    assert_eq!(err.to_string(), "invalid parameter 0".to_string());
}

#[test]
fn test_st_cv_mt_incorrect_eval() {
    let records =
        Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap();
    let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap();
    let mut dataset: Dataset<f64, f64> = (records, targets).into();
    // second one should throw an error
    let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }];
    let err = dataset
        .cross_validate_multi(5, &params, |_pred, _truth| {
            if false {
                Ok(array![0f32])
            } else {
                Err(Error::Parameters("eval".to_string()))
            }
        })
        .unwrap_err();
    assert_eq!(err.to_string(), "invalid parameter eval".to_string());
}
Collaborator

Is it possible to adapt these tests to use cross_validate_multi?

C: Fn(
    &ArrayView2<E>,
    &ArrayView2<E>,
) -> std::result::Result<Array1<FACC>, crate::error::Error>,
Collaborator

What I meant in my comment was that since cross_validate can be expressed in terms of cross_validate_multi, we should only have one cross validation API that behaves like cross_validate_multi. This means cross_validate should behave the same as the old cross_validate_multi, and it should return a 2D array instead of 1D.

Contributor Author

Actually old cross_validate_multi returns a 1d array, while the previous implementation of cross_validate returns a scalar

Collaborator

The signature of cross_validate_multi on master is

pub fn cross_validate_multi<O, ER, M, FACC, C>(
    &'a mut self,
    k: usize,
    parameters: &[M],
    eval: C,
) -> std::result::Result<Array2<FACC>, ER>

The return is actually a 2D array

Contributor Author

OK, my bad, I got confused.

@YuhanLiin
Collaborator

After looking at the change in-depth, I'd like to propose a rework to the target traits (sorry that I'm suggesting this at this point in the PR, but I do believe it will make everything significantly easier). The main thorn in our side has been the fact that as_multi_targets() yields ArrayView2 from both 1D and 2D targets, which is bad when we want to turn the view back into the original target type. This was the problem with FromTargetArray, map_targets(), and a bunch of other functions. I propose changing the target traits as such:

trait AsTargets {
  type Elem;
  type Ix: Dimensions;
  fn as_targets(&self) -> ArrayView<Self::Elem, Self::Ix>;
  fn ntargets(&self) -> usize { ... }
}

impl AsTargets for ArrayBase<S, Ix1> {
  type Ix = Ix1;
  ...
}
impl AsTargets for ArrayBase<S, Ix2> {
  type Ix = Ix2;
  ...
}

trait AsSingleTarget: AsTargets<Ix = Ix1> {}

// Might not need this
trait AsMultiTargets: AsTargets<Ix = Ix2> {}

If we then change most of the trait bounds to use AsTargets instead of AsMultiTargets, it allows 1D targets to return 1D views in generic contexts, which solves the problem described above. Since this is a pretty big change I'd be willing to make the change.
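A std-only sketch of the proposed trait layout follows. It makes several assumptions that are not in the proposal itself: the dimension markers and the `Dimension` trait are local stand-ins rather than ndarray's, plain vectors replace arrays, and a dimension-dependent `Shape` takes the place of the `ArrayView` return type so that no external crate is needed. What it does show is the key property: generic code bounded by `AsTargets` gets a 1D result for 1D targets and a 2D result for 2D targets, and `AsSingleTarget` is just `AsTargets` with `Ix = Ix1`.

```rust
/// Hypothetical dimension trait: each marker picks its own shape type.
pub trait Dimension {
    type Shape: PartialEq + std::fmt::Debug;
}
pub struct Ix1;
pub struct Ix2;
impl Dimension for Ix1 {
    type Shape = usize; // number of samples
}
impl Dimension for Ix2 {
    type Shape = (usize, usize); // (samples, targets)
}

pub trait AsTargets {
    type Elem;
    type Ix: Dimension;
    /// Stand-in for `as_targets()`: the result type depends on `Ix`.
    fn target_shape(&self) -> <Self::Ix as Dimension>::Shape;
    fn ntargets(&self) -> usize;
}

impl AsTargets for Vec<f64> {
    type Elem = f64;
    type Ix = Ix1;
    fn target_shape(&self) -> usize {
        self.len()
    }
    fn ntargets(&self) -> usize {
        1
    }
}

impl AsTargets for Vec<Vec<f64>> {
    type Elem = f64;
    type Ix = Ix2;
    fn target_shape(&self) -> (usize, usize) {
        (self.len(), self.ntargets())
    }
    fn ntargets(&self) -> usize {
        self.first().map_or(0, |row| row.len())
    }
}

/// Marker trait: any `AsTargets` type whose dimension is `Ix1`.
pub trait AsSingleTarget: AsTargets<Ix = Ix1> {}
impl<T: AsTargets<Ix = Ix1>> AsSingleTarget for T {}

/// Generic over dimensionality: 1D in, 1D shape out; 2D in, 2D shape out.
pub fn shape_of<T: AsTargets>(targets: &T) -> <T::Ix as Dimension>::Shape {
    targets.target_shape()
}

/// Only accepts single-target types, so the shape is known to be a `usize`.
pub fn nsamples_single<T: AsSingleTarget>(targets: &T) -> usize {
    targets.target_shape()
}
```

In this sketch, calling `nsamples_single` with a `Vec<Vec<f64>>` is rejected at compile time, because its `Ix` is `Ix2`.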

As for the errors I see on your branch, they come from tfidf_validation.rs and whitening.rs. For tfidf you need load_set() to return Array1 instead of Array2, since Gaussian Naive Bayes now expects a 1D array. This can be done by changing the from_shape_vec() call in load_set(). For whitening the issue is due to map_targets() and datasets::winequality() always returning 2D arrays. The problem with map_targets() would be solved with the above rework, which would also allow you to change winequality() to return a 1D target.

@PABannier
Contributor Author

@YuhanLiin No worries. The change is potentially very disruptive, so if a solution can significantly reduce the workload, it is more than welcome. If you could give it a shot, that'd be great. I'll try to give it a shot myself today, and push if it is successful.

@YuhanLiin
Collaborator

Just let me know how it goes, so we don't duplicate work.

@PABannier
Contributor Author

@YuhanLiin I tried to start from scratch again and implement what you just described. It does seem to be a less breaking change, but I still run into some issues that I don't know how to solve. I'll let you take the lead on implementing what you proposed.

@YuhanLiin
Collaborator

I reworked the PR in #206. The refactoring of the main linfa crate is entirely complete. Can you fix any remaining failures and update the docs in the main linfa crate?

@YuhanLiin
Collaborator

Closed by #206

@YuhanLiin YuhanLiin closed this Mar 20, 2022