Skip to content

Commit

Permalink
CU-8697qfvzz: Add tests regarding training meta-cats during supervise…
Browse files Browse the repository at this point in the history
…d training
  • Loading branch information
mart-r committed Jan 29, 2025
1 parent c84dcde commit 6ec9abb
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re
with self.subTest(f'CUI: {filtered_cui}'):
self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis)

def _test_train_sup_with_meta_cat(self, train_meta_cats: bool):
# def side_effect(doc, *args, **kwargs):
# raise ValueError()
# # return doc
meta_cat = _get_meta_cat(self.meta_cat_dir)
cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
with patch.object(MetaCAT, "train_raw") as mock_train:
with patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc):
cat.train_supervised_raw(get_fixed_meta_cat_data(), never_terminate=True,
train_meta_cats=train_meta_cats)
if train_meta_cats:
mock_train.assert_called()
else:
mock_train.assert_not_called()

def test_train_supervised_does_not_train_meta_cat_by_default(self):
self._test_train_sup_with_meta_cat(False)

def test_train_supervised_can_train_meta_cats(self):
self._test_train_sup_with_meta_cat(True)

def test_train_supervised_no_leak_extra_cui_filters(self):
self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'})

Expand Down Expand Up @@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self):
CAT.load_model_pack(self.temp_dir.name)


META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")


def _get_meta_cat(meta_cat_dir):
config = ConfigMetaCAT()
config.general["category_name"] = "Status"
Expand All @@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir):
embeddings=None,
config=config)
os.makedirs(meta_cat_dir, exist_ok=True)
json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json")
json_path = META_CAT_JSON_PATH
meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir)
return meta_cat


def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH):
with open(path) as f:
data = json.load(f)
for proj_num, project in enumerate(data['projects']):
if 'name' not in project:
project['name'] = f"Proj_{proj_num}"
if 'cuis' not in project:
project['cuis'] = ''
if 'id' not in project:
project['id'] = f'P{proj_num}'
for doc in project['documents']:
if 'entities' in doc and 'annotations' not in doc:
ents = doc.pop("entities")
doc['annotations'] = list(ents.values())
for ann in doc['annotations']:
if 'pretty_name' in ann and 'value' not in ann:
ann['value'] = ann.pop('pretty_name')
return data


class TestLoadingOldWeights(unittest.TestCase):
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"..", "examples", "cdb_old_broken_weights_in_config.dat")
Expand Down

0 comments on commit 6ec9abb

Please sign in to comment.