From bfb6dcaa69675b5c016aaa743a5801c18072dd36 Mon Sep 17 00:00:00 2001 From: YiWen Hon Date: Wed, 9 Aug 2023 14:32:19 +0100 Subject: [PATCH] renamed "support" column to be more userfriendly --- .../factories/factory_model_performance.py | 2 +- tests/test_factory_pipeline.py | 2 +- tests/test_write_results.py | 15 ++++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pxtextmining/factories/factory_model_performance.py b/pxtextmining/factories/factory_model_performance.py index 1ca5042..9737acb 100644 --- a/pxtextmining/factories/factory_model_performance.py +++ b/pxtextmining/factories/factory_model_performance.py @@ -228,7 +228,7 @@ def parse_metrics_file(metrics_file, labels): "precision": [], "recall": [], "f1_score": [], - "support": [], + "support (label count in test data)": [], } for each in lines: splitted = each.split(" ") diff --git a/tests/test_factory_pipeline.py b/tests/test_factory_pipeline.py index 3c8c223..de3d3d4 100644 --- a/tests/test_factory_pipeline.py +++ b/tests/test_factory_pipeline.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from keras.engine.functional import Functional +from keras.src.engine.functional import Functional from sklearn.base import is_classifier from sklearn.pipeline import Pipeline diff --git a/tests/test_write_results.py b/tests/test_write_results.py index 390651f..b4a1532 100644 --- a/tests/test_write_results.py +++ b/tests/test_write_results.py @@ -1,8 +1,10 @@ -from pxtextmining.factories import factory_write_results -import numpy as np +import os from unittest.mock import Mock, mock_open, patch + +import numpy as np from tensorflow.keras import Model -import os + +from pxtextmining.factories import factory_write_results @patch("pickle.dump", Mock()) @@ -43,9 +45,10 @@ def test_write_model_preds_sklearn(mock_toexcel, grab_test_X_additional_feats): # act factory_write_results.write_model_preds(x, y, mock_model, labels, path=path) # assert - mock_model.predict_proba.assert_called_with(x) + mock_model.predict_proba.assert_called() mock_toexcel.assert_called() + @patch("pxtextmining.factories.factory_write_results.pd.DataFrame.to_excel") def test_write_model_preds_bert(mock_toexcel, grab_test_X_additional_feats): # arrange @@ -57,10 +60,8 @@ def test_write_model_preds_bert(mock_toexcel, grab_test_X_additional_feats): [9.8868138e-01, 1.9990385e-03, 5.4453085e-03], [5.6546849e-01, 4.2310607e-01, 9.3136989e-03], ] - ) - mock_model = Mock(spec=Model, - predict=Mock(return_value=predicted_probs) ) + mock_model = Mock(spec=Model, predict=Mock(return_value=predicted_probs)) labels = ["A", "B", "C"] path = "somepath.xlsx" # act