From 7321d7e9958a5ea1b62e58dc6ce8a3f5146d5330 Mon Sep 17 00:00:00 2001
From: kingagl
Date: Thu, 14 Mar 2024 13:29:00 +0100
Subject: [PATCH] fix providing metrics and exceptions, add saving intermediate steps

---
 src/stylo_metrix/stylo_metrix.py | 118 ++++++++++++++++++++++++-------
 1 file changed, 92 insertions(+), 26 deletions(-)

diff --git a/src/stylo_metrix/stylo_metrix.py b/src/stylo_metrix/stylo_metrix.py
index 5dca861..deb720c 100644
--- a/src/stylo_metrix/stylo_metrix.py
+++ b/src/stylo_metrix/stylo_metrix.py
@@ -14,6 +14,9 @@
 # along with this program. If not, see .
 
 import os
+import re
+from pathlib import Path
+from typing import List, Union
 
 import numpy as np
 import pandas as pd
@@ -29,26 +32,47 @@
 
 
 class StyloMetrix(BaseEstimator, TransformerMixin):
+    """Class for counting linguistic metrics.
+
+    Args:
+        lang (str): Language. One of ['de', 'en', 'pl', 'ru', 'ukr'].
+        nlp (spacy.language.Language, optional): Language model to use from spacy. Defaults to None. If None, predefined models are used.
+        metrics (Union[MetricGroup, List[str], None], optional): Metrics to use, given as a MetricGroup or as a list of category / metric names. Defaults to None. If not defined, all available metrics are used.
+        exceptions (Union[MetricGroup, List[str], None], optional): Metrics to remove, given as a MetricGroup or as a list of category / metric names. Defaults to None. If not defined, no metric is removed.
+        debug (bool, optional): Whether debug output should be collected and returned. Defaults to False.
+        save_path (Union[str, Path], optional): Directory where results and intermediate steps are saved; it is created if missing. Defaults to None.
+        output_name (str, optional): Filename prefix for SM output. Defaults to "sm_output".
+        debug_name (str, optional): Filename prefix for SM debug. Defaults to "sm_debug".
+        save_step (int, optional): Define after how many steps StyloMetrix should be saved. Defaults to None. If not defined, no intermediate steps are saved.
+        nlp_customization (optional): NLP customization. Defaults to None.
+    """
+
     def __init__(
         self,
-        lang,
-        nlp=None,
-        metrics=None,
-        exceptions=None,
-        debug=False,
-        save_path=None,
+        lang: str,
+        nlp: spacy.language.Language = None,
+        metrics: Union[MetricGroup, List[str], None] = None,
+        exceptions: Union[MetricGroup, List[str], None] = None,
+        debug: bool = False,
+        save_path: Union[str, Path] = None,
+        output_name: str = "sm_output",
+        debug_name: str = "sm_debug",
+        save_step: int = None,
         nlp_customization=None,
     ):
         super().__init__()
         self._debug = debug
         self._customization = nlp_customization
+        self._save_step = save_step
+        self.output_name = output_name
+        self.debug_name = debug_name
         self._init_nlp(lang, nlp)
         self._init_metrics(metrics, exceptions, self.nlp)
         self._set_pipeline()
         if save_path:
-            if not os.path.exists(save_path):
-                raise Exception(f"Path {save_path} is not exists")
-            self._save_path = save_path
+            self._save_path = Path(save_path)
+            self._save_path.mkdir(parents=True, exist_ok=True)
+            self._file_number = self._determine_number()
 
     def fit(self, texts, outputs=None):
         return self
@@ -94,36 +118,47 @@ def transform(self, texts):
             debugs_content = doc_content + metric_debugs
             values.append(values_content)
             debugs.append(debugs_content)
+
+            if self._save_path and self._save_step and (i % max(self._save_step - 1, 1) == 0):
+                values_temp = pd.DataFrame(values, columns=columns + m_columns)
+                debugs_temp = pd.DataFrame(debugs, columns=columns + m_columns)
+
+                self._save(values_temp, self.output_name + f"{self._file_number}_temp")
+                self._save(debugs_temp, self.debug_name + f"{self._file_number}_temp")
 
         columns = columns + m_columns
         values = pd.DataFrame(values, columns=columns)
         debugs = pd.DataFrame(debugs, columns=columns)
 
-        output_name = "sm_output"
-        debug_name = "sm_debug"
-
         if self._debug:
             if self._save_path:
-                deb_num = self._save(values, self._save_path, output_name)
-                self._save(values, self._save_path, debug_name, deb_num)
+                self._save(values, self.output_name + f"{self._file_number}")
+                self._save(debugs, self.debug_name + f"{self._file_number}")
             return values, debugs
         else:
             if self._save_path:
-                self._save(values, self._save_path, output_name)
+                self._save(values, self.output_name + f"{self._file_number}")
             return values
 
-    def _save(self, value, path, base_name, number=None):
-        if not number:
-            number = 1
-            for file in os.listdir(path):
-                if file.startswith(base_name):
-                    n = int(file[len(base_name) :].replace(".csv", ""))
-                    if n >= number:
-                        number = n + 1
-        file_name = f"{base_name}{number}.csv"
-        file_path = os.path.join(path, file_name)
+    def _save(self, value, base_name):
+        file_name = f"{base_name}.csv"
+        file_path = os.path.join(self._save_path, file_name)
         value.to_csv(file_path)
         print(f"File saved in location: {file_path}")
+        if not file_path.replace(".csv", "").endswith("_temp"):
+            Path(file_path.replace(".csv", "_temp.csv")).unlink(missing_ok=True)
+
+    def _determine_number(self):
+        number = 1
+        for file in os.listdir(self._save_path):
+            if file.startswith(self.output_name):
+                n = file[len(self.output_name) :].replace(".csv", "")
+                if re.findall(r"\d+", n):
+                    n = int(re.findall(r"\d+", n)[0])
+                else:
+                    n = 1
+                if n >= number:
+                    number = n + 1
         return number
 
     def set_debug(self, debug):
@@ -134,12 +169,18 @@ def set_nlp_customization(self, nlp_customization):
 
     def _init_metrics(self, metrics, exceptions, nlp):
         base_metrics = MetricGroup()
+
         if not metrics:
             base_metrics += self._lang.get_metrics()
+        elif not isinstance(metrics, MetricGroup):
+            base_metrics += self._list_to_metric_group(metrics)
+
         else:
             for metric in metrics:
                 base_metrics += metric
-        if exceptions:
+        if exceptions and not isinstance(exceptions, MetricGroup):
+            base_metrics -= self._list_to_metric_group(exceptions)
+        elif exceptions and isinstance(exceptions, MetricGroup):
             for exception in exceptions:
                 base_metrics -= exception
         for metric in base_metrics:
@@ -161,3 +202,28 @@ def _init_nlp(self, lang, nlp):
 
         self.nlp = nlp
         self._lang = lang
+
+    def _list_to_metric_group(self, metrics):
+        defined_metrics = []
+        metrics_df = pd.DataFrame(
+            {
+                "category": [
+                    metric.category.__name__
+                    for metric in self._lang.get_metrics().metrics
+                ],
+                "metric_name": [
+                    metric.__name__ for metric in self._lang.get_metrics().metrics
+                ],
+                "metric": list(self._lang.get_metrics().metrics),
+            }
+        )
+        for metric in metrics:
+            if metric in metrics_df["category"].values:
+                defined_metrics += metrics_df.loc[
+                    metrics_df["category"] == metric, "metric"
+                ].tolist()
+            elif metric in metrics_df["metric_name"].values:
+                defined_metrics += metrics_df.loc[
+                    metrics_df["metric_name"] == metric, "metric"
+                ].tolist()
+        return MetricGroup(metrics=defined_metrics)