forked from fedbiomed/fedbiomed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_data_manager.py
98 lines (74 loc) · 3.99 KB
/
test_data_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import unittest
import pandas as pd
from torch.utils.data import Dataset
from fedbiomed.common.data import DataManager
from fedbiomed.common.data._torch_data_manager import TorchDataManager
from fedbiomed.common.data._sklearn_data_manager import SkLearnDataManager
from fedbiomed.common.exceptions import FedbiomedDataManagerError
from fedbiomed.common.constants import TrainingPlans
class TestDataManager(unittest.TestCase):
    """Unit tests for the `DataManager` facade.

    Covers how `DataManager.load` reacts to the combination of dataset/target
    types and training plan type, and how `DataManager.__getattr__` behaves
    once a framework specific data manager has been loaded.
    """

    class CustomDataset(Dataset):
        """Minimal, empty PyTorch `Dataset` used as a test fixture."""

        def __init__(self):
            # Intentionally empty: the tests only rely on the object's type.
            self.X_train = []
            self.Y_train = []

        def __len__(self):
            return len(self.Y_train)

        def __getitem__(self, idx):
            return self.X_train[idx], self.Y_train[idx]

    def setUp(self):
        """No shared fixture: each test builds its own `DataManager`."""

    def tearDown(self):
        """Nothing to clean up."""

    def test_data_manager_01_load(self):
        """ Testing load method of DataManager """

        # Datasets that are expected to be rejected for a torch training plan.
        # Each case builds its OWN DataManager so that a failure is attributed
        # to the value under test and not to an instance left over from a
        # previous case.
        for invalid_dataset in ('invalid-argument', 12, [12, 12, 12, 12]):
            with self.subTest(dataset=invalid_dataset), \
                    self.assertRaises(FedbiomedDataManagerError):
                data_manager = DataManager(dataset=invalid_dataset)
                data_manager.load(tp_type=TrainingPlans.TorchTrainingPlan)

        # Test passing PyTorch Dataset while training plan is SkLearn
        with self.assertRaises(FedbiomedDataManagerError):
            data_manager = DataManager(dataset=TestDataManager.CustomDataset())
            data_manager.load(tp_type=TrainingPlans.SkLearnTrainingPlan)

        # Test Torch Dataset Scenario
        data_manager = DataManager(dataset=TestDataManager.CustomDataset())
        data_manager.load(tp_type=TrainingPlans.TorchTrainingPlan)
        self.assertIsInstance(data_manager._data_manager_instance, TorchDataManager)

        # Test SkLearn Scenario
        data_manager = DataManager(dataset=pd.DataFrame([[1, 2, 3], [1, 2, 3]]), target=pd.Series([1, 2]))
        data_manager.load(tp_type=TrainingPlans.SkLearnTrainingPlan)
        self.assertIsInstance(data_manager._data_manager_instance, SkLearnDataManager)

        # Test auto PyTorch dataset creation from pandas inputs
        data_manager = DataManager(dataset=pd.DataFrame([[1, 2, 3], [1, 2, 3]]), target=pd.Series([1, 2]))
        data_manager.load(tp_type=TrainingPlans.TorchTrainingPlan)
        self.assertIsInstance(data_manager._data_manager_instance, TorchDataManager)

        # Test if inputs are not supported by SkLearnTrainingPlan
        data_manager = DataManager(dataset=['non-pd-or-numpy'], target=['non-pd-or-numpy'])
        with self.assertRaises(FedbiomedDataManagerError):
            data_manager.load(tp_type=TrainingPlans.SkLearnTrainingPlan)

        # Test undefined training plan
        data_manager = DataManager(dataset=pd.DataFrame([[1, 2, 3], [1, 2, 3]]), target=pd.Series([1, 2]))
        with self.assertRaises(FedbiomedDataManagerError):
            data_manager.load(tp_type='NanaNone')

    def test_data_manager_01___getattr___(self):
        """ Test __getattr__ magic method of DataManager """

        data_manager = DataManager(dataset=pd.DataFrame([[1, 2, 3], [1, 2, 3]]), target=pd.Series([1, 2]))
        data_manager.load(tp_type=TrainingPlans.TorchTrainingPlan)

        # Only the absence of an exception matters here, so the returned
        # attributes are deliberately not kept. Any exception is reported as a
        # test failure (not an error) with an explicit message.
        try:
            data_manager.__getattr__('load')
            data_manager.__getattr__('dataset')
        except Exception as e:
            self.fail(f'Error while calling __getattr__ method of DataManager {str(e)}')

        # Test attribute error try/except block: an unknown attribute is
        # expected to surface as FedbiomedDataManagerError
        with self.assertRaises(FedbiomedDataManagerError):
            data_manager.__getattr__('toto')
# Allow running this test module directly as a script.
if __name__ == "__main__":  # pragma: no cover
    unittest.main()