Skip to content

Commit

Permalink
work in progress almost done
Browse files Browse the repository at this point in the history
  • Loading branch information
rwijtvliet committed Aug 16, 2024
1 parent 922b070 commit dfebb52
Show file tree
Hide file tree
Showing 4 changed files with 1,315 additions and 14 deletions.
33 changes: 19 additions & 14 deletions portfolyo/tools/wavg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Mapping, overload
from typing import Iterable, Mapping, Optional, overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -43,28 +43,33 @@
# --> Otherwise, if values are identical --> result is that value
# --> Otherwise, result is NaN.

# Mapping of values ot weights:
# - If mapping is unclear: raise Error.
# - If there more values than weights: no worries; value apparently not needed.
# - If there more weights than values: raise Error.

RESULT_IF_WEIGHTSUM0_VALUESNOTUNIFORM = np.nan


@overload
def general(
fr: pd.Series, weights: Iterable | Mapping | pd.Series = None, axis: int = 0
fr: pd.Series, weights: Optional[Iterable | Mapping | pd.Series], axis: int = 0
) -> float:
...


@overload
def general(
fr: pd.DataFrame,
weights: Iterable | Mapping | pd.Series | pd.DataFrame = None,
weights: Optional[Iterable | Mapping | pd.Series | pd.DataFrame],
axis: int = 0,
) -> pd.Series:
...


def general(
fr: pd.Series | pd.DataFrame,
weights: Iterable | Mapping | pd.Series | pd.DataFrame = None,
weights: Optional[Iterable | Mapping | pd.Series | pd.DataFrame] = None,
axis: int = 0,
) -> float | tools_unit.Q_ | pd.Series:
"""
Expand Down Expand Up @@ -92,14 +97,12 @@ def general(
return dataframe(fr, weights, axis)
elif isinstance(fr, pd.Series):
return series(fr, weights)
else:
raise TypeError(
f"Parameter ``fr`` must be Series or DataFrame; got {type(fr)}."
)
raise TypeError(f"Parameter ``fr`` must be Series or DataFrame; got {type(fr)}.")


def series(
s: pd.Series, weights: Iterable | Mapping | pd.Series = None
s: pd.Series,
weights: Optional[Iterable | Mapping | pd.Series] = None,
) -> float | tools_unit.Q_:
"""
Weighted average of series.
Expand Down Expand Up @@ -142,7 +145,7 @@ def series(
replaceable = s.isna() & (weights == 0.0)
s[replaceable] = 0.0

# If we arrive here, ``s`` only has NaN on locations where weight != 0.
# If we arrive here, if ``s`` contains NaN, it is on locations where weight != 0.

# Check if ALL weights are 0.
# In that case, the result is NaN.
Expand Down Expand Up @@ -174,7 +177,7 @@ def series(

def dataframe(
df: pd.DataFrame,
weights: Iterable | Mapping | pd.Series | pd.DataFrame = None,
weights: Optional[Iterable | Mapping | pd.Series | pd.DataFrame] = None,
axis: int = 0,
) -> pd.Series:
"""
Expand Down Expand Up @@ -444,14 +447,14 @@ def weights_as_series(weights: Iterable | Mapping, refindex: Iterable) -> pd.Ser
elif isinstance(weights, Mapping):
weights = pd.Series(weights)
elif isinstance(weights, Iterable):
weights = pd.Series(weights, refindex)
weights = pd.Series(weights, refindex) # will only work if same length
else:
raise TypeError("``weights`` must be iterable or mapping.")
# Step 2: avoid Series of Quantity-objects (convert to pint-series instead).
return tools_unit.avoid_frame_of_objects(weights)


def values_areuniform(series: pd.Series, mask: Iterable = None) -> bool:
def values_areuniform(series: pd.Series, mask: Optional[Iterable] = None) -> bool:
"""Return True if all values in series are same. If mask is provided, only compare
values where the mask is True. If there are no values to compare, return True."""
values = series[mask].values if mask is not None else series.values
Expand All @@ -463,7 +466,9 @@ def values_areuniform(series: pd.Series, mask: Iterable = None) -> bool:
return True


def concatseries(series: Iterable[pd.Series], refindex: Iterable = None) -> pd.Series:
def concatseries(
series: Iterable[pd.Series], refindex: Optional[Iterable]
) -> pd.Series:
"""Concatenate some series, and try to make it a pint-series if possible."""
dtypes = set()
for s in series:
Expand Down
86 changes: 86 additions & 0 deletions tests/tools/test_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest


@pytest.fixture(
scope="session",
params=[
{
"val": "a2",
"wei": "a1",
"exp": "a1",
"axis": "a0",
},
{
"val": "b2",
"wei": "b2",
"exp": "b1",
"axis": "b1",
},
{
"val": "c3",
"wei": "c2",
"exp": "c1",
"axis": "c0",
},
{
"val": "d2",
"wei": "d2",
"exp": "d1",
"axis": "d1",
},
],
)
def complexcases(request):
return request.param


@pytest.fixture(
scope="session",
params=[
{
"val": "e1",
"wei": "e1",
"exp": "e0",
},
{
"val": "f1",
"wei": "f1",
"exp": "f0",
},
{
"val": "g1",
"wei": "g1",
"exp": "g0",
},
{
"val": "h1",
"wei": "h0",
"exp": "h2",
},
],
)
def easycases(request):
return request.param


@pytest.fixture
def val2d(easycases, complexcases):
allcases = [easycases] + [complexcases]
values = [cas["val"] for cas in allcases]
return [val for val in values if val[1] == "2"]


@pytest.fixture
def val1d(easycases, complexcases):
allcases = [*easycases, *complexcases]
return [case["val"] for case in allcases if case["val"][1] == "1"]


@pytest.fixture
def val0d(easycases, complexcases):
allcases = [*easycases, *complexcases]
return [case["val"] for case in allcases if case["val"][1] == "0"]


def test_two_d_cases(val2d):
print(val2d)
28 changes: 28 additions & 0 deletions tests/tools/test_test2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# I have a pytest integer fixture fix1 which is parametrized with the values 1,2,3, and another pytest integer fixture fix2 which is parametrized with the values 10,11, 20. I want to create a fixture fix3 which returns every value for fix1 and for fix2, one after the other, as long as the value is odd. So, fix3 should return the integer values 1, then 3, then 11. How do I do that?

import pytest


# Define fixture for `fix1`
@pytest.fixture(params=[[1], [2], [3]])
def alist(request):
return request.param


# Define fixture for `fix2`
@pytest.fixture(params=[9, 10, 11])
def anelement(request):
return request.param


def test_1(anelement):
print(f"test_1: {anelement:=}")


@pytest.fixture
def longer_list(alist, anelement):
alist.append(anelement)


def test_listcontent(longer_list):
print("---".join(longer_list))
Loading

0 comments on commit dfebb52

Please sign in to comment.