Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split AsTargets into AsSingleTargets and AsMultiTargets #203

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f955154
split AsTargets into AsSingleTargets and AsMultiTargets
PABannier Feb 20, 2022
365a116
impl targets
PABannier Feb 20, 2022
b7a6e67
added iter implementation
PABannier Feb 20, 2022
86b40cd
fix metrics
PABannier Feb 20, 2022
ad148cb
fix prelude
PABannier Feb 20, 2022
d28484c
impl datasets
PABannier Feb 20, 2022
d19f45f
remove useless impl
PABannier Feb 20, 2022
f609700
fix metrics
PABannier Feb 20, 2022
c7dab14
made AsMultiTargets a supertrait for AsSingleTargets
PABannier Feb 20, 2022
d396611
updated impl_dataset
PABannier Feb 20, 2022
01678a8
fix supertrait associated type
PABannier Feb 20, 2022
a06998d
added impls
PABannier Feb 20, 2022
b2c6fba
WIP fixing errors
PABannier Feb 20, 2022
5cd01d7
pass comments
PABannier Feb 21, 2022
47a1bb5
fix implementation
PABannier Feb 21, 2022
0e0afc7
changed AsTargets to AsSingleTargets in linfa-bayes
PABannier Feb 21, 2022
ceeb298
changed AsTargets to AsSingleTargets in linfa-elasticnet
PABannier Feb 21, 2022
bb75942
changed AsTargets to AsSingleTargets in linfa-kernel
PABannier Feb 21, 2022
e1c6ab0
changed AsTargets to AsSingleTargets in linfa-linear
PABannier Feb 21, 2022
ad33a4a
added impl for &T with AsSingleTargets as trait bounds
PABannier Feb 21, 2022
e2b2f3d
same
PABannier Feb 21, 2022
f42b2bf
added dimension generics to target for Dataset and DatasetView
PABannier Feb 22, 2022
7ead65c
fix linfa kernel
PABannier Feb 22, 2022
b43d410
changed AsTargets to AsSingleTargets in linfa-svm
PABannier Feb 22, 2022
6d69d4a
commented out IntoTargets explicit conversion for Ix1 to Ix2
PABannier Feb 22, 2022
feb3c75
removed IntoTargets trait
PABannier Feb 23, 2022
0ebc59e
fix metrics_clustering
PABannier Feb 23, 2022
be64d10
impl FromTargetArray for Array1
PABannier Feb 27, 2022
ed4cfe6
changed impl
PABannier Feb 27, 2022
53ebb9c
made cross_validate generic and remove cross_validate_multi
PABannier Feb 28, 2022
7c6b9cf
made iter_fold generic
PABannier Feb 28, 2022
8e80ec6
partial fix of tests
PABannier Feb 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::{
super::traits::{Predict, PredictInplace},
iter::{ChunksIter, DatasetIter, Iter},
AsTargets, AsTargetsMut, CountedTargets, Dataset, DatasetBase, DatasetView, Float,
FromTargetArray, Label, Labels, Records, Result,
AsMultiTargets, AsMultiTargetsMut, AsSingleTargets, AsSingleTargetsMut, CountedTargets,
Dataset, DatasetBase, DatasetView, Float, FromTargetArray, Label, Labels, Records, Result,
};
use crate::traits::Fit;
use ndarray::{
Expand Down Expand Up @@ -123,7 +123,7 @@ impl<R: Records, S> DatasetBase<R, S> {
}
}

impl<L, R: Records, T: AsTargets<Elem = L>> DatasetBase<R, T> {
impl<L, R: Records, T: AsSingleTargets<Elem = L>> DatasetBase<R, T> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@YuhanLiin The implementation for T: AsMultiTargets<Elem = L> would be very similar. What do you recommend to avoid duplicating very similar code?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use AsMultiTargets instead of AsSingleTarget as the bound for these dataset impls. AsMultiTarget is more general than AsSingleTarget, and even without a blanket impl, types that implement AsSingleTarget probably also implement AsMultiTarget.

/// Map targets with a function `f`
///
/// # Example
Expand Down
112 changes: 98 additions & 14 deletions src/dataset/impl_targets.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,54 @@
use std::collections::HashMap;

use super::{
AsProbabilities, AsTargets, AsTargetsMut, CountedTargets, DatasetBase, FromTargetArray, Label,
Labels, Pr, Records,
AsMultiTargets, AsMultiTargetsMut, AsProbabilities, AsSingleTargets, AsSingleTargetsMut,
CountedTargets, DatasetBase, FromTargetArray, Label, Labels, Pr, Records,
};
use ndarray::{
Array1, Array2, ArrayBase, ArrayView2, ArrayViewMut2, Axis, CowArray, Data, DataMut, Dimension,
Ix1, Ix2, Ix3, OwnedRepr, ViewRepr,
};

impl<'a, L, S: Data<Elem = L>> AsTargets for ArrayBase<S, Ix1> {
impl<'a, L, S: Data<Elem = L>> AsSingleTargets for ArrayBase<S, Ix1> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.view().insert_axis(Axis(1))
}
}

impl<'a, L, S: Data<Elem = L>> AsSingleTargets for ArrayBase<S, Ix2> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.view()
}
}
PABannier marked this conversation as resolved.
Show resolved Hide resolved

impl<'a, L, S: Data<Elem = L>> AsMultiTargets for ArrayBase<S, Ix1> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.view().insert_axis(Axis(1))
}

fn ntargets(&self) -> usize {
1
}
}
PABannier marked this conversation as resolved.
Show resolved Hide resolved

impl<'a, L, S: Data<Elem = L>> AsMultiTargets for ArrayBase<S, Ix2> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.view()
}

fn ntargets(&self) -> usize {
self.len_of(Axis(1))
}
}

impl<'a, L: Clone + 'a, S: Data<Elem = L>> FromTargetArray<'a, L> for ArrayBase<S, Ix2> {
type Owned = ArrayBase<OwnedRepr<L>, Ix2>;
type View = ArrayBase<ViewRepr<&'a L>, Ix2>;
Expand All @@ -30,57 +62,109 @@ impl<'a, L: Clone + 'a, S: Data<Elem = L>> FromTargetArray<'a, L> for ArrayBase<
}
}

impl<L, S: DataMut<Elem = L>> AsTargetsMut for ArrayBase<S, Ix1> {
impl<L, S: DataMut<Elem = L>> AsSingleTargetsMut for ArrayBase<S, Ix1> {
type Elem = L;

fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.view_mut().insert_axis(Axis(1))
}
}

impl<L, S: Data<Elem = L>> AsTargets for ArrayBase<S, Ix2> {
impl<L, S: DataMut<Elem = L>> AsSingleTargetsMut for ArrayBase<S, Ix2> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.view()
fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.view_mut()
}
}
PABannier marked this conversation as resolved.
Show resolved Hide resolved

impl<L, S: DataMut<Elem = L>> AsTargetsMut for ArrayBase<S, Ix2> {
impl<L, S: DataMut<Elem = L>> AsMultiTargetsMut for ArrayBase<S, Ix1> {
type Elem = L;

fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.view_mut().insert_axis(Axis(1))
}

fn ntargets(&self) -> usize {
1
}
}

impl<L, S: DataMut<Elem = L>> AsMultiTargetsMut for ArrayBase<S, Ix2> {
type Elem = L;

fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.view_mut()
}

fn ntargets(&self) -> usize {
self.len_of(Axis(1))
}
}

impl<T: AsTargets> AsTargets for &T {
impl<T: AsSingleTargets> AsSingleTargets for &T {
type Elem = T::Elem;

fn as_multi_targets(&self) -> ArrayView2<Self::Elem> {
(*self).as_multi_targets()
}
}

impl<L: Label, T: AsTargets<Elem = L>> AsTargets for CountedTargets<L, T> {
impl<T: AsMultiTargets> AsMultiTargets for &T {
type Elem = T::Elem;

fn as_multi_targets(&self) -> ArrayView2<Self::Elem> {
(*self).as_multi_targets()
}

fn ntargets(&self) -> usize {
(*self).ntargets()
}
}

impl<L: Label, T: AsSingleTargets<Elem = L>> AsSingleTargets for CountedTargets<L, T> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<L> {
self.targets.as_multi_targets()
}
}

impl<L: Label, T: AsTargetsMut<Elem = L>> AsTargetsMut for CountedTargets<L, T> {
impl<L: Label, T: AsMultiTargets<Elem = L>> AsMultiTargets for CountedTargets<L, T> {
type Elem = L;

fn as_multi_targets(&self) -> ArrayView2<Self::Elem> {
self.targets.as_multi_targets()
}

fn ntargets(&self) -> usize {
self.targets.ntargets()
}
}

impl<L: Label, T: AsSingleTargetsMut<Elem = L>> AsSingleTargetsMut for CountedTargets<L, T> {
type Elem = L;

fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.targets.as_multi_targets_mut()
}
}

impl<L: Label, T: AsMultiTargetsMut<Elem = L>> AsMultiTargetsMut for CountedTargets<L, T> {
type Elem = L;

fn as_multi_targets_mut(&mut self) -> ArrayViewMut2<'_, Self::Elem> {
self.targets.as_multi_targets_mut()
}

fn ntargets(&self) -> usize {
self.targets.ntargets()
}
}

impl<'a, L: Label + 'a, T> FromTargetArray<'a, L> for CountedTargets<L, T>
where
T: AsTargets<Elem = L> + FromTargetArray<'a, L>,
T: AsSingleTargets<Elem = L> + FromTargetArray<'a, L>,
PABannier marked this conversation as resolved.
Show resolved Hide resolved
T::Owned: Labels<Elem = L>,
T::View: Labels<Elem = L>,
{
Expand Down Expand Up @@ -152,7 +236,7 @@ impl<L: Label, S: Data<Elem = L>, I: Dimension> Labels for ArrayBase<S, I> {
}

/// Counted labels can act as labels
impl<L: Label, T: AsTargets<Elem = L>> Labels for CountedTargets<L, T> {
impl<L: Label, T: AsSingleTargets<Elem = L>> Labels for CountedTargets<L, T> {
PABannier marked this conversation as resolved.
Show resolved Hide resolved
type Elem = L;

fn label_count(&self) -> Vec<HashMap<L, usize>> {
Expand All @@ -163,7 +247,7 @@ impl<L: Label, T: AsTargets<Elem = L>> Labels for CountedTargets<L, T> {
impl<F: Copy, L: Copy + Label, D, T> DatasetBase<ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: AsTargets<Elem = L>,
T: AsSingleTargets<Elem = L>,
PABannier marked this conversation as resolved.
Show resolved Hide resolved
{
/// Transforms the input dataset by keeping only those samples whose label appears in `labels`.
///
Expand Down
6 changes: 3 additions & 3 deletions src/dataset/iter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{AsTargets, DatasetBase, DatasetView, FromTargetArray, Records};
use super::{AsMultiTargets, AsSingleTargets, DatasetBase, DatasetView, FromTargetArray, Records};
use ndarray::{s, ArrayBase, ArrayView1, ArrayView2, Axis, Data, Ix2};
use std::marker::PhantomData;

Expand Down Expand Up @@ -63,7 +63,7 @@ impl<'a, 'b: 'a, R: Records, T> DatasetIter<'a, 'b, R, T> {
impl<'a, 'b: 'a, F: 'a, L: 'a, D, T> Iterator for DatasetIter<'a, 'b, ArrayBase<D, Ix2>, T>
where
D: Data<Elem = F>,
T: AsTargets<Elem = L> + FromTargetArray<'a, L>,
T: AsMultiTargets<Elem = L> + FromTargetArray<'a, L>,
{
type Item = DatasetView<'a, F, L>;

Expand Down Expand Up @@ -136,7 +136,7 @@ impl<'a, 'b: 'a, F, T> ChunksIter<'a, 'b, F, T> {

impl<'a, 'b: 'a, F, E: 'b, T> Iterator for ChunksIter<'a, 'b, F, T>
where
T: AsTargets<Elem = E> + FromTargetArray<'b, E>,
T: AsMultiTargets<Elem = E> + FromTargetArray<'b, E>,
{
type Item = DatasetBase<ArrayView2<'a, F>, T::View>;

Expand Down
27 changes: 24 additions & 3 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ pub trait Records: Sized {
fn nfeatures(&self) -> usize;
}

/// Return a reference to single or multiple target variables
pub trait AsTargets {
/// Return a reference to single target variable
pub trait AsSingleTargets {
type Elem;

/// Returns a view on targets as two-dimensional array
Expand All @@ -222,6 +222,17 @@ pub trait AsTargets {
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trait should only have the method as_single_target() and should only be implemented for 1D datasets. Same with the mutable version of the trait.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The as_single_target method shouldn't be fallible, so it won't have a default impl. Otherwise if we mark the wrong type with this trait we'd get panics. This trait should only be implemented on "single-target" types, such as Array1.

}

/// Return a reference to a multi-target variable
pub trait AsMultiTargets {
type Elem;

/// Convert to a multi-target
fn as_multi_targets(&self) -> ArrayView2<Self::Elem>;

/// Returns the number of targets
fn ntargets(&self) -> usize;
}

/// Helper trait to construct counted labels
///
/// This is implemented for objects which can act as targets and created from a target matrix. For
Expand All @@ -236,7 +247,7 @@ pub trait FromTargetArray<'a, F> {
fn new_targets_view(targets: ArrayView2<'a, F>) -> Self::View;
}

pub trait AsTargetsMut {
pub trait AsSingleTargetsMut {
type Elem;

/// Returns a mutable view on targets as two-dimensional array
Expand All @@ -254,6 +265,16 @@ pub trait AsTargetsMut {
}
}

pub trait AsMultiTargetsMut {
type Elem;

/// Convert to a multi-target
fn as_multi_targets_mut(&mut self) -> Result<ArrayViewMut2<Self::Elem>>;

/// Returns the number of targets
fn ntargets(&self) -> usize;
}

/// Convert to probability matrix
///
/// Some algorithms are working with probabilities. Targets which allow an implicit conversion into
Expand Down
14 changes: 7 additions & 7 deletions src/metrics_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::fmt;
use ndarray::prelude::*;
use ndarray::Data;

use crate::dataset::{AsTargets, DatasetBase, Label, Labels, Pr, Records};
use crate::dataset::{AsSingleTargets, DatasetBase, Label, Labels, Pr, Records};
use crate::error::{Error, Result};

/// Return tuple of class index for each element of prediction and ground_truth
Expand Down Expand Up @@ -264,7 +264,7 @@ pub trait ToConfusionMatrix<A, T> {
impl<L: Label, S, T> ToConfusionMatrix<L, ArrayBase<S, Ix1>> for T
where
S: Data<Elem = L>,
T: AsTargets<Elem = L> + Labels<Elem = L>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
fn confusion_matrix(&self, ground_truth: ArrayBase<S, Ix1>) -> Result<ConfusionMatrix<L>> {
self.confusion_matrix(&ground_truth)
Expand All @@ -274,7 +274,7 @@ where
impl<L: Label, S, T> ToConfusionMatrix<L, &ArrayBase<S, Ix1>> for T
where
S: Data<Elem = L>,
T: AsTargets<Elem = L> + Labels<Elem = L>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
fn confusion_matrix(&self, ground_truth: &ArrayBase<S, Ix1>) -> Result<ConfusionMatrix<L>> {
let targets = self.try_single_target()?;
Expand Down Expand Up @@ -307,16 +307,16 @@ impl<L: Label, R, R2, T, T2> ToConfusionMatrix<L, &DatasetBase<R, T>> for Datase
where
R: Records,
R2: Records,
T: AsTargets<Elem = L>,
T2: AsTargets<Elem = L> + Labels<Elem = L>,
T: AsSingleTargets<Elem = L>,
T2: AsSingleTargets<Elem = L> + Labels<Elem = L>,
{
fn confusion_matrix(&self, ground_truth: &DatasetBase<R, T>) -> Result<ConfusionMatrix<L>> {
self.targets()
.confusion_matrix(ground_truth.try_single_target()?)
}
}

impl<L: Label, S: Data<Elem = L>, T: AsTargets<Elem = L> + Labels<Elem = L>, R: Records>
impl<L: Label, S: Data<Elem = L>, T: AsSingleTargets<Elem = L> + Labels<Elem = L>, R: Records>
ToConfusionMatrix<L, &DatasetBase<R, T>> for ArrayBase<S, Ix1>
{
fn confusion_matrix(&self, ground_truth: &DatasetBase<R, T>) -> Result<ConfusionMatrix<L>> {
Expand Down Expand Up @@ -477,7 +477,7 @@ impl<D: Data<Elem = Pr>> BinaryClassification<&[bool]> for ArrayBase<D, Ix1> {
}
}

impl<R: Records, R2: Records, T: AsTargets<Elem = bool>, T2: AsTargets<Elem = Pr>>
impl<R: Records, R2: Records, T: AsSingleTargets<Elem = bool>, T2: AsSingleTargets<Elem = Pr>>
BinaryClassification<&DatasetBase<R, T>> for DatasetBase<R2, T2>
{
fn roc(&self, y: &DatasetBase<R, T>) -> Result<ReceiverOperatingCharacteristic> {
Expand Down
11 changes: 8 additions & 3 deletions src/metrics_clustering.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Common metrics for clustering
use crate::dataset::{AsTargets, DatasetBase, Label, Labels, Records};
use crate::dataset::{AsSingleTargets, DatasetBase, Label, Labels, Records};
use crate::error::{Error, Result};
use crate::Float;
use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
Expand Down Expand Up @@ -62,8 +62,13 @@ impl<F: Float> DistanceCount<F> {
}
}

impl<'a, F: Float, L: 'a + Label, D: Data<Elem = F>, T: AsTargets<Elem = L> + Labels<Elem = L>>
SilhouetteScore<F> for DatasetBase<ArrayBase<D, Ix2>, T>
impl<
'a,
F: Float,
L: 'a + Label,
D: Data<Elem = F>,
T: AsSingleTargets<Elem = L> + Labels<Elem = L>,
> SilhouetteScore<F> for DatasetBase<ArrayBase<D, Ix2>, T>
{
fn silhouette_score(&self) -> Result<F> {
if self.ntargets() > 1 {
PABannier marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading