Skip to content

Commit

Permalink
result class
Browse files Browse the repository at this point in the history
  • Loading branch information
Polkas committed Oct 25, 2023
1 parent 53959b0 commit 47182ed
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 18 deletions.
108 changes: 90 additions & 18 deletions src/cat2cat/cat2cat_ml.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,100 @@
from pandas import DataFrame, concat
from numpy import arange, repeat, setdiff1d, in1d, intersect1d, sum, NaN, mean
from numpy import repeat, setdiff1d, in1d, sum, NaN, nanmean, isnan, round

from sklearn.model_selection import train_test_split

from cat2cat.mappings import get_mappings, get_freqs, cat_apply_freq
from cat2cat.dataclass import cat2cat_data, cat2cat_mappings, cat2cat_ml
from cat2cat.cat2cat_utils import dummy_c2c
from cat2cat.mappings import get_mappings
from cat2cat.dataclass import cat2cat_mappings, cat2cat_ml

from typing import Optional, Any, Dict
from typing import Any, Dict

__all__ = ["cat2cat_ml_run"]


class cat2cat_ml_run_results:
    """Summary of the results of a `cat2cat_ml_run` function call.

    Args:
        res (Dict): raw results from the cat2cat_ml_run function call.
            Each value is a per-group dict; the keys read here are
            "naive", "freq" and "acc" (a dict keyed by model class name).
        mappings (cat2cat_mappings): dataclass with mappings related arguments.
            Please check out the `cat2cat.dataclass.cat2cat_mappings` for more information.
        ml (cat2cat_ml): dataclass with ml related arguments.
            Please check out the `cat2cat.dataclass.cat2cat_ml` for more information.

    Attributes:
        res (Dict): raw results from the cat2cat_ml_run function call.
        mappings (cat2cat_mappings): the mappings dataclass passed in.
        ml (cat2cat_ml): the ml dataclass passed in.
        models_names (list): class names of the models in `ml.models`.
        mean_acc (Dict): mean accuracy of the "naive" and "most_freq"
            baselines and of each model, ignoring NaN entries.
        percent_failed (Dict): for each model, percent of groups with a NaN accuracy.
        percent_better (Dict): for each model, percent of groups where the model
            accuracy is greater than the mean accuracy of the most frequent
            category solution.

    Methods:
        get_raw: get raw results
    """

    def __init__(self, res, mappings, ml) -> None:
        self.res = res
        self.mappings = mappings
        self.ml = ml
        self.models_names = [type(m).__name__ for m in self.ml.models]

        # A plain float NaN avoids the `numpy.NaN` alias, which NumPy 2.0 removed.
        nan = float("nan")
        groups = list(self.res.values())

        mean_acc = {
            "naive": self._rounded_mean([g.get("naive", nan) for g in groups], 3),
            # The most frequent category baseline is rounded to 2 digits (the
            # other means use 3); kept as is so the reported numbers do not change.
            "most_freq": self._rounded_mean([g.get("freq", nan) for g in groups], 2),
        }
        percent_failed = {}
        percent_better = {}

        for name in self.models_names:
            # A group without an "acc" entry for this model counts as failed (NaN).
            vals = [(g.get("acc") or {}).get(name, nan) for g in groups]
            mean_acc[name] = self._rounded_mean(vals, 3)
            n_failed = len([v for v in vals if isnan(v)])
            # Comparisons involving NaN are False, so failed groups are never "better".
            n_better = len([v for v in vals if v > mean_acc["most_freq"]])
            percent_failed[name] = self._percent(n_failed, len(vals))
            percent_better[name] = self._percent(n_better, len(vals))

        self.mean_acc = mean_acc
        self.percent_failed = percent_failed
        self.percent_better = percent_better

    @staticmethod
    def _rounded_mean(values, ndigits):
        """Return the rounded mean of the non-NaN values, or NaN if there are none."""
        valid = [v for v in values if not isnan(v)]
        if not valid:
            # Calling nanmean on an empty or all-NaN input only yields NaN plus a warning.
            return float("nan")
        return round(nanmean(valid), ndigits)

    @staticmethod
    def _percent(count, total):
        """Return `count` as a percent of `total` (3 digits), or NaN when `total` is 0."""
        if total == 0:
            return float("nan")
        return round(count / total * 100, 3)

    def get_raw(self) -> Dict:
        """Get raw results"""
        return self.res

    def __repr__(self) -> str:
        lines = ["Accuracy {}: {}".format(k, v) for k, v in self.mean_acc.items()]
        lines.append("")
        lines.extend(
            "Percent of failed {}: {}".format(k, v)
            for k, v in self.percent_failed.items()
        )
        lines.append("")
        lines.extend(
            "Percent of better {} over most frequent category solution: {}".format(
                k, v
            )
            for k, v in self.percent_better.items()
        )
        lines.append("Features: {}".format(self.ml.features))
        lines.append("Test sample size: {}".format(self.ml.test_size))
        return "\n".join(lines) + "\n"


# Previous name of the results class, kept as an alias for backward compatibility.
cat2cat_ml_run_class = cat2cat_ml_run_results


def cat2cat_ml_run(
mappings: cat2cat_mappings, ml: cat2cat_ml, **kwargs: Any
) -> cat2cat_ml_run_class:
) -> cat2cat_ml_run_results:
"""Automatic mapping in a panel dataset - cat2cat procedure
Args:
Expand All @@ -35,6 +104,7 @@ def cat2cat_ml_run(
Please check out the `cat2cat.dataclass.cat2cat_ml` for more information.
**kwargs: additional arguments passed to the `cat2cat_ml_run` function.
min_match (float): minimum share of categories from the base period that have to be matched in the mapping table. Between 0 and 1. Default 0.8.
test_size (float): share of the data used for testing. Between 0 and 1. Default 0.2.
Returns:
cat2cat_ml_run_results
Expand Down Expand Up @@ -87,7 +157,8 @@ def cat2cat_ml_run(
), "The mapping table does not cover all categories in the data. Please check the direction in the mapping table."

features = ml.features
methods = ml.models
models = ml.models
models_names = [type(m).__name__ for m in models]

train_g = {n: g for n, g in ml.data[features + [ml.cat_var]].groupby(ml.cat_var)}

Expand All @@ -99,7 +170,7 @@ def cat2cat_ml_run(
res[g_name] = {
"ncat": len(matched_cat),
"naive": 1 / len(matched_cat),
"acc": dict(zip(methods, repeat(NaN, len(methods)))),
"acc": dict(zip(models_names, repeat(NaN, len(models_names)))),
"freq": NaN,
}
data_small_g_list = list()
Expand All @@ -122,24 +193,25 @@ def cat2cat_ml_run(
X_train, X_test, y_train, y_test = train_test_split(
data_small_g[features],
data_small_g[ml.cat_var],
test_size=0.2,
test_size=kwargs.get("test_size", 0.2),
random_state=42,
)

gcounts = y_train.value_counts()
gfreq_max = gcounts.index[0]
res[g_name]["freq"] = mean(gfreq_max == y_test)
res[g_name]["freq"] = nanmean(gfreq_max == y_test)

if X_test.shape[0] == 0 | X_train.shape[0] < 5:
continue

for m in methods:
for m in models:
ml_name = type(m).__name__
m.fit(X_train, y_train)
res[g_name]["acc"][m] = m.score(X_test, y_test)
res[g_name]["acc"][ml_name] = m.score(X_test, y_test)
except:
continue

return res
return cat2cat_ml_run_results(res, mappings, ml)


def _cat2cat_ml(
Expand Down
7 changes: 7 additions & 0 deletions tests/test_cat2cat_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@
)
cat2cat_ml_run(mappings=mappings, ml=ml)


def test_cat2cat_ml_run_repr():
    """Placeholder test; presumably meant to cover the ``__repr__`` of the run results.

    TODO: the body is empty, so this test passes without asserting anything.
    """
    pass


def test_cat2cat_ml_run_get_raw():
    """Placeholder test; presumably meant to cover ``get_raw`` of the run results.

    TODO: the body is empty, so this test passes without asserting anything.
    """
    pass

0 comments on commit 47182ed

Please sign in to comment.