-
-
Notifications
You must be signed in to change notification settings - Fork 243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split AsTargets
into AsSingleTargets
and AsMultiTargets
#203
Conversation
src/dataset/impl_dataset.rs
Outdated
@@ -123,7 +123,7 @@ impl<R: Records, S> DatasetBase<R, S> { | |||
} | |||
} | |||
|
|||
impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> { | |||
impl<L, R: Records, T: AsSingleTargets<Elem = L>> DatasetBase<R, T> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@YuhanLiin The implementation for T: AsMultiTargets<Elem = L>
would be very similar. What do you recommend to avoid duplicating very similar code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use AsMultiTargets
instead of AsSingleTarget
as the bound for these dataset impls. AsMultiTarget
is more general than AsSingleTarget
, and even without a blanket impl, types that implement AsSingleTarget
probably also implement AsMultiTarget
.
src/dataset/mod.rs
Outdated
@@ -222,6 +222,17 @@ pub trait AsTargets { | |||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This trait should only have the method as_single_target()
and should only be implemented for 1D datasets. Same with the mutable version of the trait.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The as_single_target
method shouldn't be fallible, so it won't have a default impl. Otherwise if we mark the wrong type with this trait we'd get panics. This trait should only be implemented on "single-target" types, such as Array1
.
@YuhanLiin I'm changing |
@YuhanLiin I'd be tempted to create a WDYT? |
Try this: change the definition of |
impl_regression!(ArrayView2<'_, f32>, ArrayView2<'_, f32>, f32); | ||
impl_regression!(ArrayView2<'_, f64>, ArrayView2<'_, f64>, f64); | ||
// impl_regression!(Array2<f32>, Array2<f32>, f32); | ||
// impl_regression!(Array2<f64>, Array2<f64>, f64); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I commented out these tests, since the macro calls try_single_target
. Should we remove them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure those are impls used by the rest of the library, not tests. This fitting algorithm is clearly single-target, so just remove the impls with 2D targets and leave the other ones.
src/dataset/impl_dataset.rs
Outdated
self.insert_axis(Axis(1)) | ||
} | ||
} | ||
// impl<F, D: Data<Elem = F>> IntoTargets<ArrayBase<D, Ix2>> for ArrayBase<D, Ix1> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this impl anymore. It converts Array1d into Array2d in Dataset
, which creates conflicting implementations in the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change has a lot of consequences. A lot of 1d targets were implicitly converted to Array2d, which is expected for functions like cross_validate
to work. For DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix2>>
, there are cross_validate
and cross_validate_multi
. IMHO, we should remove cross_validate_multi
from the impl and leave cross_validate
, then we implement cross_validate
for DatasetBase<ArrayBase<D, Ix2>, ArrayBase<S, Ix1>>
. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That whole trait should be removed. The fact that it has into
as a method is a red flag, since that makes it a worse version of Into
. From what I understand the problem you're having currently is that, because most Dataset
methods were implemented on 2D targets, they won't the compile now that we have datasets with 1D targets. For methods that are applicable for both 1D and 2D targets (most methods are like this), place them in a generic impl block bound by AsMultiTargets
, since as_multi_targets
normalizes 1D and 2D targets into 2D arrays. Methods applicable only for 1D targets should be implemented for AsSingleTarget
if possible; otherwise implement for ArrayBase<S, Ix1>
. Methods applicable only for 2D targets should be implemented for AsMultiTargets
if possible; otherwise implement for ArrayBase<S, Ix2>
. This strategy should minimize caller breakage.
For cross_validate
and cross_validate_multi
, I'm pretty sure that cross_validate
applies to both 1D and 2D targets, while cross_validate_multi
applies only to 2D and works a bit differently. multi
allows the evaluation function to differentiate between different targets, since its evaluation function takes Array2
instead of Array1
. However, it's possible to express the semantics of 2D cross_validate
in terms of cross_validate_multi
, so we don't really need 2 functions. All we really need is one generic cross_validate
function that covers both 1D and 2D targets, and it'll behave like cross_validate_multi
in the 2D case. This should be easily doable by making the evaluation closure generic on the target dimensionality.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is well noted. I'm left confused by the FromTargetArray
trait. It is not implemented for Array1d. However, bootstrap
, bootstrap_samples
, etc. are implemented with the FromTargetArray
and AsMultiTargets
traits. I only see one implementation of FromTargetArray
for CountedTargets
but none for Array1d and Array2d. I don't know how to deal with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FromTargetArray
acts like the reverse of AsMultipleTargets
. It takes a 2D target array and converts it back into the original "target type", such as CountedTargets
. If implemented on Array2
it converts the target array into itself (identity operation pretty much). If implemented on CountedTargets
it wraps the target array into a CountedTargets
. Typically, a function like bootstrap
takes some targets, calls as_multi_targets
on it, manipulates the 2D target array, then calls from_target
to turn it back into the original target type. This doesn't work on 1D targets right now, since FromTargetArray
is not implemented on Array1
. The impl should turn the 2D target array (which we know only has one target) into a 1D target array. Something like:
impl<'a, L: Clone + 'a, S: Data<Elem = L>> FromTargetArray<'a, L> for ArrayBase<S, Ix1> {
type Owned = Array1<L>;
type View = ArrayView1<'a, L>;
fn new_targets(targets: Array2<L>) -> Self::Owned {
// something like targets.reshape(targets.nrows()).unwrap()
}
fn new_targets_view(targets: ArrayView2<'a, L>) -> Self::View {
// same here
}
}
@YuhanLiin Hey. Could you have a look at what I've done? I'm running into a lot of issues and it feels like it's never ending. I have errors stating |
#[test] | ||
fn test_st_cv_mt_all_correct() { | ||
let records = | ||
Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap(); | ||
let targets = array![[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]; | ||
let mut dataset: Dataset<f64, f64> = (records, targets).into(); | ||
let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 2 }]; | ||
let acc = dataset | ||
.cross_validate_multi(5, ¶ms, |_pred, _truth| Ok(array![5., 6.])) | ||
.unwrap(); | ||
assert_eq!(acc.dim(), (params.len(), dataset.ntargets())); | ||
assert_eq!(acc, array![[5., 6.], [5., 6.]]) | ||
} | ||
#[test] | ||
fn test_st_cv_mt_one_incorrect() { | ||
let records = | ||
Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap(); | ||
let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap(); | ||
let mut dataset: Dataset<f64, f64> = (records, targets).into(); | ||
// second one should throw an error | ||
let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 0 }]; | ||
let err = dataset | ||
.cross_validate_multi(5, ¶ms, |_pred, _truth| Ok(array![5.])) | ||
.unwrap_err(); | ||
assert_eq!(err.to_string(), "invalid parameter 0".to_string()); | ||
} | ||
|
||
#[test] | ||
fn test_st_cv_mt_incorrect_eval() { | ||
let records = | ||
Array2::from_shape_vec((5, 2), vec![1., 1., 2., 2., 3., 3., 4., 4., 5., 5.]).unwrap(); | ||
let targets = Array1::from_shape_vec(5, vec![1., 2., 3., 4., 5.]).unwrap(); | ||
let mut dataset: Dataset<f64, f64> = (records, targets).into(); | ||
// second one should throw an error | ||
let params = vec![MockFittable { mock_var: 1 }, MockFittable { mock_var: 1 }]; | ||
let err = dataset | ||
.cross_validate_multi(5, ¶ms, |_pred, _truth| { | ||
if false { | ||
Ok(array![0f32]) | ||
} else { | ||
Err(Error::Parameters("eval".to_string())) | ||
} | ||
}) | ||
.unwrap_err(); | ||
assert_eq!(err.to_string(), "invalid parameter eval".to_string()); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to adapt these tests to use cross_validate_multi
?
C: Fn( | ||
&ArrayView2<E>, | ||
&ArrayView2<E>, | ||
) -> std::result::Result<Array1<FACC>, crate::error::Error>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I meant in my comment was that since cross_validate
can be expressed in terms of cross_validate_multi
, we should only have one cross validation API that behaves like cross_validate_multi
. This means cross_validate
should behave the same as the old cross_validate_multi
, and it should return a 2D array instead of 1D.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually old cross_validate_multi
returns a 1d array, while the previous implementation of cross_validate
returns a scalar
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The signature of cross_validate_multi
on master is
pub fn cross_validate_multi<O, ER, M, FACC, C>(
&'a mut self,
k: usize,
parameters: &[M],
eval: C,
) -> std::result::Result<Array2<FACC>, ER>
The return is actually a 2D array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok my bad, i got confused
After looking at the change in-depth, I'd like to propose a rework to the target traits (sorry that I'm suggesting this at this point in the PR, but I do believe it will make everything significantly easier). The main thorn in our side has been the fact that trait AsTargets {
type Elem;
type Ix: Dimensions;
fn as_targets(&self) -> ArrayView<Self::Elem, Self::Ix>;
fn ntargets(&self) -> usize { ... }
}
impl AsTargets for ArrayBase<S, Ix1> {
type Ix = Ix1;
...
}
impl AsTargets for ArrayBase<S, Ix2> {
type Ix = Ix2;
...
}
trait AsSingleTarget: AsTargets<Ix = Ix1> {}
// Might not need this
trait AsMultiTargets: AsTargets<Ix = Ix2> {} If we then change most of the trait bounds to use As for the errors I see on your branch, they come from |
@YuhanLiin no worries, the chance is possibly very disruptive so if a solution can significantly alleviate the work load, it is more than welcome. If you could give it shot, that'd be great. I'll try to give it a shot myself today, and push if it is any successful. |
Just let me know how it goes, so we don't duplicate work. |
@YuhanLiin I tried to start from scratch again and implement what you just said. It does seem to be a less breaking change, but I still run into some issues which I don't know how to solve. I let you take the lead on what you proposed to implement. |
I reworked the PR in #206. The refactoring of the main |
Closed by #206 |
1) What was done before?
linfa
offers one traitAsTargets
which supports both the single-task and the multi-task case.2) What does this PR change? The goal of this PR is to split
AsTargets
intoAsSingleTargets
andAsMultiTargets
, add the corresponding implementations and integrate it tolinfa
codebase.3) Why is it better? As explained in #195 , splitting
AsTargets
intoAsSingleTargets
andAsMultiTargets
would be greatly beneficial to implement multi-task models.