Skip to content

Commit

Permalink
WIP: Unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
srcansiz committed Mar 25, 2022
1 parent 9a91807 commit fa1ccbe
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
2 changes: 2 additions & 0 deletions fedbiomed/common/training_plans/_fedbiosklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def __init__(self, model_args: dict = {}):
if not isinstance(model_args, dict):
model_args = {}

print('###', model_args)

if 'model' not in model_args:
msg = ErrorNumbers.FB303.value + ": SKLEARN model not provided"
logger.critical(msg)
Expand Down
26 changes: 15 additions & 11 deletions tests/test_fedbiosklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,59 @@
from sklearn.linear_model import SGDRegressor

from fedbiomed.common.training_plans import SGDSkLearnModel
from fedbiomed.common.exceptions import FedbiomedTrainingPlanError
from fedbiomed.common.exceptions import FedbiomedTrainingPlanError


class TestModel(SGDSkLearnModel):
"""
What is it ?
"""

def __init__(self, model_args: dict = {}):
super(TestModel).__init__()

def adhoc(self):
print('adhoc')


class TestFedbiosklearn(unittest.TestCase):

def setUp(self):
self.model = SGDRegressor(max_iter=1000, tol=1e-3)


def tearDown(self):
pass

def test_init(self):
kw = {'toto': 'le' , 'lelo':'la', 'max_iter':7000, 'tol': 0.3456, 'n_features': 10, 'model': 'SGDRegressor' }
kw = {'toto': 'le', 'lelo': 'la', 'max_iter': 7000, 'tol': 0.3456, 'n_features': 10, 'model': 'SGDRegressor'}
fbsk = SGDSkLearnModel(kw)
m = fbsk.get_model()
p = m.get_params()
self.assertEqual(p['max_iter'] , 7000)
self.assertEqual(p['tol'], 0.3456 )
self.assertTrue( np.allclose(m.coef_, np.zeros(10)) )
self.assertEqual(p['max_iter'], 7000)
self.assertEqual(p['tol'], 0.3456)
self.assertTrue(np.allclose(m.coef_, np.zeros(10)))
self.assertIsNone(p.get('lelo'))
self.assertIsNone(p.get('toto'))
self.assertIsNone(p.get('model'))

def test_not_implemented_method(self):
kw = {'toto': 'le', 'lelo': 'la', 'max_iter': 7000, 'tol': 0.3456, 'n_features': 10, 'model': 'SGDRegressor'}
t = TestModel(kw)
self.assertRaises(FedbiomedTrainingPlanError,lambda: t.training_data())
self.assertRaises(FedbiomedTrainingPlanError, lambda: t.training_data())

def test_save_and_load(self):
randomfile = tempfile.NamedTemporaryFile()

skm = SGDSkLearnModel({'max_iter': 1000, 'tol':1e-3, 'n_features': 5, 'model': 'SGDRegressor'})
skm = SGDSkLearnModel({'max_iter': 1000, 'tol': 1e-3, 'n_features': 5, 'model': 'SGDRegressor'})
skm.save(randomfile.name)

self.assertTrue(os.path.exists(randomfile.name) and os.path.getsize(randomfile.name) > 0 )
self.assertTrue(os.path.exists(randomfile.name) and os.path.getsize(randomfile.name) > 0)

m = skm.load(randomfile.name)

self.assertEqual(m.max_iter,1000)
self.assertEqual(m.max_iter, 1000)
self.assertEqual(m.tol, 0.001)


if __name__ == '__main__': # pragma: no cover
unittest.main()
unittest.main()
2 changes: 0 additions & 2 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def test_basic_json(self):
loop1 = js.serialize_msg(js.deserialize_msg(msg))
self.assertEqual(loop1, msg)

loop2 = js.deserialize_msg(js.serialize_msg(msg))
self.assertEqual(loop1, msg)
pass

def test_errnum_json(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/testsupport/fake_training_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def training_routine(self, **kwargs):
"""
time.sleep(FakeModel.SLEEPING_TIME)

def testing_routine(self, **kwargs):
pass


def after_training_params(self) -> List[int]:
"""Fakes `after_training_params` method of TrainingPlan classes.
Originally used to get the parameters after training is performed.
Expand Down

0 comments on commit fa1ccbe

Please sign in to comment.