
Commit

fix providing metrics and exceptions, add saving intermediate steps
kingagla committed Mar 14, 2024
1 parent c4a8f2d commit 7321d7e
Showing 1 changed file with 92 additions and 26 deletions.
118 changes: 92 additions & 26 deletions src/stylo_metrix/stylo_metrix.py
@@ -14,6 +14,9 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import re
from pathlib import Path
from typing import List, Union

import numpy as np
import pandas as pd
@@ -29,26 +32,47 @@


class StyloMetrix(BaseEstimator, TransformerMixin):
"""Class for counting linguistic metrics.
Args:
lang (str): Language. One of ['de', 'en', 'pl', 'ru', 'ukr']
nlp (spacy.language.Language, optional): spaCy language model to use. Defaults to None. If None, the predefined model for the language is used.
metrics (Union[MetricGroup, List[str], None], optional): Metrics to use, given as a MetricGroup or a list of metric/category names. Defaults to None. If not defined, all available metrics are used.
exceptions (Union[MetricGroup, List[str], None], optional): Metrics to remove, given as a MetricGroup or a list of metric/category names. Defaults to None. If not defined, no metrics are removed.
debug (bool, optional): Whether to collect debug information. Defaults to False.
save_path (Union[str, Path], optional): Directory in which results and intermediate steps are saved. Defaults to None.
output_name (str, optional): Filename for SM output. Defaults to "sm_output".
debug_name (str, optional): Filename for SM debug. Defaults to "sm_debug".
save_step (int, optional): Number of processed texts after which intermediate results are saved. Defaults to None. If not defined, no intermediate results are saved.
nlp_customization (_type_, optional): NLP customization. Defaults to None.
"""

def __init__(
self,
lang,
nlp=None,
metrics=None,
exceptions=None,
debug=False,
save_path=None,
lang: str,
nlp: spacy.language.Language = None,
metrics: Union[MetricGroup, List[str], None] = None,
exceptions: Union[MetricGroup, List[str], None] = None,
debug: bool = False,
save_path: Union[str, Path] = None,
output_name: str = "sm_output",
debug_name: str = "sm_debug",
save_step: int = None,
nlp_customization=None,
):
super().__init__()
self._debug = debug
self._customization = nlp_customization
self._save_step = save_step
self.output_name = output_name
self.debug_name = debug_name
self._init_nlp(lang, nlp)
self._init_metrics(metrics, exceptions, self.nlp)
self._set_pipeline()
if save_path:
if not os.path.exists(save_path):
raise Exception(f"Path {save_path} does not exist")
self._save_path = save_path
self._save_path = Path(save_path)
self._save_path.mkdir(parents=True, exist_ok=True)
self._file_number = self._determine_number()

def fit(self, texts, outputs=None):
return self
@@ -94,36 +118,47 @@ def transform(self, texts):
debugs_content = doc_content + metric_debugs
values.append(values_content)
debugs.append(debugs_content)

# save an intermediate snapshot every `save_step` processed texts
if self._save_path and self._save_step and (i + 1) % self._save_step == 0:
values_temp = pd.DataFrame(values, columns=columns + m_columns)
debugs_temp = pd.DataFrame(debugs, columns=columns + m_columns)

self._save(values_temp, self.output_name + f"{self._file_number}_temp")
self._save(debugs_temp, self.debug_name + f"{self._file_number}_temp")
columns = columns + m_columns

values = pd.DataFrame(values, columns=columns)
debugs = pd.DataFrame(debugs, columns=columns)

output_name = "sm_output"
debug_name = "sm_debug"

if self._debug:
if self._save_path:
deb_num = self._save(values, self._save_path, output_name)
self._save(values, self._save_path, debug_name, deb_num)
self._save(values, self.output_name + f"{self._file_number}")
self._save(debugs, self.debug_name + f"{self._file_number}")
return values, debugs
else:
if self._save_path:
self._save(values, self._save_path, output_name)
self._save(values)
return values

def _save(self, value, path, base_name, number=None):
if not number:
number = 1
for file in os.listdir(path):
if file.startswith(base_name):
n = int(file[len(base_name) :].replace(".csv", ""))
if n >= number:
number = n + 1
file_name = f"{base_name}{number}.csv"
file_path = os.path.join(path, file_name)
def _save(self, value, base_name):
file_name = f"{base_name}.csv"
file_path = os.path.join(self._save_path, file_name)
value.to_csv(file_path)
print(f"File saved in location: {file_path}")
if not file_path.replace(".csv", "").endswith("_temp"):
Path(file_path.replace(".csv", "_temp.csv")).unlink(missing_ok=True)

def _determine_number(self):
number = 1
for file in os.listdir(self._save_path):
if file.startswith(self.output_name):
n = file[len(self.output_name) :].replace(".csv", "")
if re.findall(r"\d+", n):
n = int(re.findall(r"\d+", n)[0])
else:
n = 1
if n >= number:
number = n + 1
return number

def set_debug(self, debug):
@@ -134,12 +169,18 @@ def set_nlp_customization(self, nlp_customization):

def _init_metrics(self, metrics, exceptions, nlp):
base_metrics = MetricGroup()

if not metrics:
base_metrics += self._lang.get_metrics()
elif not isinstance(metrics, MetricGroup):
base_metrics += self._list_to_metric_group(metrics)

else:
for metric in metrics:
base_metrics += metric
if exceptions:
if exceptions and not isinstance(exceptions, MetricGroup):
base_metrics -= self._list_to_metric_group(exceptions)
elif exceptions and isinstance(exceptions, MetricGroup):
for exception in exceptions:
base_metrics -= exception
for metric in base_metrics:
Expand All @@ -161,3 +202,28 @@ def _init_nlp(self, lang, nlp):

self.nlp = nlp
self._lang = lang

def _list_to_metric_group(self, metrics):
defined_metrics = []
metrics_df = pd.DataFrame(
{
"category": [
metric.category.__name__
for metric in self._lang.get_metrics().metrics
],
"metric_name": [
metric.__name__ for metric in self._lang.get_metrics().metrics
],
"metric": list(self._lang.get_metrics().metrics),
}
)
for metric in metrics:
if metric in metrics_df["category"].values:
defined_metrics += metrics_df.loc[
metrics_df["category"] == metric, "metric"
].tolist()
elif metric in metrics_df["metric_name"].values:
defined_metrics += metrics_df.loc[
metrics_df["metric_name"] == metric, "metric"
].tolist()
return MetricGroup(metrics=defined_metrics)
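
Since this commit changes the public constructor, a short usage sketch may help. It is a minimal example, assuming the package exposes the class as stylo_metrix.StyloMetrix and that the spaCy model for the chosen language is installed; the directory name and texts are illustrative, not taken from the commit.

    # Minimal usage sketch for the updated constructor (illustrative values only).
    import stylo_metrix

    sm = stylo_metrix.StyloMetrix(
        lang="en",              # one of 'de', 'en', 'pl', 'ru', 'ukr'
        save_path="results",    # directory for CSV output; created if missing
        save_step=100,          # write an intermediate snapshot every 100 texts
        output_name="sm_output",
        debug_name="sm_debug",
    )
    values = sm.transform(["First text to analyse.", "Second text to analyse."])
    print(values.head())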
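The run numbering introduced by _determine_number() can be seen in isolation with the sketch below; the directory listing is hypothetical and stands in for os.listdir(save_path).

    import re

    output_name = "sm_output"
    existing = ["sm_output1.csv", "sm_output2.csv", "sm_debug1.csv"]  # pretend contents of save_path

    number = 1
    for file in existing:
        if file.startswith(output_name):
            suffix = file[len(output_name):].replace(".csv", "")
            digits = re.findall(r"\d+", suffix)
            n = int(digits[0]) if digits else 1
            if n >= number:
                number = n + 1

    print(number)  # 3 -> the next run writes sm_output3_temp.csv snapshots, then sm_output3.csv

Once the final file is written, _save() deletes the matching *_temp.csv snapshot, so only the completed result remains in save_path.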
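Finally, _list_to_metric_group() means metrics and exceptions no longer have to be MetricGroup instances; plain lists of category names or metric class names are resolved against self._lang.get_metrics(). The sketch below shows the three call styles; the string names are placeholders rather than real StyloMetrix identifiers.

    import stylo_metrix

    # no metrics argument: every metric available for the language is used
    sm_all = stylo_metrix.StyloMetrix(lang="en")

    # a list of category and/or metric class names is converted internally
    # into a MetricGroup by _list_to_metric_group()
    sm_subset = stylo_metrix.StyloMetrix(lang="en", metrics=["SomeCategory", "SOME_METRIC"])

    # exceptions accept the same forms and are subtracted from the selection
    sm_filtered = stylo_metrix.StyloMetrix(lang="en", exceptions=["SOME_METRIC"])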
