-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] Add error modeling to CADET-Process
- Loading branch information
1 parent
514cb45
commit 9a4caf8
Showing
4 changed files
with
310 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
=============================================== | ||
Error Modeling (:mod:`CADETProcess.errorModel`) | ||
=============================================== | ||
.. currentmodule:: CADETProcess.errorModel | ||
A module to define error models in CADET-Process. | ||
.. autosummary:: | ||
:toctree: generated/ | ||
""" | ||
|
||
from .distribution import * | ||
from .variator import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from scipy.stats import norm, uniform, expon, binom, poisson | ||
from typing import Optional, Union | ||
import numpy as np | ||
|
||
|
||
class DistributionBase: | ||
""" | ||
Base class for all distributions. | ||
Handles common functionality for sampling, mean, and variance using `scipy.stats` | ||
distributions. | ||
""" | ||
|
||
def __init__(self, dist): | ||
""" | ||
Initialize the base distribution. | ||
Parameters | ||
---------- | ||
dist : scipy.stats distribution | ||
A scipy.stats frozen distribution object. | ||
""" | ||
self._dist = dist | ||
|
||
def sample( | ||
self, | ||
size: Optional[Union[int, tuple]] = None | ||
) -> Union[float, np.ndarray]: | ||
""" | ||
Generate random samples from the distribution. | ||
Parameters | ||
---------- | ||
size : int or tuple of ints, optional | ||
Number of samples to generate. Defaults to None for a single sample. | ||
Returns | ||
------- | ||
ndarray or scalar | ||
Random sample(s) from the distribution. | ||
""" | ||
return self._dist.rvs(size=size) | ||
|
||
def mean(self) -> float: | ||
""" | ||
Compute the mean of the distribution. | ||
Returns | ||
------- | ||
float | ||
The mean of the distribution. | ||
""" | ||
return self._dist.mean() | ||
|
||
def var(self) -> float: | ||
""" | ||
Compute the variance of the distribution. | ||
Returns | ||
------- | ||
float | ||
The variance of the distribution. | ||
""" | ||
return self._dist.var() | ||
|
||
|
||
class NormalDistribution(DistributionBase): | ||
"""Represents a normal (Gaussian) distribution.""" | ||
|
||
def __init__(self, mu: float, sigma: float): | ||
""" | ||
Initialize the normal distribution. | ||
Parameters | ||
---------- | ||
mu : float | ||
Mean of the distribution. | ||
sigma : float | ||
Standard deviation of the distribution. | ||
Raises | ||
------ | ||
ValueError | ||
If sigma is negative. | ||
""" | ||
if sigma < 0: | ||
raise ValueError("Sigma must be non-negative.") | ||
super().__init__(norm(loc=mu, scale=sigma)) | ||
|
||
|
||
class UniformDistribution(DistributionBase): | ||
"""Represents a uniform distribution.""" | ||
|
||
def __init__(self, lb: float, ub: float): | ||
""" | ||
Initialize the uniform distribution. | ||
Parameters | ||
---------- | ||
lb : float | ||
Lower bound of the distribution. | ||
ub : float | ||
Upper bound of the distribution. | ||
Raises | ||
------ | ||
ValueError | ||
If lb is not less than ub. | ||
""" | ||
if lb >= ub: | ||
raise ValueError("Lower bound must be less than upper bound.") | ||
super().__init__(uniform(loc=lb, scale=ub - lb)) | ||
|
||
|
||
class ExponentialDistribution(DistributionBase): | ||
"""Represents an exponential distribution.""" | ||
|
||
def __init__(self, lambda_: float): | ||
""" | ||
Initialize the exponential distribution. | ||
Parameters | ||
---------- | ||
lambda_ : float | ||
The rate parameter (inverse of mean). | ||
Raises | ||
------ | ||
ValueError | ||
If lambda_ is non-positive. | ||
""" | ||
if lambda_ <= 0: | ||
raise ValueError("Lambda must be positive.") | ||
super().__init__(expon(scale=1 / lambda_)) | ||
|
||
|
||
class BinomialDistribution(DistributionBase): | ||
"""Represents a binomial distribution.""" | ||
|
||
def __init__(self, n: int, p: float): | ||
""" | ||
Initialize the binomial distribution. | ||
Parameters | ||
---------- | ||
n : int | ||
Number of trials. | ||
p : float | ||
Probability of success in each trial. | ||
Raises | ||
------ | ||
ValueError | ||
If n is not positive or p is not in [0, 1]. | ||
""" | ||
if n <= 0: | ||
raise ValueError("Number of trials (n) must be positive.") | ||
if not (0 <= p <= 1): | ||
raise ValueError("Probability (p) must be between 0 and 1.") | ||
super().__init__(binom(n=n, p=p)) | ||
|
||
|
||
class PoissonDistribution(DistributionBase): | ||
"""Represents a Poisson distribution.""" | ||
|
||
def __init__(self, lambda_: float): | ||
""" | ||
Initialize the Poisson distribution. | ||
Parameters | ||
---------- | ||
lambda_ : float | ||
The rate parameter (mean and variance). | ||
Raises | ||
------ | ||
ValueError | ||
If lambda_ is non-positive. | ||
""" | ||
if lambda_ <= 0: | ||
raise ValueError("Lambda must be positive.") | ||
super().__init__(poisson(mu=lambda_)) | ||
|
||
|
||
# %% | ||
|
||
# Create instances of different distributions | ||
normal = NormalDistribution(mu=0, sigma=1) | ||
uniform = UniformDistribution(lb=0, ub=10) | ||
exponential = ExponentialDistribution(lambda_=2) | ||
binomial = BinomialDistribution(n=10, p=0.5) | ||
poisson = PoissonDistribution(lambda_=3) | ||
|
||
# Example usage | ||
print("Normal Samples:", normal.sample(size=5)) | ||
print("Normal Mean:", normal.mean()) | ||
print("Normal Variance:", normal.var()) | ||
|
||
print("Uniform Samples:", uniform.sample(size=5)) | ||
print("Exponential Samples:", exponential.sample(size=5)) | ||
print("Binomial Samples:", binomial.sample(size=5)) | ||
print("Poisson Samples:", poisson.sample(size=5)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Dict | ||
|
||
from CADETProcess.errorModel.distribution import DistributionBase | ||
|
||
|
||
class Variator: | ||
""" | ||
A class to manage variations in a model using error distributions. | ||
Attributes | ||
---------- | ||
model : Any | ||
The model object to which variations will be applied. | ||
error_distributions : dict | ||
A dictionary to store registered error distributions indexed by name. | ||
""" | ||
|
||
def __init__(self, model): | ||
""" | ||
Initialize the Variator with a model. | ||
Parameters | ||
---------- | ||
model : Any | ||
The model object to which variations will be applied. | ||
""" | ||
self.model = model | ||
self.error_distributions: Dict[str, VariatedVariable] = {} | ||
|
||
def add_error(self, name: str, parameter_path: str, distribution: DistributionBase): | ||
""" | ||
Register an error distribution for a specific parameter in the model. | ||
Parameters | ||
---------- | ||
name : str | ||
The unique name for the error being registered. | ||
parameter_path : str | ||
The path to the parameter in the model that the error affects. | ||
Example: 'layer1.weights[0][0]' | ||
distribution : DistributionBase | ||
The distribution object defining the error. | ||
Raises | ||
------ | ||
ValueError | ||
If the name is already registered. | ||
""" | ||
if name in self.error_distributions: | ||
raise ValueError(f"Error with name '{name}' is already registered.") | ||
|
||
self.error_distributions[name] = VariatedVariable(parameter_path, distribution) | ||
|
||
|
||
class VariatedVariable: | ||
""" | ||
Encapsulates a parameter path and its associated error distribution. | ||
Attributes | ||
---------- | ||
parameter_path : str | ||
The path to the parameter in the model. | ||
distribution : DistributionBase | ||
The distribution defining the error for this parameter. | ||
""" | ||
|
||
def __init__(self, parameter_path: str, distribution: DistributionBase): | ||
""" | ||
Initialize the VariatedVariable. | ||
Parameters | ||
---------- | ||
parameter_path : str | ||
The path to the parameter in the model. | ||
distribution : DistributionBase | ||
The distribution defining the error for this parameter. | ||
Raises | ||
------ | ||
TypeError | ||
If distribution is not an instance of DistributionBase. | ||
""" | ||
if not isinstance(distribution, DistributionBase): | ||
raise TypeError( | ||
f"Expected distribution to be an instance of DistributionBase, got {type(distribution).__name__}." | ||
) | ||
self.parameter_path = parameter_path | ||
self.distribution = distribution | ||
|
||
def __repr__(self): | ||
return f"VariatedVariable(parameter_path='{self.parameter_path}', distribution={self.distribution})" |