Skip to content

Commit

Permalink
removed class wrapper in cdb utils and fixed class set up in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-sutton-1992 committed Dec 15, 2023
1 parent 7cdd208 commit fe9ef66
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 112 deletions.
203 changes: 100 additions & 103 deletions medcat/utils/cdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,114 +7,111 @@
logger = logging.getLogger(__name__) # separate logger from the package-level one


class cdb_utils(object):
def merge_cdb(cdb1: "CDB",
cdb2: "CDB",
overwrite_training: int = 0,
full_build: bool = False):
"""Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed.
`addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build`
@staticmethod
def merge_cdb(cdb1: "CDB",
cdb2: "CDB",
overwrite_training: int = 0,
full_build: bool = False):
"""Merge two CDB's together to produce a new, single CDB. The contents of inputs CDBs will not be changed.
`addl_info` can not be perfectly merged, and will prioritise cdb1. see `full_build`
Args:
cdb1 (medcat.cdb.CDB):
The first medcat cdb to merge. In cases where merging isn't suitable isn't ideal (such as
cui2preferred_name), this cdb values will be prioritised over cdb2.
cdb2 (medcat.cdb.CDB):
The second medcat cdb to merge.
overwrite_training (int):
Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
full_build (bool):
Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description"
"""
config = deepcopy(cdb1.config)
cdb = CDB(config)

Args:
cdb1 (medcat.cdb.CDB):
The first medcat cdb to merge. In cases where merging isn't suitable isn't ideal (such as
cui2preferred_name), this cdb values will be prioritised over cdb2.
cdb2 (medcat.cdb.CDB):
The second medcat cdb to merge.
overwrite_training (int):
Choose to prioritise a CDB's context vectors values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2
full_build (bool):
Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description"
"""
config = deepcopy(cdb1.config)
cdb = CDB(config)
# Copy CDB 1 - as all settings from CDB 1 will be carried over
cdb.cui2names = deepcopy(cdb1.cui2names)
cdb.cui2snames = deepcopy(cdb1.cui2snames)
cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
cdb.cui2info = deepcopy(cdb1.cui2info)
cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
cdb.cui2tags = deepcopy(cdb1.cui2tags)
cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
cdb.name2cuis = deepcopy(cdb1.name2cuis)
cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
cdb.name2count_train = deepcopy(cdb1.name2count_train)
cdb.name_isupper = deepcopy(cdb1.name_isupper)
if full_build:
cdb.addl_info = deepcopy(cdb1.addl_info)

# Copy CDB 1 - as all settings from CDB 1 will be carried over
cdb.cui2names = deepcopy(cdb1.cui2names)
cdb.cui2snames = deepcopy(cdb1.cui2snames)
cdb.cui2count_train = deepcopy(cdb1.cui2count_train)
cdb.cui2info = deepcopy(cdb1.cui2info)
cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors)
cdb.cui2tags = deepcopy(cdb1.cui2tags)
cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids)
cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name)
cdb.name2cuis = deepcopy(cdb1.name2cuis)
cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status)
cdb.name2count_train = deepcopy(cdb1.name2count_train)
cdb.name_isupper = deepcopy(cdb1.name_isupper)
if full_build:
cdb.addl_info = deepcopy(cdb1.addl_info)
# handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
for cui in cdb2.cui2names:
names = dict()
for name in cdb2.cui2names[cui]:
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A'
# For addl_info check cui2original_names as they MUST be added
ontologies = set()
description = ''
to_build = False
if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
to_build = True
if 'cui2ontologies' in cdb2.addl_info:
ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
if 'cui2description' in cdb2.addl_info:
description = cdb2.addl_info['cui2description'][cui]
cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
if cui in cdb1.cui2names:
if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train):
if overwrite_training == 2 and cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
else:
cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]):
if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
weights = [0, 1]
else:
norm = cdb.cui2count_train[cui]
weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
for s in contexts:
cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
if cui in cdb1.cui2tags:
cdb.cui2tags[cui].append(cdb2.cui2tags[cui])
if cui in cdb1.cui2type_ids:
cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui])
else:
if cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2names[cui]
if cui in cdb2.cui2info:
cdb.cui2info[cui] = cdb2.cui2info[cui]
if cui in cdb2.cui2context_vectors:
cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui]
if cui in cdb2.cui2tags:
cdb.cui2tags[cui] = cdb2.cui2tags[cui]
if cui in cdb2.cui2type_ids:
cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]

# handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name
for cui in cdb2.cui2names:
names = dict()
for name in cdb2.cui2names[cui]:
names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)}
name_status = cdb2.name2cuis2status.get(name, 'A').get(cui, 'A') # get the name status if it exists, default to 'A'
# For addl_info check cui2original_names as they MUST be added
ontologies = set()
description = ''
to_build = False
if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']):
to_build = True
if 'cui2ontologies' in cdb2.addl_info:
ontologies.update(cdb2.addl_info['cui2ontologies'][cui])
if 'cui2description' in cdb2.addl_info:
description = cdb2.addl_info['cui2description'][cui]
cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status,
type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build)
if cui in cdb1.cui2names:
if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train):
if overwrite_training == 2 and cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2count_train[cui]
else:
cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0)
if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]):
if overwrite_training == 2 and cui in cdb2.cui2context_vectors:
weights = [0, 1]
else:
norm = cdb.cui2count_train[cui]
weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)]
contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short
for s in contexts:
cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300))))
if cui in cdb1.cui2tags:
cdb.cui2tags[cui].append(cdb2.cui2tags[cui])
if cui in cdb1.cui2type_ids:
cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui])
if overwrite_training != 1:
for name in cdb2.name2cuis:
if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if cui in cdb2.cui2count_train:
cdb.cui2count_train[cui] = cdb2.cui2names[cui]
if cui in cdb2.cui2info:
cdb.cui2info[cui] = cdb2.cui2info[cui]
if cui in cdb2.cui2context_vectors:
cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui]
if cui in cdb2.cui2tags:
cdb.cui2tags[cui] = cdb2.cui2tags[cui]
if cui in cdb2.cui2type_ids:
cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui]

if overwrite_training != 1:
for name in cdb2.name2cuis:
if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs
if name in cdb1.name2count_train and name in cdb2.name2count_train:
cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason
else:
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]
if name in cdb2.name2count_train:
cdb.name2count_train[name] = cdb2.name2count_train[name]

# snames
cdb.snames = cdb1.snames.union(cdb2.snames)
# snames
cdb.snames = cdb1.snames.union(cdb2.snames)

# vocab, adding counts if they occur in both
cdb.vocab = deepcopy(cdb1.vocab)
if overwrite_training != 1:
for word in cdb2.vocab:
if word in cdb.vocab and overwrite_training == 0:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]
# vocab, adding counts if they occur in both
cdb.vocab = deepcopy(cdb1.vocab)
if overwrite_training != 1:
for word in cdb2.vocab:
if word in cdb.vocab and overwrite_training == 0:
cdb.vocab[word] += cdb2.vocab[word]
else:
cdb.vocab[word] = cdb2.vocab[word]

return cdb
return cdb
18 changes: 9 additions & 9 deletions tests/utils/test_cdb_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import unittest
import numpy as np
from tests.helper import ForCDBMerging
from medcat.utils.cdb_utils import cdb_utils
from medcat.utils.cdb_utils import merge_cdb


class CDBMergeTests(unittest.TestCase):
@classmethod
def setUp(cls) -> None:

def setUp(self) -> None:
to_merge = ForCDBMerging()
cls.cdb1 = to_merge.cdb1
cls.cdb2 = to_merge.cdb2
cls.merged_cdb = cdb_utils.merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2)
cls.overwrite_cdb = cdb_utils.merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True)
cls.zeroes = np.zeros(shape=(1,300))
cls.ones = np.ones(shape=(1,300))
self.cdb1 = to_merge.cdb1
self.cdb2 = to_merge.cdb2
self.merged_cdb = merge_cdb(cdb1=self.cdb1, cdb2=self.cdb2)
self.overwrite_cdb = merge_cdb(cdb1=self.cdb1, cdb2=self.cdb2, overwrite_training=2, full_build=True)
self.zeroes = np.zeros(shape=(1,300))
self.ones = np.ones(shape=(1,300))

def test_merge_inserts(self):
self.assertIn("test", self.merged_cdb.cui2names["C0006826"])
Expand Down

0 comments on commit fe9ef66

Please sign in to comment.