Skip to content

Commit

Permalink
Set default values for regularization and outlier removal so user isn…
Browse files Browse the repository at this point in the history
…'t forced to provide those parameters
  • Loading branch information
ooples committed Dec 15, 2024
1 parent 6bad169 commit d7053f7
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/Models/RegularizationOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ public class RegularizationOptions
{
public RegularizationType Type { get; set; } = RegularizationType.None;
public double Strength { get; set; } = 0.0;
public double L1Ratio { get; set; } = 0.5; // Only used for ElasticNet
public double L1Ratio { get; set; } = 0.5;
}
9 changes: 7 additions & 2 deletions src/OutlierRemoval/IQROutlierRemoval.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ public class IQROutlierRemoval<T> : IOutlierRemoval<T>
private readonly T _iqrMultiplier;
private readonly INumericOperations<T> _numOps;

public IQROutlierRemoval(T iqrMultiplier)
public IQROutlierRemoval(T? iqrMultiplier = default)
{
_iqrMultiplier = iqrMultiplier;
_numOps = MathHelper.GetNumericOperations<T>();
_iqrMultiplier = iqrMultiplier ?? GetDefaultMultiplier();
}

public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
Expand Down Expand Up @@ -45,4 +45,9 @@ public IQROutlierRemoval(T iqrMultiplier)

return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
}

private T GetDefaultMultiplier()
{
return _numOps.FromDouble(1.5);
}
}
9 changes: 7 additions & 2 deletions src/OutlierRemoval/MADOutlierRemoval.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ public class MADOutlierRemoval<T> : IOutlierRemoval<T>
private readonly T _threshold;
private readonly INumericOperations<T> _numOps;

public MADOutlierRemoval(T threshold)
public MADOutlierRemoval(T? threshold = default)
{
_threshold = threshold;
_numOps = MathHelper.GetNumericOperations<T>();
_threshold = threshold ?? GetDefaultThreshold();
}

public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
Expand Down Expand Up @@ -43,4 +43,9 @@ public MADOutlierRemoval(T threshold)

return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
}

private T GetDefaultThreshold()
{
return _numOps.FromDouble(3.5);
}
}
12 changes: 7 additions & 5 deletions src/OutlierRemoval/ThresholdOutlierRemoval.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
namespace AiDotNet.OutlierRemoval;

/// <summary>
/// Removes outliers from the data using the threshold method. This method is not recommended for data sets with less than 15 data points.
/// </summary>
public class ThresholdOutlierRemoval<T> : IOutlierRemoval<T>
{
private readonly T _threshold;
private readonly INumericOperations<T> _numOps;

public ThresholdOutlierRemoval(T threshold)
public ThresholdOutlierRemoval(T? threshold = default)
{
_threshold = threshold;
_numOps = MathHelper.GetNumericOperations<T>();
_threshold = threshold ?? GetDefaultThreshold();
}

public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
Expand Down Expand Up @@ -42,4 +39,9 @@ public ThresholdOutlierRemoval(T threshold)

return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
}

private T GetDefaultThreshold()
{
return _numOps.FromDouble(3.0);
}
}
9 changes: 7 additions & 2 deletions src/OutlierRemoval/ZScoreOutlierRemoval.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ public class ZScoreOutlierRemoval<T> : IOutlierRemoval<T>
private readonly T _threshold;
private readonly INumericOperations<T> _numOps;

public ZScoreOutlierRemoval(T threshold)
public ZScoreOutlierRemoval(T? threshold = default)
{
_threshold = threshold;
_numOps = MathHelper.GetNumericOperations<T>();
_threshold = threshold ?? FindDefaultThreshold();
}

public (Matrix<T> CleanedInputs, Vector<T> CleanedOutputs) RemoveOutliers(Matrix<T> inputs, Vector<T> outputs)
Expand Down Expand Up @@ -41,4 +41,9 @@ public ZScoreOutlierRemoval(T threshold)

return (new Matrix<T>(cleanedInputs, _numOps), new Vector<T>(cleanedOutputs, _numOps));
}

private T FindDefaultThreshold()
{
return _numOps.FromDouble(3.0);
}
}
6 changes: 3 additions & 3 deletions src/Regularization/ElasticRegularization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ public class ElasticNetRegularization<T> : IRegularization<T>
private readonly INumericOperations<T> _numOps;
private readonly RegularizationOptions _options;

public ElasticNetRegularization(INumericOperations<T> numOps, RegularizationOptions options)
public ElasticNetRegularization(RegularizationOptions? options = null)
{
_numOps = numOps;
_options = options;
_numOps = MathHelper.GetNumericOperations<T>();
_options = options ?? new RegularizationOptions();
}

public Matrix<T> RegularizeMatrix(Matrix<T> matrix)
Expand Down
7 changes: 3 additions & 4 deletions src/Regularization/L1Regularization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ public class L1Regularization<T> : IRegularization<T>
private readonly INumericOperations<T> _numOps;
private readonly RegularizationOptions _options;

public L1Regularization(INumericOperations<T> numOps, RegularizationOptions options)
public L1Regularization(RegularizationOptions? options = null)
{
_numOps = numOps;
_options = options;
_numOps = MathHelper.GetNumericOperations<T>();
_options = options ?? new RegularizationOptions();
}

public Matrix<T> RegularizeMatrix(Matrix<T> matrix)
{
// L1 regularization doesn't modify the matrix directly
return matrix;
}

Expand Down
6 changes: 3 additions & 3 deletions src/Regularization/L2Regularization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ public class L2Regularization<T> : IRegularization<T>
private readonly INumericOperations<T> _numOps;
private readonly RegularizationOptions _options;

public L2Regularization(INumericOperations<T> numOps, RegularizationOptions options)
public L2Regularization(RegularizationOptions? options = null)
{
_numOps = numOps;
_options = options;
_numOps = MathHelper.GetNumericOperations<T>();
_options = options ?? new RegularizationOptions();
}

public Matrix<T> RegularizeMatrix(Matrix<T> matrix)
Expand Down

0 comments on commit d7053f7

Please sign in to comment.